.. _hello-kan:

Hello, KAN!
===========
Kolmogorov-Arnold representation theorem
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Kolmogorov-Arnold representation theorem states that if :math:`f` is a
multivariate continuous function on a bounded domain, then it can be
written as a finite composition of continuous functions of a single
variable and the binary operation of addition. More specifically, for a
smooth :math:`f : [0,1]^n \to \mathbb{R}`,

.. math:: f(x) = f(x_1,\dots,x_n)=\sum_{q=1}^{2n+1}\Phi_q\left(\sum_{p=1}^n \phi_{q,p}(x_p)\right)

where :math:`\phi_{q,p}:[0,1]\to\mathbb{R}` and
:math:`\Phi_q:\mathbb{R}\to\mathbb{R}`. In a sense, they showed that the
only true multivariate function is addition, since every other function
can be written using univariate functions and sums. However, this two-layer,
width-:math:`(2n+1)` Kolmogorov-Arnold representation may not be smooth
due to its limited expressive power. We augment its expressive power by
generalizing it to arbitrary depths and widths.
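For example, for a function of :math:`n=2` variables the theorem guarantees a
representation with :math:`2n+1=5` outer functions:

.. math:: f(x_1,x_2)=\sum_{q=1}^{5}\Phi_q\big(\phi_{q,1}(x_1)+\phi_{q,2}(x_2)\big)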
Kolmogorov-Arnold Network (KAN)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Kolmogorov-Arnold representation can be written in matrix form

.. math:: f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}

where

.. math:: {\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot)\end{pmatrix}

We notice that both :math:`{\bf \Phi}_{\rm in}` and
:math:`{\bf \Phi}_{\rm out}` are special cases of the following function
matrix :math:`{\bf \Phi}` (with :math:`n_{\rm in}` inputs and
:math:`n_{\rm out}` outputs), which we call a Kolmogorov-Arnold layer:

.. math:: {\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}

:math:`{\bf \Phi}_{\rm in}` corresponds to
:math:`n_{\rm in}=n,\ n_{\rm out}=2n+1`, and :math:`{\bf \Phi}_{\rm out}`
corresponds to :math:`n_{\rm in}=2n+1,\ n_{\rm out}=1`.
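To make the layer concrete, here is a minimal toy sketch (our illustration, not
the library implementation, which makes each edge function a learnable spline):
a Kolmogorov-Arnold layer stores one univariate function per (output, input)
pair and sums the transformed inputs over the input index.

.. code:: ipython3

    import torch

    def kan_layer(x, phi):
        # x: (batch, n_in); phi: nested list of shape (n_out, n_in) of 1D functions
        # output j is the sum over inputs i of phi[j][i](x[:, i])
        n_out, n_in = len(phi), len(phi[0])
        return torch.stack(
            [sum(phi[j][i](x[:, i]) for i in range(n_in)) for j in range(n_out)],
            dim=1)

    # toy layer with n_in=2, n_out=3, using fixed 1D functions instead of splines
    phi = [[torch.sin, torch.tanh],
           [torch.cos, torch.abs],
           [torch.exp, torch.square]]
    y = kan_layer(torch.rand(4, 2), phi)  # shape (4, 3)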
After defining the layer, we can construct a Kolmogorov-Arnold network
simply by stacking layers! Let's say we have :math:`L` layers, with the
:math:`l^{\rm th}` layer :math:`{\bf \Phi}_l` having shape
:math:`(n_{l+1}, n_{l})`. Then the whole network is

.. math:: {\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}

In contrast, a Multi-Layer Perceptron is interleaved by linear layers
:math:`{\bf W}_l` and nonlinearities :math:`\sigma`:

.. math:: {\rm MLP}({\bf x})={\bf W}_{L-1}\circ\sigma\circ\cdots\circ {\bf W}_1\circ\sigma\circ {\bf W}_0\circ {\bf x}
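Continuing the toy sketch above, stacking such layers is just function
composition; the hypothetical ``kan_layer`` helper and the 2-to-3 layer ``phi``
from the previous snippet are reused here.

.. code:: ipython3

    # compose toy Kolmogorov-Arnold layers with shapes 2 -> 3 -> 1
    def kan_forward(x, layers):
        for p in layers:   # in an MLP, the loop body would instead be x = sigma(W @ x)
            x = kan_layer(x, p)
        return x

    layers = [phi,                                   # 2 -> 3 (defined above)
              [[torch.sin, torch.tanh, torch.exp]]]  # 3 -> 1
    out = kan_forward(torch.rand(4, 2), layers)      # shape (4, 1)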
A KAN can be easily visualized: (1) a KAN is simply a stack of KAN layers;
(2) each KAN layer can be visualized as a fully-connected layer, with a
1D function placed on each edge. Let's see an example below.
Get started with KANs
~~~~~~~~~~~~~~~~~~~~~
Initialize KAN
.. code:: ipython3

    from kan import *
    # create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).
    model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
Create dataset
.. code:: ipython3

    # create dataset f(x,y) = exp(sin(pi*x)+y^2)
    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    dataset = create_dataset(f, n_var=2)
    dataset['train_input'].shape, dataset['train_label'].shape
.. parsed-literal::

    (torch.Size([1000, 2]), torch.Size([1000, 1]))
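``create_dataset`` also returns matching test splits under the keys
``'test_input'`` and ``'test_label'``, which is what the test loss reported
during training below is computed on.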
Plot KAN at initialization
.. code:: ipython3

    # plot KAN at initialization
    model(dataset['train_input']);
    model.plot(beta=100)

.. image:: intro_files/intro_15_0.png
Train KAN with sparsity regularization
.. code:: ipython3

    # train the model
    model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);

.. parsed-literal::

    train loss: 1.57e-01 | test loss: 1.31e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:18<00:00, 1.06it/s]
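Here ``lamb`` sets the overall strength of the sparsity regularization and
``lamb_entropy`` weights its entropy term; stronger regularization pushes more
activation functions toward zero, which is what makes the pruning step below
effective.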
Plot trained KAN
.. code:: ipython3

    model.plot()

.. image:: intro_files/intro_19_0.png
Prune KAN and replot (keep the original shape)
.. code:: ipython3

    model.prune()
    model.plot(mask=True)

.. image:: intro_files/intro_21_0.png
Prune KAN and replot (get a smaller shape)
.. code:: ipython3

    model = model.prune()
    model(dataset['train_input'])
    model.plot()

.. image:: intro_files/intro_23_0.png
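The forward pass ``model(dataset['train_input'])`` is run before plotting so
that the newly pruned model has cached activations for ``plot()`` to draw.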
Continue training and replot
.. code:: ipython3

    model.train(dataset, opt="LBFGS", steps=50);

.. parsed-literal::

    train loss: 4.74e-03 | test loss: 4.80e-03 | reg: 2.98e+00 : 100%|██| 50/50 [00:07<00:00, 7.03it/s]

.. code:: ipython3

    model.plot()

.. image:: intro_files/intro_26_0.png
Automatically or manually set activation functions to be symbolic
.. code:: ipython3

    mode = "auto" # "manual"

    if mode == "manual":
        # manual mode
        model.fix_symbolic(0,0,0,'sin');
        model.fix_symbolic(0,1,0,'x^2');
        model.fix_symbolic(1,0,0,'exp');
    elif mode == "auto":
        # automatic mode
        lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
        model.auto_symbolic(lib=lib)

.. parsed-literal::

    fixing (0,0,0) with sin, r2=0.999987252534279
    fixing (0,1,0) with x^2, r2=0.9999996536741071
    fixing (1,0,0) with exp, r2=0.9999988529417926
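In manual mode, ``fix_symbolic(l, i, j, fun)`` pins the activation on the edge
from neuron ``i`` of layer ``l`` to neuron ``j`` of layer ``l+1`` to the
symbolic function ``fun``. In automatic mode, ``auto_symbolic`` tries every
candidate in ``lib`` on each edge and keeps the best fit, reporting its
:math:`R^2` as above.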
Continue training to almost machine precision
.. code:: ipython3

    model.train(dataset, opt="LBFGS", steps=50);

.. parsed-literal::

    train loss: 2.02e-10 | test loss: 1.13e-10 | reg: 2.98e+00 : 100%|██| 50/50 [00:02<00:00, 22.59it/s]
Obtain the symbolic formula
.. code:: ipython3

    model.symbolic_formula()[0][0]

.. math::

    \displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \sin{\left(3.14 x_{1} \right)}}
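As a quick sanity check (a minimal sketch, not part of the original notebook),
the recovered formula can be compared against the target function ``f`` defined
above on a few random points; the remaining discrepancy mainly comes from
:math:`\pi` being rounded to 3.14 in the printed formula.

.. code:: ipython3

    # recovered symbolic formula vs. the ground-truth f(x,y) = exp(sin(pi*x) + y^2)
    g = lambda x: torch.exp(torch.sin(3.14 * x[:, [0]]) + x[:, [1]]**2)
    x = torch.rand(100, 2) * 2 - 1              # points in [-1, 1]^2
    print(torch.max(torch.abs(g(x) - f(x))))    # small, dominated by the rounded pi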