{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Demo 3: Grid"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"One important feature of KANs is that they embed splines into neural networks. However, a spline is only valid for approximating a function on a known, bounded region, while the range of activations in a neural network may change during training. So we have to update the grids accordingly. Let's first take a look at how we parametrize splines."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'B_i(x)')"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan.spline import B_batch\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# consider a 1D example.\n",
"# Suppose we have a grid on [-1,1] with G intervals and spline order k\n",
"G = 11\n",
"k = 3\n",
"grid = torch.linspace(-1,1,steps=G+1)[None,:]\n",
"\n",
"# and sample points in [-1,1]\n",
"x = torch.linspace(-1,1,steps=1001)[None,:]\n",
"\n",
"basis = B_batch(x, grid, k=k)\n",
"\n",
"for i in range(G-k):\n",
" plt.plot(x[0].detach().numpy(), basis[0,:,i].detach().numpy())\n",
" \n",
"plt.legend(['B_{}(x)'.format(i) for i in np.arange(G-k)])\n",
"plt.xlabel('x')\n",
"plt.ylabel('B_i(x)')"
]
},
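{
"cell_type": "markdown",
"id": "5b1f0c2e",
"metadata": {},
"source": [
"To see what a routine like `B_batch` computes, here is a minimal pure-PyTorch sketch of the standard Cox-de Boor recursion for a single 1D knot vector. It is an illustration, not pykan's code: the helper name `bspline_basis` is hypothetical, and we assume a strictly increasing grid (so no denominator is zero). With $G+1$ knots and order $k$ it returns $G-k$ basis functions, and on the interior region $[t_k, t_{G-k})$ they sum to one (partition of unity)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d4e2a71",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def bspline_basis(x, grid, k):\n",
"    # Cox-de Boor recursion (hypothetical helper, not pykan's API)\n",
"    # x: (n,) sample points; grid: (m,) strictly increasing knots\n",
"    # returns (n, m-k-1): column i holds B_i(x)\n",
"    x = x[:, None]\n",
"    # order 0: indicator of each knot interval [t_i, t_{i+1})\n",
"    B = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)\n",
"    for p in range(1, k + 1):\n",
"        # B_{i,p} = (x-t_i)/(t_{i+p}-t_i) B_{i,p-1} + (t_{i+p+1}-x)/(t_{i+p+1}-t_{i+1}) B_{i+1,p-1}\n",
"        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)])\n",
"        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p])\n",
"        B = left * B[:, :-1] + right * B[:, 1:]\n",
"    return B\n",
"\n",
"G_demo, k_demo = 11, 3\n",
"knots = torch.linspace(-1, 1, steps=G_demo + 1)\n",
"xs_demo = torch.linspace(-1, 1, steps=1001)\n",
"B = bspline_basis(xs_demo, knots, k_demo)\n",
"print(B.shape)  # torch.Size([1001, 8]), i.e. G-k basis functions\n",
"\n",
"# partition of unity on [t_k, t_{G-k}), where k+1 bases overlap at every point\n",
"mask = (xs_demo >= knots[k_demo]) & (xs_demo < knots[G_demo - k_demo])\n",
"print(torch.allclose(B[mask].sum(dim=1), torch.ones(int(mask.sum()))))  # True"
]
},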
{
"cell_type": "markdown",
"id": "75af662c",
"metadata": {},
"source": [
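"On a grid with $G$ intervals ($G+1$ points), there are $G-k$ B-spline basis functions of order $k$, which is what the plot above shows. Inside KAN, the grid is extended by $k$ points on each side, so each spline has $G+k$ basis functions and is a linear combination of them: $${\\rm spline}(x)=\\sum_{i=0}^{G+k-1} c_i B_i(x).$$ We don't need to worry about the implementation since it's already built into KAN. But let's check that KAN is indeed implementing this. We initialize a [1,1] KAN, which is simply a 1D spline."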
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4369a310",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.1382, grad_fn=<MeanBackward0>)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
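"from kan import KAN\n",
"\n",
"model = KAN(width=[1,1], grid=G, k=k)\n",
"# obtain coefficients c_i, shape (in_dim, out_dim, G+k)\n",
"coef = model.act_fun[0].coef\n",
"assert coef.shape[-1] == G+k\n",
"\n",
"# the model forward; inputs have shape (batch, in_dim)\n",
"xs = x[0][:,None]\n",
"model_output = model(xs)\n",
"\n",
"# spline output: evaluate the basis on the model's grid (extended by k points on each side),\n",
"# then sum_i c_i B_i(x)\n",
"basis_ext = B_batch(xs, model.act_fun[0].grid, k=k) # shape (batch, in_dim, G+k)\n",
"spline_output = torch.einsum('bik,iok->bo', basis_ext, coef)\n",
"\n",
"torch.mean((model_output - spline_output)**2)"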
]
},
{
"cell_type": "markdown",
"id": "82150587",
"metadata": {},
"source": [
"They are not the same. What's happening? Recall that we model each activation function as the sum of two parts, a residual function $b(x)$ plus the spline function, i.e., $$\\phi(x)={\\rm scale\\_base}*b(x)+{\\rm scale\\_sp}*{\\rm spline}(x),$$ and by default $b(x)={\\rm silu}(x)=x/(1+e^{-x})$."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7d76a3c4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0., grad_fn=<MeanBackward0>)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# residual output\n",
"residual_output = torch.nn.SiLU()(x[0][:,None])\n",
"scale_base = model.act_fun[0].scale_base\n",
"scale_sp = model.act_fun[0].scale_sp\n",
"torch.mean((model_output - (scale_base * residual_output + scale_sp * spline_output))**2)"
]
},
{
"cell_type": "markdown",
"id": "3d72e076",
"metadata": {},
"source": [
"What if the grid does not match the data? For example, the grid is in [-1,1], but the data is in [-10,10] or [-0.5,0.5]. Use update_grid_from_samples to adjust the grids to the samples. This grid update applies to all splines in all layers."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "46717e8b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])\n",
"Parameter containing:\n",
"tensor([[-10.0100, -6.0060, -2.0020, 2.0020, 6.0060, 10.0100]])\n"
]
}
],
"source": [
"model = KAN(width=[1,1], grid=G, k=k)\n",
"print(model.act_fun[0].grid) # by default, the grid is in [-1,1]\n",
"x = torch.linspace(-10,10,steps = 1001)[:,None]\n",
"model.update_grid_from_samples(x)\n",
"print(model.act_fun[0].grid) # now the grid spans [-10,10]. We add a 0.01 margin in case x has zero variance"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "de04db15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])\n",
"Parameter containing:\n",
"tensor([[-0.5100, -0.3060, -0.1020, 0.1020, 0.3060, 0.5100]])\n"
]
}
],
"source": [
"model = KAN(width=[1,1], grid=G, k=k)\n",
"print(model.act_fun[0].grid) # by default, the grid is in [-1,1]\n",
"x = torch.linspace(-0.5,0.5,steps = 1001)[:,None]\n",
"model.update_grid_from_samples(x)\n",
"print(model.act_fun[0].grid) # now the grid spans [-0.5,0.5]. We add a 0.01 margin in case x has zero variance"
]
},
{
"cell_type": "markdown",
"id": "e418ca2c",
"metadata": {},
"source": [
"Uniform grid or non-uniform? We consider two options: (1) a uniform grid; (2) an adaptive grid (based on the sample distribution) such that there are (roughly) the same number of samples in each interval. The parameter grid_eps interpolates between these two regimes: grid_eps = 1 gives (1), and grid_eps = 0 gives (2). By default we set grid_eps = 1 (uniform grid). There could be other options, but they are beyond our scope here."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d2c4f636",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])\n",
"Parameter containing:\n",
"tensor([[-3.4896, -2.1218, -0.7541, 0.6137, 1.9815, 3.3493]])\n"
]
}
],
"source": [
"# uniform grid\n",
"model = KAN(width=[1,1], grid=G, k=k, grid_eps = 1.)\n",
"print(model.act_fun[0].grid) # by default, the grid is in [-1,1]\n",
"x = torch.normal(0,1,size=(1000,1))\n",
"model.update_grid_from_samples(x)\n",
"print(model.act_fun[0].grid) # now the grid points are evenly spaced over the range of the samples (plus the 0.01 margin)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9b354c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])\n",
"Parameter containing:\n",
"tensor([[-3.4796, -0.8529, -0.2272, 0.2667, 0.8940, 3.3393]])\n"
]
}
],
"source": [
"# adaptive grid based on sample distribution\n",
"model = KAN(width=[1,1], grid=G, k=k, grid_eps = 0.)\n",
"print(model.act_fun[0].grid) # by default, the grid is in [-1,1]\n",
"x = torch.normal(0,1,size=(1000,1))\n",
"model.update_grid_from_samples(x)\n",
"print(model.act_fun[0].grid) # now the grid points follow the sample distribution (denser near 0)"
]
},
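{
"cell_type": "markdown",
"id": "c3a8e6f0",
"metadata": {},
"source": [
"For intuition, the interpolation controlled by grid_eps can be sketched in a few lines. This is a simplified illustration under our own assumptions, not pykan's implementation: the helper `blended_grid` is hypothetical, the adaptive grid takes $G+1$ evenly spaced order statistics of the sorted samples (so each interval holds roughly the same number of samples), and the 0.01 margin (taken from the comments in the cells above) is applied to the uniform grid only. The library also refits the spline coefficients after moving the grid; that step is omitted here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7d2b914",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def blended_grid(x, G, grid_eps, margin=0.01):\n",
"    # hypothetical helper; x: (n,) samples, returns G+1 grid points\n",
"    xs = torch.sort(x).values\n",
"    # adaptive grid: G+1 evenly spaced order statistics of the sorted samples\n",
"    idx = torch.linspace(0, len(xs) - 1, steps=G + 1).round().long()\n",
"    grid_adaptive = xs[idx]\n",
"    # uniform grid: G+1 evenly spaced points over the sample range, widened by the margin\n",
"    grid_uniform = torch.linspace(float(xs[0]) - margin, float(xs[-1]) + margin, steps=G + 1)\n",
"    # grid_eps = 1 -> uniform, grid_eps = 0 -> adaptive\n",
"    return grid_eps * grid_uniform + (1 - grid_eps) * grid_adaptive\n",
"\n",
"torch.manual_seed(0)\n",
"samples = torch.normal(0, 1, size=(1000,))\n",
"print(blended_grid(samples, G=5, grid_eps=1.0))  # evenly spaced\n",
"print(blended_grid(samples, G=5, grid_eps=0.0))  # denser near 0"
]
},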
{
"cell_type": "code",
"execution_count": null,
"id": "f7b8f994",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}