blealtan / efficient-kan
An efficient pure-PyTorch implementation of Kolmogorov-Arnold Network (KAN).
License: MIT License
Hey, I want to use your implementation; do you know how much slower training can be compared to nn.Linear?
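One quick way to check this on your own hardware is a side-by-side timing. A minimal sketch, assuming the import path used in the repro later on this page; the layer widths are placeholders:

import time
import torch
from src.efficient_kan.kan import KAN  # path as used elsewhere on this page

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 256, device=device)

mlp = torch.nn.Sequential(
    torch.nn.Linear(256, 256), torch.nn.SiLU(), torch.nn.Linear(256, 256)
).to(device)
kan = KAN([256, 256, 256]).to(device)

def bench(model, steps=100):
    for _ in range(3):  # warm-up passes
        model(x).sum().backward()
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):  # time forward + backward
        model(x).sum().backward()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.time() - start) / steps

print("MLP s/step:", bench(mlp), "KAN s/step:", bench(kan))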
Hey, guys. I have a problem where the structure is as follows:
y_pred = KAN_1(X1)*f1(X1) + KAN_2(X2)*f2(X2) + KAN_3(X3)*f3(X3), where f1, f2, f3 are all fixed, known functions. How can I train KAN_i, i = 1, 2, 3, simultaneously?
I am wondering how I can realize this training process.
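Since each KAN_i is just an nn.Module, all three can live in one wrapper module and a single optimizer trains them together; autograd reaches each KAN_i through the sum. A minimal sketch with placeholder widths; f1, f2, f3 below are stand-ins for the real fixed functions:

import torch
from src.efficient_kan.kan import KAN  # import path as used elsewhere on this page

f1, f2, f3 = torch.sin, torch.cos, torch.exp  # placeholders for your fixed functions

class Composite(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # One KAN per input; widths [1, 8, 1] are placeholders.
        self.kan1 = KAN([1, 8, 1])
        self.kan2 = KAN([1, 8, 1])
        self.kan3 = KAN([1, 8, 1])

    def forward(self, x1, x2, x3):
        return (self.kan1(x1) * f1(x1)
                + self.kan2(x2) * f2(x2)
                + self.kan3(x3) * f3(x3))

model = Composite()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # one optimizer sees all three KANs

for x1, x2, x3, y in loader:  # your iterator yielding (batch, 1)-shaped tensors
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x1, x2, x3), y)
    loss.backward()  # gradients flow into every KAN_i through the sum
    opt.step()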
Similar to what the authors show in the official git repo, this efficient-kan model could be used for continual learning (CL) settings. But for using efficient-kan in CL settings, I haven't found some attributes that need to be set, which the official pykan provides:
######### cl code from pykan
model = KAN(width=[1,1], grid=200, k=3, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)
How can I set bias_trainable=False, sp_trainable=False, sb_trainable=False here? Is there a way?
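efficient-kan does not expose those constructor flags, but a similar effect can be had by disabling gradients on the corresponding parameters after construction. A sketch based on my reading of this repo's KANLinear; the mapping onto pykan's flags is approximate, and efficient-kan has no separate bias term at all:

model = KAN([1, 1], grid_size=200)

for layer in model.layers:  # KAN keeps its KANLinear layers in a ModuleList
    # Roughly analogous to sb_trainable=False: freeze the base (residual) branch.
    layer.base_weight.requires_grad_(False)
    # Roughly analogous to sp_trainable=False: freeze the spline scaler, if enabled.
    if hasattr(layer, "spline_scaler"):
        layer.spline_scaler.requires_grad_(False)
# There is no direct counterpart to bias_trainable=False, since KANLinear has no bias.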
Getting this error when running test_simple_math.py. Any idea how to resolve it?
File "/home/chansingh/test/kan.py", line 131, in curve2coeff
solution = torch.linalg.lstsq(
^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1539, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.
(PyTorch is up to date, version '2.3.0+cu121', Python 3.11)
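For what it's worth, this assert comes out of the LAPACK gelsy routine that torch.linalg.lstsq uses by default on CPU. One workaround worth trying (an assumption, not a confirmed fix) is forcing the SVD-based driver where curve2coeff calls lstsq:

# In kan.py's curve2coeff, pass an explicit driver to torch.linalg.lstsq;
# "gelsd" routes around the SGELSY call that is failing here.
solution = torch.linalg.lstsq(A, B, driver="gelsd").solution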
Hi, thank you for your work.
As per the title, I would like a way to save the model for inference. I have tried pickle.dump, but it does not work.
Thanks
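Since the model is a plain nn.Module, the usual state_dict route should work. A sketch with placeholder widths; the grid buffers are saved along with the weights:

import torch
from src.efficient_kan.kan import KAN

model = KAN([1, 8, 1])
# ... train ...
torch.save(model.state_dict(), "kan.pt")  # weights + grid buffers

# For inference, rebuild with the same widths/grid settings, then load:
restored = KAN([1, 8, 1])
restored.load_state_dict(torch.load("kan.pt"))
restored.eval()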
Hey, I have a fork with some not-so-useful stuff on it yet (mostly just profiling showing that forward passes are slow due to B-splines, and some comparisons to MLPs).
Do you want folks to contribute to this?
Are you interested in making b-splines more efficient with something like: https://github.com/GistNoesis/FourierKAN/blob/main/fftKAN.py
Let me know what you think.
I don't quite understand KAN's code. Is it possible for KANLinear to behave like torch.nn.Linear, where only the last dimension is transformed, allowing inputs with more than 2 dimensions?
For example, in multi head attention, our input is similar to [batch, nhead, dim]
However, this is not allowed in the current KAN ("assert x.dim() == 2 and x.size(1) == self.in_features")
Excuse me! I am very interested in exploring the application of KAN in attention.
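Until that is supported natively, a thin wrapper that flattens the leading dimensions should do the job. A sketch, assuming the wrapped KANLinear only ever transforms the last dimension:

import torch

class KANLinearND(torch.nn.Module):
    # Apply a 2-D-only KANLinear over the last dim of an N-D tensor,
    # e.g. [batch, nhead, dim] in attention.
    def __init__(self, kan_linear):
        super().__init__()
        self.inner = kan_linear

    def forward(self, x):
        lead = x.shape[:-1]
        y = self.inner(x.reshape(-1, x.size(-1)))  # flatten leading dims into one batch dim
        return y.reshape(*lead, -1)                # restore them afterwards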
I wonder what equation is used in the KAN model. Does anybody know?
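From the KAN paper and this repo's code, the model is built on the Kolmogorov-Arnold representation, with each edge carrying a learnable activation of roughly this form (my reading; the exact scaling terms differ slightly between pykan and efficient-kan):

f(x) = \sum_{q=1}^{2n+1} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right),
\qquad
\phi(x) = w_b\,\mathrm{silu}(x) + w_s \sum_i c_i B_i(x)

where the B_i are B-spline basis functions on a (re-fittable) grid, the c_i are the spline coefficients, and w_b, w_s scale the base and spline branches.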
from src.efficient_kan.kan import KAN
import torch
net = KAN([1152,1152*4,1152]).to("cuda")
x = torch.rand(size=(4096*4,1152)).to("cuda")
net(x)
I found that if the hidden layer is too large, a CUDA out-of-memory error occurs.
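That matches how the layer works: the B-spline expansion materializes intermediates of size batch x in_features x (grid_size + spline_order) per layer, so activation memory grows much faster than an MLP's. For inference, chunking the batch bounds the peak. A sketch, reusing net and x from the repro above:

import torch

@torch.no_grad()
def chunked_forward(model, x, chunk=4096):
    # Push the batch through in slices to cap peak activation memory.
    return torch.cat([model(part) for part in x.split(chunk)], dim=0)

y = chunked_forward(net, x)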
D:\Users\12719\anaconda3\python.exe D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py
20%|██ | 20/100 [00:01<00:06, 12.66it/s, mse_loss=nan, reg_loss=nan]
Intel oneMKL ERROR: Parameter 6 was incorrect on entry to SGELSY.
Intel oneMKL ERROR: Parameter 6 was incorrect on entry to SGELSY.
20%|██ | 20/100 [00:02<00:08, 9.82it/s, mse_loss=nan, reg_loss=nan]
Traceback (most recent call last):
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 35, in
test_mul()
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 29, in test_mul
optimizer.step(closure)
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\optim\optimizer.py", line 459, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\optim\lbfgs.py", line 320, in step
orig_loss = closure()
^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\tests\test_simple_math.py", line 18, in closure
y = kan(x, update_grid=(i % 20 == 0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 1541, in call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 272, in forward
layer.update_grid(x)
File "D:\Users\12719\anaconda3\Lib\site-packages\torch\utils_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 210, in update_grid
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "D:\Users\12719\PycharmProjects\efficient-kan\src\efficient_kan\kan.py", line 131, in curve2coeff
solution = torch.linalg.lstsq(
^^^^^^^^^^^^^^^^^^^
RuntimeError: false INTERNAL ASSERT FAILED at "C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\BatchLinearAlgebra.cpp":1538, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.
This note shows the equivalence of KAN to an MLP under piecewise-linear approximation. I guess the non-linearity of splines might help in some cases, but it would be cool to have it as a baseline. Here's the Reddit discussion.
Hello:
If the input tensor size is [64, 28x28] and the hidden layers are [256, 256, 256, 256], the memory usage of the MLP and the KAN is similar: 382 MB and 500 MB respectively. The results are consistent with the experimental results.
However, if the input tensor size is [36864, 28x28], the memory usage of the two is hugely different: 844 MB and 14468 MB respectively. What is the reason for this? The initialization of the KAN is consistent with that given in the example, and I am using a GPU.
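A rough back-of-envelope consistent with those numbers: efficient-kan's activations include a B-spline basis tensor of shape (batch, in_features, grid_size + spline_order) per layer, so scaling the batch from 64 to 36864 multiplies that term by 576 while parameter memory stays fixed. Under assumed defaults (grid_size=5, spline_order=3, fp32):

batch, in_features = 36864, 28 * 28
grid_size, spline_order, bytes_fp32 = 5, 3, 4
# B-spline basis tensor for the first layer alone:
basis_bytes = batch * in_features * (grid_size + spline_order) * bytes_fp32
print(f"{basis_bytes / 2**20:.0f} MiB")  # ~882 MiB, before other layers and autograd buffers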
I tried to reproduce the experiments (example 4 in the official KAN). With the official KAN, I get the results below (ground truth at the top, prediction at the bottom):
But with efficient-kan, I get the results below:
It shows that the previous peak gets higher while learning a new peak.
The official model is created by: "model = KAN(width=[1, 1], grid=200, k=3, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)"
The efficient-kan model is created by: "model = KAN([1, 1], grid_size=200)"
They seem to be the same except for "bias_trainable=False, sp_trainable=False, sb_trainable=False".
Can anyone put KAN (attention and MLP) into llama2.c?
Hello, thanks for your work. Are there plans in the near future to support fitting symbolic expressions / manual input and network visualization, as the original implementation does?
Thanks in advance
Hello author, I would like to know if this efficient implementation of KAN can replace the MLP module in a transformer. What are the advantages and disadvantages?
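Structurally it is possible: replace the FFN's two nn.Linear layers with KANLinear and flatten (batch, seq, dim) to 2-D around the call; no extra activation is needed between the two layers, since KAN edges are nonlinear themselves. A sketch, assuming the KANLinear class from this repo; the trade-off, per the memory discussion above, is a roughly (grid_size + spline_order)-fold increase in parameters and activations:

import torch
from src.efficient_kan.kan import KANLinear

class KANFeedForward(torch.nn.Module):
    # Stand-in for a transformer FFN block.
    def __init__(self, dim, hidden):
        super().__init__()
        self.fc1 = KANLinear(dim, hidden)
        self.fc2 = KANLinear(hidden, dim)

    def forward(self, x):  # x: (batch, seq, dim)
        b, s, d = x.shape
        y = self.fc2(self.fc1(x.reshape(b * s, d)))  # KANLinear expects 2-D input
        return y.reshape(b, s, d)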
Using your implementation on data that was previously transposed causes this error:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
Just replacing the .view call at efficient-kan/src/efficient_kan/kan.py, line 156 (commit 605b8c2), with .reshape fixes it.
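The error message itself points at the fix: .view requires contiguous memory, which a transposed tensor does not have. A sketch of the two options; I am paraphrasing the call at line 156 rather than quoting it exactly:

# Option 1: swap .view for .reshape at that line; .reshape falls back
# to a copy when the tensor's strides require it (paraphrased call):
flat = self.b_splines(x).reshape(x.size(0), -1)  # was .view(x.size(0), -1)

# Option 2: make the input contiguous before calling the layer instead:
y = kan(x.contiguous())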
Can anyone help me with how to get a symbolic formula after training the efficient-KAN model?
as title
I noticed that most image classification tasks are based on efficient-kan instead of the original KAN. I want to know if it is possible to plot and prune efficient-kan just like in the examples of the original KAN.
I am pretty sure this is really just a dimensionality issue, but I am trying to use KANLinear as a substitute for nn.Linear to try this approach out. I can use the tutorial to get it to work with MNIST just fine, but it doesn't work out of the box elsewhere, almost certainly because I am missing something.
I keep getting this error in forward:
assert x.dim() == 2 and x.size(1) == self.in_features
AssertionError
All I am doing is dropping KANLinear in for nn.Linear, keeping in_features and out_features at the same hidden size. Is there a way forward can be edited to allow non-image inputs?
Hi, I just want to share my experience from when I developed KAN: the scale parameter seems quite important (but that was at the very beginning of the KAN project, so I could be hallucinating). Would love to hear your experimental results! Great initiative; would love to see a more efficient KAN implementation (with the good features maintained).
I am confused about the principle of KAN. From this implementation, doesn't KAN have more learnable parameters?
It seems that the improvement of KAN lies in the learnable activation functions, thus achieving better accuracy. Does KAN have any advantage in computation and memory?
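For scale: per connection, a KAN layer stores a whole spline instead of one scalar weight, so under assumed defaults (grid_size=5, spline_order=3) the count is roughly 10x an nn.Linear of the same shape; the usual argument is that KAN trades this for accuracy per neuron, not for compute or memory savings. A quick count, with parameter names taken from this repo's code:

in_f, out_f = 256, 256
grid_size, spline_order = 5, 3

linear_params = in_f * out_f + out_f                        # weight + bias
kan_params = (in_f * out_f                                  # base_weight
              + in_f * out_f * (grid_size + spline_order)   # spline_weight
              + in_f * out_f)                               # spline_scaler
print(linear_params, kan_params)  # 65792 vs 655360, roughly 10x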
File "tests/test_simple_math.py", line 166, in curve2coeff
solution = torch.linalg.lstsq(  # solve the linear system by least squares
RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1462, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.
Hello,
Is there a rule of thumb or intuition for setting the layers_hidden parameter? I'm using it for time series, with [input_size, 10, horizon]. The 10 is arbitrary, taken from the MNIST example; do you have a suggestion for setting these for best performance?
This repo should have a license to protect its owner and potential users.