2024-04-29 12:35:18 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
2024-07-13 22:17:48 -04:00
"# Example 10: Relativistic Velocity Addition"
]
},
{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we will symbolically regress $f(u,v)=\\frac{u+v}{1+uv}$, the relativistic velocity-addition formula in units where $c=1$. In relativity, we know the rapidity trick $f(u,v)={\\rm tanh}({\\rm arctanh}\\ u+{\\rm arctanh}\\ v)$. Can we rediscover the rapidity trick with a KAN?"
]
},
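{
"cell_type": "markdown",
"id": "b7e41c02",
"metadata": {},
"source": [
"As a short aside, here is a sketch of why the rapidity trick holds (this derivation is added for illustration; the symbols $a$ and $b$ are our own notation). For $|u|,|v|<1$, define the rapidities $a={\\rm arctanh}\\ u$ and $b={\\rm arctanh}\\ v$, so that $u={\\rm tanh}\\ a$ and $v={\\rm tanh}\\ b$. The addition formula for the hyperbolic tangent then gives\n",
"\n",
"$$\n",
"{\\rm tanh}(a+b)=\\frac{{\\rm tanh}\\ a+{\\rm tanh}\\ b}{1+{\\rm tanh}\\ a\\ {\\rm tanh}\\ b}=\\frac{u+v}{1+uv}=f(u,v).\n",
"$$\n",
"\n",
"In other words, velocities compose nonlinearly, but rapidities simply add."
]
},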
{
"cell_type": "markdown",
"id": "94056ef6",
"metadata": {},
"source": [
"Initialize the model and create the dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a59179d",
"metadata": {},
"outputs": [],
"source": [
"from kan import *\n",
"\n",
"# initialize KAN with G=10 grid intervals and cubic (k=3) splines\n",
"model = KAN(width=[2,1,1], grid=10, k=3)\n",
"\n",
"# create dataset\n",
"f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]])\n",
"dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9])"
]
},
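{
"cell_type": "markdown",
"id": "c3d95a17",
"metadata": {},
"source": [
"A brief note on these choices (a sketch, assuming the usual KAN convention that each node sums the outputs of its incoming activation functions; the symbols $\\phi_1$, $\\phi_2$, and $\\Phi$ are our own notation). A network with `width=[2,1,1]` represents functions of the form\n",
"\n",
"$$\n",
"f(u,v)=\\Phi\\big(\\phi_1(u)+\\phi_2(v)\\big),\n",
"$$\n",
"\n",
"which is exactly the shape of the rapidity formula, with $\\phi_1=\\phi_2={\\rm arctanh}$ and $\\Phi={\\rm tanh}$ (possibly up to affine rescalings that training can absorb). Sampling inputs from $[-0.9,0.9]$ keeps the data away from $|x|=1$, where ${\\rm arctanh}$ diverges; at the edge of this range, ${\\rm arctanh}(0.9)=\\frac{1}{2}\\ln\\frac{1.9}{0.1}=\\frac{1}{2}\\ln 19\\approx 1.47$, so the first-layer targets stay bounded."
]
},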
{
"cell_type": "markdown",
"id": "cb1f817e",
"metadata": {},
"source": [
"Train KAN and plot"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a87b97b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.27e-04 | test loss: 5.79e-04 | reg: 5.51e+00 : 100%|██| 20/20 [00:04<00:00, 4.76it/s]\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f1cfc9d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "5ca6421a",
"metadata": {},
"source": [
"After retraining, the loss remains similar, meaning that the locking does not degrade model behavior. This supports our hypothesis that the two first-layer activation functions are the same. Let's now determine what this function is using `suggest_symbolic`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2ccb7048",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 arctanh 0.999999 -16.470396 4 4 -16.470396\n",
"1 tan 0.999842 -12.540685 3 3 -12.540685\n",
"2 arcsin 0.998866 -9.771875 4 4 -9.771875\n",
"3 arccos 0.998866 -9.771725 4 4 -9.771725\n",
"4 x^0.5 0.982258 -5.815842 2 2 -5.815842\n"
]
},
{
"data": {
"text/plain": [
"('arctanh',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 4,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.9999989867210388,\n",
" 4)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,1,0,weight_simple=0.0)"
]
},
{
"cell_type": "markdown",
"id": "0092be41",
"metadata": {},
"source": [
"We can see that ${\\rm arctanh}$ is at the top of the suggestion list! So we can set both first-layer activation functions to ${\\rm arctanh}$, retrain the model, and plot it."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1bb96fe1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999993443489075\n",
"r2 is 0.9999989867210388\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'arctanh')\n",
"model.fix_symbolic(0,1,0,'arctanh')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "83b852a3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 3.91e-04 | test loss: 4.70e-04 | reg: 5.54e+00 : 100%|██| 20/20 [00:02<00:00, 7.30it/s]\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9ccd0923",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "4b98a727",
"metadata": {},
"source": [
"Next, we ask for suggestions for the remaining activation function in the second layer. We will see that ${\\rm tanh}$ is at the top of the suggestion list! So we can set it to ${\\rm tanh}$, retrain the model to near machine precision, plot it, and finally get the symbolic formula."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "99ad38b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 tanh 0.999974 -14.743149 3 3 -0.548630\n",
"1 x 0.945782 -4.204828 1 1 -0.040966\n",
"2 0 0.000000 0.000014 0 0 0.000003\n",
"3 cos 0.995867 -7.915010 2 2 0.016998\n",
"4 sin 0.995867 -7.915010 2 2 0.016998\n"
]
},
{
"data": {
"text/plain": [
"('tanh',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 3,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.9999735355377197,\n",
" 3)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(1,0,0)"
]
},
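{
"cell_type": "markdown",
"id": "e58f2b6d",
"metadata": {},
"source": [
"A quick check of how the columns in the two suggestion tables relate (inferred from the printed numbers only, not from the library documentation; the symbol $w$ is our own notation for `weight_simple`). The printed values are consistent with\n",
"\n",
"$$\n",
"\\text{r2 loss}=\\log_2(1-R^2+10^{-5}),\\qquad \\text{total loss}=w\\cdot\\text{complexity}+(1-w)\\cdot\\text{r2 loss}.\n",
"$$\n",
"\n",
"With $w=0$, as in the first call, the total loss equals the r2 loss, so candidates are ranked by fit quality alone. The second table matches $w=0.8$, which we assume is the default: for ${\\rm tanh}$, $0.8\\times 3+0.2\\times(-14.743149)\\approx -0.548630$."
]
},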
{
"cell_type": "code",
"execution_count": 9,
"id": "af24c80d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999735355377197\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(1,0,0,'tanh')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "01936f17",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.55e-08 | test loss: 6.85e-08 | reg: 5.57e+00 : 100%|██| 20/20 [00:01<00:00, 17.11it/s]\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "76bcc188",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b62b0246",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\tanh{\\left(\\operatorname{atanh}{\\left(x_{1} \\right)} + \\operatorname{atanh}{\\left(x_{2} \\right)} \\right)}$"
],
"text/plain": [
"tanh(atanh(x_1) + atanh(x_2))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"formula = model.symbolic_formula()[0][0]\n",
"nsimplify(ex_round(formula, 4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}