358 lines
42 KiB
Plaintext
358 lines
42 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Example 5: Special functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2571d531",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's construct a dataset from a function that involves a special function, $f(x,y)={\\rm exp}(J_0(20x)+y^2)$, where $J_0(x)$ is the Bessel function of the first kind of order zero."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "2075ef56",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda\n",
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 5.15e-01 | test_loss: 5.86e-01 | reg: 5.84e+00 | : 100%|█| 20/20 [00:03<00:00, 5.89it\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"\n",
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"print(device)\n",
|
|
"\n",
|
|
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 3 grid intervals (grid=3).\n",
|
|
"model = KAN(width=[2,1,1], grid=3, k=3, seed=2, device=device)\n",
|
|
"f = lambda x: torch.exp(torch.special.bessel_j0(20*x[:,[0]]) + x[:,[1]]**2)\n",
|
|
"dataset = create_dataset(f, n_var=2, device=device)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2f30c3ab",
|
|
"metadata": {},
|
|
"source": [
|
|
"Plot the trained KAN; the Bessel function shows up in the bottom left."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3f95fcdd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 6 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "187d19f9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.2\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 1.54e-02 | test_loss: 4.73e-02 | reg: 7.50e+00 | : 100%|█| 20/20 [00:02<00:00, 6.93it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.3\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = model.refine(20)\n",
|
|
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "8d50bcef",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxPElEQVR4nO3deXSTZd4+8OtO0iVtWtKNpS1LUyqLgkgpSym0LEMd0VHBceHMjLjr4IZzjjrMT0XG5dWRERAdfdHXQXQOKOVVthFGhAJlKVRAdihla0v3plvaNM1z//4oed4WAUGe9Enb63MOR2mS5tuSO1fu9RFSSgkiIiINGfQugIiIOh6GCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmjPpXQBReyClRHl5OWpra2GxWBAREQEhhN5lEfks9lyILsNut2P+/PlISEhAVFQU4uLiEBUVhYSEBMyfPx92u13vEol8kuCVKIkubt26dZg6dSocDgeA5t6Lh6fXEhQUhIyMDKSnp+tSI5GvYrgQXcS6deswefJkSCmhKMol72cwGCCEwJo1axgwRC0wXIguYLfbERsbi/r6+ssGi4fBYIDZbEZ+fj6sVqv3CyRqBzjnQnSBxYsXw+FwXFGwAICiKHA4HPjss8+8XBlR+8GeC1ELUkokJCQgLy8PV9M0hBCw2Ww4fvw4V5ERgeFC1EpZWRmioqKu6fEREREaVkTUPnFYjKiF2traa3p8TU2NRpUQtW8MF6IWLBbLNT0+JCREo0qI2jeGC1ELERERiI+Pv+p5EyEE4uPjER4e7qXKiNoXhgtRC0IIPPXUU7/osU8//TQn84nO44Q+0QW4z4Xo2rHnQnQBq9WKjIwMCCFgMFy+iXh26K9YsYLBQtQCw4XoItLT07FmzRqYzWYIIX4y3OX5mtlsxtq1azFp0iSdKiXyTQwXoktIT09Hfn4+5s2bB5vN1uo2m82GefPmoaCggMFCdBGccyG6AlJKbNy4ERMmTMCGDRswbtw4Tt4TXQZ7LkRXQAihzqlYrVYGC9HPYLgQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBD9DJfLhYKCAhw+fBgAcOLECVRUVEBRFJ0rI/JdvMwx0SXY7XZkZGTgiy++wMGDB1FTU4PGxkYEBgYiKioKY8aMwUMPPYTRo0fDZDLpXS6RT2G4EF3E9u3bMXPmTPz4449ISkrC5MmTMXjwYFgsFtjtduTk5GDVqlXIzc3FPffcg9deew1RUVF6l03kMxguRBdYv349pk+fDovFgjfffBO33HILGhsbsXTpUjidToSGhuLee++Fy+XC0qVLMXv2bFx//fVYsmQJunXrpnf5RD6B4ULUwrFjx3DzzTcjODgYS5cuxcCBAyGEQF5eHoYOHYqqqirExcUhJycHYWFhkFJi69atmDZtGtLS0vDxxx8jICBA7x+DSHec0Cc6z+1244033kBlZSUWLlyoBsvlCCGQkpKCt99+G9988w2+/fbbNqqWyLcxXIjOy83NxapVqzBlyhSkpKT8bLB4CCFwxx13YOTIkVi0aBGampq8XCmR7+MSF6Lztm3bhtraWkydOhWnTp1CXV2delt+fj7cbjcAoLGxEQcPHkRoaKh6e3R0NKZMmYLZs2ejqKgIsbGxbV4/kS9huBCdd+TIEQQFBcFms+Gxxx5DVlaWepuUEk6nEwBQWFiIX/3qV+ptQgjMnTsXgwYNgsPhQGFhIcOFOj2GC9F59fX1MJlMCAgIgNPpRENDw0XvJ6X8yW1NTU0wm82tQoioM2O4EJ3XtWtX1NfXw263Y8SIEQgODlZvq6+vx7
Zt29QQSU5OVjdOCiHQq1cvlJSUwGAwICwsTK8fgchnMFyIzktMTITL5UJ2djbeeuutVrfl5eUhKSkJVVVV6NatG5YtWwar1areLoTArFmz0L17dw6JEYGrxYhUw4cPh81mw+LFi1FXVwej0djqj4cQAgaDQf26wWDAuXPnsHz5ckyePBldunTR8acg8g0MF6LzIiIi8OSTT+KHH37AggULrnhJsdPpxF//+lfU19fjscceu+IlzEQdGYfFiFqYPn06Nm/ejLfeegtBQUF44oknEBgYCAAwmUwwmUxqL0ZKiZqaGrz++utYunQp3n33XfTr10/P8ol8Bo9/IbpAaWkpZsyYgdWrVyM9PR0zZ87EgAEDcPToUSiKAn9/f/Tt2xfZ2dl45513sHfvXsyZMwdPPPFEq+Ezos6M4UJ0EXV1dVi0aBEWLFiA4uJi2Gw2JCQkICQkBJWVlTh69CgKCwuRmJiIV155BampqTAYOMpM5MFwIbqMoqIibNiwAZmZmdi3bx+ys7MxZswYjB49GpMmTcKIESMQFBSkd5lEPofhQnSFdu3aheHDh2PXrl0YNmyY3uUQ+TT244mukNFoVJchE9HlsZUQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5Xs+F6ApJKaEoCgwGA4QQepdD5NPYcyG6CryWC9GVMeldAJEWXC4Xzpw5A0VR9C7lmgkh0KtXL/j7++tdCtEvxnChDiE/Px9//OMfkZiYqHcp1ywnJwcffPAB4uPj9S6F6BdjuFCHIKXE4MGD8frrr+tdylU7cuQI/vu//xt33XUXkpOT8eKLL4JTodTeMVyow2lPk+35+fm47777sH//fuzcuRP/+c9/9C6JSBOcnSTSiZQSH374Ifbv3w8A2LNnD7Kzs3WuikgbDBcinZw+fRqLFy8G0LwKrbGxEVu3btW5KiJtMFyIdCClxCeffILCwkLExMRg3LhxMBqNyMvL43wLdQiccyHSwZkzZ/DPf/4TAPDAAw/gzjvvRFVVFRISEjBv3jxdayPSAsOFqI0pioJFixapvZaHHnoIvXr1AtDco2lPCxKILoXDYkRt7PTp02qv5cEHH0TPnj31LYjICxguRG1IURR89NFHOHfuHHr27ImHHnqIPRXqkBguRG0oLy8Pn332GQDg4YcfRmxsrM4VEXkHw4WojSiKgg8//BDFxcXo3bs3HnjgAfZaqMNiuBC1ASkljh07hs8//xxCCDz66KOIjo7Wuywir2G4ELUBRVGwYMEClJaWwmaz4f7772evhTo0hguRl0kpsX//fixbtgxCCMyYMQPdu3fXuywir2K4EHlZU1MT5s6dC7vdjoEDB+J3v/sdey3U4TFciLxISont27fj66+/htFoxHPPPYfw8HC9yyLyOoYLkRc1NDTgv/7rv+BwODB8+HBMnTqVvRbqFBguRF4ipcTq1avx/fffIyAgAC+88AIsFoveZRG1CYYLkZdUVFTgjTfegMvlwq9//WtMmjSJvRbqNBguRF7gOeZl//79sFqtmDVrFvz9/fUui6jNMFyINCalxOHDh/Hee+9BSokHH3wQQ4YMYa+FOhWGC5HGnE4nXn31VZSUlCAhIQEzZ86E0WjUuyyiNsVwIdKQlBIZGRlYuXIl/Pz88PLLL6NHjx56l0XU5hguRBqRUuLUqVN45ZVX4HK58Jvf/AZTpkzhcBh1SgwXIo04nU7MmjULJ0+eRExMDObMmYOAgAC9yyLSBcOFSAOKouCTTz7BihUr4Ofnh1dffRX9+vVjr4U6LYYL0TWSUmLz5s2YPXs2mpqacO+992LatGkMFurUGC5E10BKiePHj+OJJ55ARUUFbrrpJrz55pvc00KdHsOF6BeSUqK4uBgPP/wwjh07hh49euAf//gHunfvzl4LdXoMF6JfQEoJu92Oxx9/HF
lZWQgNDcV7772HYcOGMViIwHAhumpSStTW1uKpp57C6tWrERgYiLfffhu33347g4XoPIYL0VWQUqKurg4zZ87EsmXL4Ofnh1deeQUPPPAADAY2JyIPtgaiKySlRE1NDZ555hksXrwYBoMBzz//PJ599lmYTCa9yyPyKWwRRFdASomysjI8+eSTyMjIgMlkwp/+9Ceedkx0CQwXop8hpcSJEyfw2GOPYdOmTQgICMCLL76IF154gTvwiS6B4UJ0GYqi4LvvvsPTTz+N48ePw2Kx4LXXXsPjjz8OPz8/vcsj8lkMF6KLkFKioaEB77//Pl5//XVUV1cjOjoa8+fPx+23384j9Il+BsOF6AJSSuTl5eEvf/kLVqxYAUVRkJycjIULF2Lw4MFcbkx0BbhajOg8KSWcTif+53/+B+PHj8dXX30Fo9GIRx55BP/7v//LYCG6Cuy5UKcnpYTb7cbu3bvx97//HStXroTL5ULv3r3xyiuv4L777oOfnx+DhegqMFyoU1MUBSdOnMDcuXPxr3/9C3V1dfDz88O0adPw17/+Fb1792aoEP0CDBfqdKSU6lUjP/roI3z22WcoKSmBwWDA8OHD8eyzz+KOO+6Av78/g4XoF2K4UKchpYTL5cLevXvx6aef4uuvv0ZJSQmEEIiPj8fMmTPxu9/9DhaLhaFCdI0YLtSheeZT8vPz8d1332HZsmXYsWMHHA4HhBDo06cPHn30UUyfPh1du3ZlqBBphOFCHY6UEoqi4Ny5c/juu++QkZGBnTt3oqKiAlJK+Pn5ITExEQ888ADuvPNOdOvWjaFCpDGGC3Uoy5cvx+bNm3Hq1Cns3r0bpaWlkFLCYDAgOjoaEydOxL333ovk5GQEBwczVIi8hOFCHcrmzZvxwQcfAACEEOjevTvGjRuHKVOmIDk5GVFRURBCMFSIvIzhQh1KYmIifvvb36Jbt24YPXo0UlJSLnrZYSmlThUSdQ4MF+oQhBDYv38/goKC0K9fPwghcOjQIRw6dEjv0q7avn372LOidk9IfoSjDqCxsRF5eXlwu916l3LNDAYD4uPjeZ0YatcYLkREpDkOixFdoZafwzhsRXR5PBWZ6Art2bMHRqMRe/bs0bsUIp/HcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhegKSClRWVkJAKisrAQv4Ep0eQwXosuw2+2YP38+EhISMHHiREgpMXHiRCQkJGD+/Pmw2+16l0jkk4TkRzCii1q3bh2mTp0Kh8MB4OKXOQ4KCkJGRgbS09N1qZHIVzFciC5i3bp1mDx5MqSUUBTlkvczGAwQQmDNmjUMGKIWGC5EF7Db7YiNjUV9ff1lg8XDYDDAbDYjPz8fVqvV+wUStQOccyG6wOLFi+FwOK4oWABAURQ4HA589tlnXq6MqP1gz4WoBSklEhISkJeXd1UrwoQQsNlsOH78uDofQ9SZMVyIWigrK0NUVNQ1PT4iIkLDiojaJw6LEbVQW1t7TY+vqanRqBKi9o3hQtSCxWK5pseHhIRoVAlR+8ZwIWohIiIC8fHxVz1vIoRAfHw8wsPDvVQZUfvCcCFqQQiBp5566hc99umnn+ZkPtF5nNAnugD3uRBdO/ZciC5gtVqRkZEBIQQMhss3Ec8O/RUrVjBYiFpguBBdRHp6OtasWQOz2QwhxE+GuzxfM5vNWLt2LSZNmqRTpUS+ieFCdAnp6enIz8/HvHnzYLPZWt1ms9kwb948FBQUMFiILoJzLkRXQEqJjRs3YsKECdiwYQPGjRvHyXuiy2DPhegKCCHUORWr1cpgIfoZDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs
0xXIiISHMMF6Kf4XK5UFBQgMOHDwMATpw4gYqKCiiKonNlRL6LlzkmugS73Y6MjAx88cUXOHjwIGpqatDY2IjAwEBERUVhzJgxeOihhzB69GiYTCa9yyXyKQwXoovYvn07Zs6ciR9//BFJSUmYPHkyBg8eDIvFArvdjpycHKxatQq5ubm455578NprryEqKkrvsol8BsOF6ALr16/H9OnTYbFY8Oabb+KWW25BY2Mjli5dCqfTidDQUNx7771wuVxYunQpZs+ejeuvvx5LlixBt27d9C6fyCcwXIhaOHbsGG6++WYEBwdj6dKlGDhwIIQQyMvLw9ChQ1FVVYW4uDjk5OQgLCwMUkps3boV06ZNQ1paGj7++GMEBATo/WMQ6Y4T+kTnud1uvPHGG6isrMTChQvVYLkcIQRSUlLw9ttv45tvvsG3337bRtUS+TaGC9F5ubm5WLVqFaZMmYKUlJSfDRYPIQTuuOMOjBw5EosWLUJTU5OXKyXyfVziQnTetm3bUFtbi6lTp+LUqVOoq6tTb8vPz4fb7QYANDY24uDBgwgNDVVvj46OxpQpUzB79mwUFRUhNja2zesn8iUMF6Lzjhw5gqCgINhsNjz22GPIyspSb5NSwul0AgAKCwvxq1/9Sr1NCIG5c+di0KBBcDgcKCwsZLhQp8dwITqvvr4eJpMJAQEBcDqdaGhouOj9pJQ/ua2pqQlms7lVCBF1ZgwX6vROnTqFTZs2YevWrXA4HLDb7RgxYgSCg4PV+9TX12Pbtm1qiCQnJ6sbJ4UQ6NWrF0pKStDU1ITc3FwkJSUhMDBQrx+JSHdcikydztmzZ5GZmYnMzExs2rQJZ86cgRACvXv3xokTJ/D+++/j4YcfbvWYvLw8JCUloaqqCn369MHu3bthtVrV24UQmDVrFt555x0YjUYEBgZixIgRGDt2LNLS0pCUlMQlytSpMFyowyssLMSmTZuwefNmbNq0CSdPngQADB48WH3zT0lJgaIoSElJQVhYGL799ttWE/aX2ucCNA+TFRYWIjU1FbfddhumT5+OzZs3Y/PmzdiyZQvsdjvMZjNGjhypPl9iYiL8/f11+X0QtQWGC3U4xcXFrcIkNzcXAHD99derb+5jxoxBRETETx77/vvv409/+hP+3//7f3jxxRfVoa/LhUtDQwOeffZZrFq1Ct9//z369eunfj+3240ff/wRmzdvRmZmJrZu3Yrq6moEBQVh1KhRSE1NRVpaGm666Sb4+fm1wW+HqG0wXKjdKy0tRWZmphomR48eBQD0798fqampSE1NxdixY6/o7K+6ujo8+OCDWLt2LV599VU88cQTCAwMxMmTJzF8+HB1WCw7OxtWqxU1NTV4/fXX8dFHH+Hdd9/FAw88cNnv39TUhL1796r1ZmVloba2FhaLBcnJyWq9Q4YM4WGY1K4xXKjdKSsrw5YtW9R5k0OHDgEAEhIS1Dfn1NTUX3zOV2lpKWbMmIHVq1cjPT0dM2fOxIABA3D06FEoigJ/f3/07dsX2dnZeOedd7B3717MmTMHTzzxBIxG41U9l8vlwp49e9Se1rZt2+BwOBAaGorRo0erP8vgwYOv+nsT6YnhQj6vsrISW7ZsUd+A9+/fDwCw2WytwiQ6Olqz56yrq8OiRYuwYMECFBcXw2azISEhASEhIaisrMTRo0dRWFiIxMREvPLKK0hNTYXBcO0HXjQ2NiInJ0cNzu3bt6OhoQFWqxUpKSnqz3rDDTdo8nxE3sJwIZ9TVVWFrVu3qmGyb98+SCnRu3dvpKWlqfMmbbFRsaioCBs2bEBmZiby8vLQ0NCAsLAw3HDDDZg0aRJGjBiBoKAgrz2/0+nErl271LDZuXMnnE4nwsPDMWbMGDVsruQcNKK2xHAh3dXU1CArK0t9A92zZw8URUFMTAzS0tLUSe/evXvrWqfb7YaUEgaDQbdeQ0NDA3bu3KnOL2VnZ8PlciEyMhJjx45Vw6Zfv34MG9IVw4XaXG1tLb
Zv3672THJycuB2u9GjR49WPZO4uDi+Qf4Mh8OBHTt2qGGze/duNDU1oWvXrq2GDPv27cvfJbUphgt5necN0BMmu3btUt8AW4YJ3wCvXW1tLXbs2KFuEP3hhx/U4G4ZNgxu8jaGC2nOM3TjCZOdO3fC5XIhKioKY8eOVcOEQzfeV11dje3bt/9kyDE2Nlb9d0hNTdV9yJE6HoYLXTOn04ns7OyLTjq3nAfgpLP+7HY7srKy1E2dLRdLtAwbnupM14rhQletsbERu3fvVodeduzYoS6XbbmCictlfV9FRUWrxRSeZd5xcXHqv2NaWhp69Oihc6XU3jBc6Ge5XC788MMPaphs375d3eiXkpKizptwo1/7V15eftENqn379tVkgyp1HgwX+omWR5Rs2rRJvUKjxWJptWucR5R0fKWlpeoQWmZmpnq0Tr9+/VqFTWRkpM6Vkq9huJB6uKInTLKystTDFZOTk9WeydChQ3m4YidXVFTUKmw8h4IOHDiw1Tlu4eHhOldKemO4dEKKouDAgQNqmGzduhV2ux2BgYEYNWqUGibDhg3jsfB0WYWFhWrQZGZmqpczGDRokBo2nssYUOfCcOkEpJQ4dOiQGiZbtmxBRUUFAgICMGLECDVMhg8fzgta0TU5e/asuqEzMzNTvRDbjTfeqIbN6NGj0aVLF71LJS9juHRAUkocPXpUbeBbtmxBaWkp/Pz8MGLECLWRjxgxgpfiJa86depUq7ApKCiAwWDATTfdpL4Ok5OTERISoneppDGGSwcgpURubq7agDMzM1FSUgKTyYSkpCS1EY8cOdKrhywSXY6UEidPnmx1iemioiIYjUYMHTpU3WMzatQoBAcH610uXSOGSzvkaaQtw+TcuXMwGo1ITExU9yaMHDkSFotF73KJLsrzoajlnI3nQ9GwYcPUsPH2ydPkHQyXduL06dOtwiQ/P7/V8EJaWhpGjRrV6rrvRO2JZzi3ZdiUl5fD398fSUlJ6twgh3PbB4ZLOzFo0CAcP35cnRhNS0tDcnIyrFar3qUReYWiKDh8+LAaNJ6FKEuWLMFvf/tbvcujn8FwaScURYEQgmdzUaclpYSUku2gnWC4EBGR5nh2hwZcLhfOnj0LRVH0LuWaCSHQs2dPbp6kq8I2QBdiuGigoKAATz75JBITE/Uu5Zrl5ORg4cKFsNlsepdC7UjLNuAZvmqvJ2KzDWiD4aIBKSUGDx6MOXPmeO05jh49io8//hi33XYbxo4d67Xn+ctf/gKOlNLVklJi0KBBGDJkCFauXImxY8fiwQcf1LusK7Z+/XocPHgQ48ePR1NTE9uABhguGvPGRKPb7cbs2bPx9ddfY+PGjdi4caNXlhyzQdG1EELgu+++w7Jly2C323H//fe3i1OzFUXBkiVLsHz5ctx9992Ijo7Wu6QOoX32WzuZqqoq/PDDDwCaj9MoKCjQuSKiixszZgyEEDhw4ADKy8v1LueK2O127Nq1C0Bz/e11OM/X8LfYDpSUlKgNtb6+HmfPntW5IqKLS0xMhMViQUlJCY4cOaJ3OVfk0KFDOHfuHIKDgzFq1Ci9y+kwGC7tQFFRERoaGgA0d+HZcyFf1bNnT9hsNrhcLmRlZfn8UKuUEpmZmWhsbER8fDzi4uL0LqnDYLi0AwUFBXC73erfCwsLfb7RUudkNpuRlJQEAMjKymr1uvVFLpcLmzZtAgCMHj2aZ5hpiOHi46SUKCwsbPW14uJinaoh+nljx46FEAIHDx5EaWmp3uVcVmFhIQ4dOgSDwYDx48frXU6HwnBpB4qKilr9vaysjD0X8lmJiYno0qULSktLsX//fr3LuSQpJXbt2oWKigpERkbipptu0rukDoXh0g5c+OmvvLy8Q+yEpo4pJiYG1113HdxuNzIzM332g5CUEt999x2klLjxxhvRvXt3vUvqUBguPs7tdqOsrAwA1GuzVFVVoampSc+yiC4pICAAKSkpAIAtW7bA6XTqXNHF1dTUYPv27Q
CACRMmwGg06lxRx8Jw8XFNTU2w2+0AgL59+wIAamtr4XK5dKyK6NKEEBg3bhxMJhOOHDmCM2fO6F3SRR05cgSnT5+G2WxW54lIOwwXH9fY2IiqqioAQHx8PIQQcDgcPvtpkAiAOsxUU1ODbdu2+dzQmJQSGzduhNPphM1mw3XXXad3SR0Ow8XHNTQ0oK6uDgaDATabDUajEfX19aivr9e7NKJLioiIwPDhwyGlxLp163xujtDlcmHDhg0AmnflBwcH61xRx8Nw8XEOhwMNDQ0wGAzo2bMnTCYTnE4n6urq9C6N6JIMBgNuvvlmCCGwc+dOn1uSXFBQgAMHDsBoNGLixIl6l9MhMVx8XG1tLZxOJ/z8/NCjRw/4+fmhqakJtbW1epdGdElCCKSkpMBqtaKoqAjZ2dk+MzQmpURWVhbsdju6du2KxMREzrd4AcPFx1VXV8PlciEgIABRUVEICAiA2+1GTU2N3qURXVbPnj0xZMgQKIqCtWvX+lS4rF+/HlJKDBs2DF27dtW7pA6J4eLjqquroSgKAgMDERYWBrPZDEVRUF1drXdpRJdlMpnw61//GgCQmZmprnrUW3l5uboEOT09nacgewl/qz7ObrdDURSYzWaEhITAbDZDSqmuICPyVUIITJgwARaLBWfPnsWePXt0771IKZGTk4PCwkKEhoZyCbIXMVx8nOfTXnBwMIKCgtRVLb7yKZDocuLj4zFo0CA0NTVhzZo1epcDAFi7di3cbjcGDx6M3r17611Oh8Vw8WFSSlRUVAAAQkNDERAQ0GqXvt6fAol+TkBAANLT0wEA33//ve5zhXa7Hd9//z0A4Oabb4a/v7+u9XRkDBcf5+mhhIaGws/PD126dGn1dSJfJoTApEmTYDabkZeXh3379ulWi2dI7PTp0wgODsakSZM4JOZFDBcf5+m5WK1WGAwGNVw450LtRf/+/TFgwAA0NjbqumpMSomvv/4aLpcLN954I3flexnDxQdIKeFwOLB//37U1NSoja/lxL3VaoUQolW4tLxfdXU1srKyUFFRweEy8ilms1ldNbZu3TrdhsbKy8uxfv16AMBvfvMbBAQE6FJHZ8Fw8QFutxsvvPACxowZgxkzZqCxsRFA8yWNPcNf4eHhAJpDBmgOF8+RGk1NTZg5cybS09PxyCOPqJdEJvIFQgjccsstMJvNyM3N1WXVmJQSmzdvRn5+PqxWK2655RYOiXkZw6UNSSnVPy3l5eXhyy+/RENDA1avXo2jR48CaA4Nz6e8C3sutbW16iVkz5w5gzVr1qCpqQkbNmzw6Qs0Uec0cOBADBo0CI2NjVixYkWbh4vb7caXX34JRVEwatQoxMXFtenzd0YMF41dLDyA5gMoFy5ciFmzZqGoqKjVkNbOnTvVTZEOh0P9ZOdyudRjXsLCwgD8X8+lrq5OvabLgQMH1Mc3NDRgx44drWqoq6vD6tWrkZOTc9EDBC9VM5FWAgMDceeddwIA/v3vf6vXKGorp06dwubNm2EwGHD33XfDZDK16fN3RgwXDRUVFeH555/HO++8A4fDoX5dSolvvvkGf/7zn/Huu+9i9uzZrd7kd+/e3erN/eDBgwAAp9MJh8MBIYQaKl26dFGP3W9sbISUEocOHWr1/fbs2aP+v6IoeOONN3DPPffgtttuw/bt21s9l8vlwldffYV58+a1eYOnzkMIgdtuuw1hYWHIz89XrwDZFqSUWLlyJSorKxEbG4vx48dzSKwNML410tTUhJdeegmff/45jEYjIiMjMX36dAghoCgKli9frvY0vv32WxQVFSEmJgZNTU04dOgQgOaTZBVFQV5eHqSUqK+vV09E9gyHhYaGQgiBhoYGdW7lxIkTAJobsJQSx44dQ2NjIwICAlBUVIR//etfcLvdqKiowCeffIJRo0ap9/3mm2/wyCOPwOl0YufOnYiOjtbht0edQZ8+fTBu3DisWLECS5YswdSpU9tkUt3hcGD58uUAgMmTJyMqKs
rrz0nsuWjGYDBg3LhxiImJgdvtxhdffKFOzFdVVbVa319aWqr+va6uDmfPngUA3HDDDQCAwsJCuFwu9aJgJpMJoaGhAJovdWw0GuF0OtHQ0AC3261e6a9fv34QQuDcuXPqMNm+fftQXFysPndWVhYqKysBNPdaFi1aBKfTibCwMNx1113cVEZeYzQa8fvf/x4mkwk7d+7E3r17vd57kVIiOzsbBw4cQGBgIO6++272WtoIw0UjBoMB06ZNw9///ncYjUYcOHAABQUFAICzZ8+itLRU3QSpKApycnIgpUR5eTkqKirg5+ennnNUVlYGh8OBmpoauFwu+Pv7qzvzLRYL/Pz84HK5UFdXB6fTiZKSEgBASkoK/P39YbfbUVZWBikldu/eDUVREB4eDn9/fxQVFSEvLw9A80KAvXv3wmg04m9/+xumTJnCQ/zIazzH8A8aNAgOhwOLFy/2ergoiqJ+0Bs6dCiGDBnCcGkjfCfRWFJSEiIjI1FVVaXOfRw/fhwNDQ3o1q0b0tLSAAA//vgjpJQoLCyEw+FAcHAwhgwZAqPRiJqaGvWP2+1GQEAAgoKCADSfMebv7w+Xy4Xq6mrU1taisrISBoMBQ4cORUhICBoaGlBQUAAppbpyLC0tDTExMXA6ndi/f7+6W7m6uhpdu3bFxIkT2ejI6ywWC/7whz9ACIFVq1bh1KlTXn2+M2fOYN26dRBCYNq0aQgMDPTq89H/YbhoLCoqCtddd53aHZdS4vDhw5BSomfPnkhOTgbQPE9SX1+PM2fOoKmpCWFhYejbty/8/f1RX18Pu92OyspK9URks9kMoHlDWmBgoHrsfnV1Nerq6mAymdC3b19ERERAURScPn0aDQ0N6nzMyJEj1R3JniE5T30DBgxAZGSkDr8t6myEELjjjjvQs2dPlJaW4osvvvBa70VKiS+//BKlpaWIjY3F5MmT+QGqDTFcNGYymZCYmAgA+OGHH9DY2IjDhw8DABISEnD99dfDaDSiuLgYFRUV6hBVdHQ0evTogaCgILhcLpSXl6OyshJSSlgsFnUuJDAwEEFBQVAUBZWVlSgrK4PT6URgYCCio6PVCfm8vDxUVlaiuLgYBoMB/fv3V+d0Dh06hPr6ejVkhg4dCqPR2Ka/J+q8unfvjnvvvRcA8MUXX7SaE9RSWVkZlixZAgCYOnUqunfv7pXnoYtjuHhBUlIShBDIzc1FUVGR2nsYMGAAevfuDbPZjJqaGhQWFuLkyZMAgF69eqFLly6wWCxwu90oKSlBeXk5gP87tBJAq/mXyspKlJSUoKmpCSEhIbBareoR4qdOnUJhYSFqampgNpvRu3dv3HDDDRBC4PTp08jPz8fJkychhMCQIUPa+DdEnZkQAr///e8RGRmJ06dPY/ny5Zr3XqSUyMjIwIkTJxAWFob7779f0+9PP4/hojEhBK6//noEBwejrKwMOTk5OHfunNp7iIyMRHh4OBobG3Hs2DGcPn0aAGCz2RAYGKjuZykpKVEPrezSpYvaszCZTOqy5PLycnVDZpcuXRAUFKTuPD579qy6JDk8PByRkZHo168f/P39UVZWhuzsbJSXlyMwMBD9+/fncAG1qfj4eNx5552QUmLRokXqa10r5eXl+PDDDyGlxB133IHrrruOr/E2xnDxgpiYGMTExMDlcmH16tWw2+0wm82Ii4uDxWJBbGwsAGDv3r0oLCwEAMTFxcHPzw8REREAmjdkejY1RkREqKu4DAaDulu/rKwM586dAwBERkbC399fDZeioiLs27cPUkrExMQgJCQEMTExCA8PR319PVauXImGhgZERUVxbwu1OSEEHn30UVitVhw/fhxfffWVZr0XKSU+//xzHD16FGFhYZgxYwZXQeqAv3EvCA4OxqBBgwAAK1euRGNjI6KiotC9e3eYTCbYbDYAwM6dO1FeXg6TyYQ+ffrAYDCga9euAJrDwTMs5gkcoLlRejaBlZaWqsudu3fvDqPRiJ49e8LPzw
+VlZXYsWMHgOZekZ+fH8LCwtCnTx9IKbFu3ToAzZ8gPXtoiNqKEAIDBw7ElClToCgK3n//fZSWlmryvc+ePYsPPvgAUkrcd999GDBgAHstOmC4eIFnPT8A9WywhIQEda6kf//+AJpXbdXU1CA4OBgxMTEAgG7dugFo3kjpCZcLV3J5AqhluPTo0UP9b3BwMGpra9Wl0P369QMA+Pn5YfDgwQCaj5YBOJlP+jEYDHjyyScRERGB3NxcfPLJJxc9++5quN1uvPfeezhz5gyio6Px1FNP8fWtE4aLFwghMGbMGISEhKhfGzVqFIxGI4QQGDBgAAwGA1wuF6SUiIyMVHsnnhUthYWF6ubIrl27qp+8hBBquBQXF6vDap5wiYiIQGRkpHrwpef5PI/3HP0CNDfukSNHevvXQXRRQgj0798f999/P6SU+Mc//oEjR4784uExKSV27dqFxYsXQwiBP/7xj+jTp4+2RdMVY7h4SUJCAsaNGwegeUK+5fUjEhISEBwcrN7XZrMhODgYQgi153Lu3DmUlZW1GgbziIqKghACxcXFOHfuHIQQiI6OhhCi1aQ+AAQFBSEhIQFAc2MeNWqUGk7R0dEYNmwYhwxINwaDATNmzEDfvn1RUlKCl19+GfX19b/oe1VXV+Oll15CdXU1hgwZggcffJCvbR0xXLzEz88Pc+fOxUsvvYSPP/5YHY4Cmt/Ue/Xqpf49MTFRnXDs1q0bjEajuvveZDL9ZFgsKioKBoMBVVVVqKqqgtFoVHsuBoMBN910k3rf2NhYdcgNAHr27IlZs2Zh6NCheOmll7j2n3QXHR2Nl19+GQEBAfj3v/+NhQsXqtcqulJNTU2YO3cusrKyEBwcjDlz5qgLX0gfDBcvEUIgNjYWs2bNwq233tpqtUpwcDDS09MBNO+4b3n0SlRUlLqnBQACAgJaTeh77tPyNNnAwEC1NyKEwIQJE9Tbx40bp871AM3h8+ijj2Ljxo3qMRxEevLs2v/DH/4ARVHw1ltvYcWKFVc8/+J2u/HPf/4TCxYsgJQSjz/+OMaNG8fXts545L4XXe7F/cwzz0BRFMTFxWHEiBHq18PCwhAcHKwep2+xWNS9Lx7h4eGwWCzqNWNCQ0NbfUobOXIkZs2ahePHj+O55577SR1CCF4/nHyKn58fZs+ejePHj2PTpk149tlnYbFYcPPNN1+2HbndbixZsgQvvvgiGhoacOutt+KFF17gxcB8AP8FdOCZlH/zzTfVv3uEhobCarWqK8U8QdJSSEgIIiMjW034t1xOHBAQgOeffx5SSggh+AmOfJ4QAuHh4fjwww9x3333Yc+ePXjkkUewYMEC3H777T9Z8SWlhMPhwMKFC/HWW2/B4XAgNTUVCxcubLWQhvTDYTGdeN70L3zjN5vNrTY1xsbG/qSXERgY2GrSPj4+/ifXYRFCwGAwMFio3RBCoHfv3li8eDEGDx6MsrIyPPLII3jhhRdw6NAhdXWlw+HApk2bcPfdd2POnDmor6/HpEmT8Omnn6Jbt258zfsIhouPMZlM6unFANRlyy0ZDAYkJSWpf/ecZUbU3gkhkJCQgKVLlyItLU3tnaSlpeG2227D448/jokTJ+L222/Hhg0b4Ofnh0cffRSLFy9Gjx492A58CIfFfNCYMWPw6aefwmg0IjU19Se3CyFw1113qZduvf3229moqMMQQiAuLg7Lli3Dp59+ik8//RS5ubnYtGmTep/AwECkpqbiueeew/jx41stgiHfwHDRmBbnI916663485//jJCQEIwfP/6i3zc+Ph7/+c9/AABWq9XrV/QjulJavRZDQ0PxzDPPYPr06dixYwe+//57lJaWwmazYdKkSbjxxhvVIWO+/n0Pw0UDQgjs378fr732mqbft6amBn/72980/Z4/Z9++fewF0VXzVhtoKSQkRF3csn79eqxfv94rz8M2oA0hGfnXrLGxESdPnrzqjV++yGAwwGaz/WSBANHlsA3QhRguRE
SkOa4WayeklFAUhWPL1KmxHbQfDJd2Yu/evTCbzdi7d6/epRDpZu/evQgKCmI7aAcYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4tANSSlRWVgIAKisreaEk6pQ87aDlf8l3MVx8mN1ux/z585GQkIAJEyagsbEREyZMQEJCAubPnw+73a53iURex3bQPgnJ+PdJ69atw9SpU+FwOACg1ac0IQQAICgoCBkZGUhPT9elRiJvYztovxguPmjdunWYPHmyer3wSzEYDBBCYM2aNWxY1OGwHbRvDBcfY7fbERsbi/r6+ss2KA+DwQCz2Yz8/HxYrVbvF0jUBtgO2j/OufiYxYsXw+FwXFGDAgBFUeBwOPDZZ595uTKitsN20P6x5+JDpJRISEhAXl7eVa2EEULAZrPh+PHj6jg0UXvFdtAxMFx8SFlZGaKioq7p8RERERpWRNT22A46Bg6L+ZDa2tprenxNTY1GlRDph+2gY2C4+BCLxXJNjw8JCdGoEiL9sB10DAwXHxIREYH4+PirHi8WQiA+Ph7h4eFeqoyo7bAddAwMFx8ihMBTTz31ix779NNPcxKTOgS2g46BE/o+huv7idgOOgL2XHyM1WpFRkYGhBAwGC7/z+PZmbxixQo2KOpQ2A7aP4aLD0pPT8eaNWtgNpshhPhJN9/zNbPZjLVr12LSpEk6VUrkPWwH7RvDxUelp6cjPz8f8+bNg81ma3WbzWbDvHnzUFBQwAZFHRrbQfvFOZd2QEqJiooK1NTUICQkBOHh4Zy0pE6H7aB9YbgQEZHmOCxGRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHm/j8h+UhYj3axHwAAAABJRU5ErkJggg==",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 6 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "733a2a41",
|
|
"metadata": {},
|
|
"source": [
|
|
"`suggest_symbolic` does not return anything that matches this activation, since the Bessel function isn't included in the default `SYMBOLIC_LIB`. We want to add it to the library."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "031db28f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 0 0.000000 0.000014 0 0 0.000003\n",
|
|
"1 x 0.001602 -0.002298 1 1 0.799540\n",
|
|
"2 sin 0.161428 -0.253977 2 2 1.549205\n",
|
|
"3 cos 0.161428 -0.253977 2 2 1.549205\n",
|
|
"4 1/x^2 0.099456 -0.151116 2 2 1.569777\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('0',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 0,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.0,\n",
|
|
" 0)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "4b8549a2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dict_keys(['x', 'x^2', 'x^3', 'x^4', 'x^5', '1/x', '1/x^2', '1/x^3', '1/x^4', '1/x^5', 'sqrt', 'x^0.5', 'x^1.5', '1/sqrt(x)', '1/x^0.5', 'exp', 'log', 'abs', 'sin', 'cos', 'tan', 'tanh', 'sgn', 'arcsin', 'arccos', 'arctan', 'arctanh', '0', 'gaussian'])"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"SYMBOLIC_LIB.keys()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5db9e7cf",
|
|
"metadata": {},
|
|
"source": [
|
|
"Add the Bessel function J0 to the symbolic library. We need to provide a name and a PyTorch implementation; `c` is the complexity assigned to J0."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "cbde1924",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"add_symbolic('J0', torch.special.bessel_j0, c=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bda24c6d",
|
|
"metadata": {},
|
|
"source": [
|
|
"After adding the Bessel function, we check `suggest_symbolic` again."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "83e5cfdd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 0 0.000000 0.000014 0 0 0.000003\n",
|
|
"1 J0 0.198505 -0.319216 1 1 0.736157\n",
|
|
"2 x 0.001602 -0.002298 1 1 0.799540\n",
|
|
"3 sin 0.161428 -0.253977 2 2 1.549205\n",
|
|
"4 cos 0.161428 -0.253977 2 2 1.549205\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('0',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 0,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.0,\n",
|
|
" 0)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# J0 fitting is not very good\n",
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4180de14",
|
|
"metadata": {},
|
|
"source": [
|
|
"The fitting r2 is still not high. This is because the ground truth is J0(20x), whose scale factor 20 lies outside the default search range (-10,10). We therefore enlarge the search range (`a_range`) so that it includes 20; now J0 appears at the top of the list.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "e78f4674",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 J0 0.998912 -9.830484 1 1 -1.166097\n",
|
|
"1 0 0.000000 0.000014 0 0 0.000003\n",
|
|
"2 x 0.001602 -0.002298 1 1 0.799540\n",
|
|
"3 cos 0.583964 -1.265186 2 2 1.346963\n",
|
|
"4 sin 0.583964 -1.265186 2 2 1.346963\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('J0',\n",
|
|
" (<function torch._C._special.special_bessel_j0>,\n",
|
|
" J0,\n",
|
|
" 1,\n",
|
|
" <function torch._C._special.special_bessel_j0>),\n",
|
|
" 0.9989116787910461,\n",
|
|
" 1)"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.suggest_symbolic(0,0,0,a_range=(-40,40))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|