{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Hello, KAN!"
]
},
{
"cell_type": "markdown",
"id": "59cf5cd0",
"metadata": {},
"source": [
"### Kolmogorov-Arnold representation theorem"
]
},
{
"cell_type": "markdown",
"id": "f88e5321",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation theorem states that if $f$ is a multivariate continuous function\n",
"on a bounded domain, then it can be written as a finite composition of continuous functions of a\n",
"single variable and the binary operation of addition. More specifically, for a continuous $f : [0,1]^n \\to \\mathbb{R}$,\n",
"\n",
"\n",
"$$f(x) = f(x_1,...,x_n)=\\sum_{q=1}^{2n+1}\\Phi_q(\\sum_{p=1}^n \\phi_{q,p}(x_p))$$\n",
"\n",
"where $\\phi_{q,p}:[0,1]\\to\\mathbb{R}$ and $\\Phi_q:\\mathbb{R}\\to\\mathbb{R}$. In a sense, Kolmogorov and Arnold showed that the only truly multivariate function is addition, since every other continuous function can be written using univariate functions and sums. However, the univariate functions in this 2-layer, width-$(2n+1)$ representation can be highly non-smooth, so the representation may not be learnable in practice. We augment its expressive power by generalizing it to arbitrary depths and widths."
]
},
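{
"cell_type": "markdown",
"id": "ka-product-md",
"metadata": {},
"source": [
"Here is a small numerical illustration of the idea. It is not Kolmogorov's construction, and it uses 2 outer terms rather than $2n+1$. The product $xy$ can be written exactly with only univariate functions and addition: $xy = \\Phi_1(x + y) + \\Phi_2(x + (-y))$, where $\\Phi_1(s) = s^2/4$ and $\\Phi_2(s) = -s^2/4$. The sketch uses plain NumPy and no `kan` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ka-product-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Inner univariate functions phi_{q,p}: q indexes the outer sum, p the input\n",
"inner = {\n",
"    (1, 1): lambda x: x, (1, 2): lambda y: y,\n",
"    (2, 1): lambda x: x, (2, 2): lambda y: -y,\n",
"}\n",
"# Outer univariate functions Phi_q\n",
"outer = {1: lambda s: s ** 2 / 4, 2: lambda s: -(s ** 2) / 4}\n",
"\n",
"def product_via_sums(x, y):\n",
"    # f(x, y) = sum_q Phi_q(phi_{q,1}(x) + phi_{q,2}(y)): only univariate maps and addition\n",
"    return sum(outer[q](inner[(q, 1)](x) + inner[(q, 2)](y)) for q in (1, 2))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x, y = rng.uniform(0, 1, size=(2, 1000))\n",
"max_err = np.max(np.abs(product_via_sums(x, y) - x * y))\n",
"print(max_err)  # only floating-point round-off remains"
]
},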
{
"cell_type": "markdown",
"id": "ebd8766a",
"metadata": {},
"source": [
"### Kolmogorov-Arnold Network (KAN)"
]
},
{
"cell_type": "markdown",
"id": "2cf3b1ee",
"metadata": {},
"source": [
"The Kolmogorov-Arnold representation can be written in matrix form\n",
"\n",
"$$f(x)={\\bf \\Phi}_{\\rm out}\\circ{\\bf \\Phi}_{\\rm in}\\circ {\\bf x}$$\n",
"\n",
"where \n",
"\n",
"$${\\bf \\Phi}_{\\rm in}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{2n+1,1}(\\cdot) & \\cdots & \\phi_{2n+1,n}(\\cdot) \\end{pmatrix},\\quad {\\bf \\Phi}_{\\rm out}=\\begin{pmatrix} \\Phi_1(\\cdot) & \\cdots & \\Phi_{2n+1}(\\cdot)\\end{pmatrix}$$"
]
},
{
"cell_type": "markdown",
"id": "f6521452",
"metadata": {},
"source": [
"Both ${\\bf \\Phi}_{\\rm in}$ and ${\\bf \\Phi}_{\\rm out}$ are special cases of the following function matrix ${\\bf \\Phi}$ with $n_{\\rm in}$ inputs and $n_{\\rm out}$ outputs, which we call a Kolmogorov-Arnold layer:\n",
"\n",
"$${\\bf \\Phi}= \\begin{pmatrix} \\phi_{1,1}(\\cdot) & \\cdots & \\phi_{1,n_{\\rm in}}(\\cdot) \\\\ \\vdots & & \\vdots \\\\ \\phi_{n_{\\rm out},1}(\\cdot) & \\cdots & \\phi_{n_{\\rm out},n_{\\rm in}}(\\cdot) \\end{pmatrix}$$\n",
"\n",
"${\\bf \\Phi}_{\\rm in}$ corresponds to $n_{\\rm in}=n, n_{\\rm out}=2n+1$, and ${\\bf \\Phi}_{\\rm out}$ corresponds to $n_{\\rm in}=2n+1, n_{\\rm out}=1$."
]
},
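{
"cell_type": "markdown",
"id": "ka-layer-md",
"metadata": {},
"source": [
"Below is a minimal NumPy sketch of one Kolmogorov-Arnold layer with $n_{\\rm in}$ inputs and $n_{\\rm out}$ outputs. Each edge $(j, i)$ has its own univariate function $\\phi_{j,i}$, and output $j$ is $\\sum_i \\phi_{j,i}(x_i)$. As a simplifying assumption, each $\\phi_{j,i}$ is a piecewise-linear interpolant of learnable values on a fixed grid. The `kan` package uses B-splines plus a base function, so this is not its implementation. The class name `KALayerSketch` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ka-layer-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"class KALayerSketch:\n",
"    # Hypothetical Kolmogorov-Arnold layer: edge (j, i) carries its own univariate\n",
"    # function phi_{j,i}, here piecewise-linear through learnable values on a fixed\n",
"    # grid (a simplification; the kan package uses B-splines plus a base function)\n",
"    def __init__(self, n_in, n_out, num_grid=6, low=-1.0, high=1.0, seed=0):\n",
"        rng = np.random.default_rng(seed)\n",
"        self.n_in, self.n_out = n_in, n_out\n",
"        self.grid = np.linspace(low, high, num_grid)\n",
"        # coef[j, i, :] holds the values of phi_{j,i} at the grid points\n",
"        self.coef = rng.normal(scale=0.1, size=(n_out, n_in, num_grid))\n",
"\n",
"    def __call__(self, x):\n",
"        # x: shape (batch, n_in) -> output: shape (batch, n_out)\n",
"        out = np.zeros((x.shape[0], self.n_out))\n",
"        for j in range(self.n_out):\n",
"            for i in range(self.n_in):\n",
"                # np.interp evaluates phi_{j,i}; outside the grid it holds the end values\n",
"                out[:, j] += np.interp(x[:, i], self.grid, self.coef[j, i])\n",
"        return out\n",
"\n",
"# Same shape as the first layer of KAN(width=[2,5,1]): n_in=2, n_out=5\n",
"layer = KALayerSketch(n_in=2, n_out=5)\n",
"x = np.random.default_rng(1).uniform(-1, 1, size=(4, 2))\n",
"y = layer(x)\n",
"print(y.shape)  # (4, 5)"
]
},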
{
"cell_type": "markdown",
"id": "1b410498",
"metadata": {},
"source": [
"After defining the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Let's say we have $L$ layers, where the $l^{\\rm th}$ layer ${\\bf \\Phi}_l$ has shape $(n_{l+1}, n_{l})$. Then the whole network is\n",
"\n",
"$${\\rm KAN}({\\bf x})={\\bf \\Phi}_{L-1}\\circ\\cdots \\circ{\\bf \\Phi}_1\\circ{\\bf \\Phi}_0\\circ {\\bf x}$$"
]
},
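{
"cell_type": "markdown",
"id": "kan-stack-md",
"metadata": {},
"source": [
"Stacking can be sketched the same way. The hypothetical helpers `ka_layer` and `kan_forward` use piecewise-linear edge functions as a stand-in for splines. They apply ${\\bf \\Phi}_0$, then ${\\bf \\Phi}_1$, and so on, where layer $l$ holds an $(n_{l+1}, n_l)$ array of functions. The choice `width = [2, 5, 1]` mirrors the model created later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "kan-stack-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def ka_layer(x, coef, grid):\n",
"    # coef has shape (n_out, n_in, len(grid)); edge (j, i) applies the piecewise-linear\n",
"    # function through (grid, coef[j, i]) to x[:, i], and output j sums its incoming edges\n",
"    n_out, n_in, _ = coef.shape\n",
"    return np.stack(\n",
"        [sum(np.interp(x[:, i], grid, coef[j, i]) for i in range(n_in)) for j in range(n_out)],\n",
"        axis=1,\n",
"    )\n",
"\n",
"def kan_forward(x, layers, grid):\n",
"    # KAN(x) = Phi_{L-1} o ... o Phi_1 o Phi_0 (x)\n",
"    for coef in layers:\n",
"        x = ka_layer(x, coef, grid)\n",
"    return x\n",
"\n",
"width = [2, 5, 1]\n",
"grid = np.linspace(-1.0, 1.0, 6)\n",
"rng = np.random.default_rng(0)\n",
"# Layer k has shape (n_{k+1}, n_k): one function per (output, input) pair\n",
"layers = [rng.normal(scale=0.5, size=(width[k + 1], width[k], grid.size)) for k in range(len(width) - 1)]\n",
"x = rng.uniform(-1, 1, size=(8, 2))\n",
"out = kan_forward(x, layers, grid)\n",
"print([c.shape[:2] for c in layers], out.shape)  # [(5, 2), (1, 5)] (8, 1)"
]
},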
{
"cell_type": "markdown",
"id": "54bbde9a",
"metadata": {},
"source": [
"In contrast, a multi-layer perceptron (MLP) interleaves linear layers ${\\bf W}_l$ with nonlinearities $\\sigma$:\n",
"\n",
"$${\\rm MLP}({\\bf x})={\\bf W}_{L-1}\\circ\\sigma\\circ\\cdots\\circ {\\bf W}_1\\circ\\sigma\\circ {\\bf W}_0\\circ {\\bf x}$$"
]
},
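{
"cell_type": "markdown",
"id": "mlp-compare-md",
"metadata": {},
"source": [
"For comparison, here is a bare NumPy sketch of the MLP formula above. The learnable parameters are the matrices ${\\bf W}_l$, with biases omitted to match the formula. The nonlinearity $\\sigma$ is fixed; `np.tanh` is an arbitrary choice. In a KAN, by contrast, the nonlinearities sit on the edges and are learned. The helper `mlp_forward` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mlp-compare-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mlp_forward(x, weights, sigma=np.tanh):\n",
"    # MLP(x) = W_{L-1} o sigma o ... o W_1 o sigma o W_0 (x):\n",
"    # learnable linear maps, fixed nonlinearity (no biases, matching the formula)\n",
"    for W in weights[:-1]:\n",
"        x = sigma(x @ W.T)\n",
"    return x @ weights[-1].T\n",
"\n",
"width = [2, 5, 1]\n",
"rng = np.random.default_rng(0)\n",
"weights = [rng.normal(size=(width[k + 1], width[k])) for k in range(len(width) - 1)]\n",
"x = rng.uniform(-1, 1, size=(8, 2))\n",
"out = mlp_forward(x, weights)\n",
"print(out.shape)  # (8, 1)"
]
},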
{
"cell_type": "markdown",
"id": "1c5f7795",
"metadata": {},
"source": [
"A KAN is easy to visualize: (1) a KAN is simply a stack of KAN layers, and (2) each KAN layer can be drawn as a fully connected layer with a 1D function placed on each edge. Let's see an example below."
]
},
{
"cell_type": "markdown",
"id": "adcb5f75",
"metadata": {},
"source": [
"### Get started with KANs"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Initialize KAN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [],
"source": [
"from kan import *\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons; cubic splines (k=3) on 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=0)"
]
},
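{
"cell_type": "markdown",
"id": "param-count-md",
"metadata": {},
"source": [
"A back-of-the-envelope count of the spline coefficients in this model, assuming each edge function is a B-spline of order `k` on `grid` intervals, which has `grid + k` basis coefficients. The `kan` package also keeps some extra per-edge parameters (for example, weights on a base function), so its actual parameter count should be somewhat larger than this figure."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "param-count-code",
"metadata": {},
"outputs": [],
"source": [
"width, grid, k = [2, 5, 1], 5, 3\n",
"# one learnable univariate function per edge between consecutive layers\n",
"n_edges = sum(width[i] * width[i + 1] for i in range(len(width) - 1))\n",
"# a B-spline of order k on `grid` intervals has grid + k coefficients\n",
"coefs_per_edge = grid + k\n",
"print(n_edges, coefs_per_edge, n_edges * coefs_per_edge)  # 15 8 120"
]
},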
{
"cell_type": "markdown",
"id": "3d72e076",
"metadata": {},
"source": [
"Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "46717e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1000, 2]), torch.Size([1000, 1]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"dataset['train_input'].shape, dataset['train_label'].shape"
]
},
{
"cell_type": "markdown",
"id": "8c6add1d",
"metadata": {},
"source": [
"Plot KAN at initialization"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac76f858",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot KAN at initialization\n",
"model(dataset['train_input']);\n",
"model.plot(beta=100)"
]
},
{
"cell_type": "markdown",
"id": "ddf67e30",
"metadata": {},
"source": [
"Train KAN with sparsity regularization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97111d75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.57e-01 | test loss: 1.31e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:18<00:00, 1.06it/s]\n"
]
}
],
"source": [
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
]
},
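{
"cell_type": "markdown",
"id": "reg-sketch-md",
"metadata": {},
"source": [
"This sketch shows roughly what `lamb` and `lamb_entropy` control. It follows the regularizer described in the KAN paper; pykan's internals may differ in the details. For one layer with edge outputs $\\phi_{j,i}(x_i)$ on a batch:\n",
"\n",
"- The L1 norm of an edge is its mean absolute output.\n",
"- The layer's L1 norm is the sum over its edges.\n",
"- The entropy term is computed from the edge norms, normalized to sum to 1.\n",
"\n",
"Training then minimizes the prediction loss plus `lamb` times the sum over layers of this quantity. The function `layer_regularizer` and its `mu_l1` argument (the paper's L1 weight) are hypothetical. The two toy layers below have the same L1 norm, so the gap between them comes entirely from the entropy term, which favors a few strong edges."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "reg-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def layer_regularizer(edge_acts, mu_l1=1.0, mu_entropy=10.0):\n",
"    # edge_acts: shape (batch, n_out, n_in), the values phi_{j,i}(x_i) of one layer\n",
"    l1_per_edge = np.mean(np.abs(edge_acts), axis=0)  # |phi_{j,i}|_1\n",
"    layer_l1 = l1_per_edge.sum()                      # |Phi|_1\n",
"    p = l1_per_edge / layer_l1\n",
"    entropy = -np.sum(p * np.log(p + 1e-12))          # small epsilon avoids log(0)\n",
"    return mu_l1 * layer_l1 + mu_entropy * entropy\n",
"\n",
"# Equal-magnitude edges maximize the entropy term; one dominant edge minimizes it\n",
"uniform = np.ones((10, 2, 3))\n",
"sparse = np.zeros((10, 2, 3))\n",
"sparse[:, 0, 0] = 6.0\n",
"print(layer_regularizer(uniform), layer_regularizer(sparse))"
]
},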
{
"cell_type": "markdown",
"id": "2f30c3ab",
"metadata": {},
"source": [
"Plot trained KAN"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "61d537b7",
"metadata": {},
"source": [
"Prune KAN and replot (keep the original shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1269a698",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.prune()\n",
"model.plot(mask=True)"
]
},
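{
"cell_type": "markdown",
"id": "prune-sketch-md",
"metadata": {},
"source": [
"Pruning keeps a hidden neuron only if it has at least one strong incoming edge and one strong outgoing edge. This sketch implements that rule as described in the KAN paper. It assumes the per-edge L1 norms (mean absolute activation) are already computed, and it uses the paper's default threshold of $10^{-2}$. The helper name and the edge norms below are made up for illustration, and `model.prune()` may differ in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prune-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def important_hidden_nodes(l1_in, l1_out, threshold=1e-2):\n",
"    # l1_in:  shape (n_hidden, n_prev), L1 norms of edges entering each hidden node\n",
"    # l1_out: shape (n_next, n_hidden), L1 norms of edges leaving each hidden node\n",
"    incoming = l1_in.max(axis=1)\n",
"    outgoing = l1_out.max(axis=0)\n",
"    return (incoming > threshold) & (outgoing > threshold)\n",
"\n",
"# 2 inputs -> 5 hidden -> 1 output, with made-up edge norms for illustration\n",
"l1_in = np.array([[0.9, 0.0], [0.0, 0.001], [0.0, 0.7], [0.005, 0.002], [0.3, 0.2]])\n",
"l1_out = np.array([[1.2, 0.5, 0.8, 0.4, 0.003]])\n",
"print(important_hidden_nodes(l1_in, l1_out))"
]
},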
{
"cell_type": "markdown",
"id": "576856cf",
"metadata": {},
"source": [
"Prune KAN and replot (get a smaller shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7fe6fb12",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "bd08ad99",
"metadata": {},
"source": [
"Continue training and replot"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "18a2db11",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 4.74e-03 | test loss: 4.80e-03 | reg: 2.98e+00 : 100%|██| 50/50 [00:07<00:00, 7.03it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "af27aba7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "cf35d505",
"metadata": {},
"source": [
"Automatically or manually set activation functions to be symbolic"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b3c0642b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with sin, r2=0.999987252534279\n",
"fixing (0,1,0) with x^2, r2=0.9999996536741071\n",
"fixing (1,0,0) with exp, r2=0.9999988529417926\n"
]
}
],
"source": [
"mode = \"auto\" # \"manual\"\n",
"\n",
"if mode == \"manual\":\n",
" # manual mode\n",
" model.fix_symbolic(0,0,0,'sin');\n",
" model.fix_symbolic(0,1,0,'x^2');\n",
" model.fix_symbolic(1,0,0,'exp');\n",
"elif mode == \"auto\":\n",
" # automatic mode\n",
" lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']\n",
" model.auto_symbolic(lib=lib)"
]
},
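{
"cell_type": "markdown",
"id": "symbolic-fit-md",
"metadata": {},
"source": [
"As the log above suggests, `auto_symbolic` compares each learned activation against the functions in `lib` and keeps the best match by $R^2$. Below is a simplified stand-in for that idea, not pykan's code. For each candidate $g$, it fits $y \\approx c\\,g(ax+b)+d$ by grid-searching $(a, b)$ and solving for $(c, d)$ by least squares. It then keeps the candidate with the highest $R^2$. The candidate subset, search ranges, and helper names are assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "symbolic-fit-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical candidate library (a subset of the lib used above)\n",
"CANDIDATES = {\n",
"    'x': lambda u: u,\n",
"    'x^2': lambda u: u ** 2,\n",
"    'exp': np.exp,\n",
"    'tanh': np.tanh,\n",
"    'sin': np.sin,\n",
"}\n",
"\n",
"def fit_affine(g, x, y, ab_range=(-5.0, 5.0), n=41):\n",
"    # Fit y ~ c * g(a * x + b) + d: grid search over (a, b),\n",
"    # closed-form least squares for (c, d); returns the best R^2 found\n",
"    best_r2 = -np.inf\n",
"    for a in np.linspace(*ab_range, n):\n",
"        for b in np.linspace(*ab_range, n):\n",
"            feat = g(a * x + b)\n",
"            A = np.stack([feat, np.ones_like(feat)], axis=1)\n",
"            coef, *_ = np.linalg.lstsq(A, y, rcond=None)\n",
"            resid = y - A @ coef\n",
"            best_r2 = max(best_r2, 1 - resid.var() / y.var())\n",
"    return best_r2\n",
"\n",
"def best_symbolic(x, y, candidates=CANDIDATES):\n",
"    scores = {name: fit_affine(g, x, y) for name, g in candidates.items()}\n",
"    best = max(scores, key=scores.get)\n",
"    return best, scores[best]\n",
"\n",
"# A learned activation that happens to look like sin(pi * x)\n",
"x_samples = np.linspace(-1, 1, 200)\n",
"name, r2 = best_symbolic(x_samples, np.sin(np.pi * x_samples))\n",
"print(name, round(r2, 4))"
]
},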
{
"cell_type": "markdown",
"id": "821ba616",
"metadata": {},
"source": [
"Continue training to almost machine precision"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c0800415",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.02e-10 | test loss: 1.13e-10 | reg: 2.98e+00 : 100%|██| 50/50 [00:02<00:00, 22.59it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "e39da499",
"metadata": {},
"source": [
"Obtain the symbolic formula"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bf44f7e0",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \\sin{\\left(3.14 x_{1} \\right)}}$"
],
"text/plain": [
"1.0*exp(1.0*x_2**2 + 1.0*sin(3.14*x_1))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
}
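,
{
"cell_type": "markdown",
"id": "formula-check-md",
"metadata": {},
"source": [
"A quick numerical check that the recovered expression matches the function used to build the dataset. The sketch re-types the coefficients exactly as printed above, so agreement is expected only up to that displayed precision (for example, 3.14 in place of $\\pi$). It uses plain NumPy, not the `kan` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "formula-check-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def target(x1, x2):\n",
"    # the function used to build the dataset: exp(sin(pi*x) + y^2)\n",
"    return np.exp(np.sin(np.pi * x1) + x2 ** 2)\n",
"\n",
"def recovered(x1, x2):\n",
"    # the symbolic formula with its coefficients exactly as printed above\n",
"    return 1.0 * np.exp(1.0 * x2 ** 2 + 1.0 * np.sin(3.14 * x1))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x1, x2 = rng.uniform(-1, 1, size=(2, 1000))\n",
"rel_err = np.max(np.abs(recovered(x1, x2) - target(x1, x2)) / target(x1, x2))\n",
"print(rel_err)  # small; limited by printing 3.14 instead of pi"
]
}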
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}