{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# API 7: Pruning\n",
"\n",
"Pruning makes a neural network sparser, and therefore more efficient and more interpretable. KANs provide two ways of pruning: automatic pruning, which removes components that fall below a threshold, and manual pruning, where you choose which neurons to keep."
]
},
{
"cell_type": "markdown",
"id": "7fd6a742",
"metadata": {},
"source": [
"## Pruning nodes"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n",
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 3.46e-02 | test_loss: 3.46e-02 | reg: 4.91e+00 | : 100%|█| 20/20 [00:05<00:00, 3.36it\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan import *\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device)\n",
"\n",
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2, device=device)\n",
"dataset['train_input'].shape, dataset['train_label'].shape\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20, lamb=0.01);\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "280cc49f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
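"mode = 'auto'\n",
"\n",
"if mode == 'auto':\n",
"    # automatic: prune hidden neurons using a threshold\n",
"    model = model.prune_node(threshold=1e-2)  # 1e-2 is the default threshold\n",
"elif mode == 'manual':\n",
"    # manual: pass the ids of the active neurons explicitly\n",
"    model = model.prune_node(active_neurons_id=[[0]])\n",
"\n",
"# plot the pruned model in either mode\n",
"model.plot()"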
]
},
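{
"cell_type": "markdown",
"id": "9c1e5a70",
"metadata": {},
"source": [
"The automatic mode above can be pictured with a small, library-free sketch. It assumes that each edge carries an importance score and that a hidden neuron is kept only when both its strongest incoming edge and its strongest outgoing edge score above the threshold. The scores and the names `incoming`, `outgoing`, and `kept` are made up for illustration; this is one reading of the idea, not the exact code inside `prune_node`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f7b2d18",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: made-up scores, not values from the model above.\n",
"incoming = [[0.90, 0.40], [0.001, 0.002], [0.30, 0.005]]  # incoming[j][i]: edge from input i to hidden neuron j\n",
"outgoing = [0.80, 0.50, 0.004]  # outgoing[j]: edge from hidden neuron j to the output\n",
"threshold = 1e-2\n",
"\n",
"# Assumed rule: keep neuron j only if its strongest incoming edge\n",
"# and its outgoing edge both score above the threshold.\n",
"kept = [j for j in range(len(outgoing))\n",
"        if max(incoming[j]) > threshold and outgoing[j] > threshold]\n",
"print(kept)"
]
},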
{
"cell_type": "markdown",
"id": "cf7001ab",
"metadata": {},
"source": [
"## Pruning edges"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b58417be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 7.84e-02 | test_loss: 7.80e-02 | reg: 7.26e+00 | : 100%|█| 6/6 [00:01<00:00, 3.72it/s\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan import *\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device)\n",
"\n",
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2, device=device)\n",
"dataset['train_input'].shape, dataset['train_label'].shape\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=6, lamb=0.01);\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4d57cbfe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
}
],
"source": [
"model.prune_edge()"
]
},
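{
"cell_type": "markdown",
"id": "b27d40e6",
"metadata": {},
"source": [
"Edge pruning can be sketched in a similar, library-free way. The snippet below assumes that `prune_edge` compares a per-edge importance score with a threshold and masks out the edges that fall below it, while leaving the nodes in place. The scores, the threshold value, and the names `edge_scores` and `mask` are hypothetical and chosen only for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d0a91c4",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: made-up scores, not values from the model above.\n",
"edge_scores = [[0.90, 0.02], [0.001, 0.40]]  # one score per edge\n",
"threshold = 3e-2  # assumed value, chosen for illustration\n",
"\n",
"# An edge stays active (mask 1.0) only if its score exceeds the threshold.\n",
"mask = [[1.0 if s > threshold else 0.0 for s in row] for row in edge_scores]\n",
"print(mask)"
]
},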
{
"cell_type": "code",
"execution_count": 5,
"id": "e3a23aed",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "1db74fbd",
"metadata": {},
"source": [
"## Pruning nodes and edges together"
]
},
{
"cell_type": "markdown",
"id": "4e7e2c8a",
"metadata": {},
"source": [
"To prune nodes and edges in a single step, call `model.prune()`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1ea08f0e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 3.46e-02 | test_loss: 3.46e-02 | reg: 4.91e+00 | : 100%|█| 20/20 [00:05<00:00, 3.70it\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 22 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan import *\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=1, device=device)\n",
"\n",
"# create dataset f(x,y) = exp(sin(pi*x)+y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2, device=device)\n",
"dataset['train_input'].shape, dataset['train_label'].shape\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20, lamb=0.01);\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "4fc161de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a8dd8a8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}