{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Hello, KAN!"
]
},
{
"cell_type": "markdown",
"id": "59cf5cd0",
"metadata": {},
"source": [
"### Kolmogorov-Arnold representation theorem"
]
},
{
"cell_type": "markdown",
"id": "f88e5321",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation theorem states that if $f$ is a multivariate continuous function\n",
"on a bounded domain, then it can be written as a finite composition of continuous functions of a\n",
"single variable and the binary operation of addition. More specifically, for a continuous $f : [0,1]^n \\to \\mathbb{R}$,\n",
"\n",
"$$f(x) = f(x_1,\\dots,x_n)=\\sum_{q=1}^{2n+1}\\Phi_q\\left(\\sum_{p=1}^n \\phi_{q,p}(x_p)\\right)$$\n",
"\n",
"where $\\phi_{q,p}:[0,1]\\to\\mathbb{R}$ and $\\Phi_q:\\mathbb{R}\\to\\mathbb{R}$. In a sense, Kolmogorov and Arnold showed that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums. However, because this 2-layer, width-$(2n+1)$ representation has such a restricted structure, its univariate functions may need to be non-smooth (even fractal), which makes them hard to learn in practice. We increase its expressive power by generalizing it to arbitrary depths and widths."
]
},
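{
"cell_type": "markdown",
"id": "b7e2a4c1",
"metadata": {},
"source": [
"As a concrete check, the target function used later in this notebook, $f(x_1,x_2)=\\exp(\\sin(\\pi x_1)+x_2^2)$, already has this form with a single nonzero outer term: $\\phi_{1,1}(x_1)=\\sin(\\pi x_1)$, $\\phi_{1,2}(x_2)=x_2^2$ and $\\Phi_1=\\exp$, with the remaining $2n$ terms set to zero. The plain-Python sketch below is only an illustration (it is not part of pykan, and the names `inner`, `outer` and `ka_form` are hypothetical); it evaluates $\\sum_q\\Phi_q(\\sum_p\\phi_{q,p}(x_p))$ and compares the result with $f$ evaluated directly.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Hypothetical inner functions phi_{q,p}, one row per outer term q.\n",
"# Only q = 1 is needed for this f; the other 2n terms of the theorem are zero.\n",
"inner = [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]]  # phi_{1,1}, phi_{1,2}\n",
"outer = [math.exp]  # Phi_1\n",
"\n",
"\n",
"def ka_form(x):\n",
"    # sum over q of Phi_q(sum over p of phi_{q,p}(x_p))\n",
"    return sum(Phi(sum(phi(xp) for phi, xp in zip(row, x)))\n",
"               for Phi, row in zip(outer, inner))\n",
"\n",
"\n",
"def f(x):\n",
"    return math.exp(math.sin(math.pi * x[0]) + x[1] ** 2)\n",
"\n",
"\n",
"x = (0.3, 0.7)  # a point in [0, 1]^2\n",
"print(ka_form(x), f(x))  # the two values agree\n",
"```"
]
},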
{
"cell_type": "markdown",
"id": "ebd8766a",
"metadata": {},
"source": [
"### Kolmogorov-Arnold Network (KAN)"
]
},
{
"cell_type": "markdown",
"id": "2cf3b1ee",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation can be written in matrix form\n",
"\n",
"$$f(x)={\\bf \\Phi}_{\\rm out}\\circ{\\bf \\Phi}_{\\rm in}\\circ {\\bf x}$$\n",
"\n",
"where \n",
"\n",
"$${\\bf \\Phi}_{\\rm in}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{2n+1,1}(\\cdot) & \\cdots & \\phi_{2n+1,n}(\\cdot) \\end{pmatrix},\\quad {\\bf \\Phi}_{\\rm out}=\\begin{pmatrix} \\Phi_1(\\cdot) & \\cdots & \\Phi_{2n+1}(\\cdot)\\end{pmatrix}$$"
]
},
{
"cell_type": "markdown",
"id": "f6521452",
"metadata": {},
"source": [
"We notice that both ${\\bf \\Phi}_{\\rm in}$ and ${\\bf \\Phi}_{\\rm out}$ are special cases of the following function matrix ${\\bf \\Phi}$ (with $n_{\\rm in}$ inputs and $n_{\\rm out}$ outputs), which we call a Kolmogorov-Arnold layer:\n",
"\n",
"$${\\bf \\Phi}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n_{\\rm in}}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{n_{\\rm out},1}(\\cdot) & \\cdots & \\phi_{n_{\\rm out},n_{\\rm in}}(\\cdot) \\end{pmatrix}$$\n",
"\n",
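"Applying ${\\bf \\Phi}$ to ${\\bf x}$ produces $n_{\\rm out}$ outputs, the $j^{\\rm th}$ of which is $\\sum_{i=1}^{n_{\\rm in}}\\phi_{j,i}(x_i)$. ${\\bf \\Phi}_{\\rm in}$ corresponds to $n_{\\rm in}=n, n_{\\rm out}=2n+1$, and ${\\bf \\Phi}_{\\rm out}$ corresponds to $n_{\\rm in}=2n+1, n_{\\rm out}=1$."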
]
},
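{
"cell_type": "markdown",
"id": "c4d9f6a2",
"metadata": {},
"source": [
"A minimal sketch of a Kolmogorov-Arnold layer in plain Python, assuming each $\\phi_{j,i}$ is an ordinary Python callable (in pykan the edge functions are learnable splines; the helper names here are hypothetical). Output $j$ of the layer is $\\sum_i \\phi_{j,i}(x_i)$, so composing ${\\bf \\Phi}_{\\rm out}$ (shape $1\\times 5$) with ${\\bf \\Phi}_{\\rm in}$ (shape $5\\times 2$) reproduces the $n=2$ Kolmogorov-Arnold form of $f(x_1,x_2)=\\exp(\\sin(\\pi x_1)+x_2^2)$.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def ka_layer(phis, x):\n",
"    # phis is an (n_out, n_in) matrix of univariate callables;\n",
"    # output j is sum_i phis[j][i](x[i])\n",
"    return [sum(phi(xi) for phi, xi in zip(row, x)) for row in phis]\n",
"\n",
"\n",
"def zero(t):\n",
"    return 0.0\n",
"\n",
"\n",
"n = 2\n",
"# Phi_in: n_in = n, n_out = 2n + 1; only the first row is non-trivial here\n",
"phi_in = [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]]\n",
"phi_in += [[zero] * n for _ in range(2 * n)]\n",
"# Phi_out: n_in = 2n + 1, n_out = 1\n",
"phi_out = [[math.exp] + [zero] * (2 * n)]\n",
"\n",
"x = [0.3, 0.7]\n",
"y = ka_layer(phi_out, ka_layer(phi_in, x))  # Phi_out o Phi_in o x\n",
"print(y)  # a one-element list\n",
"```"
]
},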
{
"cell_type": "markdown",
"id": "1b410498",
"metadata": {},
"source": [
"After defining the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Let's say we have $L$ layers, with the $l^{\\rm th}$ layer ${\\bf \\Phi}_l$ having shape $(n_{l+1}, n_{l})$. Then the whole network is\n",
"\n",
"$${\\rm KAN}({\\bf x})={\\bf \\Phi}_{L-1}\\circ\\cdots \\circ{\\bf \\Phi}_1\\circ{\\bf \\Phi}_0\\circ {\\bf x}$$"
]
},
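{
"cell_type": "markdown",
"id": "d8a1e3b5",
"metadata": {},
"source": [
"A minimal sketch of stacking, assuming the same list-of-callables representation as a plain-Python KA layer: each layer ${\\bf \\Phi}_l$ is an $(n_{l+1}, n_l)$ matrix of univariate functions, and the network is their composition. The edge function `make_edge` is a hypothetical stand-in (a fixed random mix of $\\tanh(t)$ and $t$); pykan instead parametrizes each edge with a learnable B-spline. The widths `[2, 5, 1]` match the model created in the next section.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"from functools import reduce\n",
"\n",
"\n",
"def ka_layer(phis, x):\n",
"    # phis has shape (n_out, n_in); output j is sum_i phis[j][i](x[i])\n",
"    return [sum(phi(xi) for phi, xi in zip(row, x)) for row in phis]\n",
"\n",
"\n",
"def make_edge(rng):\n",
"    # hypothetical edge function: a * tanh(t) + b * t with random a, b\n",
"    a, b = rng.uniform(-1, 1), rng.uniform(-1, 1)\n",
"    return lambda t: a * math.tanh(t) + b * t\n",
"\n",
"\n",
"def make_kan(widths, seed=0):\n",
"    rng = random.Random(seed)\n",
"    # layer l maps n_l inputs to n_{l+1} outputs: shape (n_{l+1}, n_l)\n",
"    return [[[make_edge(rng) for _ in range(n_in)] for _ in range(n_out)]\n",
"            for n_in, n_out in zip(widths[:-1], widths[1:])]\n",
"\n",
"\n",
"def kan_forward(layers, x):\n",
"    # KAN(x) = Phi_{L-1} o ... o Phi_1 o Phi_0 (x); Phi_0 is applied first\n",
"    return reduce(lambda h, phis: ka_layer(phis, h), layers, x)\n",
"\n",
"\n",
"layers = make_kan([2, 5, 1])\n",
"print([(len(p), len(p[0])) for p in layers])  # [(5, 2), (1, 5)]\n",
"print(kan_forward(layers, [0.3, 0.7]))  # a one-element list\n",
"```"
]
},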
{
"cell_type": "markdown",
"id": "54bbde9a",
"metadata": {},
"source": [
"In contrast, a Multi-Layer Perceptron (MLP) consists of linear layers ${\\bf W}_l$ interleaved with fixed nonlinearities $\\sigma$:\n",
"\n",
"$${\\rm MLP}({\\bf x})={\\bf W}_{L-1}\\circ\\sigma\\circ\\cdots\\circ {\\bf W}_1\\circ\\sigma\\circ {\\bf W}_0\\circ {\\bf x}$$"
]
},
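{
"cell_type": "markdown",
"id": "e5f2c7d9",
"metadata": {},
"source": [
"For comparison, a minimal NumPy sketch of the MLP formula above, with biases omitted as in the formula and $\\sigma=\\tanh$ chosen arbitrarily. The weight matrices ${\\bf W}_l$ (shape $(n_{l+1}, n_l)$) are the learnable parameters and $\\sigma$ is fixed and applied on the nodes, whereas in a KAN every edge carries its own learnable univariate function.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def mlp_forward(weights, x, sigma=np.tanh):\n",
"    # MLP(x) = W_{L-1} o sigma o ... o W_1 o sigma o W_0 (x)\n",
"    h = x\n",
"    for l, W in enumerate(weights):\n",
"        h = W @ h  # learnable linear map W_l\n",
"        if l < len(weights) - 1:\n",
"            h = sigma(h)  # fixed nonlinearity between linear maps\n",
"    return h\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# widths 2 -> 5 -> 1; W_l has shape (n_{l+1}, n_l)\n",
"weights = [rng.normal(size=(5, 2)), rng.normal(size=(1, 5))]\n",
"print(mlp_forward(weights, np.array([0.3, 0.7])).shape)  # (1,)\n",
"```"
]
},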
{
"cell_type": "markdown",
"id": "1c5f7795",
"metadata": {},
"source": [
"A KAN can be easily visualized: (1) a KAN is simply a stack of KAN layers, and (2) each KAN layer can be visualized as a fully-connected layer with a 1D function placed on each edge. Let's see an example below."
]
},
{
"cell_type": "markdown",
"id": "adcb5f75",
"metadata": {},
"source": [
"### Get started with KANs"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Initialize KAN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n",
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
}
],
"source": [
"from kan import *\n",
"torch.set_default_dtype(torch.float64)\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 3 grid intervals (grid=3).\n",
"model = KAN(width=[2,5,1], grid=3, k=3, seed=42, device=device)"
]
},
{
"cell_type": "markdown",
"id": "3d72e076",
"metadata": {},
"source": [
"Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "46717e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1000, 2]), torch.Size([1000, 1]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import create_dataset\n",
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2, device=device)\n",
"dataset['train_input'].shape, dataset['train_label'].shape"
]
},
{
"cell_type": "markdown",
"id": "8c6add1d",
"metadata": {},
"source": [
"Plot KAN at initialization"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac76f858",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot KAN at initialization\n",
"model(dataset['train_input']);\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "ddf67e30",
"metadata": {},
"source": [
"Train KAN with sparsity regularization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97111d75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.85e-02 | test_loss: 1.77e-02 | reg: 6.93e+00 | : 100%|█| 50/50 [00:09<00:00, 5.13it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=50, lamb=0.001);"
]
},
{
"cell_type": "markdown",
"id": "2f30c3ab",
"metadata": {},
"source": [
"Plot trained KAN"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92a4f67a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "576856cf",
"metadata": {},
"source": [
"Prune KAN and replot"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7fe6fb12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "bd08ad99",
"metadata": {},
"source": [
"Continue training, refine the spline grid to 10 intervals, and train again"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "18a2db11",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.79e-02 | test_loss: 1.72e-02 | reg: 7.66e+00 | : 100%|█| 50/50 [00:06<00:00, 7.21it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8768d56c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.4\n"
]
}
],
"source": [
"model = model.refine(10)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "46f73098",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 4.67e-04 | test_loss: 4.73e-04 | reg: 7.66e+00 | : 100%|█| 50/50 [00:06<00:00, 7.37it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "cf35d505",
"metadata": {},
"source": [
"Automatically or manually set activation functions to be symbolic"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b3c0642b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with sin, r2=0.9999999188529035, c=2\n",
"fixing (0,1,0) with x^2, r2=0.9999999809840728, c=2\n",
"fixing (1,0,0) with exp, r2=0.9999999904907739, c=2\n",
"saving model version 0.6\n"
]
}
],
"source": [
"mode = \"auto\" # \"manual\"\n",
"\n",
"if mode == \"manual\":\n",
" # manual mode\n",
" model.fix_symbolic(0,0,0,'sin');\n",
" model.fix_symbolic(0,1,0,'x^2');\n",
" model.fix_symbolic(1,0,0,'exp');\n",
"elif mode == \"auto\":\n",
" # automatic mode\n",
" lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']\n",
" model.auto_symbolic(lib=lib)"
]
},
{
"cell_type": "markdown",
"id": "821ba616",
"metadata": {},
"source": [
"Continue training till machine precision"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c0800415",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 3.33e-10 | test_loss: 7.20e-11 | reg: 0.00e+00 | : 100%|█| 50/50 [00:02<00:00, 24.56it\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.7\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "e39da499",
"metadata": {},
"source": [
"Obtain the symbolic formula"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bf44f7e0",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \\sin{\\left(3.1416 x_{1} \\right)}}$"
],
"text/plain": [
"1.0*exp(1.0*x_2**2 + 1.0*sin(3.1416*x_1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import ex_round\n",
"\n",
"ex_round(model.symbolic_formula()[0][0],4)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}