GitHub_collection_pykan/tutorials/physics_2A_conservation_law.ipynb
2024-08-11 13:15:16 -04:00

167 lines
13 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
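"# Physics 2A: Conservation Laws\n",
"\n",
"This notebook trains a KAN so that its gradient is orthogonal to a given flow field, which makes the learned function constant along the flow (a conserved quantity). The example flow $(\\dot{x}, \\dot{p}) = (p, -x)$ is the harmonic oscillator, generated by the Hamiltonian $H = \\tfrac{1}{2}(x^2 + p^2)$, so the symbolic formula recovered at the end should depend on the inputs only through $x_1^2 + x_2^2$."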
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0d2987",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"from kan.utils import batch_jacobian, create_dataset_from_data\n",
"import numpy as np\n",
"\n",
"# KAN with 2 inputs and 1 output; the scalar output is the candidate conserved quantity.\n",
"# Constructed before sampling so the statement order of the original cell is preserved.\n",
"model = KAN(width=[2, 1], seed=42)\n",
"\n",
"# 1000 phase-space points (x, p) drawn uniformly from [-1, 1)^2.\n",
"x = 2 * torch.rand(1000, 2) - 1\n",
"# Harmonic-oscillator flow (dx/dt, dp/dt) = (p, -x), generated by H = 1/2 * (x**2 + p**2).\n",
"position, momentum = x[:, [0]], x[:, [1]]\n",
"flow = torch.cat([momentum, -position], dim=1)\n",
"\n",
"def pred_fn(model, x):\n",
"    \"\"\"Return the unit-length gradient of `model` with respect to each row of `x`.\n",
"\n",
"    Only the direction of the gradient is kept, so the loss cannot be driven\n",
"    to zero by simply shrinking the gradient magnitude.\n",
"\n",
"    Args:\n",
"        model: callable passed to `batch_jacobian` (here the KAN).\n",
"        x: batch of input points, one sample per row.\n",
"\n",
"    Returns:\n",
"        Tensor shaped like the jacobian, L2-normalized along dim 1.\n",
"    \"\"\"\n",
"    # create_graph=True is forwarded to batch_jacobian; presumably it keeps the\n",
"    # autograd graph so the loss can be differentiated w.r.t. the model\n",
"    # parameters -- TODO confirm against kan.utils.\n",
"    grad = batch_jacobian(model, x, create_graph=True)\n",
"    # F.normalize divides by max(||grad||_2, eps), so a vanishing gradient yields\n",
"    # zeros instead of the NaNs produced by a bare grad / norm division.\n",
"    return torch.nn.functional.normalize(grad, p=2, dim=1)\n",
"\n",
"def loss_fn(grad_normalized, flow):\n",
"    \"\"\"Mean squared dot product between the flow and the normalized gradient.\n",
"\n",
"    The value is zero exactly when the two vectors are orthogonal at every\n",
"    sample, i.e. when the model output does not change (to first order)\n",
"    along the flow.\n",
"    \"\"\"\n",
"    alignment = (flow * grad_normalized).sum(dim=1)\n",
"    return alignment.pow(2).mean()\n",
"\n",
"\n",
"# Package the sampled points as inputs and the flow vectors as labels.\n",
"dataset = create_dataset_from_data(x, flow)\n",
"# Train with the custom prediction (normalized gradient) and orthogonality loss;\n",
"# the trailing semicolon suppresses the return value in the cell output.\n",
"model.fit(\n",
"    dataset,\n",
"    steps=20,\n",
"    pred_fn=pred_fn,\n",
"    loss_fn=loss_fn,\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "60c88d7f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc30lEQVR4nO3de3BU9fkG8Ocsu0l2CeRGMCYBZEMIIIEQLrlwN0IoKUiLQJUKVnAcLzC1HXuJrSOOZdBpLSgFpNqKMgW1IKhBqYFw29wlCQQUEoFAAkmAzZVNsrfv7w/c/SUICubsnt3N85nJMAJ7zovw5jnfyzlHEkIIEBERyUildAFEROR7GC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7tdIFEHkDIQSuXr2K1tZWBAYGIiwsDJIkKV0WkcfiyIXoezQ2NmLdunWIjY1FeHg4Bg8ejPDwcMTGxmLdunVobGxUukQijyTxZWFEN7d3717Mnz8fJpMJwPXRi4Nj1KLT6bBjxw6kp6crUiORp2K4EN3E3r17kZGRASEE7Hb7LX+fSqWCJEnIyspiwBB1wnAhukFjYyOio6PR1tb2vcHioFKpoNVqUV1djeDgYNcXSOQFuOZCdIMtW7bAZDLdVrAAgN1uh8lkwrvvvuviyoi8B0cuRJ0IIRAbG4szZ87gTlpDkiTo9XpUVFRwFxkRGC5EXVy5cgXh4eHd+nxYWJiMFRF5J06LEXXS2trarc+3tLTIVAmRd2O4EHUSGBjYrc/36dNHpkqIvBvDhaiTsLAwxMTE3PG6iSRJiImJQWhoqIsqI/IuDBeiTiRJwooVK37UZ1euXMnFfKJvcUGf6Aa8z4Wo+zhyIbpBcHAwduzYAUmSoFJ9f4s47tDfuXMng4WoE4YL0U2kp6cjKysLWq0WkiR9Z7rL8XNarRZ79uzBzJkzFaqUyDMxXIhuIT09HdXV1Vi7di30en2XX9Pr9Vi7di1qamoYLEQ3wTUXotsghEBOTg7S0tKwb98+TJ8+nYv3RN+DIxei2yBJknNNJTg4mMFC9AMYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGC9EPMJvNOHPmDL788ksAQHl5OS5fvgy+rYLo1vg+F6JbaG9vx+7du7Fp0yaUlZXBYrFAkiQIIRAcHIyZM2di5cqVGDVqFB/BT3QDjlyIbqK+vh6PPvooli1bBp1OhzfeeAN5eXkoLS3FoUOH8Pzzz+Orr75Ceno6Nm3aBKvVqnTJRB5FrXQBRJ6msbERS5cuxdGjR/Hmm2/iwQcfhNVqRWZmJoxGI4YOHYrMzEwsXrwYGzZswJ/+9CdYrVY888wzHMEQfYvTYkSdCCGQmZmJjRs3YuvWrcjIyIAkSTAajRg9ejSqq6sxadIk7N+/HxqNBjabDevXr8df/vIXfPLJJ0hKSlL6j0DkETgtRtRJZWUl3nrrLSxfvhw/+clPfnAk0qtXLzzxxBNITk7Gq6++CpvN5qZKiTwbw4Wok08//RRWqxWPP/44VCoVhBC33BXm+DV/f388+eSTOHLkCM6fP+/miok8E9dciL4lhEBeXh6GDh0Km82G5557Dna7HcD1nWONjY0AgDNnzuC5556DSnX92iwiIgKLFy+GRqPBiRMnMHjwYKX+CEQeg+FC9C273Y66ujpERkaitrYWa9euvek018WLF7Fu3Trnf997771Yvnw5QkJCUFtb686SiTwWw4WoE5vNho6ODkiSBD8/P2e4CCFgsVgAAJIkQaPROD/j+H0Wi6XLzxP1ZNwtRj2W1WpFeXk5cnNznV9VVVUYMmQIcnJycOnSJed6S3NzMx566CHU19cjISEBmzZtQq9evQ
AAWq0Wfn5+GDduHAYPHozZs2cjJSUFqampCAsLU/KPSKQYjlyox2hpaUFRURFyc3NhMBhQWFiI1tZWaDQajB07FgsXLgQAvP766zh58iTS0tKcu8WMRiP8/PwAAIGBgUhMTHSOUoQQ+Mc//gF/f3/ExsZi+/bteO211wAAQ4cORUpKCiZOnIiUlBQMGTKE98JQj8CRC/msixcvwmAwIDc3F3l5eSgrK4PdbkdISIhzZJGSkoKxY8dCq9UCuH4D5ZQpU3DXXXdh586d6NOnDwDc8j4XIQRqamqQlpaGRYsW4aWXXoIQAhcuXIDBYEBeXh7y8vJQXl4OIQT69evnPG9qairGjBnjDC0iX8JwIZ9gs9lw8uTJLlNcjm3Ber0eqampzq+4uDjnTq+b2bVrFx555BEsW7YMq1evhlarRUNDw3fCRa1W4+rVq1i+fDmqqqqwd+9e9O/f/6bHbGxsRGFhIfLy8mAwGFBUVIS2tjYEBARg7NixztqSk5MREhLikv9HRO7EcCGvZDKZUFxc7ByZFBQUoKmpCWq1GmPGjHGODFJTU3HXXXfd0bFtNhvWrl2LF198EbNmzcILL7yAwYMH47333kNLSwuio6Pxs5/9DEVFRXj++edRX1+Pbdu2ITEx8bbPYbFYUFZW5hxV5ebmoq6uDgAwfPjwLmF4zz33cCqNvA7DhbxCXV1dl1FJaWkprFYrgoKCkJSUhIkTJyI1NRXjxo2DTqfr9vlsNhs+/PBDvPDCC7hy5QpSUlKQmJiI4OBg1NXVoaCgAMeOHcOkSZPw17/+FXFxcd06nxACZ8+edQZNbm4uvvrqKwDAXXfd5QyalJQUjB49mrvSyOMxXMjj2O12nDp1yrlmkZubizNnzgAABg4c6FwcT01NxYgRI5y7tlyhrq4Ou3fvxp49e/DNN9+go6MDISEhGDNmDObPn48pU6bA39/fJec2Go0oKChwhk1xcTE6Ojqg0+kwfvx4Z9gkJSUhKCjIJTUQ/VgMF1Jce3s7iouLnesR+fn5aGhogEqlwujRo527rVJTUxEZGalIjXa7HWazGTabDRqNRpFF+I6ODpSWljqn0gwGA65evQpJkjBy5MguU4EDBgzgVBopiuFCbnf58mXk5+c7RyZffvklLBYLAgMDkZSU5PwGOX78eOduLfouIQQqKyudQZOXl4fTp08DACIjI50jvJSUFMTHx0Ot5p0H5D4MF3IpIQQqKiq6XG1XVFQAAKKiopCamur8Jjhy5Eh+A+ymK1euOLc/GwwGHD161BncEyZMYHCT2zBcSFYdHR0oKSlxrhPk5+fj8uXLkCQJ8fHxznWCiRMnYsCAAUqX6/Pa29tx9OhR58gmPz8fRqOxy5Sj4+9DqSlH8k0MF+oWo9GI/Px858ikqKioy6KzY61kwoQJXHT2AI7NEp13pXXeLNF5V5qrN0uQb2O40G1zbJftvCXYsV02IiKiy70Zo0aN4nZZL1FXV+cMm7y8PJSUlHTZ5u3YKDBu3Dj07t1b6XLJSzBc6JYsFgtKS0ud33gMBgPq6+sBACNGjOiyJZg3+vkOxw2qjrDJz8933qCakJDQ5fE1d3qDKvUcDBdyampqQkFBgfOu986PKBk/frxzbj4pKYmPKOlB7HY7Tp482WUqraqqCsD1R+t03gL9Q4/WoZ6D4dJD3fhwRYPBgBMnTkAIgfDw8C5P8uXDFelGFy9edI5sHA8FtdlsCAkJQXJysjNsxo4di4CAAKXLJQUwXHoIq9WK48ePO4MkNzcXFy9eBADExcV12cUVExPDKS66I62trc7XGTie9eZ4nUFiYmKXjQL9+vVTulxyA4aLj2ppaUFhYaGz2R3vLvHz8/vOU3jZ7CQ3x4vYOt9zU1NTAwCIjY3tEjaxsbG8mPFBDBcfUVNT41x0z8vLw7Fjx2C32xEaGtplTjwxMZHTFKSICxcuOC928vLycPz4cQghEBYW1mXknJCQ4LLntZH7MFy8kOPdJZ3XSy5cuA
AAiImJ6bKLa+jQoVxgJY/U1NTkfMeNY3RtMpng7++PsWPHOv8dJycnIzQ0VOly6Q4xXLzEoUOHYDAYYDAYUFBQgObmZqjVaiQmJnZZfL/Vy6qIPJ3FYsGxY8e67Eqrra0FcP0dN44LppkzZ/LfuRdguHgJs9kMIQRUKhUkSXL+yLlq8lVCCAghYLfbnV9CCPj5+fHJAV6A4eIlHH9NDBPqydgH3oOPoJVZW1sbPv30U5jNZqVLuW1CCIwbNw7Dhg1TuhTyEW1tbcjKyvK6Phg7diz7QCYMF5k1NTVh3759WLx4sWzHtNvtaGlpQXNzM/z8/BASEiLrTY2VlZU4ePAgm4pk4+o+0Gg0CAkJkXVXWWVlJQ4dOsQ+kAnDxQUiIiIwadKkbg3dhRCw2Ww4cuQI/vnPf6KwsBBNTU3QaDQYMmQIFi1ahF/84hcICQnp9hRBYGAgioqKunUMohtFRERg4sSJ3f73abPZUFBQgLffftv5llK1Wg29Xo+FCxfi4YcflqUPevfujeLi4m4dg/4f96h6ICEEjEYjnn32WcybNw///e9/ceXKFfTu3Rt2ux2FhYX4zW9+g9mzZyM/Px9cNiNfJIRAU1MT/vjHP2Lu3LnYvn076urqnE9m/vLLL/G73/0Os2bNwuHDh2G32xWumDpjuHgYIQQuXbqExYsXY/PmzfD398dvf/tbHD58GMXFxSgqKsLbb7+NUaNGoaSkBAsWLMAnn3zCgCGfIoTAlStXsGzZMqxfvx4qlQorVqzA/v37UVhYiIKCAvz73/9GQkICysvL8dBDD2Hnzp0MGA/CaTEP4hixLF++HDk5OYiNjcWbb76JlJSULjdCPvTQQ5gxYwYyMzOxdetWPPHEE9Bqtbj//vu5i4Z8QktLC55++mlkZWVh4MCB2LhxI6ZNm9alDxYsWID7778fL774Iv71r3/hmWeegVarxezZs9kHHoAjFw9iNpuRmZmJffv2ITY2Fh988AFSU1O/c4e9JEno168f1q1bh+XLl6OhoQFPP/00Tp8+zREMeT2LxYKXX34Zn376KQYMGID//Oc/mD59+k37IDQ0FK+++iqefPJJtLS0YOXKlc6ne5OyGC4eQgiB9957D1u3bkV4eDjeeustDB8+/JZXYJIkQafTYfXq1Zg9ezaqqqrw7LPP4tq1a26unEg+Qgjs2rULmzdvRt++fbFp0yYkJiZ+70gkICAAL774IubNm4eLFy/i17/+NZqbm91YNd0Mw8UDCCFQWVmJl19+GQCwatUqJCUl3dbQPjAwEK+99hr0ej1ycnKwefNmXrWRV3K8Y+jPf/4zLBYL/vCHP2DatGm31Qc6nQ6vvPIK4uLikJubiw0bNnD9RWEMFw9gsViwatUqXLp0CRkZGfjlL39523PGkiRh0KBBWL16NdRqNf7+97/j66+/ZsCQ17HZbHjllVdQVVWF++67D48//vhtP3RVkiRERUVh9erV8Pf3x/r161FeXs4+UBDDRWFCCOzfvx8ff/wx+vfvj5deeumOb5CUJAkZGRmYN28e6uvrsWbNGthsNhdVTCQ/IQQKCwvx/vvvIzg4GKtWrYJOp7ujY0iShJkzZ+LBBx+E0WjEmjVrYLFYXFQx/RCGi8JMJhPWrFkDs9mMp556CnFxcT9qp4tGo0FmZibCwsKwe/du5OXl8aqNvIbZbMYrr7yCa9euYenSpUhISPhRfdCrVy/8/ve/R//+/bFnzx4cPnyYfaAQhouChBDYs2cPCgsLERMTg2XLlv3oLZSSJCEuLg5Lly5FW1sb/va3v/GqjbyCEAIHDx7EgQMHEBkZiaeffvpHv4NIkiRnL3V0dOC1115DR0eHzBXT7WC4KMhkMuH111+H3W7HU089hfDw8G4dT6VS4YknnsDdd9+NnJwcjl7IK5jNZrzxxhswm81YtmwZoqOju3U8SZLw2GOPITo6GkeOHIHBYGAfKIDhohAhBPbt24ejR48iJiYGixYtkuXGr0GDBm
Hx4sVob2/Hhg0buPZCHk0Igfz8fBw+fBiRkZFYsmSJLH0QFRWFJUuWwGw2Y+PGjbBarTJUS3eC4aIQi8WCzZs3w2q14rHHHkNYWJgsx5UkCb/61a8QFhaG7Oxs53vKiTyRzWbD5s2b0dHRgYcffhhRUVGyHFeSJCxZsgT9+/dHTk4OSkpK2AduxnBRgBACpaWlOHLkCCIiIrBw4UJZH1eh1+vxwAMPoLW1FVu2bGFTkcc6ffo0srOzERwcjEceeUTWPhgwYAB+/vOfw2Qy4Z133mEfuBnDRQFCCLzzzjtoa2vDvHnzZLtac3CMXrRaLXbt2oVLly7JenwiOQghsG3bNjQ3NyM9PR0xMTGyHl+lUmHJkiXo3bs3srKyUFNTI+vx6fsxXBRw8eJFZGVlQafTYenSpT96Z8ytSJKEhIQEJCUloba2lk9NJo9kNBqxY8cO+Pn54dFHH5W9DwBg5MiRSE1NRX19PT766CP2gRsxXNxMCIGsrCzU1dUhJSUFI0eOdMl5NBoNlixZAgDYtm0bt2OSRxFC4MCBA6iqqkJ8fDwmTJjgkicZq9VqLFmyBCqVCu+//z7a2tpkPwfdHMPFzTo6OrB9+3YAwMMPPwyNRuOS80iShBkzZmDAgAEoLS1FWVmZS85D9GPYbDZs27YNdrsdCxYsgFardcl5JEnC9OnTMXDgQJSXl6O0tJSjFzdhuLjZiRMnUFJSgsjISMyYMcOl750IDw/H7Nmz0d7ejg8//JBNRR7j7NmzOHLkCEJDQzF37lyX9kFoaCjmzJkDs9mMDz74wGXnoa4YLm4khMBHH32EtrY2pKend/umyR8iSRIWLFgAPz8/7NmzB01NTS49H9HtcEwNNzU1YfLkyRg4cKBLzydJEubPnw9/f398/vnnaGhocOn56DqGixu1trbik08+gVqtxoMPPuiWt+WNGTMGw4YNw7lz55Cfn8/RCymuo6MDu3btgkqlwoIFC1yykH+j+Ph4jBgxAtXV1cjNzWUfuAHDxY1KSkrwzTffQK/XY/z48W4JF51Ohzlz5sBms3G3DHmEU6dO4fjx44iIiMCkSZPc0gdarRZz586FzWbDjh072AduwHBxE8cb9iwWC2bNmoU+ffq45bySJOGnP/0pAgICsH//fhiNRrecl+hmhBD47LPPYDKZMG3aNJdPDTs4XkvRu3dvHDp0CJcvX3bLeXsyhoubtLS04IsvvoBGo3H5AuaNhg8fjhEjRqCmpoZTY6Sojo4OZGVlQaVS4YEHHnBrHwwdOhT33nsvamtr+VBXN2C4uElZWRnOnj0LvV6PhIQEt547ICAAGRkZsNlsvKGSFFVRUYGTJ0/i7rvvRnJyslvDxc/PDxkZGbDb7ewDN2C4uIHjvS0WiwX3338/AgMD3Xp+SZIwa9YsBAQE4MCBA9w1RooQQiA7OxsmkwkTJ05Ev3793Hp+x5sqtVotDh06xCliF2O4uEFbWxuys7OhVqsxe/ZsRWoYPnw4hgwZgurqapSUlChSA/VsVqsVn332GSRJwpw5c9w6anGIi4tDXFwcLl26hKNHj7r9/D0Jw8UNTp8+jYqKCkRFRSExMVGRptLpdEhLS4PVasXnn3/OKQFyuwsXLuDYsWMICwtz+5SYQ0BAANLS0mCz2dgHLsZwcTEhBPbv34+2tjZMnDgRwcHBitThmBpTq9XOeojcRQgBg8GA5uZmJCYmIiIiQpE6JElCeno6NBoNDhw4gGvXrilSR0/AcHExq9WK//3vf85v7kpcrTkkJCQgKioKlZWVqKysVKwO6nmEEM6RQnp6Onr16qVYLfHx8YiOjsbZs2dx+vRpxerwdQwXF7t06RLKysoQEhKi2FSAQ3BwMJKTk9HW1oaDBw9ySoDcxmg0orCwEDqdDlOnTlW0D/r27YvU1FS0t7ezD1yI4eJCQggUFhaioaEB8fHxiIyMVLQex24ZSZLwxRdfwGazKVoP9Rzl5eWora3FkCFDoNfrFa2lcx9kZ2
ezD1yE4eJi2dnZEEIgLS0NarVa0VokSUJycjL69OmD0tJS3qVMbiGEQE5ODqxWKyZPnoyAgABF65EkCRMmTEBQUBCOHTuGuro6RevxVQwXF7p27Rpyc3Ph5+en+FSAw4ABAzB8+HBcvnwZpaWlSpdDPYDZbMaBAwegUqlw3333eUQfREZGYvjw4TAajTh69CinxlyA4eJClZWVqKqqQlRUFIYNG6Z0OQCu36U8ZcoU2O127Nu3j01FLldTU4NTp06hX79+bn86xa1oNBpMnToVdrsdOTk5SpfjkxguLnTkyBG0t7djwoQJ6Nu3r9LlALg+JTBt2jT06tULBoOBrz8mlysuLkZzczPi4+PRv39/pcsB8P99oFarYTAY0N7ernRJPofh4iKdr4imT5/uEVMBDvHx8QgPD0dFRQWqq6uVLod8XE5ODoQQmDp1qqJbkG80YsQI9O/fH9988w3Onz+vdDk+h+HiIo2NjSgrK0Pv3r0V34J8o379+mHUqFFobW1FUVGR0uWQDzOZTCgsLIRGo3Hbu1tuV2hoKBISEnDt2jX2gQswXFzk66+/Rm1tLe655x4MGjRI6XK6UKlUmDp1KoQQOHjwoNLlkA87f/48zp07h4iICMTFxSldThcqlQpTpkwBAPaBCzBcXMRgMMBisSA5ORlarVbpcrqQJAmpqanQaDQoKCiAyWRSuiTyUUVFRTCZTEhISEBQUJDS5XTh2Jrv5+eH4uJi9oHMGC4uYLPZUFhYCEmSMHXqVKXLualhw4YhIiIC586d43wzuUTnkfHkyZOhUnnet5uhQ4fi7rvvRlVVFc6dO6d0OT5F2bv6fJRKpcLq1asxb94857Db07b8BgUFYe7cuWhvb0dAQABaWlqULol8jN1uR0pKCoxGo0f3wYIFC2C1WtG3b1+OXmTEcJGZSqXCyZMnsX37dgDAmTNnFK7o1gIDAxEUFIScnByMGTNG6XLIh6hUKpw6dQoajQajR4/G7t278fHHHytd1k2p1Wr4+fkhOzubfSAjSXjapYSXs9lsqK+vh91uV7qUOxIUFOT2N2SS72IfEMOFiIhk53krbERE5PUYLl5CCAG73e5xC6JE7sQ+8B4MFy9RWloKrVbLJxlTj1ZaWgqdTsc+8AIMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DxQsIIdDQ0AAAaGho4OPGqUdy9EHnH8lzMVw8WGNjI9atW4fY2FikpaXBbDYjLS0NsbGxWLduHRobG5Uukcjl2Afeia859lB79+7F/PnzYTKZAKDLVZokSQAAnU6HHTt2ID09XZEaiVyNfeC9GC4eaO/evcjIyHC+de9WVCoVJElCVlYWG4t8DvvAuzFcPExjYyOio6PR1tb2vQ3loFKpoNVqUV1djeDgYNcXSOQG7APvxzUXD7NlyxaYTKbbaigAsNvtMJlMePfdd11cGZH7sA+8H0cuHkQIgdjYWJw5c+aOdsJIkgS9Xo+KigrnPDSRt2If+AaGiwe5cuUKwsPDu/X5sLAwGSsicj/2gW/gtJgHaW1t7dbnW1paZKqESDnsA9/AcPEggYGB3fp8nz59ZKqESDnsA9/AcPEgYWFhiImJueP5YkmSEBMTg9DQUBdVRuQ+7APfwHDxIJIkYcWKFT/qsytXruQiJvkE9oFv4IK+h+H+fiL2gS/gyMXDBAcHY8eOHZAkCSrV9//1OO5M3rlzJxuKfAr7wPsxXDxQeno6srKyoNVqIUnSd4b5jp/TarXYs2cPZs6cqVClRK7DPvBuDBcPlZ6ejurqaqxduxZ6vb7Lr+n1eqxduxY1NTVsKPJp7APvxTUXLyCEgNFoREtLC/r06YPQ0FAuWlKPwz7wLgwXIiKSHafFiIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXI
iISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdv8HKMKmRdQUW2cAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 500x200 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e8cb9a2f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2\n",
"fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2\n",
"saving model version 0.2\n"
]
}
],
"source": [
"model.auto_symbolic()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1b143bf8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.191 x_{1}^{2} - 1.191 x_{2}^{2} + 2.329$"
],
"text/plain": [
"-1.191*x_1**2 - 1.191*x_2**2 + 2.329"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import ex_round\n",
"# Take the first entry of the first item returned by symbolic_formula()\n",
"# (presumably the expression for the single output -- verify against pykan docs)\n",
"# and round its coefficients to 3 decimals for display.\n",
"symbolic_expr = model.symbolic_formula()[0][0]\n",
"ex_round(symbolic_expr, 3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "782f818f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}