{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "fd0d2987", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "checkpoint directory created: ./model\n", "saving model version 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it" ] }, { "name": "stdout", "output_type": "stream", "text": [ "saving model version 0.1\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [
 "from kan import *\n",
 "from kan.utils import batch_jacobian, create_dataset_from_data\n",
 "import numpy as np\n",
 "import torch  # made explicit; `torch` was otherwise only in scope via the wildcard import above\n",
 "\n",
 "# Lower bound on the gradient norm used in pred_fn, so a vanishing gradient\n",
 "# normalises to a zero vector instead of producing NaN/inf.\n",
 "GRAD_NORM_EPS = 1e-12\n",
 "\n",
 "model = KAN(width=[2,1], seed=42)\n",
 "\n",
 "# Goal: learn a conserved quantity of the harmonic oscillator with Hamiltonian\n",
 "# H = 1/2 * (x**2 + p**2). Only the direction of the model's gradient enters the\n",
 "# loss, so the learned function can match H at best up to scale and offset.\n",
 "# Sample 1000 points with both coordinates (x, p) uniform in [-1, 1).\n",
 "x = torch.rand(1000,2) * 2 - 1\n",
 "# Flow generated by H: (dx/dt, dp/dt) = (p, -x).\n",
 "flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1)\n",
 "\n",
 "def pred_fn(model, x):\n",
 "    \"\"\"Return the gradient of `model` at `x`, normalised to unit length per sample.\n",
 "\n",
 "    A gradient whose norm is below GRAD_NORM_EPS is divided by GRAD_NORM_EPS\n",
 "    instead of by its own norm, which avoids a 0/0 division (NaN loss).\n",
 "    \"\"\"\n",
 "    grad = batch_jacobian(model, x, create_graph=True)\n",
 "    # assumes grad has one row per sample, so dim=1 is the (x, p) axis -- TODO confirm\n",
 "    grad_norm = torch.linalg.norm(grad, dim=1, keepdim=True)\n",
 "    return grad / torch.clamp(grad_norm, min=GRAD_NORM_EPS)\n",
 "\n",
 "def loss_fn(grad_normalized, flow):\n",
 "    \"\"\"Mean over samples of the squared dot product of flow and gradient direction.\n",
 "\n",
 "    The value is zero when the model's gradient is orthogonal to the flow at\n",
 "    every sample, i.e. when the model output does not change along the flow.\n",
 "    \"\"\"\n",
 "    return torch.mean(torch.sum(flow * grad_normalized, dim=1)**2)\n",
 "\n",
 "\n",
 "# The flow vectors are presumably used as the labels handed to loss_fn -- TODO confirm.\n",
 "dataset = create_dataset_from_data(x, flow)\n",
 "model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn);"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "60c88d7f", "metadata": {}, "outputs": [ { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc30lEQVR4nO3de3BU9fkG8Ocsu0l2CeRGMCYBZEMIIIEQLrlwN0IoKUiLQJUKVnAcLzC1HXuJrSOOZdBpLSgFpNqKMgW1IKhBqYFw29wlCQQUEoFAAkmAzZVNsrfv7w/c/SUICubsnt3N85nJMAJ7zovw5jnfyzlHEkIIEBERyUildAFEROR7GC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7tdIFEHkDIQSuXr2K1tZWBAYGIiwsDJIkKV0WkcfiyIXoezQ2NmLdunWIjY1FeHg4Bg8ejPDwcMTGxmLdunVobGxUukQijyTxZWFEN7d3717Mnz8fJpMJwPXRi4Nj1KLT6bBjxw6kp6crUiORp2K4EN3E3r17kZGRASEE7Hb7LX+fSqWCJEnIyspiwBB1wnAhukFjYyOio6PR1tb2vcHioFKpoNVqUV1djeDgYNcXSOQFuOZCdIMtW7bAZDLdVrAAgN1uh8lkwrvvvuviyoi8B0cuRJ0IIRAbG4szZ87gTlpDkiTo9XpUVFRwFxkRGC5EXVy5cgXh4eHd+nxYWJiMFRF5J06LEXXS2trarc+3tLTIVAmRd2O4EHUSGBjYrc/36dNHpkqIvBvDhaiTsLAwxMTE3PG6iSRJiImJQWhoqIsqI/IuDBeiTiRJwooVK37UZ1euXMnFfKJvcUGf6Aa8z4Wo+zhyIbpBcHAwduzYAUmSoFJ9f4s47tDfuXMng4WoE4YL0U2kp6cjKysLWq0WkiR9Z7rL8XNarRZ79uzBzJkzFaqUyDMxXIhuIT09HdXV1Vi7di30en2XX9Pr9Vi7di1qamoYLEQ3wTUXotsghEBOTg7S0tKwb98+TJ8+nYv3RN+DIxei2yBJknNNJTg4mMFC9AMYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGC9EPMJvNOHPmDL788ksAQHl5OS5fvgy+rYLo1vg+F6JbaG9vx+7du7Fp0yaUlZXBYrFAkiQIIRAcHIyZM2di5cqVGDVqFB/BT3QDjlyIbqK+vh6PPvooli1bBp1OhzfeeAN5eXkoLS3FoUOH8Pzzz+Orr75Ceno6Nm3aBKvVqnTJRB5FrXQBRJ6msbERS5cuxdGjR/Hmm2/iwQcfhNVqRWZmJoxGI4YOHYrMzEwsXrwYGzZswJ/+9CdYrVY888wzHMEQfYvTYkSdCCGQmZmJjRs3YuvWrcjIyIAkSTAajRg9ejSqq6sxadIk7N+/HxqNBjabDevXr8df/vIXfPLJJ0hKSlL6j0DkETgtRtRJZWUl3nrrLSxfvhw/+clPfnAk0qtXLzzxxBNITk7Gq6++CpvN5qZKiTwbw4Wok08//RRWqxWPP/44VCoVhBC33BXm+DV/f388+eSTOHLkCM6fP+/miok8E9dciL4lhEBeXh6GDh0Km82G5557Dna7HcD1nWONjY0AgDNnzuC5556DSnX92iwiIgKLFy+GRqPBiRMnMHjwYKX+CEQeg+FC9C273Y66ujpERkaitrYWa9euvek018WLF7Fu3Trnf997771Yvnw5QkJCUFtb686SiTwWw4WoE5vNho6ODkiSBD8/P2e4CCFgsVgAAJIkQaPROD/j+H0Wi6XLzxP1ZNwtRj2W1WpFeXk5cnNznV9VVVUYMmQIcnJycOnSJed6S3NzMx566CHU19cjISEBmzZtQq9evQAAWq0Wfn5+GDd
uHAYPHozZs2cjJSUFqampCAsLU/KPSKQYjlyox2hpaUFRURFyc3NhMBhQWFiI1tZWaDQajB07FgsXLgQAvP766zh58iTS0tKcu8WMRiP8/PwAAIGBgUhMTHSOUoQQ+Mc//gF/f3/ExsZi+/bteO211wAAQ4cORUpKCiZOnIiUlBQMGTKE98JQj8CRC/msixcvwmAwIDc3F3l5eSgrK4PdbkdISIhzZJGSkoKxY8dCq9UCuH4D5ZQpU3DXXXdh586d6NOnDwDc8j4XIQRqamqQlpaGRYsW4aWXXoIQAhcuXIDBYEBeXh7y8vJQXl4OIQT69evnPG9qairGjBnjDC0iX8JwIZ9gs9lw8uTJLlNcjm3Ber0eqampzq+4uDjnTq+b2bVrFx555BEsW7YMq1evhlarRUNDw3fCRa1W4+rVq1i+fDmqqqqwd+9e9O/f/6bHbGxsRGFhIfLy8mAwGFBUVIS2tjYEBARg7NixztqSk5MREhLikv9HRO7EcCGvZDKZUFxc7ByZFBQUoKmpCWq1GmPGjHGODFJTU3HXXXfd0bFtNhvWrl2LF198EbNmzcILL7yAwYMH47333kNLSwuio6Pxs5/9DEVFRXj++edRX1+Pbdu2ITEx8bbPYbFYUFZW5hxV5ebmoq6uDgAwfPjwLmF4zz33cCqNvA7DhbxCXV1dl1FJaWkprFYrgoKCkJSUhIkTJyI1NRXjxo2DTqfr9vlsNhs+/PBDvPDCC7hy5QpSUlKQmJiI4OBg1NXVoaCgAMeOHcOkSZPw17/+FXFxcd06nxACZ8+edQZNbm4uvvrqKwDAXXfd5QyalJQUjB49mrvSyOMxXMjj2O12nDp1yrlmkZubizNnzgAABg4c6FwcT01NxYgRI5y7tlyhrq4Ou3fvxp49e/DNN9+go6MDISEhGDNmDObPn48pU6bA39/fJec2Go0oKChwhk1xcTE6Ojqg0+kwfvx4Z9gkJSUhKCjIJTUQ/VgMF1Jce3s7iouLnesR+fn5aGhogEqlwujRo527rVJTUxEZGalIjXa7HWazGTabDRqNRpFF+I6ODpSWljqn0gwGA65evQpJkjBy5MguU4EDBgzgVBopiuFCbnf58mXk5+c7RyZffvklLBYLAgMDkZSU5PwGOX78eOduLfouIQQqKyudQZOXl4fTp08DACIjI50jvJSUFMTHx0Ot5p0H5D4MF3IpIQQqKiq6XG1XVFQAAKKiopCamur8Jjhy5Eh+A+ymK1euOLc/GwwGHD161BncEyZMYHCT2zBcSFYdHR0oKSlxrhPk5+fj8uXLkCQJ8fHxznWCiRMnYsCAAUqX6/Pa29tx9OhR58gmPz8fRqOxy5Sj4+9DqSlH8k0MF+oWo9GI/Px858ikqKioy6KzY61kwoQJXHT2AI7NEp13pXXeLNF5V5qrN0uQb2O40G1zbJftvCXYsV02IiKiy70Zo0aN4nZZL1FXV+cMm7y8PJSUlHTZ5u3YKDBu3Dj07t1b6XLJSzBc6JYsFgtKS0ud33gMBgPq6+sBACNGjOiyJZg3+vkOxw2qjrDJz8933qCakJDQ5fE1d3qDKvUcDBdyampqQkFBgfOu986PKBk/frxzbj4pKYmPKOlB7HY7Tp482WUqraqqCsD1R+t03gL9Q4/WoZ6D4dJD3fhwRYPBgBMnTkAIgfDw8C5P8uXDFelGFy9edI5sHA8FtdlsCAkJQXJysjNsxo4di4CAAKXLJQUwXHoIq9WK48ePO4MkNzcXFy9eBADExcV12cUVExPDKS66I62trc7XGTie9eZ4nUFiYmKXjQL9+vVTulxyA4aLj2ppaUFhYaGz2R3vLvHz8/vOU3jZ7CQ3x4vYOt9zU1NTAwCIjY3tEjaxsbG8mPFBDBcfUVNT41x0z8vLw7Fjx2C32xEaGtplTjwxMZHTFKSICxcuOC928vLycPz4cQghEBYW1mXknJCQ4LLntZH7MFy8kOPdJZ3XSy5cuAAAiImJ6bKLa+j
QoVxgJY/U1NTkfMeNY3RtMpng7++PsWPHOv8dJycnIzQ0VOly6Q4xXLzEoUOHYDAYYDAYUFBQgObmZqjVaiQmJnZZfL/Vy6qIPJ3FYsGxY8e67Eqrra0FcP0dN44LppkzZ/LfuRdguHgJs9kMIQRUKhUkSXL+yLlq8lVCCAghYLfbnV9CCPj5+fHJAV6A4eIlHH9NDBPqydgH3oOPoJVZW1sbPv30U5jNZqVLuW1CCIwbNw7Dhg1TuhTyEW1tbcjKyvK6Phg7diz7QCYMF5k1NTVh3759WLx4sWzHtNvtaGlpQXNzM/z8/BASEiLrTY2VlZU4ePAgm4pk4+o+0Gg0CAkJkXVXWWVlJQ4dOsQ+kAnDxQUiIiIwadKkbg3dhRCw2Ww4cuQI/vnPf6KwsBBNTU3QaDQYMmQIFi1ahF/84hcICQnp9hRBYGAgioqKunUMohtFRERg4sSJ3f73abPZUFBQgLffftv5llK1Wg29Xo+FCxfi4YcflqUPevfujeLi4m4dg/4f96h6ICEEjEYjnn32WcybNw///e9/ceXKFfTu3Rt2ux2FhYX4zW9+g9mzZyM/Px9cNiNfJIRAU1MT/vjHP2Lu3LnYvn076urqnE9m/vLLL/G73/0Os2bNwuHDh2G32xWumDpjuHgYIQQuXbqExYsXY/PmzfD398dvf/tbHD58GMXFxSgqKsLbb7+NUaNGoaSkBAsWLMAnn3zCgCGfIoTAlStXsGzZMqxfvx4qlQorVqzA/v37UVhYiIKCAvz73/9GQkICysvL8dBDD2Hnzp0MGA/CaTEP4hixLF++HDk5OYiNjcWbb76JlJSULjdCPvTQQ5gxYwYyMzOxdetWPPHEE9Bqtbj//vu5i4Z8QktLC55++mlkZWVh4MCB2LhxI6ZNm9alDxYsWID7778fL774Iv71r3/hmWeegVarxezZs9kHHoAjFw9iNpuRmZmJffv2ITY2Fh988AFSU1O/c4e9JEno168f1q1bh+XLl6OhoQFPP/00Tp8+zREMeT2LxYKXX34Zn376KQYMGID//Oc/mD59+k37IDQ0FK+++iqefPJJtLS0YOXKlc6ne5OyGC4eQgiB9957D1u3bkV4eDjeeustDB8+/JZXYJIkQafTYfXq1Zg9ezaqqqrw7LPP4tq1a26unEg+Qgjs2rULmzdvRt++fbFp0yYkJiZ+70gkICAAL774IubNm4eLFy/i17/+NZqbm91YNd0Mw8UDCCFQWVmJl19+GQCwatUqJCUl3dbQPjAwEK+99hr0ej1ycnKwefNmXrWRV3K8Y+jPf/4zLBYL/vCHP2DatGm31Qc6nQ6vvPIK4uLikJubiw0bNnD9RWEMFw9gsViwatUqXLp0CRkZGfjlL39523PGkiRh0KBBWL16NdRqNf7+97/j66+/ZsCQ17HZbHjllVdQVVWF++67D48//vhtP3RVkiRERUVh9erV8Pf3x/r161FeXs4+UBDDRWFCCOzfvx8ff/wx+vfvj5deeumOb5CUJAkZGRmYN28e6uvrsWbNGthsNhdVTCQ/IQQKCwvx/vvvIzg4GKtWrYJOp7ujY0iShJkzZ+LBBx+E0WjEmjVrYLFYXFQx/RCGi8JMJhPWrFkDs9mMp556CnFxcT9qp4tGo0FmZibCwsKwe/du5OXl8aqNvIbZbMYrr7yCa9euYenSpUhISPhRfdCrVy/8/ve/R//+/bFnzx4cPnyYfaAQhouChBDYs2cPCgsLERMTg2XLlv3oLZSSJCEuLg5Lly5FW1sb/va3v/GqjbyCEAIHDx7EgQMHEBkZiaeffvpHv4NIkiRnL3V0dOC1115DR0eHzBXT7WC4KMhkMuH111+H3W7HU089hfDw8G4dT6VS4YknnsDdd9+NnJwcjl7IK5jNZrzxxhswm81YtmwZoqOju3U8SZLw2GOPITo6GkeOHIHBYGAfKIDhohAhBPbt24ejR48iJiYGixYtkuXGr0GDBmHx4sVob2/Hhg0
buPZCHk0Igfz8fBw+fBiRkZFYsmSJLH0QFRWFJUuWwGw2Y+PGjbBarTJUS3eC4aIQi8WCzZs3w2q14rHHHkNYWJgsx5UkCb/61a8QFhaG7Oxs53vKiTyRzWbD5s2b0dHRgYcffhhRUVGyHFeSJCxZsgT9+/dHTk4OSkpK2AduxnBRgBACpaWlOHLkCCIiIrBw4UJZH1eh1+vxwAMPoLW1FVu2bGFTkcc6ffo0srOzERwcjEceeUTWPhgwYAB+/vOfw2Qy4Z133mEfuBnDRQFCCLzzzjtoa2vDvHnzZLtac3CMXrRaLXbt2oVLly7JenwiOQghsG3bNjQ3NyM9PR0xMTGyHl+lUmHJkiXo3bs3srKyUFNTI+vx6fsxXBRw8eJFZGVlQafTYenSpT96Z8ytSJKEhIQEJCUloba2lk9NJo9kNBqxY8cO+Pn54dFHH5W9DwBg5MiRSE1NRX19PT766CP2gRsxXNxMCIGsrCzU1dUhJSUFI0eOdMl5NBoNlixZAgDYtm0bt2OSRxFC4MCBA6iqqkJ8fDwmTJjgkicZq9VqLFmyBCqVCu+//z7a2tpkPwfdHMPFzTo6OrB9+3YAwMMPPwyNRuOS80iShBkzZmDAgAEoLS1FWVmZS85D9GPYbDZs27YNdrsdCxYsgFardcl5JEnC9OnTMXDgQJSXl6O0tJSjFzdhuLjZiRMnUFJSgsjISMyYMcOl750IDw/H7Nmz0d7ejg8//JBNRR7j7NmzOHLkCEJDQzF37lyX9kFoaCjmzJkDs9mMDz74wGXnoa4YLm4khMBHH32EtrY2pKend/umyR8iSRIWLFgAPz8/7NmzB01NTS49H9HtcEwNNzU1YfLkyRg4cKBLzydJEubPnw9/f398/vnnaGhocOn56DqGixu1trbik08+gVqtxoMPPuiWt+WNGTMGw4YNw7lz55Cfn8/RCymuo6MDu3btgkqlwoIFC1yykH+j+Ph4jBgxAtXV1cjNzWUfuAHDxY1KSkrwzTffQK/XY/z48W4JF51Ohzlz5sBms3G3DHmEU6dO4fjx44iIiMCkSZPc0gdarRZz586FzWbDjh072AduwHBxE8cb9iwWC2bNmoU+ffq45bySJOGnP/0pAgICsH//fhiNRrecl+hmhBD47LPPYDKZMG3aNJdPDTs4XkvRu3dvHDp0CJcvX3bLeXsyhoubtLS04IsvvoBGo3H5AuaNhg8fjhEjRqCmpoZTY6Sojo4OZGVlQaVS4YEHHnBrHwwdOhT33nsvamtr+VBXN2C4uElZWRnOnj0LvV6PhIQEt547ICAAGRkZsNlsvKGSFFVRUYGTJ0/i7rvvRnJyslvDxc/PDxkZGbDb7ewDN2C4uIHjvS0WiwX3338/AgMD3Xp+SZIwa9YsBAQE4MCBA9w1RooQQiA7OxsmkwkTJ05Ev3793Hp+x5sqtVotDh06xCliF2O4uEFbWxuys7OhVqsxe/ZsRWoYPnw4hgwZgurqapSUlChSA/VsVqsVn332GSRJwpw5c9w6anGIi4tDXFwcLl26hKNHj7r9/D0Jw8UNTp8+jYqKCkRFRSExMVGRptLpdEhLS4PVasXnn3/OKQFyuwsXLuDYsWMICwtz+5SYQ0BAANLS0mCz2dgHLsZwcTEhBPbv34+2tjZMnDgRwcHBitThmBpTq9XOeojcRQgBg8GA5uZmJCYmIiIiQpE6JElCeno6NBoNDhw4gGvXrilSR0/AcHExq9WK//3vf85v7kpcrTkkJCQgKioKlZWVqKysVKwO6nmEEM6RQnp6Onr16qVYLfHx8YiOjsbZs2dx+vRpxerwdQwXF7t06RLKysoQEhKi2FSAQ3BwMJKTk9HW1oaDBw9ySoDcxmg0orCwEDqdDlOnTlW0D/r27YvU1FS0t7ezD1yI4eJCQggUFhaioaEB8fHxiIyMVLQex24ZSZLwxRdfwGazKVoP9Rzl5eWora3FkCFDoNfrFa2lcx9kZ2ezD1yE4eJi2dn
ZEEIgLS0NarVa0VokSUJycjL69OmD0tJS3qVMbiGEQE5ODqxWKyZPnoyAgABF65EkCRMmTEBQUBCOHTuGuro6RevxVQwXF7p27Rpyc3Ph5+en+FSAw4ABAzB8+HBcvnwZpaWlSpdDPYDZbMaBAwegUqlw3333eUQfREZGYvjw4TAajTh69CinxlyA4eJClZWVqKqqQlRUFIYNG6Z0OQCu36U8ZcoU2O127Nu3j01FLldTU4NTp06hX79+bn86xa1oNBpMnToVdrsdOTk5SpfjkxguLnTkyBG0t7djwoQJ6Nu3r9LlALg+JTBt2jT06tULBoOBrz8mlysuLkZzczPi4+PRv39/pcsB8P99oFarYTAY0N7ernRJPofh4iKdr4imT5/uEVMBDvHx8QgPD0dFRQWqq6uVLod8XE5ODoQQmDp1qqJbkG80YsQI9O/fH9988w3Onz+vdDk+h+HiIo2NjSgrK0Pv3r0V34J8o379+mHUqFFobW1FUVGR0uWQDzOZTCgsLIRGo3Hbu1tuV2hoKBISEnDt2jX2gQswXFzk66+/Rm1tLe655x4MGjRI6XK6UKlUmDp1KoQQOHjwoNLlkA87f/48zp07h4iICMTFxSldThcqlQpTpkwBAPaBCzBcXMRgMMBisSA5ORlarVbpcrqQJAmpqanQaDQoKCiAyWRSuiTyUUVFRTCZTEhISEBQUJDS5XTh2Jrv5+eH4uJi9oHMGC4uYLPZUFhYCEmSMHXqVKXLualhw4YhIiIC586d43wzuUTnkfHkyZOhUnnet5uhQ4fi7rvvRlVVFc6dO6d0OT5F2bv6fJRKpcLq1asxb94857Db07b8BgUFYe7cuWhvb0dAQABaWlqULol8jN1uR0pKCoxGo0f3wYIFC2C1WtG3b1+OXmTEcJGZSqXCyZMnsX37dgDAmTNnFK7o1gIDAxEUFIScnByMGTNG6XLIh6hUKpw6dQoajQajR4/G7t278fHHHytd1k2p1Wr4+fkhOzubfSAjSXjapYSXs9lsqK+vh91uV7qUOxIUFOT2N2SS72IfEMOFiIhk53krbERE5PUYLl5CCAG73e5xC6JE7sQ+8B4MFy9RWloKrVbLJxlTj1ZaWgqdTsc+8AIMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DxQsIIdDQ0AAAaGho4OPGqUdy9EHnH8lzMVw8WGNjI9atW4fY2FikpaXBbDYjLS0NsbGxWLduHRobG5Uukcjl2Afeia859lB79+7F/PnzYTKZAKDLVZokSQAAnU6HHTt2ID09XZEaiVyNfeC9GC4eaO/evcjIyHC+de9WVCoVJElCVlYWG4t8DvvAuzFcPExjYyOio6PR1tb2vQ3loFKpoNVqUV1djeDgYNcXSOQG7APvxzUXD7NlyxaYTKbbaigAsNvtMJlMePfdd11cGZH7sA+8H0cuHkQIgdjYWJw5c+aOdsJIkgS9Xo+KigrnPDSRt2If+AaGiwe5cuUKwsPDu/X5sLAwGSsicj/2gW/gtJgHaW1t7dbnW1paZKqESDnsA9/AcPEggYGB3fp8nz59ZKqESDnsA9/AcPEgYWFhiImJueP5YkmSEBMTg9DQUBdVRuQ+7APfwHDxIJIkYcWKFT/qsytXruQiJvkE9oFv4IK+h+H+fiL2gS/gyMXDBAcHY8eOHZAkCSrV9//1OO5M3rlzJxuKfAr7wPsxXDxQeno6srKyoNVqIUnSd4b5jp/TarXYs2cPZs6cqVClRK7DPvBuDBcPlZ6ejurqaqxduxZ6vb7Lr+n1eqxduxY1NTVsKPJp7APvxTUXLyCEgNFoREtLC/r06YPQ0FAuWlKPwz7wLgwXIiKSHafFiIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3
DhYiIZMdwISIi2TFciIhIdv8HKMKmRdQUW2cAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Draw the trained KAN; the rendered figure is stored as this cell's image output.\n", "model.plot()" ] }, { "cell_type": "code", "execution_count": 4, "id": "e8cb9a2f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2\n", "fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2\n", "saving model version 0.2\n" ] } ], "source": [ "# Run KAN's automatic symbolic fitting; the log below reports (0,0,0) and (0,1,0)\n", "# both being fixed to x^2 with r2 ~ 1.\n", "model.auto_symbolic()" ] }, { "cell_type": "code", "execution_count": 5, "id": "1b143bf8", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle - 1.191 x_{1}^{2} - 1.191 x_{2}^{2} + 2.329$" ], "text/plain": [ "-1.191*x_1**2 - 1.191*x_2**2 + 2.329" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from kan.utils import ex_round\n", "# Display the model's symbolic formula with coefficients rounded to 3 decimal places.\n", "# [0][0] presumably selects the expression for the single model output -- TODO confirm\n", "# against symbolic_formula()'s return structure.\n", "# The displayed result is a constant plus a multiple of x_1**2 + x_2**2, i.e. a function\n", "# of H = 1/2 * (x**2 + p**2): the conserved quantity is recovered up to scale and offset.\n", "ex_round(model.symbolic_formula()[0][0], 3)" ] }, { "cell_type": "code", "execution_count": null, "id": "782f818f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }