GitHub_collection_pykan/hellokan.ipynb

{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Hello, KAN!"
]
},
{
"cell_type": "markdown",
"id": "59cf5cd0",
"metadata": {},
"source": [
"### Kolmogorov-Arnold representation theorem"
]
},
{
"cell_type": "markdown",
"id": "f88e5321",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation theorem states that if $f$ is a multivariate continuous function\n",
"on a bounded domain, then it can be written as a finite composition of continuous functions of a\n",
"single variable and the binary operation of addition. More specifically, for a continuous $f : [0,1]^n \\to \\mathbb{R}$,\n",
"\n",
"\n",
"$$f(x) = f(x_1,...,x_n)=\\sum_{q=1}^{2n+1}\\Phi_q(\\sum_{p=1}^n \\phi_{q,p}(x_p))$$\n",
"\n",
"where $\\phi_{q,p}:[0,1]\\to\\mathbb{R}$ and $\\Phi_q:\\mathbb{R}\\to\\mathbb{R}$. In a sense, Kolmogorov and Arnold showed that the only true multivariate function is addition, since every other function can be written using univariate functions and sums. However, this 2-layer, width-$(2n+1)$ representation is restrictive: its univariate functions may be non-smooth, which makes them hard to learn in practice. We augment its expressive power by generalizing it to arbitrary depths and widths."
]
},
{
"cell_type": "markdown",
"id": "ebd8766a",
"metadata": {},
"source": [
"### Kolmogorov-Arnold Network (KAN)"
]
},
{
"cell_type": "markdown",
"id": "2cf3b1ee",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation can be written in matrix form\n",
"\n",
"$$f(x)={\\bf \\Phi}_{\\rm out}\\circ{\\bf \\Phi}_{\\rm in}\\circ {\\bf x}$$\n",
"\n",
"where \n",
"\n",
"$${\\bf \\Phi}_{\\rm in}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{2n+1,1}(\\cdot) & \\cdots & \\phi_{2n+1,n}(\\cdot) \\end{pmatrix},\\quad {\\bf \\Phi}_{\\rm out}=\\begin{pmatrix} \\Phi_1(\\cdot) & \\cdots & \\Phi_{2n+1}(\\cdot)\\end{pmatrix}$$"
]
},
{
"cell_type": "markdown",
"id": "f6521452",
"metadata": {},
"source": [
"We notice that both ${\\bf \\Phi}_{\\rm in}$ and ${\\bf \\Phi}_{\\rm out}$ are special cases of the following function matrix ${\\bf \\Phi}$ (with $n_{\\rm in}$ inputs and $n_{\\rm out}$ outputs), which we call a Kolmogorov-Arnold layer:\n",
"\n",
"$${\\bf \\Phi}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n_{\\rm in}}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{n_{\\rm out},1}(\\cdot) & \\cdots & \\phi_{n_{\\rm out},n_{\\rm in}}(\\cdot) \\end{pmatrix}$$\n",
"\n",
"${\\bf \\Phi}_{\\rm in}$ corresponds to $n_{\\rm in}=n, n_{\\rm out}=2n+1$, and ${\\bf \\Phi}_{\\rm out}$ corresponds to $n_{\\rm in}=2n+1, n_{\\rm out}=1$."
]
},
{
"cell_type": "markdown",
"id": "1b410498",
"metadata": {},
"source": [
"After defining the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Let's say we have $L$ layers, with the $l^{\\rm th}$ layer ${\\bf \\Phi}_l$ having shape $(n_{l+1}, n_{l})$. Then the whole network is\n",
"\n",
"$${\\rm KAN}({\\bf x})={\\bf \\Phi}_{L-1}\\circ\\cdots \\circ{\\bf \\Phi}_1\\circ{\\bf \\Phi}_0\\circ {\\bf x}$$"
]
},
{
"cell_type": "markdown",
"id": "54bbde9a",
"metadata": {},
"source": [
"In contrast, a Multi-Layer Perceptron interleaves linear layers ${\\bf W}_l$ with nonlinearities $\\sigma$:\n",
"\n",
"$${\\rm MLP}({\\bf x})={\\bf W}_{L-1}\\circ\\sigma\\circ\\cdots\\circ {\\bf W}_1\\circ\\sigma\\circ {\\bf W}_0\\circ {\\bf x}$$"
]
},
{
"cell_type": "markdown",
"id": "1c5f7795",
"metadata": {},
"source": [
"A KAN can be easily visualized: (1) a KAN is simply a stack of KAN layers; (2) each KAN layer can be visualized as a fully-connected layer with a 1D function placed on each edge. Let's see an example below."
]
},
{
"cell_type": "markdown",
"id": "adcb5f75",
"metadata": {},
"source": [
"### Get started with KANs"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Initialize KAN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
}
],
"source": [
"from kan import *\n",
"torch.set_default_dtype(torch.float64)\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 3 grid intervals (grid=3).\n",
"model = KAN(width=[2,5,1], grid=3, k=3, seed=42)"
]
},
{
"cell_type": "markdown",
"id": "3d72e076",
"metadata": {},
"source": [
"Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "46717e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1000, 2]), torch.Size([1000, 1]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import create_dataset\n",
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"dataset['train_input'].shape, dataset['train_label'].shape"
]
},
{
"cell_type": "markdown",
"id": "8c6add1d",
"metadata": {},
"source": [
"Plot KAN at initialization"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac76f858",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot KAN at initialization (the forward pass caches the activations that model.plot() draws)\n",
"model(dataset['train_input']);\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "ddf67e30",
"metadata": {},
"source": [
"Train KAN with sparsity regularization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97111d75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.05e-02 | test_loss: 9.86e-03 | reg: 3.69e+00 | : 100%|█| 50/50 [00:23<00:00, 2.15it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# train the model with LBFGS for 50 steps; lamb sets the strength of the sparsity regularization\n",
"model.fit(dataset, opt=\"LBFGS\", steps=50, lamb=0.001);"
]
},
{
"cell_type": "markdown",
"id": "2f30c3ab",
"metadata": {},
"source": [
"Plot trained KAN"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92a4f67a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "576856cf",
"metadata": {},
"source": [
"Prune KAN and replot"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7fe6fb12",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "bd08ad99",
"metadata": {},
"source": [
"Continue training, then refine the grid and train again"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "18a2db11",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.78e-02 | test_loss: 1.71e-02 | reg: 5.11e+00 | : 100%|█| 50/50 [00:06<00:00, 7.45it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8768d56c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.4\n"
]
}
],
"source": [
"# refine the spline grid from 3 to 10 intervals\n",
"model = model.refine(10)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "46f73098",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.91e-04 | test_loss: 3.06e-04 | reg: 5.11e+00 | : 100%|█| 50/50 [00:06<00:00, 7.25it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "cf35d505",
"metadata": {},
"source": [
"Automatically or manually set activation functions to be symbolic"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b3c0642b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with sin, r2=0.9999999588837324, c=2\n",
"fixing (0,1,0) with x^2, r2=0.9999999973097223, c=2\n",
"fixing (1,0,0) with exp, r2=0.999999987160932, c=2\n",
"saving model version 0.6\n"
]
}
],
"source": [
"mode = \"auto\" # \"manual\"\n",
"\n",
"if mode == \"manual\":\n",
" # manual mode\n",
" model.fix_symbolic(0,0,0,'sin');\n",
" model.fix_symbolic(0,1,0,'x^2');\n",
" model.fix_symbolic(1,0,0,'exp');\n",
"elif mode == \"auto\":\n",
" # automatic mode\n",
" lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']\n",
" model.auto_symbolic(lib=lib)"
]
},
{
"cell_type": "markdown",
"id": "821ba616",
"metadata": {},
"source": [
"Continue training until the loss reaches machine precision"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c0800415",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.97e-15 | test_loss: 1.87e-15 | reg: 5.11e+00 | : 100%|█| 50/50 [00:03<00:00, 15.86it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.7\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "e39da499",
"metadata": {},
"source": [
"Obtain the symbolic formula"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bf44f7e0",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \\sin{\\left(3.1416 x_{1} \\right)}}$"
],
"text/plain": [
"1.0*exp(1.0*x_2**2 + 1.0*sin(3.1416*x_1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import ex_round\n",
"\n",
"ex_round(model.symbolic_formula()[0][0],4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16e635f0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
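The layer and network definitions in this notebook can be sketched in plain Python, independent of pykan. This is a minimal illustration under stated assumptions: each edge function $\phi$ is an ordinary Python callable rather than the learnable cubic spline that pykan fits (the `grid` and `k` arguments above), and the helpers `kan_layer`, `kan`, and `mlp` are hypothetical names, not pykan APIs. It hand-builds a width-`[2,1,1]` KAN for the tutorial's target $f(x,y)=\exp(\sin(\pi x)+y^2)$, which is the structure the symbolic fit recovers, and includes an MLP forward pass for contrast.

```python
import math
from typing import Callable, List, Sequence

# Assumption: an edge function is any Python callable float -> float
# (pykan instead learns each one as a spline on a grid).
Univariate = Callable[[float], float]
# A KAN layer with n_in inputs and n_out outputs: an n_out x n_in matrix of 1D functions.
Layer = List[List[Univariate]]


def kan_layer(phi: Layer, x: Sequence[float]) -> List[float]:
    """Hypothetical helper (not pykan): output j is sum_i phi[j][i](x[i])."""
    for row in phi:
        if len(row) != len(x):
            raise ValueError("each row of phi needs one function per input")
    return [sum(f(xi) for f, xi in zip(row, x)) for row in phi]


def kan(layers: Sequence[Layer], x: Sequence[float]) -> List[float]:
    """Hypothetical helper (not pykan): KAN(x) = Phi_{L-1} o ... o Phi_1 o Phi_0 (x)."""
    out = list(x)
    for phi in layers:
        out = kan_layer(phi, out)
    return out


def mlp(weights: Sequence[Sequence[Sequence[float]]], x: Sequence[float],
        sigma: Univariate = math.tanh) -> List[float]:
    """Hypothetical helper for contrast: W_{L-1} o sigma o ... o sigma o W_0 (x)."""
    out = list(x)
    for idx, w in enumerate(weights):
        # linear layer: learned weights, no per-edge function
        out = [sum(wij * xj for wij, xj in zip(row, out)) for row in w]
        if idx < len(weights) - 1:  # no nonlinearity after the last linear layer
            out = [sigma(v) for v in out]
    return out


# Width [2, 1, 1]: Phi_0 has shape (1, 2), Phi_1 has shape (1, 1).
phi0: Layer = [[lambda x: math.sin(math.pi * x), lambda y: y ** 2]]
phi1: Layer = [[math.exp]]


def target(x: float, y: float) -> float:
    return math.exp(math.sin(math.pi * x) + y ** 2)


# The hand-built KAN reproduces the target exactly at a few sample points.
for point in [(0.0, 0.0), (0.5, -0.3), (-0.7, 0.9)]:
    assert math.isclose(kan([phi0, phi1], point)[0], target(*point))
```

The summation inside `kan_layer` is the only multivariate operation, mirroring the Kolmogorov-Arnold theorem; the KAN puts its flexibility in the edge functions themselves, whereas `mlp` learns only the weights and keeps $\sigma$ fixed.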