GitHub_collection_pykan/tutorials/Example_10_relativity-addition.ipynb
kindxiaoming bfe4e84ec1 clean
2024-04-29 12:35:18 -04:00

433 lines
65 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
"# Example 10: Use of lock for Relativity Addition"
]
},
{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we will symbolically regress $f(u,v)=\\frac{u+v}{1+uv}$. In relavitity, we know the rapidity trick $f(u,v)={\\rm tanh}({\\rm arctanh}\\ u+{\\rm arctanh}\\ v)$. Can we rediscover rapidity trick with KAN?"
]
},
{
"cell_type": "markdown",
"id": "94056ef6",
"metadata": {},
"source": [
"Intialize model and create dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a59179d",
"metadata": {},
"outputs": [],
"source": [
"from kan import KAN, create_dataset\n",
"\n",
"# initialize KAN with G=3\n",
"model = KAN(width=[2,1,1], grid=10, k=3)\n",
"\n",
"# create dataset\n",
"f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]])\n",
"dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9])"
]
},
{
"cell_type": "markdown",
"id": "cb1f817e",
"metadata": {},
"source": [
"Train KAN and plot"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a87b97b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.28e-04 | test loss: 6.37e-04 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 5.41it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f1cfc9d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAApsklEQVR4nO3daVBVZ54G8OdcLuBFNrmCbIosF6ImcQtGTTJitKXTKC60JppMYoKdrumq2JOaLzPTk6qpSn+YmZpJ6UzXLJpMlI4dg1GDBpVW42jHmKBGxbgAAoIgqMgu673nPx/MPc1xBT1wt+dX5Yf7HsHXhL/Pec+7HEVEBERERAYyuboDRETkfRguRERkOIYLEREZjuFCRESGY7gQEZHhGC5ERGQ4hgsRERmO4UJERIZjuBARkeEYLkREZDiGCxERGY7hQkREhmO4EBGR4RguRERkOIYLEREZzuzqDhB5AhHBzZs30dHRgeDgYFitViiK4upuEbktjlyIHqClpQXr16+HzWZDZGQkEhMTERkZCZvNhvXr16OlpcXVXSRySwrfREl0b0VFRcjJyUFnZyeA26MXJ+eoJSgoCNu3b0dmZqZL+kjkrhguRPdQVFSErKwsiAhUVb3v7zOZTFAUBYWFhQwYon4YLkR3aGlpQXx8PLq6uh4YLE4mkwkWiwW1tbUIDw8f+g4SeQDOuRDdYfPmzejs7BxQsACAqqro7OxEXl7eEPeMyHNw5ELUj4jAZrOhsrISgykNRVGQlJSE8vJyriIjAsOFSKexsRGRkZGP9fVWq9XAHhF5Jj4WI+qno6Pjsb6+vb3doJ4QeTaGC1E/wcHBj/X1ISEhBvWEyLMxXIj6sVqtSE5OHvS8iaIoSE5ORkRExBD1jMizMFyI+lEUBe+8884jfe3atWs5mU/0I07oE92B+1yIHh9HLkR3CA8Px/bt26EoCkymB5eIc4f+jh07GCxE/TBciO4hMzMThYWFsFgsUBTlrsddzjaLxYI9e/ZgwYIFLuopkXtiuBDdR2ZmJmpra7Fu3TokJSXpriUlJWHdunWoq6tjsBDdA+dciAZARHDo0CHMmzcPBw8exNy5czl5T/QAHLkQDYCiKNqcSnh4OIOF6CEYLkREZDiGCxERGY7hQkREhmO4EBGR4RguRERkOIYLEREZjuFCRESGY7gQEZHhGC5ERGQ4hgsRERmO4UJERIZjuBARkeEYLkREZDgeuU80QCICEbnny8OISI8jF6JBYKgQDYzZ1R0g8hQMFqKB48iFiIgMx5ELuQ2Hw4Fjx46hra3N1V3xeNOnT8eYMWNc3Q3yYZzQJ7fR1dWFF154AV1dXQgJCXF1dzySiKC8vBybNm1Cdna2q7tDPowjF3IbIgKTyYQPPvgAGRkZru6O2xMROBwO9Pb24tatW7Db7bBarcjKygLvGcnVGC7kdsxmMwIDA13dDZcSEaiqilu3buHmzZtoaGjA1atXce3aNVy/fl371djYiBs3bqC5uRmTJk3Cjh07uPCA3ALDhcgNOPfQNDY2oqSkBF9//TVOnDiB8vJy3LhxQxuZPGhEEhISgu7u7mHsNdH9MVyIXEhE0NbWhiNHjiA/Px9HjhxBfX097Hb7oL9Xa2sr2tvbh6CXRIPHcCFyARFBU1MTtm7dio0bN+L8+fMDChQ/Pz/tsWFQUBBCQkIwevRoxMTEICkpyecfJ5L7YLgQDSMRQV9fH3bv3o3f/va3OHv2LFRVvev3+fn5ISIiAklJSUhLS0NKSgoSEhIQExODiIgIhIeHIzg4GBaLBYGBgTCbzVAUBX19fS74WxHdjeFCNExEBA0NDXjvvfewZcsW9PT06K4HBAQgLS0NmZmZmDdvHp588kmMHj0aAQEBAHhCAHkWhgvRMBARnD59Gr/85S9x4sQJ3bXg4GC89NJLWLNmDWbOnIng4GAADBPybAwXoiEmIjhy5AjeeustVFVVae1msxnz58/Hb37zG8yYMU
N7tEXkDRguREPIGSyvvfYa6urqtHar1Yr33nsPubm5CAoKYqiQ12G4EA0REUFJSQlyc3N1wZKcnIwNGzZgzpw5MJl4dix5J4YL0RAQEdTX1+Ptt99GZWWl1j5p0iT8/ve/x+TJkzlaIa/G2yaiIdDd3Y2/+7u/w/Hjx7W25ORkBgv5DIYLkcFEBHl5edi6davWZrVasWHDBgYL+QyGC5GBRAQXLlzAb3/7W21DY0BAAN5//33MmTOHwUI+g+FCZKDe3l68//77ugn8l19+GatXr+bkPfkU/rQTGUREsHfvXnzxxRdaW3JyMv7xH/+RZ36Rz2G4EBmktbUV//RP/6Qd6+Lv74+///u/x/jx4/k4jHwOw4XIACKC/Px8nDx5UmubO3cuVqxYwWAhn8RwITLAzZs38bvf/Q4OhwMAMHLkSPzt3/4tgoKCXNwzItdguBA9JhHBtm3bcP78ea1t0aJFmD17Nkct5LMYLkSPqaWlBf/zP/+jvZclJCQEv/71r+Hv7+/inhG5DsOF6DGICHbv3o1z585pbVlZWZg2bRpHLeTTGC5Ej6GzsxMffvihNtcSFBSEv/qrv4LZzGP7yLcxXIgekYjgT3/6k+78sLlz52LGjBkctZDPY7gQPSK73Y6PP/5Yt6/lF7/4hfZaYiJfxnAhekQXL17E/v37tc+TJ0/G3LlzOWohAsOF6JGICD799FO0tLQAuP2++9dffx3BwcGu7RiRm2C4ED2CxsZGfP7559rnuLg4LF68mKMWoh8xXIgGSUSwf/9+VFVVaW1LlixBXFycC3tF5F4YLkSD1Nvbi08++US3/HjVqlUctRD1w3AhGqQLFy7gm2++0T7PmDEDU6ZMYbgQ9cNwIRoEEcHnn3+OtrY2AIDJZMKqVav4vhaiOzBciAahpaVF9zKw2NhY/PSnP+WohegODBeiARIRHD16FOXl5Vrbz372M8TGxrqwV0TuieFCNECqqiI/Px99fX0AgMDAQCxfvpyjFqJ7YLgQDVBtbS0OHjyofZ44cSLS09MZLkT3wHAhGgARwb59+9DQ0KC1LV26FCEhIS7sFZH7YrgQDUBPTw+2bdsGEQEAhIWFYcmSJRy1EN0Hw4VoAM6fP687Wn/WrFlITU11YY+I3BvDheghRAQ7d+5Ee3s7gNt7W1asWMHXGBM9AMOF6CFaW1tRUFCgfY6NjcVPfvITPhIjegCGC9EDiAiOHTuG0tJSrS0zMxMxMTEu7BWR+2O4ED2Aqqr47LPPtL0tAQEBWLFiBUctRA/BcCF6gLq6Ohw4cED7PGHCBDz77LMMF6KHYLgQ3Ydzb0t9fb3WtmTJEu5tIRoAhgvRffT09OCzzz7T9raEhoZi6dKlHLUQDQDDheg+zp49i+LiYu3zc889hyeeeMKFPSLyHAwXontwHlJ569YtALf3trzyyivc20I0QAwXontobGzUvbclISGBe1uIBoHhQnQHEUFRURGqqqq0tuzsbERFRbmwV0SeheFCdIeenh7k5eVBVVUAwMiRI/HKK69w1EI0CAwXon5EBCdPnsSxY8e0ttmzZ2Py5MkMF6JBYLgQ9aOqKjZt2oTOzk4AgJ+fH15//XUEBga6uGdEnoXhQtRPRUUFdu3apX1OTU3FT3/6U45aiAaJ4UL0I1VVkZeXhxs3bgAAFEXBa6+9hoiICBf3jMjzMFyIflRXV4dPPvlE+xwTE4OVK1dy1EL0CBguRLg9kZ+Xl4crV65obS+//DLGjRvnwl4ReS6GCxGA2tpafPTRR9o5YqNHj0Zubi5HLUSPiOFCPk9VVXz00Ueorq7W2lasWIG0tDSGC9EjYriQTxMRlJeXY+PGjdqoJTIyEr/61a9gMrE8iB4Vq4d8mt1ux7/927/p3tnyxhtvcNRC9JgYLuSzRASHDx/G1q1btbaEhAT86le/gp+fnwt7RuT5GC7kk0QETU1NeO+999DR0QHg9m78X//610hISHBx74g8H8OFfJLD4cC//uu/6l
4GNmvWLLz55pt8HEZkAIYL+RwRwe7du/G73/1Om8QPCwvD+++/j9DQUBf3jsg7MFzIp4gITp06hXfffVf3lsm1a9fi+eef56iFyCAMF/IZzmXHubm5qKmp0drnz5+Pd999l5P4RAZiuJBPEBFcuHABq1atwpkzZ7T21NRUrF+/HmFhYS7sHZH3YbiQ11NVFYcOHcKyZcvw/fffa+3R0dHYsGEDUlNT+TiMyGBmV3eAaKiICNra2vDf//3f+Jd/+Rc0Nzdr16KiorBx40a88MILDBaiIcBwIa8jIujq6sKBAwfwz//8z/juu++gqqp2fezYsdi4cSPmz5/PYCEaIgwX8goiAofDgfr6ehQVFSEvLw/FxcXo7e3Vfo+iKEhPT8d//ud/YurUqQwWoiHEcCG35Nx/AkAXAs52VVXR09ODpqYmVFRU4Pjx4zh8+DCOHz+OGzdu6L4eAEaOHInVq1fjH/7hHxAVFcVgIRpiDBdyS9XV1fjggw8wbtw4xMbGwmw2o62tDQ0NDbh69Spqampw5coV1NfXo6WlBXa7/Z7fx9/fH7Nnz8ZvfvMbZGRkwM/Pj8FCNAwYLuSWLl68iP/6r/+Cw+EAcHv0cudo5EGCg4Mxe/Zs/PKXv8SCBQsQFBTEUCEaRgwXckuXLl3SggXAQ4PFZDIhLCwMEyZMQGZmJhYtWoQJEyYgICCAoULkAgwXckuNjY0ICQlBZ2enttJLURSYzWYEBgYiNDQUUVFRSEpKwqRJkzB9+nQ8+eSTiI2NZaAQuQGGC7mlv/mbv8Grr76K2tpabU4lODgYVqsVERERCA8PR3BwsC5IGChE7oPhQm5FVVUcO3ZMO1QSuP2eFT8/P/T19aGhoQENDQ0u7KF7s9vtaGpqcnU3iBgu5D5MJhMmTZqEvXv3Yu/eva7ujscKDQ3FqFGjXN0N8nGKDGYJDtEQEhHY7fZBrQqjezObzTCZeHQguQ7DhYiIDMdbGyIiMhzDhYiIDMdwISIiwzFciIjIcFyKTDRA9zupmYjuxpEL0QCdOnUKJpMJp06dcnVXiNwew4WIiAzHcCEiIsMxXIiIyHAMFyIiMhzDhYiIDMdwISIiwzFciIjIcAwXIiIyHMOFiIgMx3AhIiLDMVyIiMhwDBciIjIcw4WIiAzHcCEiIsMxXIgGQETQ3NwMAGhubta924WI7sZwIXqAlpYWrF+/HjabDfPnzwcAzJ8/HzabDevXr0dLS4trO0jkphThLRjRPRUVFSEnJwednZ0A7v0myqCgIGzfvh2ZmZku6SORu2K4EN1DUVERsrKyICJQVfW+v89kMkFRFBQWFjJgiPphuBDdoaWlBfHx8ejq6npgsDiZTCZYLBbU1tYiPDx86DtI5AE450J0h82bN6Ozs3NAwQIAqqqis7MTeXl5Q9wzIs/BkQtRPyICm82GysrKQa0IUxQFSUlJKC8v1+ZjiHwZw4Won8bGRkRGRj7W11utVgN7ROSZ+FiMqJ/W1tbH+vr29naDekLk2cyu7gCRK/X29qKqqgqlpaUoLS3FuXPnHuv7ffnll5g2bRpSU1NhtVr5iIx8FsOFfEpXVxfKy8tRVlaGsrIyVFVVwW63IygoCDabDStWrMChQ4dQU1Mz6O8dFRWF69ev46OPPgIAjBo1CqmpqUhLS4PNZkNMTAzDhnwG51zIq7W3t6OsrAylpaUoKytDTU0NRARhYWFIS0tDamoqUlNTER8fr/3Dv379erz77ruDntBft24d1q5di1u3buHSpUvan1ldXQ1VVREcHIzU1FTYbDakpaVh7NixMJn4ZJq8E8OFvMrNmze1UUlpaSnq6+sBAJGRkdooIjU1FVFRUfcdRRi9z6WnpwcVFRVavyorK9HX14cRI0YgJSVFC7jExESYzXyYQN6B4UIeS0Rw7do1bb6krKwMN2/eBADExsYiLS1NC5NRo0YN6nsPdof+nj17sGDBggF9b7vdjsuXL2thU15eju7ubpjNZiQlJWlhk5
ycjBEjRgyq30TuguFCHkNVVVy5ckX7R7msrAxtbW1QFAUJCQm6kUlwcPBj/3kDPVtsx44dAw6We1FVFbW1tdpoq7y8HO3t7TCZTEhISIDNZtMepxnx9yIaDgwXclvOO3znqKS8vBxdXV3aHb4zSFJSUobsDr+lpQV5eXn493//d1RUVGjtycnJWLt2Ld544w2EhYUZ+meKCBoaGrS/c2lpKZqamgD8eUTmDJzBjsiIhgvDhdxG/7mJ0tJSVFRUoK+vD4GBgdo/pmlpaUhMTIS/v/+w9k1E0NTUhPb2doSEhCAiImJYV371n0sqKytDQ0MDgD/PJTl/RUZGckUauQWGC7nMrVu3tGXBpaWluHz5sm5VlTNMxo0bx1VVd2hra9Mtqb5y5Yq2Cq5/2MTFxTFsyCUYLjRsWlpadHfftbW1EBHdfpDU1FTExsbyH8RB6uzsxKVLl7TAqaqqgsPh0PbvOMMmISEBfn5+ru4u+QCGCw0JEUFjY6M2KiktLcX169cBAGPGjNHtMRk9ejTDxGD9Tx4oKytDRUUFent7ERAQgOTkZN0jxoCAAFd3l7wQw4UMISK4evWqbsNic3MzFEVBfHy8bqc633ky/BwOB6qrq7UFAuXl5ejs7ISfnx8SExO11Wg2mw0Wi8XV3SUvwHChR6KqKmpqarRRSXl5OTo6OmAymbR/rNLS0pCSkoKRI0e6urt0BxFBXV2d7jFla2srFEXB2LFjtVGlzWZDaGioq7tLHojhQgPS19eHyspK3ca/np4e+Pv763aZJycnIzAw0NXdpUESEdy4cUO31+bGjRsAgOjoaN0iAb5SgAaC4UL31N3drZ2PVVpaqh3waLFYtLOxUlNTMX78eB5Z4qWam5t1e22uXr0KAIiIiNDttYmOjuacGd2F4UIAgI6ODm2uxHnYooggNDRUtyw4Pj6ey4J9VEdHh7Yarby8XDuQMyQkRLvhsNlsPJCTADBcfFZTU5Nu8t15V2q1WnVnco0ZM4Z3pXRP3d3ddx3I6Rzd9n9UytGtb2K4+ADnAY/9TwtubGwEcPs4ET5PJyP09fWhqqpKN7pxzsvdeSAn5+W8H8PFC6mqirq6Om1UUlpaqh3wOG7cOG1UwpVANJScB432PxvOuaLQedCo8+eQKwq9D8PFC9x5hHtZWZl2wOOdy4K5h4FcRURQX1+vLRBw7oUCgPj4eN1JAtwL5fkYLh6ot7cXFRUVd+2+DgwM1J51c/c1uTsRuetAzmvXrgG4/cpo56iGB3J6JoaLhygpKdEtC1ZVFSNHjtTNl/DcKPJ0ra2tulMEnOfPhYeHaz/nU6ZM4asGPADDxUOcOnUKvb29CAsLQ1hYGEJDQxEUFMS7OfJqdrsdra2taGtrQ2trK9rb2zFx4kQuPPEADBcP0dfXB7PZzDAhn+ZwOKAoCvfReACGCxERGY7xT0REhuO22R+pqopLly6hq6vL1V3xeOPHjzf8vfI0PFRVRUVFBevAAAkJCT5dBwyXH9ntdvzhD39Ab28vRowY4erueKxr165hzZo1mDp1qqu7Qo/Abrfj008/RV9fH+vgEYkIrl+/jrfeegtTpkxxdXdchuHyIxGBoihYuXIlnnjiCVd3xyN0dXWhuroa48aNQ1BQEBwOBz744ANwGs9ziQhMJhNefvllpKWlubo7HqG7uxvV1dUYO3asVgfr1q3z+TpguNzBz88P/v7+ru6G2xMRlJWVYcOGDYiIiMDkyZORkZHB1WxegnUwMM46+PDDDzFq1Cg8/fTTmDNnDusAnNCnRyQiOHHiBPr6+nDt2jUcOnQIra2tru4W0bASEZw8eRJ9fX24fv06Dh8+zDr4EcOFHklTUxMuXryofY6Li8PYsWNd2COi4dfc3IzS0lLtM+vgzxguNGgigrNnz6KtrU1rmzp1Ks8xI58iIigpKdGNVKZMmcI6+BHDhQatr68P3333nTZhabFYMG3aND5nJp/S29uL4uJi7bPFYsHUqVNZBz9iuN
CgiAguX76MyspKrc1msyE6OtqFvSIaXiKCqqoqXL58WWuz2WwYM2aM6zrlZhguNCiqquLIkSPo6+sDAJhMJsyePZunMZNPUVUVf/rTn1gHD8BwoQETEdTV1eHUqVNaW0xMDCZNmsRHAeQzRARXrlxBSUmJ1hYXF4eJEyeyDvphuNCAqaqKgwcPorOzEwCgKApeeOEFvqKWfIqqqvjqq690R+Q8//zzfMvrHRguNCDOuZbjx49rbZGRkXj22Wd5t0Y+Q0RQWVmJ77//XmuLjo7GM888wzq4A8OFBqS3txe7d+/W7tYURcHcuXN9+mA+8j29vb348ssv0d3dDeB2HWRkZCA0NNTFPXM/DBd6KBFBcXExfvjhB60tLi4Os2fP5t0a+QwRwXfffYcLFy5obWPHjsXMmTNZB/fAcKEHEhHU19ejoKAADocDAGA2m7Fo0SKEhIS4uHdEw8NZB19++SVUVQUA+Pv7Y+HChZxzvA+GCz1Qd3c3tm7dips3b2pt06dP52Yx8ildXV3Iz89HU1OT1jZ9+nQ89dRTrIP7YLjQfdntdhQWFuoeh0VGRmLJkiUwm3mgNvkGu92Offv24dy5c1pbZGQksrOzWQcPwHChe1JVFd9++y3++Mc/ase8BAQEYPny5RgzZgzv1sgnqKqK4uJi7N+/X6uDwMBA/PznP0dkZCTr4AEYLnQXEcH58+fx2WefaTuQFUXBggULeIYY+QwRwcWLF5Gfn6+rg3nz5mHKlCmsg4dguJCO88ykTZs2oaOjQ2ufOnUqfvazn8Fk4o8MeT8RQXV1NTZv3qyrgylTpuCll15iHQwA/wuRRkRQU1ODjRs36ibwExMT8eqrr2LEiBG8WyOvJyKora3Fhx9+qKuD8ePHY+XKlayDAeJsFAH4853ahg0b0NDQoLVHR0cjNzcXo0aNYkGR13MGy8aNG3V1EBUVhdWrV7MOBoHhQhARXLp0CR9++CGuX7+utVutVqxZswaxsbEsKPJ6zqNd/vd//xfXrl3T2iMiIvDWW28hLi6OdTAIDBcfp6oqSkpKkJeXh+bmZq3darXi7bffRlJSEguKvJ6qqjh37txddTBq1Cjk5uYiOTmZdTBIDBcfZrfbcfToUeTn52snHQPA6NGj8fbbbyMlJYUFRV7Pbrfj22+/xbZt23Dr1i2t3Wq1Ijc3FzabjXXwCBguPkhE0N3djd27d+PAgQPaMkvg9plha9asQUJCAguKvJqIoKenB3v27MH+/ft1deCcaxw/fjzr4BExXHyMiODmzZv4wx/+gNOnT2sbw4Dbr2l98803ER0dzYIiryYiaGpqQn5+Pk6dOqWdFwYASUlJWL16NWJiYlgHj4Hh4kNUVcWFCxewZcsW1NfXa+2KouCZZ57BqlWrEBYWxoIir6aqKsrKyvDpp5+irq5Oa1cUBVOnTsXKlSsRHh7OOnhMDBcf4Bz+HzhwAHv27NG9Qc/f3x8LFizAwoULERgYyIIiryUi6O3txf/93/+hsLBQN8/o7++PF198EQsXLuQ+FoMwXLyc86jw/Px8nD17Vjf8DwsLw4oVKzBjxgwewEdeTURw/fp1fP755zhz5oyuDkJDQ5GTk4Nnn32WdWAg/pf0UiICu92O7777Djt37tQdFa4oCpKSkvDaa69x4p68mojA4XDg5MmT2LFjh27HPXD79IlVq1Zx4n4IMFy8kIjgxo0b2LlzJ06cOAG73a5d8/f3x5w5c5CdnY3g4GAWFHkt56R9QUEBiouLdXVgNpvx3HPPITs7G6GhoayDIcBw8SLO0crx48exc+dONDY26q6PHj0ay5cvx7Rp0+Dn58eCIq/krINTp07hiy++0J06Adzecb906VKkp6ezDoYQw8VLiAgaGhqwc+dOnDp1SneX5ufnh2nTpiEnJwdRUVEsJvJazlF7QUEBTp48qasDk8mEp59+Gjk5OVxuPwwYLh7OuRLs6NGjKCws1B1dAdw+viI7OxuzZ8+Gv78/C4q8kn
Ml2LfffovCwkLdHCNwe/FKVlYWnnvuOQQEBLAOhgHDxYOpqoqqqirs2LEDFy9e1K2A8fPzw+TJk5GTk8PNYOTVVFVFTU0NvvjiC5w/f15XByaTCU8++SRycnJ4AOswY7h4IBFBW1sb9u/fj0OHDunW6wO351ays7Px7LPPcrRCXktE0NHRgYMHD+LQoUO6c8GA23MrWVlZmDlzJkcrLsBw8SDOicozZ86goKAAdXV1uuNb/P39MWPGDGRnZ/P93uS1nMuLf/jhBxQUFODKlSu662azGdOnT0d2djbnGF2I4eIhnJshd+3adddEpaIoiIuLw7Jly/DUU09xBQx5LRHBtWvX8OWXX+LkyZO6wyYBIDY2FosXL8bTTz8Ns9nMOnAhhoubExF0dXXh8OHDKCoqQmtrq+56UFAQXnzxRfzkJz9BSEgIi4m8kvMk76NHj2Lfvn1oaWnRXbdYLPiLv/gLLFiwgPtW3ATDxU2JCFRVRWlpKXbu3ImKigrdIzCTyYSJEydi6dKlGD9+PEwmkwt7SzR0VFXFpUuX8MUXX+DSpUt3TdinpaVh8eLFSEpKYh24EYaLG3LuLN6zZw+OHj2Knp4e3fXRo0dj4cKFnKgkryYiaG1txb59+/D111+ju7tbdz0iIgIvvfQSZs2axUNX3RDDxY2ICPr6+nDixAkUFBTctbM4MDAQM2fOxMKFC2G1WllM5JX677DftWsXGhoadNedC1eysrK4cMWNMVzchKqquHr1Knbu3IkzZ87A4XBo15wHTS5duhRPPPEETCYTC4q8kvOkiYKCApw+fVq3cAUAEhISsHjxYkycOJELV9wcw8XFnBOVR44cwZ49e9DW1qa7HhoaiszMTGRkZMBisbCYyCs5T5r45ptvsGfPnrsm7IODgzFv3jzMnTsXI0eOZB14AIaLC6mqiurqanz++ef33GE/ZcoULFmyBLGxsZyoJK+lqipqa2uxY8eO++6wX7JkCeLj41kHHoTh4gLO0cpXX32Fffv2oaOjQ3c9KioKixcvRnp6Otfqk9dyjlaOHDmCvXv3or29XXfdarVi0aJFSE9P58IVD8RwGWaqquLKlSvIz8/HhQsX7tphP3v2bCxatAgREREsJvJaIoKrV69i27Ztd41W/P39kZ6ejkWLFmH06NGsAw/FcBkmzpVgX3/9NQoKCu6aW4mLi8PPf/5zPPnkk5yoJK/lXAl27Ngx7Nq16665lejoaCxduhSTJ09mHXg4hsswEBHcvHkT27Ztw4kTJ3R3aQEBAXjhhRewcOFChIWFsZjIa4kImpubsXPnThQXF+tWRPr7+2PWrFlYuHAhRo0axTrwAgyXIaaqKi5cuIAtW7agvr5edy0mJgbLly/neWDk9VRVRVlZGT799FPU1dXprkVFRSEnJ4ejFS/DcBkizsdgBw8exO7du9HV1aVdM5vNmDFjBpYtW8a5FfJqzsdghw8fxq5du3Svh/Dz88P06dOxdOlSzq14IYbLEBARtLe347PPPsO3336rewwWFhaGZcuWYdasWVwJRl7N+b6V7du349ixY7rHYCEhIcjOzsbzzz/POvBSDBeDOXcYb9q0CeXl5dpqMEVRkJycjFdffRUJCQksJvJqIoLr168jLy8PpaWlumuJiYlYuXIlD1z1cgwXA4kIysvL8fHHH+vOQ/Lz88Pzzz+PZcuW8Vh88noigsrKSmzatEk3z+jn54eZM2di2bJlPBbfBzBcDKKqKkpKSrBp0ybdO1csFguWLl2KjIwMDv/J66mqih9++AF5eXm6ZcYjRozAokWL8OKLL7IOfATDxQCqqqK4uBiffPKJ7j3eVqsVf/mXf4mnnnqKw3/yeqqq4uTJk3fVQXh4OF599VVMnjyZdeBDGC6PSVVVHDt2DFu2bNGtCIuPj0dubi7nV8gnqKqK48ePY8uWLboVYbGxsVi9ejUSExNZBz6G4fIY+hdU/2BJSUlBbm4uxowZw4Iir6eqKr7//vu7giUxMZ
F14MMYLo9IRHD27Fn8/ve/1wXLhAkTkJuby/0r5BNEBOfPn8cnn3yiC5bU1FTWgY9juDwCEUFVVRU2b96se7b8xBNP4Be/+AXCw8NZUOT1RATV1dXIy8vTneydmpqKNWvW8BgXH8fZtUFynhP28ccfo7m5WWtPTk7GmjVrGCzkE0QETU1NyMvLQ1NTk9aemJiIt956i8FCDJfB6u7uxpYtW1BbW6u1xcbG8hEA+ZSenh5s3boVNTU1Wlt0dDTefPNNWK1W1gExXAZDVVXs27cPZ86c0drCwsKwevVqREdHs6DIJ6iqiv379+vqICQkBK+//jpiYmJYBwSA4TJgIoJz587hj3/8o3akS0BAAF555RWkpKSwoMgniAguXLiAoqIi7cw8f39/LF++nHVAOgyXARARtLW1IT8/H93d3QBunxU2b948PPPMMywo8gnOA1m3b9+u1QEAzJ07FzNmzOAGSdLhT8MAiAiKiop076FITU1FVlYWzGYuuCPfICI4cOCAbp4lJSUFWVlZ8PPzc2HPyB0xXB7Cudzy8OHD2uOw4OBgrFixAkFBQS7uHdHwEBFcuXIFhw8f1tpGjhyJ5cuXIygoiKN3ugvD5SEcDgf27t2rbRBzPg4bP348C4p8hsPhQFFRkW5fV0ZGBo91oftiuDyA8wj9/qtiYmNj8eKLL/L5MvkMEUFFRQXrgAaFPxkP4HA4cODAAfT29gIATCYTMjMzERIS4uKeEQ0fh8OBr776Cj09PQBu18H8+fMRGhrq4p6RO2O43IeIoKamBufOndPaxo0bh+nTp/MxAPkM51xL/zqIj4/nKkl6KIbLfYgIvv76a+1uTVEUZGRkwGKxuLhnRMNHRPDNN9/oluDPmTOHdUAPxXC5j+bmZpw+fVr7PGbMGEydOpV3a+RTWlpadHUQGRnJOqABYbjcg4igpKRE95rWGTNmcK6FfIqzDvof0Jqens46oAFhuNyD3W5HcXGxtq/FYrEgPT2dd2vkU+x2O06cOKF9Zh3QYDBc7qG+vh6XL1/WPqekpCA6Otp1HSJygTvrIDk5mXVAA8ZwuYfTp0/rJjDT09N5vAX5nDNnzujq4JlnnmEd0IAxXO6gqipaW1u11TChoaGYOHEiHwWQT1FVFW1tbRgxYgSA23UwYcIE1gENGE9dvIPJZMLKlSsxd+5cnD59Gr29vRg1apSru0U0rEwmE15++WVkZGTgzJkzrAMaNIZLPyKCS5cuaXtb4uLiICK6Yy/o/hwOh+7sKfJMqqqioqJCq4PY2Fht5Rg9HOvgNobLj0wmE+Li4lBSUsIiegwjRozAyJEjXd0NekTOOjh79izOnj3r6u54LIvF4vN1oIhzva2PExE4HA5Xd8MrmEwmHmjooVgHxvH1OmC4EBGR4Xw3VomIaMgwXIiIyHAMFyIiMhzDhYiIDMdw8RAOhwMdHR1cyUM+zeFwoL29nXXgARguHqK2thbvvPMOamtrXd0VIpepra3FX//1X7MOPADDhYiIDMdwISIiwzFciIjIcAwXIiIyHMOFiIgMx3AhIiLDMVyIiMhwDBciIjIcw4WIiAzHcCEiIsMxXIiIyHAMFyIiMhzDhYiIDMdwISIiwzFcPICIoKmpCR0dHWhqaoKIuLpLRMOOdeBZFOH/IbfV0tKCzZs34z/+4z9QUVGhtScnJ+Odd97BG2+8gfDwcNd1kGgYsA48E8PFTRUVFSEnJwednZ0AoLtLUxQFABAUFITt27cjMzPTJX0kGmqsA8/FcHFDRUVFyMrKgohAVdX7/j6TyQRFUVBYWMjCIq/DOvBsDBc309LSgvj4eHR1dT2woJxMJhMsFgtqa2v5aIC8BuvA83FC381s3rwZnZ2dAyooAFBVFZ2dncjLyxvinhENH9aB5+PIxY2ICGw2GyorKwe1EkZRFCQlJaG8vFx7Dk3kqVgH3oHh4kYaGxsRGRn5WF9vtVoN7BHR8GMdeAc+FnMjHR0dj/X17e3tBvWEyHVYB9
6B4eJGgoODH+vrQ0JCDOoJkeuwDrwDw8WNWK1WJCcnD/p5saIoSE5ORkRExBD1jGj4sA68A8PFjSiKgnfeeeeRvnbt2rWcxCSvwDrwDpzQdzNc30/EOvAGHLm4mfDwcGzfvh2KosBkevD/HufO5B07drCgyKuwDjwfw8UNZWZmorCwEBaLBYqi3DXMd7ZZLBbs2bMHCxYscFFPiYYO68CzMVzcVGZmJmpra7Fu3TokJSXpriUlJWHdunWoq6tjQZFXYx14Ls65eADneyza29sREhKCiIgITlqSz2EdeBaGCxERGY6PxYiIyHAMFyIiMhzDhYiIDMdwISIiwzFciIjIcAwXIiIyHMOFiIgMx3AhIiLDMVyIiMhwDBciIjIcw4WIiAzHcCEiIsMxXIiIyHAMFyIiMtz/A/OS5qZqBe6FAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "2795dfc8",
"metadata": {},
"source": [
"We notice that the two functions in the first layer look the same. Let's try to lock them!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "17b6b983",
"metadata": {},
"outputs": [],
"source": [
"model.lock(0,[[0,0],[1,0]])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eb976f5a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtrklEQVR4nO3daVRUZ54/8O8tqoACCpASZFNkKYiaxBW3JK1GI51GcWtNNJnEBDuZ6XNiJv95Mz3dOWfOSb+YmTOTozN9Zno0mSiJHYP7LlFjazQmLlExLoCAIAgoS7HIUst9/i+wbriuoBdq+37O8UU9Remj8uN7n/ssVxJCCBAREWlI5+4OEBGR72G4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOb27O0DkDYQQaGhoQFtbG8LCwmA2myFJkru7ReSxOHIhegir1YrVq1fDYrEgOjoaycnJiI6OhsViwerVq2G1Wt3dRSKPJPFJlET3V1BQgEWLFqG9vR1A9+jFxTVqCQkJwZYtW5CVleWWPhJ5KoYL0X0UFBQgOzsbQgjIsvzAr9PpdJAkCXv27GHAEPXAcCG6i9VqRWJiIjo6Oh4aLC46nQ5GoxFVVVWIjIzs/w4SeQHOuRDdZf369Whvb+9VsACALMtob29HXl5eP/eMyHtw5ELUgxACFosFZWVl6EtpSJKElJQUlJSUcBUZERguRCr19fWIjo5+os+bzWYNe0TknXhbjKiHtra2J/p8a2urRj0h8m4MF6IewsLCnujzJpNJo54QeTeGC1EPZrMZqampfZ43kSQJqampiIqK6qeeEXkXhgtRD5Ik4b333nusz65cuZKT+UR3cEKf6C7c50L05DhyIbpLZGQktmzZAkmSoNM9vERcO/S3bt3KYCHqgeFCdB9ZWVnYs2cPjEYjJEm653aXq81oNGLv3r2YPXu2m3pK5JkYLkQPkJWVhaqqKqxatQopKSmq91JSUrBq1SpUV1czWIjug3MuRL0ghMDhw4cxc+ZMHDp0CDNmzODkPdFDcORC1AuSJClzKpGRkQwWokdguBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeZ45D5RLwkhIIS478PDiEiNIxeiPmCoEPWO3t0dIPIWDBai3uPIhYiINMeRC3kMp9OJEydOoKWlxd1d8Xrjx4/HkCFD3N0N8mOc0CeP0dHRgRdeeAEdHR0wmUzu7o5XEkKgpKQE69atQ05Ojru7Q36MIxfyGEII6HQ6fPzxx5g+fbq7u+PxhBBwOp2w2Wy4ffs2HA4HzGYzsrOzwWtGcjeGC3kcvV6PoKAgd3fDrYQQkGUZt2/fRkNDA2pra3Hjxg3U1dXh5s2byq/6+nrcunULTU1NGDVqFLZu3cqFB+QRGC5EHsC1h6a+vh6FhYU4duwYTp8+jZKSEty6dUsZmTxsRGIymdDZ2TmAvSZ6MIYLkRsJIdDS0oKjR48iPz8fR48eRU1NDRwOR59/r+bmZrS2tvZDL4n6juFC5AZCCDQ2NmLjxo1Yu3YtLl261KtACQgIUG4bhoSEwGQyYfDgwYiLi0NKSorf304kz8FwIRpAQgjY7Xbs2rULf/zjH3HhwgXIsnzP1wUEBCAqKgopKSnIyMhAWloakpKSEBcXh6ioKERGRiIsLAxGoxFBQUHQ6/WQJAl2u90NfyuiezFciAaIEAK1tbX48MMPsWHDBnR1daneDwwMREZGBrKysjBz5kw8/fTTGDx4MAIDAwHwhADyLgwXogEghMC5c+fw7rvv4vTp06r3wsLC8PLLL2PFihWYPHkywsLCADBMyLsxXIj6mRACR48exdtvv43y8nKlXa/XY9asWfj973+PiR
MnKre2iHwBw4WoH7mC5fXXX0d1dbXSbjab8eGHHyI3NxchISEMFfI5DBeifiKEQGFhIXJzc1XBkpqaijVr1mDatGnQ6Xh2LPkmhgtRPxBCoKamBu+88w7KysqU9lGjRuHzzz/H6NGjOVohn8bLJqJ+0NnZid/97nc4deqU0paamspgIb/BcCHSmBACeXl52Lhxo9JmNpuxZs0aBgv5DYYLkYaEELh8+TL++Mc/KhsaAwMD8dFHH2HatGkMFvIbDBciDdlsNnz00UeqCfxXXnkFy5cv5+Q9+RV+txNpRAiBffv2Yfv27Upbamoq/vmf/5lnfpHfYbgQaaS5uRn/8i//ohzrYjAY8E//9E8YPnw4b4eR32G4EGlACIH8/HycOXNGaZsxYwaWLFnCYCG/xHAh0kBDQwP+9Kc/wel0AgBCQ0Pxj//4jwgJCXFzz4jcg+FC9ISEENi0aRMuXbqktM2dOxdTp07lqIX8FsOF6AlZrVb87//+r/JcFpPJhPfffx8Gg8HNPSNyH4YL0RMQQmDXrl24ePGi0padnY1x48Zx1EJ+jeFC9ATa29vxySefKHMtISEh+Lu/+zvo9Ty2j/wbw4XoMQkh8O2336rOD5sxYwYmTpzIUQv5PYYL0WNyOBz47LPPVPtafvOb3yiPJSbyZwwXosd05coVHDhwQHk9evRozJgxg6MWIjBciB6LEAJffvklrFYrgO7n3b/xxhsICwtzb8eIPATDhegx1NfXY/PmzcrrhIQEzJs3j6MWojsYLkR9JITAgQMHUF5errTNnz8fCQkJbuwVkWdhuBD1kc1mwxdffKFafrxs2TKOWoh6YLgQ9dHly5fx3XffKa8nTpyIMWPGMFyIemC4EPWBEAKbN29GS0sLAECn02HZsmV8XgvRXRguRH1gtVpVDwOLj4/HL3/5S45aiO7CcCHqJSEEjh8/jpKSEqXtV7/6FeLj493YKyLPxHAh6iVZlpGfnw+73Q4ACAoKwuLFizlqIboPhgtRL1VVVeHQoUPK65EjRyIzM5PhQnQfDBeiXhBCYP/+/aitrVXaFixYAJPJ5MZeEXkuhgtRL3R1dWHTpk0QQgAAIiIiMH/+fI5aiB6A4ULUC5cuXVIdrT9lyhSkp6e7sUdEno3hQvQIQghs27YNra2tALr3tixZsoSPMSZ6CIYL0SM0Nzdjx44dyuv4+Hi89NJLvCVG9BAMF6KHEELgxIkTKCoqUtqysrIQFxfnxl4ReT6GC9FDyLKMr776StnbEhgYiCVLlnDUQvQIDBeih6iursbBgweV1yNGjMCkSZMYLkSPwHAhegDX3paamhqlbf78+dzbQtQLDBeiB+jq6sJXX32l7G0JDw/HggULOGoh6gWGC9EDXLhwASdPnlReP/fcc3jqqafc2CMi78FwIboP1yGVt2/fBtC9t+XVV1/l3haiXmK4EN1HfX296rktSUlJ3NtC1AcMF6K7CCFQUFCA8vJypS0nJwcxMTFu7BWRd2G4EN2lq6sLeXl5kGUZABAaGopXX32VoxaiPmC4EPUghMCZM2dw4sQJpW3q1KkYPXo0w4WoDxguRD3Isox169ahvb0dABAQEIA33ngDQUFBbu4ZkXdhuBD1UFpaip07dyqv09PT8ctf/pKjFqI+YrgQ3SHLMvLy8nDr1i0AgCRJeP311xEVFeXmnhF5H4YL0R3V1dX44osvlNdxcXFYunQpRy1Ej4HhQoTuify8vDxcv35daXvllVcwbNgwN/aKyHsxXIgAVFVV4dNPP1XOERs8eDByc3M5aiF6TAwX8nuyLOPTTz9FRUWF0rZkyRJkZGQwXIgeE8OF/JoQAiUlJVi7dq0yaomOjsZvf/tb6HQsD6LHxeohv+ZwOPAf//Efqme2vPnmmxy1ED0hhgv5LSEEjhw5go0bNyptSUlJ+O1vf4uAgAA39ozI+zFcyC8JIdDY2IgPP/wQbW1tALp347///vtISkpyc++IvB/DhfyS0+nEv//7v6
seBjZlyhS89dZbvB1GpAGGC/kdIQR27dqFP/3pT8okfkREBD766COEh4e7uXdEvoHhQn5FCIGzZ8/igw8+UD1lcuXKlXj++ec5aiHSCMOF/IZr2XFubi4qKyuV9lmzZuGDDz7gJD6Rhhgu5BeEELh8+TKWLVuG8+fPK+3p6elYvXo1IiIi3Ng7It/DcCGfJ8syDh8+jIULF+LHH39U2mNjY7FmzRqkp6fzdhiRxvTu7gBRfxFCoKWlBX/+85/xb//2b2hqalLei4mJwdq1a/HCCy8wWIj6AcOFfI4QAh0dHTh48CD+9V//FT/88ANkWVbeHzp0KNauXYtZs2YxWIj6CcOFfIIQAk6nEzU1NSgoKEBeXh5OnjwJm82mfI0kScjMzMR///d/Y+zYsQwWon7EcCGP5Np/AkAVAq52WZbR1dWFxsZGlJaW4tSpUzhy5AhOnTqFW7duqT4PAKGhoVi+fDn+8Ic/ICYmhsFC1M8YLuSRKioq8PHHH2PYsGGIj4+HXq9HS0sLamtrcePGDVRWVuL69euoqamB1WqFw+G47+9jMBgwdepU/P73v8f06dMREBDAYCEaAAwX8khXrlzB//zP/8DpdALoHr3cPRp5mLCwMEydOhXvvvsuZs+ejZCQEIYK0QBiuJBHunr1qhIsAB4ZLDqdDhERERgxYgSysrIwd+5cjBgxAoGBgQwVIjdguJBHqq+vh8lkQnt7u7LSS5Ik6PV6BAUFITw8HDExMUhJScGoUaMwfvx4PP3004iPj2egEHkAhgt5pH/4h3/Aa6+9hqqqKmVOJSwsDGazGVFRUYiMjERYWJgqSBgoRJ6D4UIeRZZlnDhxQjlUEuh+zkpAQADsdjtqa2tRW1vrxh56NofDgcbGRnd3g4jhQp5Dp9Nh1KhR2LdvH/bt2+fu7nit8PBwDBo0yN3dID8nib4swSHqR0IIOByOPq0Ko/vT6/XQ6Xh0ILkPw4WIiDTHSxsiItIcw4WIiDTHcCEiIs0xXIiISHNcikzUSw86qZmI7sWRC1EvnT17FjqdDmfPnnV3V4g8HsOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIekEIgaamJgBAU1OT6tkuRHQvhgvRQ1itVqxevRoWiwWzZs0CAMyaNQsWiwWrV6+G1Wp1bweJPJQkeAlGdF8FBQVYtGgR2tvbAdz/SZQhISHYsmULsrKy3NJHIk/FcCG6j4KCAmRnZ0MIAVmWH/h1Op0OkiRhz549DBiiHhguRHexWq1ITExER0fHQ4PFRafTwWg0oqqqCpGRkf3fQSIvwDkXorusX78e7e3tvQoWAJBlGe3t7cjLy+vnnhF5D45ciHoQQsBisaCsrKxPK8IkSUJKSgpKSkqU+Rgif8ZwIeqhvr4e0dHRT/R5s9msYY+IvBNvixH10Nzc/ESfb21t1agnRN5N7+4OELmTzWZDeXk5ioqKUFRUhIsXLz7R77d7926MGzcO6enpMJvNvEVGfovhQn6lo6MDJSUlKC4uRnFxMcrLy+FwOBASEgKLxYIlS5bg8OHDqKys7PPvHRMTg5s3b+LTTz8FAAwaNAjp6enIyMiAxWJBXFwcw4b8BudcyKe1traiuLgYRUVFKC4uRmVlJYQQiIiIQEZGBtLT05Geno7ExETlB//q1avxwQcf9HlCf9WqVVi5ciVu376Nq1evKn9mRUUFZFlGWFgY0tPTYbFYkJGRgaFDh0Kn451p8k0MF/IpDQ0NyqikqKgINTU1AIDo6GhlFJGeno6YmJgHjiK03ufS1dWF0tJSpV9lZWWw2+0IDg5GWlqaEnDJycnQ63kzgXwDw4W8lhACdXV1ynxJcXExGhoaAADx8fHIyMhQwmTQoEF9+r37ukN/7969mD17dq9+b4fDgWvXrilhU1JSgs7OTuj1eqSkpChhk5
qaiuDg4D71m8hTMFzIa8iyjOvXrys/lIuLi9HS0gJJkpCUlKQamYSFhT3xn9fbs8W2bt3a62C5H1mWUVVVpYy2SkpK0NraCp1Oh6SkJFgsFuV2mhZ/L6KBwHAhj+W6wneNSkpKStDR0aFc4buCJC0trd+u8K1WK/Ly8vCf//mfKC0tVdpTU1OxcuVKvPnmm4iIiND0zxRCoLa2Vvk7FxUVobGxEcDPIzJX4PR1REY0UBgu5DF6zk0UFRWhtLQUdrsdQUFByg/TjIwMJCcnw2AwDGjfhBBobGxEa2srTCYToqKiBnTlV8+5pOLiYtTW1gL4eS7J9Ss6Opor0sgjMFzIbW7fvq0sCy4qKsK1a9dUq6pcYTJs2DCuqrpLS0uLakn19evXlVVwPcMmISGBYUNuwXChAWO1WlVX31VVVRBCqPaDpKenIz4+nj8Q+6i9vR1Xr15VAqe8vBxOp1PZv+MKm6SkJAQEBLi7u+QHGC7UL4QQqK+vV0YlRUVFuHnzJgBgyJAhqj0mgwcPZphorOfJA8XFxSgtLYXNZkNgYCBSU1NVtxgDAwPd3V3yQQwX0oQQAjdu3FBtWGxqaoIkSUhMTFTtVOczTwae0+lERUWFskCgpKQE7e3tCAgIQHJysrIazWKxwGg0uru75AMYLvRYZFlGZWWlMiopKSlBW1sbdDqd8sMqIyMDaWlpCA0NdXd36S5CCFRXV6tuUzY3N0OSJAwdOlQZVVosFoSHh7u7u+SFGC7UK3a7HWVlZaqNf11dXTAYDKpd5qmpqQgKCnJ3d6mPhBC4deuWaq/NrVu3AACxsbGqRQJ8pAD1BsOF7quzs1M5H6uoqEg54NFoNCpnY6Wnp2P48OE8ssRHNTU1qfba3LhxAwAQFRWl2msTGxvLOTO6B8OFAABtbW3KXInrsEUhBMLDw1XLghMTE7ks2E+1tbUpq9FKSkqUAzlNJpNywWGxWHggJwFguPitxsZG1eS766rUbDarzuQaMmQIr0rpvjo7O+85kNM1uu15q5SjW//EcPEDrgMee54WXF9fD6D7OBHeTyct2O12lJeXq0Y3rnm5uw/k5Lyc72O4+CBZllFdXa2MSoqKipQDHocNG6aMSrgSiPqT66DRnmfDuVYUug4adX0fckWh72G4+IC7j3AvLi5WDni8e1kw9zCQuwghUFNToywQcO2FAoDExETVSQLcC+X9GC5eyGazobS09J7d10FBQcq9bu6+Jk8nhLjnQM66ujoA3Y+Mdo1qeCCnd2K4eInCwkLVsmBZlhEaGqqaL+G5UeTtmpubVacIuM6fi4yMVL7Px4wZw0cNeAGGi5c4e/YsbDYbIiIiEBERgfDwcISEhPBqjnyaw+FAc3MzWlpa0NzcjNbWVowcOZILT7wAw8VL2O126PV6hgn5NafTCUmSuI/GCzBciIhIc4x/IiLSHLfN3iHLMq5evYqOjg53d8XrDR8+XPPnytPAkGUZpaWlrAMNJCUl+XUdMFzucDgc+Mtf/gKbzYbg4GB3d8dr1dXVYcWKFRg7dqy7u0KPweFw4Msvv4TdbmcdPCYhBG7evIm3334bY8aMcXd33IbhcocQApIkYenSpXjqqafc3R2v0NHRgYqKCgwbNgwhISFwOp34+OOPwWk87yWEgE6nwyuvvIKMjAx3d8crdHZ2oqKiAkOHDlXqYNWqVX5fBwyXuwQEBMBgMDz0a3p+0/jr6i0hBIqLi7FmzRpERUVh9OjRmD59ut/+e/ga1kHvuOrgk08+waBBg/Dss89i2rRpfvvv0RPDpZeEEJCFQN2telRW16DLZkOMOQrDExMQHBzkd99MQgicPn0adrsddXV1OHz4MG+F+QGlDuobcL1HHSQlxiM4yD/r4MyZM7Db7bh58yaOHDni17fCemK49JLD6cR3p8/i3MUrsDscd1olDBkchaxpzyHaHOVXhdXY2IgrV64orxMSEjB06FA39ogGgsPpxIkz53D+UpFSBxIkxAyOwku/mI
roqEF+VQdNTU0oKipSXrMOfsalyL0ghMCFK8U4U3gRBoMBk8Y+ixefm4Sh8UNQV9+IA9+eQKfN5u5uDhghBC5cuICWlhalbezYsTzHzMcJIfBTUQl+vHAJBoMeE8c8gxlTJyIxbghu1jfi4LET6PKzOigsLERzc7PSNmbMGNbBHRy59ILd4cBPV65Cp9Nh9i+mIjWp+8pkRFoqthccxI26W6i6UYu04cP84qrNbrfjhx9+UO65G41GjBs3zi/+7v7M7nDgYlF3Hbz0/BSkuOogNQXbv/4GNTdvoaqmDqlJQ/3ie8Fms+HkyZPKa6PRiLFjx/rF3703OHLpha4uG1pv30ZoaAgS4rqfzChJEoKDAjE8MQFCCDRam1WfkWUZra2tuHbtGkpLS+F0Ot3Ue20JIXDt2jWUlZUpbRaLBbGxsW7sFQ2ELtudOggxquogKCgQwxPjH1oHFRUVKCsr86k6KC8vx7Vr15Q2i8WCIUOGuK9THoYjl14QEIAQ0EkSdHddlbhOIb572WFzczP+/Oc/o729HbIs4w9/+INPPEtFlmUcPXoUdrsdAKDT6TB16lQEBAT4zA8Ouj8hAIju//O7r84fVgdr1qxBe3s7hBD43e9+5zN18O2337IOHoIjl0cRApJwwmiwwWiwAc5WwNHS/cvZCoPUjhCDDXqd+l6zyWTC3/7t3+LNN9/0mUP2hBCorq7G2bNnlba4uDiMGjWKtwJ8nRCQhKO7DvRdgLMNcLR2/3K2PbQO3n33Xbzxxhs+VQfXr19HYWGh0paQkICRI0eyDnrgyOVRnK0Ita7Hq6OLIElAYNVpoMc30NNBdqRn2mEIKgNaJcA0HpAk6PV6DBo0SDXp7e1kWcahQ4fQ3t4OoHtvwwsvvMBH1PoDZytCrZ9jybPFkAAEVv8I9Pg5OirQAcsEGwyBZUCbBISNU+ogMjLS5+rgm2++UR2R8/zzz/vEiExLDJdHaS+Grv0nhBi6h/vCaYNNDoDDqUOwwQ6DBBgCAYhGwPpXIGwMIPneP6trruXUqVNKW3R0NCZNmsSrNX/QUQJdx8X71IGEYIMDBt2dOkATYD0ChI722TooKyvDjz/+qLTFxsZiwoQJrIO7+N7/vtaEA8DP95FbOoOw97IFzV1BmJxUhdFxdT8PZIRd9bW+xGazYdeuXcrVmiRJmDFjhl8fzOdX5HvrYN+VNLR0BWHSsGo8q6oD9df6EpvNht27d6OzsxNAdx1Mnz4d4eHhbu6Z5/GNm6ADqMIageoWE9q6AvFTzRA4Zd+/WhFC4OTJk/jpp5+UtoSEBEydOpVXa36q0hqBG3fq4GJtjN/UwQ8//IDLly8rbUOHDsXkyZNZB/fBcOmjIWG3ERpohwRgaGQzdLr7X6HJsoympiZYrVY4nU40NDSgpaXF6w6zE0KgpqYGO3bsUFbB6PV6zJ07FyaTyc29I3fpWQeJEb2vg8bGRrS2tnptHezevRuyLAMADAYD5syZwznHB2C49FFM2G2kmhsRrHdgTHwtHnS9YrfbsXPnThw6dAgmkwmbN2/Gt99+63VF1dnZiY0bN6KhoUFpGz9+PDeL+bnosNtIMTchWO/A6Pi6h9bB7t278c0338BkMmHLli04duyY19VBR0cH8vPz0djYqLSNHz8ezzzzDOvgATjn0kcCQIfdAKeQ0OXQQ5K67vt1gYGBeP3111Vtrk1n3sLhcGDPnj2q22HR0dGYP38+9Hp+6/gzIYAOW+/qYNmyZao2b6yD/fv34+LFi0pbdHQ0cnJyWAcPwX+ZR5EkdK+57L7SkgA8l3wdE4beQGRI511fq4NrfaYkScrGMm8kyzK+//57fP3118pVZmBgIBYvXowhQ4Z41Q8H0sDddSABzyVXYrxDf28dwLfq4OTJkzhw4IBSB0FBQfj1r3+N6Oho1sFDMFweJTgZCIwD7HUAuotqcFjPq7Q7hSMZgNCxgOS9heQihMClS5fw1VdfKTuQJUnC7NmzeYaYvwoe3l
0Htp/rwBxqA+DaNNmjDsJG+0wdXLlyBfn5+ao6mDlzJsaMGcM6eASGy6MYooGh/w+wN+Khyyt1wYAhZsC61V9cZyatW7cObW1tSvvYsWPxq1/9ymd2WVMfGaKBxPfv1MFD6IJ8pg4qKiqwfv16VR2MGTMGL7/8MuugFxgujyJJgD6i+5ePE0KgsrISa9euVU3gJycn47XXXkNwcDCv1vyVn9VBVVUVPvnkE1UdDB8+HEuXLmUd9BLDhQD8fKW2Zs0a1NbWKu2xsbHIzc3FoEH+9RAo8k+uYFm7dq2qDmJiYrB8+XLWQR8wXAhCCFy9ehWffPIJbt68qbSbzWasWLEC8fHxLCjyea6jXf7v//4PdXV1SntUVBTefvttJCQksA76gOHi52RZRmFhIfLy8tDU1KS0m81mvPPOO0hJSWFBkc+TZRkXL168pw4GDRqE3NxcpKamsg76iOHixxwOB44fP478/HzlpGMAGDx4MN555x2kpaWxoMjnORwOfP/999i0aRNu376ttJvNZuTm5sJisbAOHgPDxQ8JIdDZ2Yldu3bh4MGDyjJLoPvMsBUrViApKYkFRT5NCIGuri7s3bsXBw4cUNWBa65x+PDhrIPHxHDxM0IINDQ04C9/+QvOnTunOobDYrHgrbfeQmxsLAuKfJoQAo2NjcjPz8fZs2eV88IAICUlBcuXL0dcXBzr4AkwXPyILMu4fPkyNmzYgJqaGqVdkiRMmDABy5YtQ0REBAuKfJosyyguLsaXX36J6upqpV2SJIwdOxZLly5FZGQk6+AJMVz8gGv4f/DgQezdu1f1BD2DwYDZs2djzpw5CAoKYkGRzxJCwGaz4a9//Sv27Nmjmmc0GAx48cUXMWfOHO5j0QjDxce5jgrPz8/HhQsXVMP/iIgILFmyBBMnTuQBfOTThBC4efMmNm/ejPPnz6vqIDw8HIsWLcKkSZNYBxriv6SPEkLA4XDghx9+wLZt21RHhUuShJSUFLz++uucuCefJoSA0+nEmTNnsHXrVtWOe6D79Illy5Zx4r4fMFx8kBACt27dwrZt23D69Gk4HA7lPYPBgGnTpiEnJwdhYWEsKPJZrkn7HTt24OTJk6o60Ov1eO6555CTk4Pw8HDWQT9guPgQ12jl1KlT2LZtG+rr61XvDx48GIsXL8a4ceMQEBDAgiKf5KqDs2fPYvv27apTJ4DuHfcLFixAZmYm66AfMVx8hBACtbW12LZtG86ePau6SgsICMC4ceOwaNEixMTEsJjIZ7lG7Tt27MCZM2dUdaDT6fDss89i0aJFXG4/ABguXs61Euz48ePYs2eP6ugKoPv4ipycHEydOhUGg4EFRT7JtRLs+++/x549e1RzjED34pXs7Gw899xzCAwMZB0MAIaLF5NlGeXl5di6dSuuXLmiWgETEBCA0aNHY9GiRdwMRj5NlmVUVlZi+/btuHTpkqoOdDodnn76aSxatIgHsA4whosXEkKgpaUFBw4cwOHDh1Xr9YHuuZWcnBxMmjSJoxXyWUIItLW14dChQzh8+LDqXDCge24lOzsbkydP5mjFDRguXsQ1UXn+/Hns2LED1dXVquNbDAYDJk6ciJycHD7fm3yWa3nxTz/9hB07duD69euq9/V6PcaPH4+cnBzOMboRw8VLuDZD7ty5856JSkmSkJCQgIULF+KZZ57hChjyWUII1NXVYffu3Thz5ozqsEkAiI+Px7x58/Dss89Cr9ezDtyI4eLhhBDo6OjAkSNHUFBQgObmZtX7ISEhePHFF/HSSy/BZDKxmMgnuU7yPn78OPbv3w+r1ap632g04he/+AVmz57NfSseguHioYQQkGUZRUVF2LZtG0pLS1W3wHQ6HUaOHIkFCxZg+PDh0Ol0buwtUf+RZRlXr17F9u3bcfXq1Xsm7DMyMjBv3jykpKSwDjwIw8UDuXYW7927F8ePH0dXV5fq/cGDB2POnDmcqCSfJoRAc3Mz9u/fj2PHjqGzs1P1flRUFF
5++WVMmTKFh656IIaLBxFCwG634/Tp09ixY8c9O4uDgoIwefJkzJkzB2azmcVEPqnnDvudO3eitrZW9b5r4Up2djYXrngwhouHkGUZN27cwLZt23D+/Hk4nU7lPddBkwsWLMBTTz0FnU7HgiKf5DppYseOHTh37pxq4QoAJCUlYd68eRg5ciQXrng4houbuSYqjx49ir1796KlpUX1fnh4OLKysjB9+nQYjUYWE/kk10kT3333Hfbu3XvPhH1YWBhmzpyJGTNmIDQ0lHXgBRgubiTLMioqKrB58+b77rAfM2YM5s+fj/j4eE5Uks+SZRlVVVXYunXrA3fYz58/H4mJiawDL8JwcQPXaOWbb77B/v370dbWpno/JiYG8+bNQ2ZmJtfqk89yjVaOHj2Kffv2obW1VfW+2WzG3LlzkZmZyYUrXojhMsBkWcb169eRn5+Py5cv37PDfurUqZg7dy6ioqJYTOSzhBC4ceMGNm3adM9oxWAwIDMzE3PnzsXgwYNZB16K4TJAXCvBjh07hh07dtwzt5KQkIBf//rXePrppzlRST7LtRLsxIkT2Llz5z1zK7GxsViwYAFGjx7NOvByDJcBIIRAQ0MDNm3ahNOnT6uu0gIDA/HCCy9gzpw5iIiIYDGRzxJCoKmpCdu2bcPJkydVKyINBgOmTJmCOXPmYNCgQawDH8Bw6WeyLOPy5cvYsGEDampqVO/FxcVh8eLFPA+MfJ4syyguLsaXX36J6upq1XsxMTFYtGgRRys+huHST1y3wQ4dOoRdu3aho6NDeU+v12PixIlYuHAh51bIp7lugx05cgQ7d+5UPR4iICAA48ePx4IFCzi34oMYLv1ACIHW1lZ89dVX+P7771W3wSIiIrBw4UJMmTKFK8HIp7met7JlyxacOHFCdRvMZDIhJycHzz//POvARzFcNObaYbxu3TqUlJQoq8EkSUJqaipee+01JCUlsZjIpwkhcPPmTeTl5aGoqEj1XnJyMpYuXcoDV30cw0VDQgiUlJTgs88+U52HFBAQgOeffx4LFy7ksfjk84QQKCsrw7p161TzjAEBAZg8eTIWLlzIY/H9AMNFI7Iso7CwEOvWrVM9c8VoNGLBggWYPn06h//k82RZxk8//YS8vDzVMuPg4GDMnTsXL774IuvATzBcNCDLMk6ePIkvvvhC9Rxvs9mMv/mbv8EzzzzD4T/5PFmWcebMmXvqIDIyEq+99hpGjx7NOvAjDJcnJMsyTpw4gQ0bNqhWhCUmJiI3N5fzK+QXZFnGqVOnsGHDBtWKsPj4eCxfvhzJycmsAz/DcHkCPQuqZ7CkpaUhNzcXQ4YMYUGRz5NlGT/++OM9wZKcnMw68GMMl8ckhMCFCxfw+eefq4JlxIgRyM3N5f4V8gtCCFy6dAlffPGFKljS09NZB36O4fIYhBAoLy/H+vXrVfeWn3rqKfzmN79BZGQkC4p8nhACFRUVyMvLU53snZ6ejhUrVvAYFz/H2bU+cp0T9tlnn6GpqUlpT01NxYoVKxgs5BeEEGhsbEReXh4aGxuV9uTkZLz99tsMFmK49FVnZyc2bNiAqqoqpS0+Pp63AMivdHV1YePGjaisrFTaYmNj8dZbb8FsNrMOiOHSF7IsY//+/Th//rzSFhERgeXLlyM2NpYFRX5BlmUcOHBAVQcmkwlvvPEG4uLiWAcEgOHSa0IIXLx4EV9//bVypEtgYCBeffVVpKWlsaDILwghcPnyZRQUFChn5hkMBixevJh1QCoMl14QQqClpQX5+fno7OwE0H1W2MyZMzFhwgQWFPkF14GsW7ZsUeoAAGbMmIGJEydygySp8LuhF4QQKCgoUD2HIj09HdnZ2dDrueCO/IMQAgcPHlTNs6SlpSE7OxsBAQFu7Bl5IobLI7iWWx45ckS5HRYWFoYlS5YgJCTEzb0jGhhCCFy/fh1HjhxR2kJDQ7F48WKEhIRw9E73YLg8gtPpxL59+5QNYq7bYcOHD2dBkd9wOp
0oKChQ7euaPn06j3WhB2K4PITrCP2eq2Li4+Px4osv8v4y+Q0hBEpLS1kH1Cf8zngIp9OJgwcPwmazAQB0Oh2ysrJgMpnc3DOigeN0OvHNN9+gq6sLQHcdzJo1C+Hh4W7uGXkyhssDCCFQWVmJixcvKm3Dhg3D+PHjeRuA/IZrrqVnHSQmJnKVJD0Sw+UBhBA4duyYcrUmSRKmT58Oo9Ho5p4RDRwhBL777jvVEvxp06axDuiRGC4P0NTUhHPnzimvhwwZgrFjx/JqjfyK1WpV1UF0dDTrgHqF4XIfQggUFhaqHtM6ceJEzrWQX3HVQc8DWjMzM1kH1CsMl/twOBw4efKksq/FaDQiMzOTV2vkVxwOB06fPq28Zh1QXzBc7qOmpgbXrl1TXqelpSE2NtZ9HSJyg7vrIDU1lXVAvcZwuY9z586pJjAzMzN5vAX5nfPnz6vqYMKECawD6jWGy11kWUZzc7OyGiY8PBwjR47krQDyK7Iso6WlBcHBwQC662DEiBGsA+o1nrp4F51Oh6VLl2LGjBk4d+4cbDYbBg0a5O5uEQ0onU6HV155BdOnT8f58+dZB9RnDJcehBC4evWqsrclISEBQgjVsRf0YE6nU3X2FHknWZZRWlqq1EF8fLyycowejXXQjeFyh06nQ0JCAgoLC1lETyA4OBihoaHu7gY9JlcdXLhwARcuXHB3d7yW0Wj0+zqQhGu9rZ8TQsDpdLq7Gz5Bp9PxQEMvxTrQjr/XAcOFiIg057+xSkRE/YbhQkREmmO4EBGR5hguRESkOYaLl3A6nWhra+NKHvJrTqcTra2trAMvwHDxElVVVXjvvfdQVVXl7q4QuU1VVRX+/u//nnXgBRguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuHiBYQQaGxsRFtbGxobGyGEcHeXiAYc68C7SIL/Qx7LarVi/fr1+K//+i+UlpYq7ampqXjvvffw5ptvIjIy0n0dJBoArAPvxHDxUAUFBVi0aBHa29sBQHWVJkkSACAkJARbtmxBVlaWW/pI1N9YB96L4eKBCgoKkJ2dDSEEZFl+4NfpdDpIkoQ9e/awsMjnsA68G8PFw1itViQmJqKjo+OhBeWi0+lgNBpRVVXFWwPkM1gH3o8T+h5m/fr1aG9v71VBAYAsy2hvb0deXl4/94xo4LAOvB9HLh5ECAGLxYKysrI+rYSRJAkpKSkoKSlR7kMTeSvWgW9guHiQ+vp6REdHP9HnzWazhj0iGnisA9/A22IepK2t7Yk+39raqlFPiNyHdeAbGC4eJCws7Ik+bzKZNOoJkfuwDnwDw8WDmM1mpKam9vl+sSRJSE1NRVRUVD/1jGjgsA58A8PFg0iShPfee++xPrty5UpOYpJPYB34Bk7oexiu7ydiHfgCjlw8TGRkJLZs2QJJkqDTPfy/x7UzeevWrSwo8imsA+/HcPFAWVlZ2LNnD4xGIyRJumeY72ozGo3Yu3cvZs+e7aaeEvUf1oF3Y7h4qKysLFRVVWHVqlVISUlRvZeSkoJVq1ahurqaBUU+jXXgvTjn4gVcz7FobW2FyWRCVFQUJy3J77AOvAvDhYiINMfbYkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERae7/A6671DuQXeOIAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "8214259e",
"metadata": {},
"source": [
"Now lock symbols appear in the top-left corners of the two locked activation functions!"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0298d20a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.13e-04 | test loss: 6.00e-04 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 5.68it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "markdown",
"id": "5ca6421a",
"metadata": {},
"source": [
"After retraining, the loss remains similar, meaning that locking does not degrade model behavior; this supports our hypothesis that these two activation functions are the same. Let's now determine what this function is using `suggest_symbolic`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2ccb7048",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"function , r2\n",
"arctanh , 0.9999993678015309\n",
"tan , 0.9998485210873531\n",
"arcsin , 0.998865199664262\n",
"sqrt , 0.9830640000050016\n",
"x^2 , 0.9830517375289431\n"
]
},
{
"data": {
"text/plain": [
"('arctanh',\n",
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
" 0.9999993678015309)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,1,0)"
]
},
{
"cell_type": "markdown",
"id": "0092be41",
"metadata": {},
"source": [
"We can see that ${\\rm arctanh}$ is at the top of the suggestion list! So we can set both first-layer activations to ${\\rm arctanh}$, retrain the model, and plot it."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1bb96fe1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999992221865773\n",
"r2 is 0.9999993678015309\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'arctanh')\n",
"model.fix_symbolic(0,1,0,'arctanh')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "83b852a3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.39e-04 | test loss: 2.54e-03 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 6.33it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9ccd0923",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArGElEQVR4nO3deVRV57038O8+HIaDHGQUVEQFj+KYqNFUjRmNpNKYVGNibNWmNfPVrnY1b7v63vYmXat3NdeVvpjkJmmaWCVO0YAZqgZTE81kY0Y1CojgACIqw5FROOfs5/3jcZ8BEUE3nGF/P2u5CBtIHgw/vueZFSGEABERkY5M/m4AERGFHoYLERHpjuFCRES6Y7gQEZHuGC5ERKQ7hgsREemO4UJERLpjuBARke4YLkREpDuGCxER6Y7hQkREumO4EBGR7hguRESkO4YLERHpjuFCRES6M/u7AUTBQAiB2tpaNDU1ISYmBomJiVAUxd/NIgpY7LkQdcFut2PVqlWw2WxITk7G8OHDkZycDJvNhlWrVsFut/u7iUQBSeFNlESdKywsxPz589HS0gJA9l40Wq8lOjoa+fn5yM7O9ksbiQIVw4WoE4WFhcjJyYEQAqqqXvbzTCYTFEXBtm3bGDBEXhguRB3Y7XakpaWhtbW1y2DRmEwmWCwWVFZWIi4urvcbSBQEOOdC1MHatWvR0tLSrWABAFVV0dLSgry8vF5uGVHwYM+FyIsQAjabDeXl5ehJaSiKgoyMDJSWlnIVGREYLkQ+ampqkJycfE1fn5iYqGOLiIITh8WIvDQ1NV3T1zc2NurUEqLgxnAh8hITE3NNX2+1WnVqCVFwY7gQeUlMTERmZmaP500URUFmZiYSEhJ6qWVEwYXhQuRFURQsX778qr52xYoVnMwnuogT+kQdcJ8L0bVjz4Wog7i4OOTn50NRFJhMXZeItkO/oKCAwULkheFC1Ins7Gxs27YNFosFiqJcMtylPbNYLNi+fTtmz57tp5YSBSaGC9FlZGdno7KyErm5ucjIyPD5WEZGBnJzc3Hq1CkGC1EnOOdC1A1CCHz00Ue44447sGvXLtx2222cvCfqAnsuRN2gKIp7TiUuLo7BQnQFDBciItIdw4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFciIhIdwwXIiLSHcOFiIh0x3AhIiLdMVyIiEh3DBciItIdj9wn6iYhBIQQnV4eRkS+2HMh6gGGClH3mP3dAKJgwWAh6j72XIiISHfsuVDAcLlc2Lt3LxoaGvzdlKA3efJkpKSk+LsZZGCc0KeA0draipkzZ6K1tRVWq9XfzQlKQgiUlpZizZo1mDt3rr+bQwbGngsFDCEETCYT/vrXv+LWW2/1d3MClhACLpcLTqcTDocDqqoCkHNCUVFRyMnJAV8zkr8xXCjgmM1mREZG+rsZfqMteb5w4QJqampQWVmJY8eO4fjx46ioqEB1dTXq6urQ0NCApqYmuFwuAEBaWhoKCgq48IACAsOFKAAIIdDW1oaysjJ8/vnn+PTTT3HgwAFUVFSgsbERDofjiv8ORVHcQUPkbwwXIj/ReignT57E22+/jfz8fBw4cACNjY3+bhrRNWO4EPmBqqooLy/HSy+9hE2bNqG6urrLz1cUBWazGRaLBTExMbBarejfvz/69++P8PBwKIqClJQUhIeH99F3QNQ1hgtRHxJCoLGxEa+++ipyc3NRVVXV6ef169cPQ4cOxZgxYzBu3DiMHDkSQ4cOxYABAxAbGwuLxYLw8HCYzZ4SVhQFTqezr74Voi4xXIj6iBAChw8fxq9+9St8+OGH7lVemri4OMyYMQN33303ZsyYgaFDhyI6OppnmVFQYrgQ9QFVVbFz50488cQTOH78uM/HBg0ahCVLlmDx4sUYMWIEzGYzw4SCHsOFqJepqoq33noLTz75JGpra93Po6OjsWTJEvzmN7/BsGHDYDLxNCYKHQ
wXol6kqiq2bt2Kxx9/HPX19e7nGRkZeO6555CTk4OwsDD2VCjkMFyIeokQAp988gmefPJJn2CZOnUqXn/9dYwZM4ahQiGL4ULUC4QQKCsrw+OPP46zZ8+6n8+cORN5eXlIT09nsFBI4yAvUS9oamrCr3/9axQXF7ufTZo0CWvWrGGwkCEwXIh0pqoqXnrpJezYscP9LC0tDa+++iqGDRvGYCFDYLgQ6UgIgW+++QbPPfec+5wvi8WClStXYuLEiQwWMgyGC5GOWlpa8PTTT6OmpgaA3DW/bNkyzJs3j8FChsJwIdKJEAIFBQX44IMP3M/Gjh2L3/3udz7HtBAZAcOFSCc1NTVYuXKl+3j8yMhI/Nd//RdSU1PZayHDYbgQ6UAIgby8PBw6dMj9LCcnBzk5OQwWMiSGC5EOTp8+jVdeecV9vXBcXBx++9vfGvpGTTI2hgvRNRJCYN26dSgvL3c/W7BgAVeHkaExXIiu0blz57B69Wp3ryUhIQFPPvkkwsLC/NwyIv9huBBdAyEE8vPzcfToUfezefPmYezYsey1kKExXIiuQWNjI1avXu2++Cs2NhaPPPIIj88nw2MFEF0lIQR27dqFAwcOuJ9lZ2fjuuuuY6+FDI/hQnSVHA4HVq9e7bOvZdmyZdwwSQSGC9FVEULgwIED+Pjjj93PpkyZghkzZrDXQgSGC9FVEUJg/fr1aGxsBACYTCYsXboUFovFzy0jCgwMF6KrcObMGbzzzjvu94cNG8bd+EReGC5EPSSEwI4dO3Dy5En3s3nz5mHAgAF+bBVRYGG4EPVQW1sbNm7c6F5+bLVa8cADD7DXQuSF4ULUA0IIHDx4EF988YX72bRp0zBu3DiGC5EXhgtRD7311ltoamoCICfyH3zwQURERPi5VUSBheFC1AN1dXV499133e+np6dj9uzZ7LUQdcBwIeomIQQ+/fRTlJWVuZ/NmTMHKSkpfmwVUWBiuBB1k6qq2Lx5M5xOJwAgKioK9913H3stRJ1guBB108mTJ/Hhhx+63x87diwmT57McCHqBMOFqBuEEHj//fdx9uxZ97N58+YhJibGj60iClwMF6JuaGtrw5YtW3yuMZ47dy57LUSXwXAh6oZDhw7hq6++cr8/ffp02Gw2P7aIKLAxXIiuQAiBgoICn70t999/P8LDw/3cMqLAxXAhuoLz58/77G0ZPHgwZs2axSExoi4wXIi6IITA559/jiNHjrifZWdnIzU11Y+tIgp8DBeiLmh7W7TbJiMiIrBgwQL2WoiugOFC1IVTp07hgw8+cL8/evRo3HjjjQwXoitguBBdhra3pbq62v3sxz/+MaxWqx9bRRQcGC5El9HW1oY333zTvbelf//+uPfee9lrIeoGhgvRZRw8eBD79u1zvz9t2jRkZWX5sUVEwYPhQtQJVVWxZcsWNDc3A/Dc28K9LUTdw3Ah6kRNTQ22bt3qfj89PR133nknh8SIuonhQtSBEAKFhYU4fvy4+9k999yDAQMG+K9RREGG4ULUQVtbG9544w24XC4AQL9+/bBw4UL2Woh6gOFC5EUIgW+++QZ79+51P5s2bRquu+46hgtRDzBciLyoqoo1a9a4J/LDwsKwZMkSREZG+rllRMGF4ULkpby83OeQSpvNhrvuuou9FqIeYrgQXSSEwPr1631um1y0aBESExP92Cqi4MRwIbqoqqoKb7zxhvv91NRULFq0iL0WoqvAcCGC7LVs2LDBZ/nxggULMGzYML+1iSiYMVyIAFRXV+O1115znyOWkJCAX/ziF+y1EF0lhgsZnhACeXl5OHr0qPvZvHnzMGbMGIYL0VViuJDhnThxAq+88opPr+WJJ55AWFiYn1tGFLwYLmRoLpcLzz//PE6cOOF+tnDhQowfP569FqJrwHAhwxJCYN++fVizZo372cCBA7F8+XKYTCwNomvBCiLDamxsxB//+EfY7XYAgKIoePzxx2Gz2dhrIbpGDBcyJJfLhZdeegm7d+
92P5s4cSIeffRR9lqIdMAqIsMRQmD37t1YuXKl++Tj6OhoPPPMM0hKSvJz64hCA8OFDEUIgdLSUixfvhz19fUA5HDYsmXLMHv2bA6HEemE4UKGIYRAVVUVHnnkERQXF7uf/+AHP8Dvf/97mM1mP7aOKLQwXMgQhBA4deoUHnroIXz88cfu52lpaXjhhReQnJzMXguRjhguFPKEEDh06BAeeOAB/Otf/3I/j4uLw4svvoiJEycyWIh0xnChkCWEQHt7OzZv3oy7777b53bJ2NhY5Obm4kc/+hGDhagXcJCZApYQAm1tbYiMjOxRAAghoKoqvv/+ezz77LPYunUr2tra3B9PSEhAbm4uHnzwQS47JuolrCwKSNru+fvvvx9FRUXuc7+6+nwhBJqamrBnzx4sW7YMt912GzZt2uQTLBkZGVi3bh0WLVrEs8OIehF7LhSQ9u7di8WLF+PYsWOoqKjA66+/juuvv97d09DC5MKFC6iursb333+PPXv2YNeuXSgpKfEJFAAwm8246667sHLlSowcOZJDYUS9jOFCAaehoQF/+MMfcOzYMQDA/v37MWfOHDz88MMYP348mpubUV5ejqKiIpSWlqKiogINDQ1QVfWSf5eiKMjMzMRTTz2FRYsWITo6msFC1AcYLhRwYmNj8cwzz2Dp0qWorq4GAJw9exZ//vOfoSjKFYfIACAsLAw2mw1Lly7FkiVLkJqaylAh6kMMFwpIs2bNwtq1a/HLX/7SZ8NjV8ESHh6OQYMG4aabbsK8efNwyy23ID4+nqFC5AcMFwpIiqJg1qxZ2LlzJ5599lls3rwZdXV17qGv8PBwxMbGYtCgQcjKysINN9yAqVOnYsyYMUhISICiKAwVIj9iuFDAUhQFgwcPRm5uLp566imUlJSgvr4eUVFRGDhwIAYNGoT4+HhERUUxTIgCDMOFAoqqqti7dy+am5s7/XhUVBQA4PTp0zh9+nRfNi0oOJ1O1NXV+bsZRAwXChwmkwljx47Fjh07sGPHDn83J2jFxsYiPj7e380gg1NEd5beEPUBIQScTme3VoNR18xmM08fIL9iuBARke740oaIiHTHcCEiIt0xXIiISHcMFyIi0h2XIhN1k/faF27YJOoaey5E3fTtt9/CZDLh22+/9XdTiAIew4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFciIhIdwwXIiLSHcOFiIh0x3AhIiLdMVyIiEh3DBciItIdw4WIiHTHcCEiIt0xXIi6QQiB+vp6AEB9fb3P3S5EdCmGC1EX7HY7Vq1aBZvNhlmzZgEAZs2aBZvNhlWrVsFut/u3gUQBShF8CUbUqcLCQsyfPx8tLS0AOr+JMjo6Gvn5+cjOzvZLG4kCFcOFqBOFhYXIycmBEAKqql7280wmExRFwbZt2xgwRF4YLkQd2O12pKWlobW1tctg0ZhMJlgsFlRWViIuLq73G0gUBDjnQtTB2rVr0dLS0q1gAQBVVdHS0oK8vLxebhlR8GDPhciLEAI2mw3l5eU9WhGmKAoyMjJQWlrqno8hMjKGC5GXmpoaJCcn+zwbDGA0gDsB/BbASAAVAFov8/WJiYm93UyigGf2dwOIAknT+fPIAJB18c8oAP0AqADqL37OUgDDAJwEUAyg6OLbBgCNjY0MFyIwXMjo2tuBsjKguBgoLkbqwYP4E4B2AKUA3ocMjqMAHBe/5P8BGAMZPtcB0NaInQaQVFAA3HADMHo0kJQEcIiMDIrDYmQsLS1Aaak7TFBeDjidQHQ0MGoUxKhRmPUf/4GPT56Es5v/yjjIYbOZycn4y9KlUCor5QcSEmTIjB4NjBoFDB7MsCHDYM+FQltDA1BS4gmTEycAIYC4OCArC5g2Tb4dMgRQFCgA5paX46Nf/Up+XjfYAfxbUbDwP/8TyooVQFMTcOQIUFQk/+zdC6gqYLXK/1ZWlgyc9HQgLKwXv3ki/2HPhUJLba0nSIqLgaoq+Tw52f
NLfdQoICXlsr0I3fe5XLjg21sqLQUcDsBiAWw2T+8mIwMID7+Gb54ocDBcKHgJAVRX+4ZJTY382ODBnl5CVpYcouqBnu7Q3759O2bPnt29f7nDIYfjtDaXlACtrTJYRozwtNlmkwFEFIQYLhQ8VBU4edJ3mKuhQfZAhg+XPZKsLPnWar3m/1x3zxYrKCjofrB0Rvu+iop8vy+TSX5f3iEZE3NN3xNRX2G4UOByOj2v8EtKfF/hZ2TIoSTtFX5UVK80wW63Iy8vD88//zzKysrczzMzM7FixQosXboU/fv31/c/KgRw+rRv2Gg9srQ0z/d9FT0yor7CcKHA0dYGHD0qf5kWFcl/djhkcNhsnjkTP8xNCCFQV1eHxsZGWK1WJCQk9O1O/Joaz9+L91zSgAG+YdPFXBJRX2K4kP80N3t6JMXFwLFjgMslh3604a2sLGDoUK6q6uj8efn3poVNx1Vw2iKBtDSGDfkFw4X6jt3uO4ldUSF/IcbHe159cz/I1Wlpkcuftd5NWZkM6n79PCE9erScw2FQUx9guFDvEAI4d8731fWZM/Jjqam+k9Tcya6/tjYZMNrffWmpfBYZ6TvEOGIEEBHh79ZSCGK4kD6EAE6d8vRKioqA+noZGkOGeIJk1Cg5dEN9y+WSw47e8zYtLbIXk5np6TmOHClPKyC6RgwXujoulxzn9x7mamqSv6y8l8+OHCmHZiiwCAFUVnpOESgulsOWiiLnuLSeTVYWEBvr79ZSEGK4UPc4HD4HPKK0VO48Dw/3DLNkZclhlshIf7eWekoIOWzp3bM5e1Z+bNAg37BJSvJvWykoMFyocxcueCaIi4tlsGgHPI4c6QmT4cMBM4+oC0l1dZ7//0VFsqcDyHDxPiNt4EDOmdElGC4kNTb67nw/fly+mo2N9Z18HzJE7hwn42lq8j1q59gxebpAx5+RoUP5M0IMF8OqrfUNk1On5HPvV6VZWXJlF1+VUmdaW30P5NQ2vVosnuXPWVk8kNOgGC5GoB3w6B0m587Jjw0e7PuLgLco0tXynpcrKpLDqtq83IgRfXJcDwUOhksoUlXPSiAtUM6f910JpC0L5kog6i0ul+dATm2RQFOT50BO742zPJAz5DBcQoHTKce/tSApKZF7GMxmOSThvSyYR7iTv2h7obR9UEVFctEA4NkLpQVOfLx/20rXjOESjNrbPWPdJSXyn9vb5RLgkSM9w1yZmdx9TYFLCHkgp9arKSqSw7eAPIBTu9ht9Gh5QCfn/oIKwyVY7N8PHD7se8Cj97lRWVnAsGE8N4qCm/f5c8XFclhNO39O69lMnsyrBoIAwyVYfPednDCNjZV/+veXQ1x8NUehzOmUF6dpf5qb5QsqhkvAY7gEC6dT9koYJmRkLpesAe6jCXgMFyIi0h3jn4iIdMdDoS5SVRVVpaVwtbb6uylBL2H4cFj1vlee+oSqqjh15AjrQAeJw4fDauDrJRguFzmdTtTk5cHR3g4T94JcHSFgrq4GHnsM1smT/d0augpaHTjb22HiLvqrIgCYT58GnnjC0HXAcNEIAZeiYMDixRg8erS/WxMcWlrkSQBpaUB0NJwuFw79z/+Ak3hB7GIdJC9ejLQxY/zdmuDQ3CzrYMgQTx385S8w+nQ2w6UDk8kE8xUO2fP+oVGMunpLCLnv5m9/k+eRjRsHzJ4Ng/5thJywsDDWQXcIITd/vvyyPPR1/Hhg9myu6gTDpduEEFCFQPXZGhyvPIW2tnakJidi2JA0WKIijVdcQgDffCP33lRXy2M8Zszwd6uol3nq4ByOV1Shrb0NKclJGG7kOvj6a1kHp0/L08anT/d3qwICw6WbnC4XPvnia3x76DDaHU4A8tVaSlIi5tx2MwYkJRirsOx2+YpNM3iw/EMhTauDrw8egsN5sQ6gICU5ETm338I6SEtjHVzEpcjdIITAgcMl+HL/9wgPD8f0yddj9szpSB80EGdqavH+nk9xoa3N383sO0IAhw7Jwt
JMnMhzzEKcEAL7Dxdj33cHERERjumTJ2L2zdORPnggzpyrxY7dnxizDurrPc+uv57XfF/Enks3OBxOHCg6ApPJhB/eOhMjhqUDAMaMzMRb2z9AVfVZVFRVwzZ8qDFetTmdwOefy+IC5N0ckyZxnDnEORxOHDjsqQPb8KEAgLEjR2DLtp04VX0WFVWnYRs+zBh14HL51kFkpDz3zAjfezew59INbe3taGhqgrWfBWkDU6AoChRFQVRkJDKGDIYqVNTW232+RlVVNDY24tixYzh69ChcLpd/Gq83IeRhgkeOeJ5lZnIowAAutLehoakJMf0sGDIo1bcO0tMghIra+vM+X6OqKhoaGkK3DoqLPc8yM+WwGAFgz6VbVCEgIFeSmTqcaRR28RRitcOyQ7vdjhdffBEtLS1QVRVPP/00oqOj+6rJvUcIYPduecMgIF+lzZghbxsMlV8c1CkhAAFxmTqQ76vqpXXwwgsvoLm5Gaqq4k9/+lPo1MGePayDLjBcrkQImOBEdHg7LOEAnI2AybNEM1xpQb8IByJM7T5fFhsbi+XLl6OmpgarV6/u40b3EiGAqirgyy89z1JS5DgzhwJCmxBQhBPR4Q5EhSuX1EGE0op+Ee0I76QOVqxYgXPnzuH111/v61b3DiHkyrB//9vzbMAADg13wHC5ElcD+tWvxk+uL4KiKIio/BLemznGRzqRdWM7wiOOAo0KYJ0CKArMZjPi4+Nx/vz5y/+7g42qAjt3Ao2N8n1FAW6+mVclG4GrATH1r+OnEw9DgYKIyn3wLoTxkQ6MurEdEZFHgUYA1qk+dWD3XvwR7IQAPvhAXgGguflmeQ0GuTFcrqSlBKbmg+gXrgIAhKsN7a4wOFQTLOEORChARDgAUQvU7wJiJgFKCP61CgGUlwN793qeJSUBN93EV2tG0FIMU8sBTx0429CuhsHh8qqDCHjVweTQrYMTJ4DPPvM8S0oCZs5kHXQQgv/3dSYcAFT3u+dbo/De4ZE4fyES04dVYOLgas/rN+EEQvXwk/Z24O235ZEvgCykO+7gpU1GoTrhUwcXovDuoVE4fyESM4ZXYOKg057fraoDIVsHDgewdSvQ1CTfVxTgtttkwJAPrhbroeP2/qg8H4vGtkgcPJ0Kl2qAv0Ih5JLLgwc9z4YMAW65ha/WDOp4vacO9p9KMU4d7NsHfPut59mgQcDtt7MOOmGAnwh9pcY0ISaiHSZFYEicHSaT2unnqaqKuro61NfXw+l0ora2Fg0NDcF3mJ02ib91q2cVjNkM3HsvYLX6tWnkP6kxzYiJaIcCgaHx569YB3a7HS6XK7jr4OxZ4K235D4vQN4Me889nGu5DIZLD6VYmzEiqQ6RZicmeQ+JdeBwOLB161bs3LkTsbGx2LRpE/bs2RN8RdXaCmzYIM9M0kydypUxBpdibcKIpFpEmZ2YlHa6yzooKChAYWEhrFYrNm7ciN27dwdfHbS1ARs3AmfOeJ5NmgTceCPr4DI459JDQgCtjnCoqgltrrDLFlVERAR+9rOf+TzTNp0FDZcL2L4d2L/f82zAAGD+fNl7IcMSAFodEVCFgjan+bK/XyMiIvDQQw/5PAu6OtBWSX71ledZYiJw//1yXwt1ir8hrsgEueRSvtJSFGBm5glMTT+FeMsF309VtM+VBaRtsAxKqipXhm3b5jneIiICWLhQ7m0Jpl8OdO2UDnUA4ObM45iabkZ8dOtlPjcE6kA79bjjsPD998tTKVgHl8VwuRJLBhA5GGivBiCgAEju1wZAO6DvYuEoEReXIQdxIWm0u1rWrZOrxABZRNnZPDvJqDrWgQIkx1ymDqyTQ6cOSkuBNWt8d+Lffrs8Vp910CWGy5WEDwDS/w/QXtP155migIiUvmlTb9L2s7z2mmezJCBPPZ47V05ikvGEDwDSfwu016LLZcYmS+jUQWWlvAyvrs7zfOxY4L77WAfdwHC5EkUBzP3ln1CnbRB7+WXg3DnP8+HDgaVLAY
uFr9aMSlEAc5z8E+q0FZL/+7/yrSYtDfj5z4GYGNZBNzBcSBICOH4ceOkleW6SJjUVeOQROYHJgqJQpwXLiy/KF1qaxETg0UdlPbAOuoXhQrKgjhyRQwDeSy2TkoDHHpMbJllQFOq0F1gvvwxUVHiex8XJYMnMZB30AMPF6FQV+O47YPVq3xv1EhOBxx8HRoxgQVHoE0JeV/y3v8nNkpr+/WWwjBvHOughhotRCSGXVu7eDbz5pufMMABITpbBMnIkC4pCn8slj8/Py/M96Tg+XgbLhAmsg6vAcDEiIeTO+61b5eYw7TgLQE5aPvYYMGwYC4pCmxByqf22bcC778pd+JoBA2QdZGWxDq4Sw8VohJArwfLy5HCY9zEco0YBDz/MSUsKfUIAdjuwfr3stXjfHjl0qAyWoUNZB9eA4WIkqgp8/70MFu8VYSaTPC9s8WI5xsyColCmqp7NkcePe54rCnDddXK5cVIS6+AaMVyMQBsGe/99eVZYq9dxHRERwJw5wN13A5GRLCgKXULI+1g+/BDIz/fcyQLII11uvx1YsADo1491oAOGS6jTdhpv2CDvY/EeBouPBxYtkie7cscxhTLtyPyNG+UBlN7DYDExwAMPyPuJzGYGi04YLqFKe5X26adAQYHvMmNFkSvBli4F0tNZTBS6tFWRX34pV0V67+MC5MKVpUtlPZh4A4meGC6hSAg5p7JlizzR1ftVWkSEvJ74nnt4jAWFNiHkuWD5+fJFlsPh+ZjZDNx0kzzdOC6OddALGC6hROutfPaZXGbsfcEXII/Kf/BBeQhlWBgLikKT1lv56itg82bfxSsAkJAg51ZmzOAwWC9iuIQKVZVzK1u2yCXGqte1s2azXA22YIHcIMliolDlfR3xv//tu4fLZJKrwR58UO7nYh30KoZLsNNWgn30kdwMdv6878eTkuTNkT/4gbw1jwVFoUjbEKn12ms6XJHRv7+8MuL227kqso8wXIKZyyUPnNyyRb71XglmNgNTpsi7J3hzJIUyVZX7VbZskSsivecYw8KA8ePlDapDhnDSvg8xXIKRNlG5bRuwZ4/nljxNaqrsrUyZwjFlCl1CyAvtCgvlMUbe+1YAObfy4x/LiXv2VvocwyWYaBP2+/bJrn91te/HIyNlIc2dy/tXKHRpE/bffSfnVk6evLTXPnWq7LXzKCO/YbgEC1WVd0y89Rawf79v119R5G2RCxbIa1hNJhYUhSZtmX1BgXyR5b28WFGAQYNkHUyaxF67nzFcAp0Qsrv/wQey6+99rz0AWK3AXXcBs2bx2AoKXdrCld27gX/+03dTMABER8v9W3Pm8Hy8AMFwCVRa1//AAbkJ7MQJ365/WJjcrzJvHicqKXQJIXvtJSVyz0rHhSsmEzBmjOytjBjBOgggDJdAJIQ8puLtt+Vafe+uPwAMHCgnKqdM4fJiCl3awpV33wU+/vjShStJSfKkCU7YBySGSyARQl5Y9MknwHvvXbrD3mIBbr0VyMnhkRUUurSFK198IedWOi5ciYiQu+vvvZebggMYwyVQqCpw7Jjs+h8+7LvD3mSSN+Lddx9gs8liYkFRKPI+aeLbby9duJKR4Vm4wiOMAhrDxd+EAJqb5V0rO3fKf/aWmCjvWpk5k11/Cl3ahP2HH8oJ+44nTcTGysn6O+7gwpUgwXDxJ1UFiouBTZuA8nLficrwcHlky733coc9hTZVlT//GzcCRUWXLlyZNEn22tPSOGEfRBgu/qAtL962TS4x9p6oVBRZRPfdB1x/Pbv+FLq8b0jdsePSHfbaSRNTp3LhShBiuPQ1VZXLKTdsuLS3EhUlu/05OXIYgMVEoUqbY1y/XvbevesgIkIOA997L0+aCGIMl74ihOyh7Nwpx5RbWjwf0yYqFy6UE/fs+lOo0k4v3rVLLrX33hSsKHLP1sKF8rBJ9tqDGsOlLwgBVFUB69Zdeo99dDRw551yspITlRTKtLtW1q8HvvnGdyVYZKQ8Dn/uXO
6wDxEMl96m3Yi3YYPvHROKIu/vXrSIvRUKfaoqD5p8441L960MGSIv8JowQfZWKCQwXHqLNln5z3/Kycr2ds/HIiKA226TY8pWK1+lUejSNgZv3y43BnsvXjGb5dzKggXcFByCGC69QQjZS8nLkxvBvIfBkpNlb2XSJI4pU2gTArDbZW/liy98NwYnJMi5lWnTWAchiuGiNyHkKpi//13eM6FRFDlJuXixPBuMxUShTAj58//3vwNlZZ7nigKMHg0sWQKkp7MOQhjDRU+qKu9aWb1aHriniYgAsrPlZKXFwoKi0KaqwPffA6+9Bpw753luNsurIebP5+IVA2C46EVVgc8+k0MA3ke4xMbKYbBp02RxEYUyVZVDYP/4h+8y45gYOQx2yy0cBjMI/rbTg8sl1+2/+abvhOWgQcAvfgGMGsViotDncgF79sgl962tnucpKcCyZfLeFa6KNAyGy7VyueQRLm++6bsizGYDHnmE8ytkDNoLrA0b5OowTWYm8Oij8kgj1oGhMFyuharKU1y9g0VR5Hr9ZcuA+HgWFIU+VZU9lo7BMm6cfIGVlMQ6MCCGy9XS5lg2bfINlhtuAH7+c+5fIWNQVXlb6rp1nmBRFHkF98MPc7e9gTFcroYQcjXMunWeORZFASZPlnMsMTEsKAp9QsiL7dau9Z1juf562WPh4auGxtm1nhJC3pS3erXvEeETJjBYyDiEAE6dksuNGxo8z7WhMAaL4TFcekK7h+Uf//Bdvz9ihAwWDoWREXjXwZkznueZmTJYOBRGYLj0jMsFFBTI+1g0KSlybDkhgQVFxuByAVu3ylsjNcnJsg44eU8XMVy6Swjg66+B3bs9Z4VFR8tjLAYPZkGRMQghj8vftctTBxYL8LOf8TgX8sFw6Q7tIMrNmz0rw0wmeZzLhAksKDIGIeSxRt4rJE0m4O675SQ+64C8MFy6Q1Xl0fne91BMmCAv+eKOYzIKVQXefRc4fdrzbMIE4K67WAd0Cf5EXIkQco7l0089z+LigAcekLfnERmBEEBpKfDxx55ncXHyvLCoKL81iwIXw+VK2tvlqzXv/Sxz5vA4CzIWhwN45x3fOvjhD+UtkqwD6gTDpSvaZsnDhz3PMjKAW2/lMAAZhxDAwYOyFjTDh8s771kHdBn8yeiKwwEUFgJOp3zfbAZ+9CO5SozIKBwO4P33PXUQFiYn8fv182+7KKAxXC5HCKCkxHdPi83G1WFkLFodlJR4no0cydVhdEUMl8vRTnp1OOT7YWFydRgn8clIVFXu7fKug9mzWQd0RQyXyzlzRo4za4YMAcaP56s1MpazZ33rID2dvXfqFoZLZ7RdyNrBlIoC3HST3IlMZBRCAF995bmuWFGA6dNZB9QtDJfOtLfLotLExACTJvHVGhlLezvw5Zee961WYMoU1gF1C8OlM1VVQEWF5/2sLHkgH5GRdKyDUaNYB9RtDJfOHDx46SVgXM9PRrN/v28dTJnCOqBu409KR0IAtbWeIy1iY2XPhUMBZCSqKg+pZB3QVeI1xx0pCvCTnwB33CF7MM3N8q4WIiMxmYCf/hSYNQs4cIB1QD3GcPGiCIGa0lK42trkg9RU+fa77/zWpmDiUlW4mpvB17bBzSQEao4cgVOrg4EDZY+eddAtqssFtanJ8HXAcLlIURQgLQ1i/37U7d/v7+YErXCLBWE8HidoKYoCkZYGfPcd6hgmV81ssSDM4MfjKEJo18kZmxACLpfL380ICSaTCSZO/AYl1oF+jF4HDBciItKdcWOViIh6DcOFiIh0x3AhIiLdMVyIiEh3DJdg4XLJU5q5koeMjHUQNBguwaKiAnj0Ud+DBImM5uRJ4OGH5VsKaAwXIiLSHcOFiIh0x3AhIiLdMVyIiEh3DBciItIdw4WIiHTHcCEiIt0xXIiISHcMFyIi0h3DhYiIdMdwISIi3TFciIhIdwwXIiLSHcOFiIh0x3AJAkII1NXVoam5GXV1dRBC+LtJRH2OdR
BcGC4BzG63Y9WqVbDZbJg0eTLeeecdTJo8GTabDatWrYLdbvd3E4l6nXcdTL7hBrz33nuYfMMNrIMApwjGf0AqLCzE/Pnz0dLSAgBIFwJ/BvB/AZxUFABAdHQ08vPzkZ2d7b+GEvWijnUwVAj8N4DfAzjBOgho7LkEoMLCQuTk5KC1tRVCiEu6/9qz1tZW5OTkoLCw0E8tJeo9rIPgxnAJMHa7HfPnz4cQAqqqdvm5qqpCCIH58+dzaIBCCusg+DFcAszatWvR0tJyxYLSqKqKlpYW5OXl9XLLiPoO6yD4MVwCiBACL7zwwlV97fPPP8/VMxQSWAehgeESQGpra1FWVtbj4hBCoKysDHV1db3UMqK+wzoIDQyXANLU1HRNX9/Y2KhTS4j8h3UQGhguASQmJuayH6uCXIZc1cXXW61WvZtE1OeuVAe/B+sgGDBcAkhiYiIyMzOhXFy/780B4MTFtx0pioLMzEwkJCT0dhOJel1XddAO4PjFtx2xDgILwyWAKIqC5cuXX9XXrlixotNiJAo2rIPQwB36AcZutyMtLQ2tra3dWoZpMplgsVhQWVmJuLi43m8gUR9gHQQ/9lwCTFxcHPLz86EoCkymrv/3mEwmKIqCgoICFhSFFNZB8GO4BKDs7Gxs27YNFosFiqJc0s3XnlksFmzfvh2zZ8/2U0uJeg/rILgxXAJUdnY2KisrkZubi4yMDJ+PZWRkIDc3F6dOnWJBUUhjHQQvzrkEAe0ei8bGRlitViQkJHDSkgyHdRBcGC5ERKQ7DosREZHuGC5ERKQ7hgsREemO4UJERLpjuBARke4YLkREpDuGCxER6Y7hQkREumO4EBGR7hguRESkO4YLERHpjuFCRES6Y7gQEZHuGC5ERKS7/w9YAKJDKoqiQgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "4b98a727",
"metadata": {},
"source": [
"We will see that ${\\rm tanh}$ is at the top of the suggestion list for the second-layer activation (${\\rm sigmoid}$ is equivalent to ${\\rm tanh}$ up to input/output affine transformations)! So we can set it to ${\\rm tanh}$, retrain the model to machine precision, plot it, and finally get the symbolic formula."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "99ad38b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"function , r2\n",
"tanh , 0.9999837308133379\n",
"sigmoid , 0.9999837287987492\n",
"arctan , 0.9995498634842791\n",
"sin , 0.996256989539414\n",
"gaussian , 0.9938095927784649\n"
]
},
{
"data": {
"text/plain": [
"('tanh',\n",
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
" 0.9999837308133379)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(1,0,0)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "af24c80d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999837308133379\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000, grad_fn=<SelectBackward0>)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(1,0,0,'tanh')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "01936f17",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.69e-11 | test loss: 5.76e-12 | reg: 2.69e+00 : 100%|██| 20/20 [00:00<00:00, 21.70it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "76bcc188",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkJUlEQVR4nO3de3Bc91338c85u7rsRRdblp3Y8kWSlQskLWE6belQ2hmSKMVpwlMXpjTDlOl/DMRMh1Jm+IMC/zDDrXZK4S+mYwPNwGBzaZ0iEtp52j60TCE8JQ+EiSJbtiXZliV55ZV3dT2/54+fTvbsaiVL8pF295z3a2ZH2SPJ+TnRV5/zux7HGGMEAECI3Fo3AAAQPYQLACB0hAsAIHSECwAgdIQLACB0hAsAIHSECwAgdIQLACB0hAsAIHSECwAgdIQLACB0hAsAIHSECwAgdIQLACB0hAsAIHTJWjcAaATGGE1PT2tubk7ZbFZdXV1yHKfWzQLqFj0XYAO5XE5nzpzRwMCAuru71dvbq+7ubg0MDOjMmTPK5XK1biJQlxyeRAlUNzQ0pJMnT6pQKEiyvRef32tJp9M6f/68BgcHa9JGoF4RLkAVQ0NDOnHihIwx8jxv3a9zXVeO4+jixYsEDBBAuAAVcrmcenp6VCwWNwwWn+u6SqVSGhsbU2dn5843EGgAzLkAFc6ePatCobCpYJEkz/NUKBR07ty5HW4Z0DjouQABxhgNDAzo0qVL2kppOI6jvr4+DQ8Ps4oMEOEClJmamlJ3d3fVz7VKGpA0LGl+g+/v6uraodYBjYNhMSBgbm5u3c8NSHpt9eN68vl82E0CGhLhAgRks9n7+v62traQWgI0NsIFCOjq6lJ/f/+W500cx1F/f7/27t27Qy0DGgvhAgQ4jqMXX3xxW9976tQpJvOBVUzoAxXW2+fyuOycy5OS3gh8PftcgLXouQAVOjs7df78eTmOI9fduET8HfoXLlwgWIAAwgWoYnBwUBcvXlQqlZLjOGuGu/xrqVRKr7zyip5++ukatRSoT4QLsI7BwUGNjY3p9OnT6uvrK/tcX1+fTp8+rfHxcYIFqII5F2ATjDGa/fa31faxjyl/4YI6PvhBJu+BDfCwMGATHMexcyquaz8SLMCGGBYDAISOcAEAhI5wAQCEjnABAISOcAEAhI5wAQCEjnABAISOcAEAhI5wAQCEjnABAISOcAEAhI5wAQCEjnABAISOI/eBzVpaknI5qbNTamqqdWuAuka4AJtljLS8LCWTHLkP3APhAgAIHXMuAIDQ8SRK1I2VlRUN/9u/aTmfr3VTGt7Bd79be7u7a90MxBjDYqgbxWJR/+e557Q4Py83k6l1cxpW86VL2v/SS3rsmWdq3RTEGD0X1A9jtOy6OvI7v6OHPvCBWremvnmefa2s2EUGS0uSMVrKZvXtT35S3DGi1ggX1J1EIqHmlpZaN6N2/MGExUVpdla6eVMaG7OviQnp1i27JHp2VsrnpWLRfu2RI9LZs3JYyYY6QLgA9cAY2/uYmJBef1363vekN96wgXLnjrSwUAqd9TQ12T8DqAOEC1ArxtjX5KT06qvSP/yD9J//aXsk25kKXVwkXFA3CBegFoyxvZJz56Tz522P5V6BkkhIzc1SKiWl0/aVyUjZrP3nBx7g5ADUDcIF2E3GSIWC9PLL0p/+qQ2YalpbpQMHpIEB6eGHpf5+6fBhqbtb6uiwAdPcbMPEde3Lcei5oG4QLsBuMUZ6+23pt35L+uY37UqvoExGeuIJ6amnpPe/X+rttb0Sd3WvMxP1aCCEC7AbPE/6znekX/s1aXS0/HP79kk//dPSJz4hPfSQ7ZEQJGhwhAuw0zxPunhR+vVfl6anS9dbW22o/PIv22Evl9OYEB2EC7CTjJG+/nXps5+1e1N8hw9Lv/mb0kc+winLiCTCBdgpxkjf/a7tsQSD5Ud+RPrCF6RHHyVUEFmEC7ATjLFzK5/7nN
1R73vve6Uvfcn2XAgWRBiDvMBOuHvXDnsND5euPf649Md/TLAgFggXIGyeJ335y9I//3Pp2sGD0h/9kT3/i2BBDBAuQJiMsUe4/MmflPaxpFLS5z9vey4EC2KCcAHCVChIv/d70syMfe840gsvSCdOECyIFcIFCIsxdj/Lt75VuvbII9KpU3a5MRAjhAsQlqkpO2Hvn+/V0mJ35O/fT68FsUO4AGEwRvqrv5Leeqt07amnpCefJFgQS4QLEIbr16WzZ0vH5nd2Si++aM8JA2KIcAHulzHSX/+1dO1a6drzz7M6DLFGuAD36+ZN6StfKfVaurqkT3+agygRa/z0A/fDGPt44qtXS9eee84+5IteC2KMcAHux+xsea+ls1P6+Z+n14LYowKA7TJG+sY31q4Qe/hhei2IPcIF2K6FBenll8uPeXnhBSmRqG27gDpAuADb4Z8h9v3vl6695z3SE0/QawFEuADb4y8/Lhbt+0RC+uQn7a58AIQLsC0TE9Krr5be9/VJH/oQvRZgFeECbJUxNlhu3ixde+45ae/e2rUJqDOEC7BVxaJ04UJp+XFHh/TRj9JrAQIIF2ArjJHeeMO+fD/2Y9Lx47VrE1CHCBdgK4yR/v7vyyfyT57keS1ABcIF2IqpKemf/qn0/uhR6QMfYEgMqEC4AJtljPSd70jj46VrzzxjD6oEUIZwATZreVn6u7+TPM++T6elZ5+taZOAekW4AJt1+bL0r/9aev+ud0k/9EMMiQFVEC7AZhgjDQ1JuZx97zh2b0tra02bBdQrwgXYjEJB+trXSu+7uqQnn6TXAqyDcAHuxRjpBz+Q/ud/Std+/MelQ4dq1yagzhEuwL0YI331q9L8vH2fTErPP8/R+sAGCBfgXqanpddeK70/elR63/sYEgM2QLgAG6m2t+XppzmkErgHwgXYyPKyPe4l+LTJj360tm0CGgDhAmxkdFT63vdK7x97THr0UYbEgHsgXID1GCN9/evS7dv2vePYifxUqrbtAhoA4QKs5+7dtXtbnnqKXguwCYQLUI0x0uuvS2++Wbr2wQ9KPT21axPQQAgXoBrPk86flxYX7fumJuljH2NvC7BJhAtQzdiY9I1vlN4fP87eFmALCBegkjHSP/6jdOtW6drzz0vt7bVrE9BgCBeg0tyc9Dd/Y0NGshsmn32WXguwBYQLEGSM9C//Uj6R/xM/IfX21q5NQAMiXICgpSXpK1+xHyWpuVn6xCeYyAe2iHABfMZI//Vf9iwx37veJb33vQyJAVtEuAA+z5P+8i/tnIskua70cz8npdO1bRfQgAgXQLK9lrffll55pXStt1caHKTXAmwD4QJINlz+4i/ss1skGyg/+7PSvn21bRfQoAgXwO+1XLhQunbokPTxj9NrAbaJcAE8T/qzP5Ompux7x7ErxA4erG27gAZGuCDejJF+8APpb/+2dK2nx07ku5QHsF1UD+Jtfl46fVqanbXvHUf6hV+ww2IAto1wQXwZY1eHffObpWuPPGKHxJhrAe4L4YJ4Mka6dk36gz8oHavf3CydOmUfCgbgvhAuiKf5eel3f1e6dKl07amnpI98hF4LEALCBfHjedKf/7n01a+Wrj34oPS5z0mtrbVrFxAhhAvixRjp1Vel3//98sMpP/tZ6eGH6bUAISFcEB/GSN/9ru2h3LljrzmO3SzJhkkgVIQL4sHzpG99S/qlX5Ju3Chdf//7pd/4DamlpXZtAyKIcEG0GWNXg738svSLvyhNTJQ+9+ij0h/+oT0/jF4LEKpkrRsA7Bh/ufEXviCdPy8tLJQ+99BD0pe+JPX1ESzADiBcEC3G2NfkpA2UL3/ZBkzQE0/YXflM4AM7hnBB4/MD5e5d6c03pa99ze68Hxuz133JpPRTPyX99m/bpccEC7BjCBc0hmBIrKzYIa47d6TxcRso3/++9Prr0tWr5cNfvgMHpBdflF54QUqlCBZghxEuqE/j49LwsJTP20Mlczn7IK/JSe
nmTfuanraf949vqaa9XXr2WbtKrL+fUAF2CeGC+nTxovT5z5eGvLbCde2zWJ55xh6d/8gjUiJBsAC7iHBBfcpk7N6UzXAcO9TV0yO95z3ST/6k9L732QMoHYdQAWqAcEF9amsrhYLj2J5HMmlDJJu1wXH4sHT8uN2v8tBDNlzSaQIFqAOEC+rTu99tj8PPZu28SXt76Z8zGXvAZFMTQQLUKcIFdcX1PI39+79ruVCwO+clO2E/NVV6xj3WtbyyopXbt0XcotYIF9QN13VlHn5Yeu01Tbz2Wq2b07Ca29qU7OiodTMQc44xW12KA+wMY4yWl5fFj+T9SyaTcl2ODkTtEC4AgNBxawMACB3hAgAIHeECAAgd4QIACB3hAmyWMdLS0tbPOgNiiHABNuuNN6RDh+xHABsiXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1yATTDG6Pbt21rxPN2+fVuGZ7oAGyJcgA3kcjmdOXNGAwMD+tCHP6zp6Wl96MMf1sDAgM6cOaNcLlfrJgJ1yTHcggFVDQ0N6eTJkyoUCpKkx4zRa5KelPT/HEeSlE6ndf78eQ0ODtauoUAdoucCVDE0NKQTJ06oWCzKGLNmGMy/ViwWdeLECQ0NDdWopUB9oucCVMjlcurp6VGxWJTnee9cf1x6p+cSfNCx67pKpVIaGxtTZ2fn7jYWqFP0XIAKZ8+eVaFQKAuWjXiep0KhoHPnzu1wy4DGQbgAAcYYffGLX9zW97700kusIgNWES5AwPT0tEZGRspCIiXpgKQPSEpKapeUqPg+Y4xGRkY0MzOza20F6lmy1g0A6slcPq+MbID4r6QkI6lJkiOpb/X9XUl3JM2uflyWlM/n1dXVVYOWA/WFcEG8eZ40NyfduSPNzqpzbEzvkuRJyku6LhsceUn/V9L/lnRVUrOkDkl7JD24+kcVJXXcuiWl01JHh9TSsst/GaB+EC6Il5UVGyT5vDQ7a4PFGCmZlNra1PHDP6y7vb164/JlVc6ezKt8ldit1Y9NskHz2JEj6kwkpOFh+4nmZhsy7e32lU7v9N8OqBuEC6JtacmGif+6e9deb2qyv/j37bMfV3/xO5I+/Su/os985jM2dDbzr5A07Tj6X7/6q3J+9Eel5eXSv292Vpqasn9WU1MpaNrbpUxGWt2MCUQN+1wQLQsL5WFSLNrrra3lv9hbW9f9I9bb57Kee+5zWVmxPSW/Tfm8HY5LJKS2tlLvJpuVXNbYIBoIFzS2YrE8TBYW7PV0ujxMmpu39Mf6O/SNMRsGjOu6chxHr7zyip5++unN/eHBeR7/tbJigyWbLYVNW5sNIKABES5oHMZIhUL5L+WlJfu5bLYUJG1tdgjqPlWeLRYsFSdwttiFCxc2HyzV+H+v2dnyv5fj2KGz9vZS4CQZyUZjIFxQv4zZ+A4/GCY7dIefy+V07tw5vfTSSxoZGXnnen9/v06dOqVPfepT6ujoCP9fXCyWh021HllHx5Z7ZMBuIVxQP1ZWysOkcm7C/6Vag7kJY4xmZmaUz+fV1tamvXv3vtN72RX+XJIfONXmkjo6NpxLAnYT4YLa8VdV+ZPdwWXBrKramL8Kzg8bfxVcc3N52LD8GTVCuGD3LC6WD3GtzmWU/UJkP8j2LC+XQrpy/07wv202S1BjVxAu2Dnz8+VhMj9vr7e22rtqf6iLoZzweV758uc7d+w11y0Pm7Y2lj9jRxAuCE/lSq7FRXvdPw7F/2XGJPTuq7Y4YnnZ9mKCy5/b21n+jFAQLtgeY+w4/3q/rIJ3xiyfrU+Vy5/9m4HK5c8hLOtG/BAu2JzKjX/5fGlZcOVKLu58G9P8fHnY+MOYqVR52HAgJzaBcEF11Y4sMcYGBxPE8bC4WB42/gKMlpbyFWmpVG3bibpEuMBaWioPk7k5e73ysMV0mjCJK3/peHD5MwdyYh
2ES1wtLJSOnc/nq9+VtrdzV4r1+b1bP2zm5kqbXit7t6xIix3CJS7WO+DRH0/3X4ynY7sqHry27rwcB3LGAuESRRsd8OivBPJfrATCTqlcUTg7W1pRmMmUL39mRWHkEC5RsN4Bj5V7GLhjRK0Fb3pmZ6vvhdrGIxJQfwiXRlS5+7qODngEtmRhoXxFWuWBnH7gcIpDwyFcGsXt2+UruTjgEVFUef5ctQM5u7ro2TQAwqVRDA/b8epMpvTibg5Rt7xsh9Lu3rWvYlE6etSGDOoa4dIolpeZ9AT8uUSGe+se4QIACB3xDwAIHeMsq4wxKubzMisrtW5Kw2vOZNTEhGtDog7CE/c6IFxWeZ6nhdFReZ4nh70g22OMnPl56fhxNe3dW+vWYBs8z9PC5cvUwX1yikVpYCDWdUC4BBhJrceOKcVKlM3xV/JkMlIiIWOMZt98s9atwn16pw46OmrdlMZQrQ7++78V98lswqWC4zhy77ESJbgGwonrvhJj7D6Et9+255F1dso7cKDWrUJIqINN8utgeLhUBw88UOtW1QXCZZP8QirOL+husShvxVNra7My6bQSrhvP4rp92xbX/Lw0OWk3tyHSyuqgUJDneWptaVEmE+M6mJkpr4N9+2rdorpAuGySMUa3pm/r9uysvMAdW2tLiw4e6FZLc3O8CsvfSe1Lpez5UIg0Wwczmpm9I8/zVq86SrU26+D+/WppoQ6oA4ulyJtgjFHuTl4zuZxc19W+PZ16oHuf0qmU5hcWdH3yVqDQYsAfCvAPHZSkPXvY2BZxfh1M356V6zjat3ePHtzfrUy6VcX5RU3EsQ6Ch29K1EEAPZdN8FaLSo6jB/d3K5uxdyYdbVldm7ihwvy8CsV5ZTPpeNy1GSPdulV677q2qBBpnjG6PXtHjiMdPLC/rA6uTlxXsWiHjNsymfjUwdRU6b3rSjFeHVaJiN0Eb8XT0vKympIJpVOtchznnQnPzGoXeCF49yJ7l7e0tKS5uTnl83lF5iAE/xkd+XzpWjbLUEAMeCuelpeXlUwm19RBNp2WkdHi4lLZ90S6DgoF6mAD9Fw2xdgfpipZ7LjvfEWZxcVFDQ8Pa2VlRcYYPfbYY0pG5WywW7fsEf++7m573lNUfnFgHUZG1VeG+dcqw2NxcVFvvfXWO3Xw+OOPR6cOJiftWWc+6qAMPZd7MUaSUcL1lHA9ySxL3pJ9mWW5WlHS9eSqfKy5qalJDz30kPr6+qIzRGCMPZV2erp0rbXVDolF5e+I6lbrIOl6Sjgr69eBs7YOHn74YfX391MHMRORW4gdZJaVXLykY3vsihC3OF326U7XqL3Lk+PmpWVJyb3S6lBBc3OzlpaWqvyhDezGDbtpzNfdzWnNcfBOHcxKktzCtBT4PdqZWK2DxJy07Kypg8WKYeOGd/Nm6dHhkq0DHhleht8K97JyR87KrJKu7eoaz9OKHBnjKOF4ch3JdSRpUVq8KSX3qKzqosJ/lHJwArOlpTQUgGhbuSNnOaeks1oHxtOKt1oHridXq4ukzIK0eCPadXD3bvmCluZm6qAKwuVejKfgjMqSl9DEbJsWV1x1ZwrqTM0HfqYivAzT86SxsfIx5gMHeCJgXFTWwUpC47NtWvIS2pcpaE+qWIqSKM85GGPrINh7P3DA3mihDHMuW3R3sUmFpSYtewnliq0yUbw7q+QvuZydLV1Lp6X9+7lbiym/DpZW3HjVwfS0PZnCl0rZcKEO1iBctqg1uayka+/i0s1LctY5ns4Yo4WFBS0uLsoYo8XFRS0tLTXeUkx/8nJsrHRH6jhSTw9zLTHW2mTrwJGUaV6MRx0sLEjXrq2tA+ZaqiJctqg1uaxsy4ISjtGe9Py6X+d5nsbHx3Xjxg0lk0lduXJFk5OTu9jSkKysSFeulO9C7upiZUzMtSaX1dayKNf1tCe1cR2MjY3pxo0bampq0pUrV3Tz5s1dbGlIPM/WwXzg77pnj900SR1Uxa3nNqx4rowceZ
4jZ51HXriuq97e3t1tWNiMkSYmpFyudK2lxd6tUVCxt+I5knHkGWfdQTHXddXX17er7QqdMdL16+VLj5ubpSNHOOplA4TLvVT5Jbo/e1crXlFNycqn9ZW+tuHX9PvzLNevl665rnT0qF3T3+h/P2xNlf/f3W0FdXlFNSUq6iDwpZGog5kZaXy8dM1xbLCkUtTBBgiXe3GzkpuSPNsddhypJemptDLM/+Fyo7P80j+YcnS0fCf+Aw8wHBZXblZKpCWvKMn+lLeW3VwF62CvIlMHc3PS5ctrV0nu20cd3APhci9ui5R+VPIWtfaQlwAnIbmtu9asHeOv4x8ZKV9uuWePdOgQBRVX79TBwsZfF6U6KBbtw/CC840dHdLhw9TBJhAu9+I4ktMsuTHYz+Efxjc8bFfG+DIZqbdXSiQoqriKWx3Mz0tvvWUDxpdKSX19dpUkdXBPhAssv8cyPFy+Iqa1VervtxOYFBSizu+xDA/bGy1fc7N0/DjzjVtAuMAWVD5vhwCCPRa/oNJpCgrRF+y5B4OlqcnWQTZLHWwB4RJ3xtgdx5culR/E19wsDQxQUIgHfxFL5Q2WHywdHdTBFhEucWVWn1EzOSldvVq+GqalxRZUWxsFhejzj3W5fLn8BssPls5O6mAbCJc4MsaGydiYPUI/eBRHKmULKpOhoBBtxtil9hMTdh9LcNm9f4PV3k4dbBPhEjf+GUmjo+UH8Em2p9Lfz6Qlos8Y20sZHS1/jIRk5xgHBphrvE+ES5wYY082vny5fEWYZM8LO3bMDgVQUIgyfwHL5ct2hWRQZ6ddbtzSQh3cJ8IlDvxhsOvX7Ss4v+K60oMP2g2SrktBIbr8YbDJSXu6cXCTsOPYnfdHjrCfKySES9T56/ZHR8ufxyLZFWFHj9peC8WEKPOHg69cKT+AUrKbIo8cKT2fiFoIBeESVf5d2tSUvUsLroKR7PxKby/jyog2f1XkzIxdFVk5HJxO22EwVkaGjnCJIv/4iqtX7aR9cDWY69ru/6FDHGOBaDPGngt27Zp95n2wDhzHPvf+yBHmGXcI4RIlwd7K2Fj5gXuSXQV25EjpZGMKClF0r95Kc7M9fLK7mzrYQYRLVPhHV1y7tnaJsePYeZXDh1kFg2gLPo54aqq8tyLZG6ujR3kWyy4gXBqdvxJsctJuBqucW/Hv0vznT1BQiKLKXvtCxaMBmprsUPCBA6yK3CWESyPzz0O6ds2u2w9yHPt878OH2RSJaPNP9L56tfyR3L7OTttbYfHKriJcGpE/UTkxYXsswWMrJBsmhw/bcKG3gqgyxu5VuX7dHmMU3Lci2V57T4+dW6G3susIl0bid/1nZmzXv3Ki0nVtIR06xPNXEF3+hH0uZ3srwePxpfI5RnrtNUO4NIrghH0ut3aiMpOxK8E6Oux7CgpR5C+zv3bNboasrINUil57nSBc6p3f9b9xo3rXP5m0x7ccOMC+FURXcOHK+PjahSuJhK2BgwfZt1InCJd6Fez6X7tWvevf2Wnv0pioRFT5PZM7d+wQWOXCFcn21o8c4cF2dYZwqUd+139srHrXv7XVTlT6Z4JRUIgif+HK+Hj1hSstLXZ+kQn7ukS41BN/wv7WLVtQlTvsEwl7uB5df0RZcOHKtWvVF67s22dvsNgUXLcIl3phjDQ3Z4up8vRiyT4R7/Bhe8CeREEhmipPmqjstWeztg78Z9pTB3WLcKm14IR95bNWJLukmK4/os6fsL95s/pJE01NduHKAw/wvJUGQbjUkr/D/upV22sJchzb9T90iLX6iDa/137liq2HIMex54GxcKXhEC614PdWJiZsj6VyojKdtsXU2UnXH9FV+YTUymX2/kkTLFxpSITLbvOf333lytreiv+sFSbsEXUb9Vb8kyZ6ejhpooERLrvFv0u7ccP2WCrnVrJZu1a/vZ1iQnT5K8Fu3rRL7St7K+m0rQN67Q2PcNkNGz3HPpGwk5QPPsgOe0Sb/6yV0VG7zDgo+I
RUeu2RQLjsNP+JeKOja/etZDL2KHB6K4g6Y+zS4tHRtftWUinp2LFSbwWRQLjsFH8YbGLCTlYGJ+1d126G7Omht4Jo84fBJibWDgf7z7E/fJi5lQgiXHaCf2zF5ctrHznc0mJ7KzzHHlFnjN2vcvmyPcYoqLnZzq3whNTIIlzC5j8Vb2Rk7WGTnZ22+8++FUSdv9N+ZGTtqsj2dqm3l30rEUe4hMk/xfjSpfL5Fde1k/aHDrG7GNFnjF24MjJS/ix7x7F1wHBwLBAuYTHGHjg5Olo+rpxM2t5KV5cNGSDKjLFDYJculS8zTibtMNj+/QyDxQThEgZj7P6Vq1fLJ+5bW6X+fnvYJMWEqDPGHo1feYPV2ir19ZUOm0QsEC73yxi7GuzatfJgaWuzwcL8CuLAGLsxcnS0vA6yWen4cbvcmDqIFcLlfvgFVRksnZ32To3llYiDYI8lWAcdHfYGi2euxBLhsl3GSFNT9mykYEHt3WuDhQlLxIE/x1IZLHv22GBht31sES7b4a+GqSwoggVx4j8y4tKl8jmWPXvsUBh1EGssX9oqf/1+5WoYfyiMgkIc+OfljYyU14E/FEYdxB7hshX+c1guXy5fv5/NEiyID78OLl0qPycsk2EoDO8gXLbCGDt5n8+XrrW02IJi8h5xYYw9Lj/4HJaWFjsUxuQ9VhEum+Wf6jo5WbqWSNhjLFhmibjw6+DmzdI1vw44zgUBhMtm+AdRXr1q/9l36BDHhCM+/DqoXCF56FDpIFZgFeGyWePj5ePLe/bYc5IoKMRJZR10dtoH3VEHqEC43Iu/3PLWrdK1piZ7ThJnhSEujLFzjZV1cPQodYCq+Km4F8+zd2vBYYCDB5lnQbx4np3ED+5nOXiQeRasi3DZiL9ZMrgqJpstnewKxIFfB7OzpWuZDHWADREuG/E8eyilP4nvOPZuLZGobbuA3VStDg4dsvu6gHUQLuvxx5iDe1ra2lgdhnjx6yDYe29rY3UY7olw2cjkZPnd2gMPMHmJ+AnWgUQdYFP4CVnP/Hz5GHM6zcOOED/z8/bR3b5Mht47NoVwqcYYaWam/EC+ffuYa0G8UAe4D4RLNZ5ni8qXTNrj9LlbQ5xUq4OuLuoAm0K4VFMs2mP1fe3t9kA+IE6KRenu3dJ76gBbQLhUMzu79iFgQNzkctQBto1wqWSMfVaLvxqmqckuvWQoAHFTWQft7dQBNo1dUJUcRzp2TDpwwPZglpcZCkA89fbaZce5nD32hTrAFhAuAY6khXxexj8/qbXVfrx9u2ZtaiTGGJngyiI0rPl8Xp4/JJZK2Y/Uwab4dRD3Ph7hsspxHJlUSs7t21qgiLbNTSTkcCxIw3Icx+7pyuW0GNzfgi1xk0k5MV+y7RgT3HobX8YY8Z8iHI7j2F9SaDjUQXjiXgeECwAgdKwWAwCEjnABAISOcAEAhI5wAQCEjnBpFJ5nz3oKHscBxI3n2XP/qIO6R7g0imJR+o//sB+BuCoUbB0ED5ZFXSJcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXAAAoSNcAAChI1wAAKEjXBqAMUZTU1O6fv26pqamZIypdZOAXUcdNBbCpY7lcjmdOXNGAwMDOnrsmE5+/OM6euyYBgYGdObMGeVyuVo3EdhxwTo41turj//Mz+hYby91UOccQ/zXpaGhIZ08eVKF1YcipYzRuyX9QFLRcSRJ6XRa58+f1+DgYO0aCuygyjpIB+qgQB3UNXoudWhoaEgnTpxQsViUMWZN99+/ViwWdeLECQ0NDdWopcDOoQ4aGz2XOpPL5dTT06NisSgv8JzwtFS6Ywt8veu6SqVSGhsbU2dn5+42Ftgh69VBRqU6uBv4euqg/tBzqTNnz55VoVAoK6iNeJ6nQqGgc+fO7XDLgN1DHTQ+ei51xBijgYEBXbp0ac0QwHo9F0lyHEd9fX0aHh6WszoODTSqje
pgvZ6LRB3UG3oudWR6elojIyNbXmJpjNHIyIhmZmZ2qGXA7qEOooFwqSNzc3P39f35fD6klgC1Qx1EA+FSR7LZ7LqfK2p1GfIG39/W1hZ2k4Bdt1EdFFR9aDiIOqgPhEsd6erqUn9/f9XxYiNbUNUGChzHUX9/v/bu3bvTTQR23L3q4K6og0ZAuNQRx3H04osvbut7T506xSQmIoE6iAZWi9WZ9db3r4f1/Ygi6qDx0XOpM52dnTp//rwcx5Hrbvy/x3VdOY6jCxcuUFCIFOqg8REudWhwcFAXL15UKpWS4zhruvn+tVQqpVdeeUVPP/10jVoK7BzqoLERLnVqcHBQY2NjOn36tPr6+so+19fXp9OnT2t8fJyCQqRRB42LOZcGYIzRzMyM8vm82tratHfvXiYtETvUQWMhXAAAoWNYDAAQOsIFABA6wgUAEDrCBQAQOsIFABA6wgUAEDrCBQAQOsIFABA6wgUAEDrCBQAQOsIFABA6wgUAEDrCBQAQOsIFABC6/w8iXfBQjVosqQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b62b0246",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 \\tanh{\\left(1.0 \\operatorname{atanh}{\\left(1.0 x_{1} \\right)} + 1.0 \\operatorname{atanh}{\\left(1.0 x_{2} \\right)} \\right)}$"
],
"text/plain": [
"1.0*tanh(1.0*atanh(1.0*x_1) + 1.0*atanh(1.0*x_2))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}