632 lines
82 KiB
Plaintext
632 lines
82 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Physics 1: Lagrangian neural network"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 94,
|
|
"id": "66865edb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"from kan.utils import batch_jacobian, batch_hessian\n",
|
|
"\n",
|
|
"torch.use_deterministic_algorithms(True)\n",
|
|
"torch.set_default_dtype(torch.float64)\n",
|
|
"\n",
|
|
"seed = 0\n",
|
|
"torch.manual_seed(seed)\n",
|
|
"\n",
|
|
"#example = 'harmonic_oscillator'\n",
|
|
"#example = 'single_pendulum'\n",
|
|
"example = 'relativistic_mass'\n",
|
|
"\n",
|
|
"# three examples: harmonic oscillator, single pendulum, relativistic mass\n",
|
|
"\n",
|
|
"# d: dimension of q (set per example below)\n",
|
|
"# the network models the Lagrangian L(q, qd) (a scalar); qdd follows from the Euler-Lagrange equations\n",
|
|
"\n",
|
|
"if example == 'harmonic_oscillator':\n",
|
|
" n_sample = 1000\n",
|
|
" # harmonic oscillator\n",
|
|
" d = 1\n",
|
|
" q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
" qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
" qdd = - q\n",
|
|
" x = torch.cat([q, qd], dim=1)\n",
|
|
" \n",
|
|
"if example == 'single_pendulum':\n",
|
|
" n_sample = 1000\n",
|
|
" # single pendulum\n",
|
|
" d = 1\n",
|
|
" q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
" qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
" qdd = - torch.sin(q)\n",
|
|
" x = torch.cat([q, qd], dim=1)\n",
|
|
" \n",
|
|
"if example == 'relativistic_mass':\n",
|
|
" n_sample = 10000\n",
|
|
" # relativistic mass (units with c = 1, so |qd| < 1)\n",
|
|
" d = 1\n",
|
|
" q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
" #qd = torch.rand(size=(n_sample,1)) * 1.998 - 0.999\n",
|
|
" #qd = 0.95 + torch.rand(size=(n_sample,1)) * 0.05\n",
|
|
" qd = torch.rand(size=(n_sample,1)) * 2 - 1\n",
|
|
" qdd = (1 - qd**2)**(3/2)\n",
|
|
" x = torch.cat([q, qd], dim=1)\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 95,
|
|
"id": "ec549451",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.compiler import kanpiler\n",
|
|
"from sympy import *\n",
|
|
"\n",
|
|
"input_variables = symbol_x, symbol_vx = symbols('x v_x')\n",
|
|
"expr = symbol_vx ** 2\n",
|
|
"model = kanpiler(input_variables, expr, grid=20)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 96,
|
|
"id": "f9930812",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASgklEQVR4nO3de3BU5f3H8c/ZEGBjUkJiuCiKSQgoNyWigKZ2LMiK6c3iOPVHrYzgWFrBf0r/AMc6ttqp8kcyeMGOMw7gwLQ2jhTCNN6LhNQb4A2omFSYIBBCWEhMyGX3+f3xsCThEsCc3XM2vF8zTGJ2N/sNzpMP3/NcjmOMMQIAwEUBrwsAAPQ9hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHX9vC4ASAbGGB0+fFhNTU1KT09Xdna2HMfxuizAt+hcgB6Ew2GVlpaqoKBAOTk5ys3NVU5OjgoKClRaWqpwOOx1iYAvOdwsDDiziooKzZ49W83NzZJs9xIT61rS0tJUVlamUCjkSY2AXxEuwBlUVFSouLhYxhhFo9GzPi8QCMhxHJWXlxMwQBeEC3CKcDisESNGqKWlpcdgiQkEAgoGg6qtrVVmZmb8CwSSAHMuwClWrlyp5ubm8woWSYpGo2pubtaqVaviXBmQPOhcgC6MMSooKFBNTY0uZGg4jqO8vDzt3r2bVWSACBegm/r6euXk5PTq9dnZ2S5WBCQnLosBXTQ1NfXq9Y2NjS5VAiQ3wgXoIj09vVevz8jIcKkSILkRLkAX2dnZys/Pv+B5E8dxlJ+fr6ysrDhVBiQXwgXownEcLVy48Du9dtGiRUzmAycwoQ+cgn0uQO/RuQCnyMzMVFlZmRzHUSDQ8xCJ7dB/9dVXCRagC8IFOINQKKTy8nIFg0E5jnPa5a7Y14LBoDZu3KiZM2d6VCngT4QLcBahUEi1tbUqKSlRXl5et8fy8vJUUlKiffv2ESzAGTDnApwHY4zeeecdTZ8+XW+99ZZuvfVWJu+BHtC5AOfBcZyTcyqZmZkEC3AOhAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWEC3AObW1tqqmp0ccffyxJ+vzzz3Xo0CFxtwrg7LifC3AWx48f17p167RixQp98sknam9vl+M4MsYoMzNTM2fO1KJFizRx4kSO4AdOQecCnEFdXZ3mzp2refPmKS0tTcuXL1dVVZW2b9+uTZs2aenSpdq5c6dCoZBWrFihjo4Or0sGfKWf1wUAfhMOh3Xfffdp69ateuGFF3TXXXepo6NDS5YsUUNDg0aPHq0lS5Zozpw5eu655/TII4+oo6NDDz30EB0McAKXxYAujDFasmSJnn/+eb388ssqLi6W4zhqaGjQtddeq9raWhUVFentt99WamqqIpGInnnmGT3xxBNav369pkyZ4vWPAPgCl8WALr766iu9+OKLmj9/vmbNmnXOTiQlJUUPPvigpk6dqqeeekqRSCRBlQL+RrgAXWzYsEEdHR164IEHFAgEZIw566qw2GMDBgzQggULtHnzZu3duzfBFQP+xJwLcIIxRlVVVRo9erQikYgWL16saDQqya4cC4fDkqSamhotXrxYgYD9t9mwYcM0Z84cpaam6osvvlBubq5XPwLgG4QLcEI0GtXBgwd12WWX6cCBAyopKTnjZa5vvvlGpaWlJ/973Lhxmj9/vgYPHqwDBw4ksmTAtwgX4ATHcZSamqrW1lY5jqP+/fufDBdjjNrb27s9Lyb2vPb29m5fBy5mhAtwQiAQUH5+vjZt2qQxY8bovffeOznfcuzYMd1zzz2qq6vTtddeqxUrViglJUWSFAwGFQ6HdXj/fuUtWybt2CHddJN0883SpZd6+SMBniFcgC5uu+02rV69Wj
t27ND06dNPrhZraGhQ//79JUnp6ekqLCw82aUYY/Tss89qUP/+GjdqlLRmjfT00/YbjhljQyb2p6BAYi8MLgKEC9DFjBkzNGrUKP3lL3/RlClTlJGR0ePzjTHat2+fli9frv9bsEBZjz8uGSPt3Stt3ixt2SJVVkovvWS/npNjQybW2Vx/vXQitIC+hE2UwClee+013XvvvZo3b56efPJJBYNBHTly5LRNlP369dPhw4c1f/587dmzRxUVFRoyZMiZv2k4LP3nPzZsNm+W3n9fammRBg6Ubrihs7O56SZp8OCE/rxAPBAuwCkikYhKSkr02GOP6fbbb9ejjz6q3NxcrV69Wo2NjRoxYoTuvPNOffjhh1q6dKnq6uq0du1aFRYWnv+btLdL27d3725iK83GjpWKimzQFBVJublcSkPSIVyAM4hEInrllVf06KOPqr6+XtOmTVNhYaEyMzN18OBBvf/++/r0009VVFSkZcuWacyYMb17Q2OkmhobMrHuZscO+9iwYd3nba67TmJVGnyOcAF6cPDgQa1bt04bN25UdXW1WltbNXjwYE2aNEmzZ8/WLbfcogEDBsTnzRsapKoqGziVldIHH0itrVJamjRlSudltGnTpEGD4lMD8B0RLsB5iEajamtrUyQSUWpq6smVYwnV2ipt3dq9u6mvt5fMJkzo3t1ceSWX0uApwgVIVsZIu3d3djaVldJ//2sfu/zyznmbm2+WJk6U+rE4FIlDuAB9yaFDtquJdTYffWQXD6SnS1OndnY2U6ZI51hmDfQG4QL0ZceP24CJrUrbssXO5QQCdmFAbN6mqMh2O4BLCBfgYhKNSrt2dZ+3qa62j40c2X3eZtw46cQRN8CFIlyAi92BA517bSor7aKBjg67Am3atM55mxtvlC65xOtqkSQIFwDdNTfbZc+x7mbLFunoUbsgYNKk7gsFhg3zulr4FOECoGfRqPTFF91XpX39tX0sP7/7vM3VV9v5HFz0CBcAF27fvs7OprLSHmUTidhz0WJdzc0323PTBg70ulp4gHAB0HtNTfYwzlhnU1Vlv5aaKk2e3P1gzpwcr6tFAhAuANzX0SF99ln3VWm1tfax0aO7z9uMHs1pAn0Q4QIgMfbu7T5v8+mn9pSBSy/tPm9TWCjF67w2JAzhAsAbR4923uOmstJ+3txsg+WGGzq7m5tukrKyvK4WF4hwAeAP7e3SJ590727277ePjR3b2d3cfrs0dKi3teKcCBcA/hT71XTqx0CAOZokwDGpAPwpFiBdg4R/CycNwgWA61pbWlSzfr2ibW1el3LeHGOUPXmyhl5zjdel9AmECwDXHT96VPtff12jfv5zac8eeyhmerrXZZ3Zl19K3/uejhw9qgP//jfh4hLCBUBcDBw6VFesWSPnn/+U1q6VbrnF65JO19oqLV0q1dQo9Q9/UB1zOa7hECAA8REISFdcIX37rbRpkz/nS/bvl3butEfXcD8bVxEuAOLDcWy3EgjYZcXt7V5XdLpt26Rw2C515lgaVxEuAOJn4kS7AXLXrs49K35hjPTuu/bj979vbykA1xAuAOJnyBBp/HjbHXz8sdfVdNfaak8HSEnx53xQkiNcAMRPv372F7cx0ttv+2veZe9eafduG4ATJnhdTZ9DuACIrx/8wIZMZaV0/LjX1VjG2FsENDVJ110nZWd7XVGfQ7gAiK/x4+3tkKurpf/9z+tqOr35pg2ZH/6Qu2fGAX+jAOJr8GB7w7Bvv7Xdix8ujTU22lOYBwywk/nsb3Ed4QIgvgIBacYM+/kbb/gjXHbutHMuV10lXX2119X0SYQLgPhyHNsdpKXZeY6GBm/riS0uaGuzx/j79ViaJEe4AIi//Hx7O+P9+6Xt272tpaPDdlCOI82c6W0tfRjhAiD+Bg60E+eRiFRR4e2lsdpae4vlrCxp6lTmW+KEcAEQf44jhUJ2SfJbb0ktLd7UYYy0ebPd1FlYKA0f7k0dFwHCBUBiTJokXX
aZPeL+yy+9qcEYacMG+3HWLLs7H3FBuABIjMGDpaIi27V4tWqsvt4e+XLJJdL06VwSiyPCBUBiOI70k5/Yj+XliT8l2Ri7t+XAAXsK8qhRiX3/iwzhAiAxHMcu/R0yxK4Yq6lJ7PsbI61bJ0Wj0h132A2UiBvCBUDiDB1qA6axMfGrxo4ckd55x65cu+MOLonFGeECIHECAenOO+0v9tdeS9ylMWOkqiq7DHnsWPsHcUW4AEgcx7GnJA8ZIm3dao+8TwRjpH/8w+6z+fGPpWAwMe97ESNcACTW8OE2YJqapPXrE3NprK7O7q9JS5N++lMuiSUA4QIgsRxHuvtue4msrCz+GypjZ4nt32/32nBQZUIQLgASy3Hs3SlHjpQ+/9xeHotn99LRIa1da9/j7rul/v3j9144iXABkHhZWXbPS1ubtGZNfMNl92575Et2tvSjH3FJLEEIFwCJ5zjSPffYifUNG6SDB+PzPsZIf/+7dOyYvafMFVfE531wGsIFgDcmTpRuvFH65pv4TeyHw9Lf/mYPzLzvPm5nnED8TQPwRv/+0ty59vOXXnJ/Yt8Y6fXXpa++ksaPt5s3uSSWMIQLAG84jt0pn5srbdsmbdrkbvfS2ir99a/2e86daw+rRMIQLgC8k50t3Xuv3an//PN2ZZcbjLGnH1dVSSNGSHfdRdeSYIQLAO84jg2XIUPsJscPPnCne+nokJYvt93LL38pDRvW+++JC0K4APDWyJHSnDl2zqWkpPfdS6xreeMNG1r330/X4gHCBYC3AgHp17+WcnKkjRvtnpTedC9tbdLTT9uw+tWvpKuucq1UnD/CBYD38vKkefOk48elJ5747ivHjLE3InvzTenyy6Xf/Iblxx7hbx2A9wIB6be/tSGzadN337V/+LD0+OP20tqiRdKVV7pfK84L4QLAH4YPl5YssZ//8Y9SdfWFBUwkIi1bZs8ru/56af585lo8RLgA8AfHkX7xC7v3pbZW+t3vpObm83utMdK//iU995zdz/LnP0uDBsW3XvSIcAHgHwMHSk89ZVeQlZdLf/rTue9WaYz02WfSwoV2rubhh+39YuhaPEW4APAPx5EKCqTSUtuBlJTYlV9tbWd+vjHSjh12r8yePfbU48WLpZSUhJaN0xEuAPzFcaTiYtvBpKTYCfqHH7YHXMbmYIyxgbN+vfSzn9l5lqIie1ksPd3T8mH187oAADhNIGCXJvfrJ/3+9/aMsNdft0EycaI97biiQnr3XRsys2ZJL7xgd+JzOcwXCBcA/pSSYg+cvOYaaelSu+u+pKT7c4YOtUuYFy6UMjIIFh8hXADElenNbnvHkaZOtTcUq6qymyP37LHzMZMnS6GQ3csSC5V43tESF8Qxvfo/DwCnO1ZXpy0LFmjQhAnuf/PYryyXu5Tjhw7p0uuv14T773f1+16sCBcArotGImqsq5OJRr0u5YIMHDRIA1kQ4ArCBQDgOpYiAwBcx4Q+gOTR9UILK8N8jc4FQPLYts0uUd62zetKcA6ECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYRLEjDGqL6+Xl9//bXq6+t7d09yIEkZY3TkyBEZyX5kHPga4eJj4XBYpaWlKigoUE5OjnJzc5WTk6OCggKVlpYqHA57XSIQd13HwfQZM2SM0fQZMxgHPsdtjn2qoqJCs2fPVnNzsyR1+1eac+ImSWlpaSorK1MoFPKkRiDeTh0H1xmjjyRNlrSdceBrdC4+VFFRoeLiYrW0tMgYc1r7H/taS0uLiouLVVFR4VGlQPwwDpIbnYvPhMNhjRgxQi0tLYpGo+d8fiAQUDAYVG1trTIzM+NfIJAAZxsHk6STnUvXe1EyDvyHzsVnVq5cqebm5vMKFkmKRqNqbm7WqlWr4lwZkDiMg+RH5+IjxhgVFBSopqbmglbCOI6jvLw87d69++R8DJCseh
oHZ+tcJMaB39C5+Mjhw4dVXV19wUssjTGqrq5WQ0NDnCoDEodx0DcQLj7S1NTUq9c3Nja6VAngHcZB30C4+Eh6enqvXp+RkeFSJYB3ehoHu2Qvie3q4fWMA38gXHwkOztb+fn5F3y92HEc5efnKysrK06VAYnT0zhokZ1raTnD6xgH/kK4+IjjOFq4cOF3eu2iRYuYxESfwDjoG1gt5jPscwEYB30BnYvPZGZmqqysTI7jKBDo+X9PIBCQ4zh69dVXGVDoUxgHyY9w8aFQKKTy8nIFg0E5jnNamx/7WjAY1MaNGzVz5kyPKgXih3GQ3AgXnwqFQqqtrVVJSYny8vK6PZaXl6eSkhLt27ePAYU+jXGQvJhzSQLGGDU0NKixsVEZGRnKyspi0hIXHcZBciFcAACu47IYAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1/w+t/uNoMzLGkwAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.get_act(x)\n",
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 97,
|
|
"id": "74a1fb02",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.2\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUf0lEQVR4nO3de3TU5Z3H8c8zucDkQm4MhJAEyNWigKQiKqmeFiQqbrcunp61rqsutF23hf1n3T+gx+26a/fU+kdSWkVPz+kqbe1uF46ughutWBGCeEFURAQSIDMhF3KZXEhCMjPP/hGYEkBE+M0l4f36x0MyM79v5Dz58H2e5/f8jLXWCgAAB7liXQAAYPwhXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjkuMdQHAWGCtVUdHh/r6+pSWlqacnBwZY2JdFhC36FyAC/D7/aqpqVFpaak8Ho9mzZolj8ej0tJS1dTUyO/3x7pEIC4ZHhYGnF9tba2WL1+u/v5+SSPdy2mnu5aUlBRt3LhRVVVVMakRiFeEC3AetbW1WrZsmay1CoVCn/s6l8slY4w2b95MwABnIFyAs/j9fuXn52tgYOCCwXKay+WS2+2Wz+dTZmZm5AsExgDWXICzPPvss+rv77+oYJGkUCik/v5+PffccxGuDBg76FyAM1hrVVpaqoaGBn2ZoWGMUVFRkQ4ePMguMkCECzBKe3u7PB7PZb0/JyfHwYqAsYlpMeAMfX19l/X+3t5ehyoBxjbCBThDWlraZb0/PT3doUqAsY1wAc6Qk5Oj4uLiL71uYoxRcXGxsrOzI1QZMLYQLsAZjDFatWrVJb139erVLOYDp7CgD5yF+1yAy0fnApwlMzNTGzdulDFGLteFh8jpO/Q3bdpEsABnIFyA86iqqtLmzZvldrtljDlnuuv019xut7Zs2aKlS5fGqFIgPhEuwOeoqqqSz+dTdXW1ioqKRn2vqKhI1dXVampqIliA82DNBbgI1lq98cYbWrx4sV5//XV9/etfZ/EeuAA6F+AiGGPCayqZmZkEC/AFCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIF+ALDA0NqaGhQe+//74kae/evTp+/Lh4WgXw+XieC/A5BgcH9eKLL2r9+vX68MMPNTw8LGOMrLXKzMzU0qVLtXr1as2dO5cj+IGz0LkA59HW1qYHHnhAK1asUEpKitatW6edO3dqz5492rZtm9auXatPP/1UVVVVWr9+vQKBQKxLBuJKYqwLAOKN3+/X/fffr927d+vpp5/W3XffrUAgoDVr1qizs1NlZWVas2aN7r33Xj355JP60Y9+pEAgoB/+8Id0MMApTIsBZ7DWas2aNXrqqaf0m9/8RsuWLZMxRp2dnZo3b558Pp8qKyu1detWJSUlKRgM6he/+IUee+wxvfTSS1q4cGGsfwQgLjAtBpzh0KFD+tWvfqWVK1fq9ttv/8JOJCEhQd///vd1ww036PHHH1cwGIxSpUB8I1yAM7z88ssKBAL67ne/K5fLJWvt5+4KO/29CRMm6KGHHtL27dvV2NgY5YqB+MSaC3CKtVY7d+5UWVmZgsGgHn74YYVCIUkjO8f8fr8kqaGhQQ8//LBcrpF/m+Xm5uree+9VUlKSPvnkE82aNStWPwIQNwgX4JRQKKTW1lbl5eWppaVF1dXV553mOnbsmGpqasJ/vvrqq7Vy5UplZWWppaUlmiUDcYtwAU4xxigpKUknT56UMUbJycnhcLHWanh4eNTrTjv9uuHh4VFfB65khAtwisvlUnFxsbZt26by8nK99dZb4fWWnp4e3XPPPWpra9O8efO0fv16JSQkSJLcbrf8fr/amtv06ROf6rV9r6ngpgIVLipUyuSUWP5IQM
wQLsAZbr31Vm3YsEH79u3T4sWLw7vFOjs7lZycLElKS0tTRUVFuEux1uqXv/ylUpNTVV5Srr2/26u6n9VJkiaXT1bBogIVLBoJm+zSbO6FwRWBcAHOsGTJEpWUlOinP/2pFi5cqPT09Au+3lqrpqYmrVu3Tg8+9KAefPRBWWvV3ditxu2N8tZ55d3h1Z5f75G1Vqme1JGwuWkkcPK+mqeE5IQo/XRA9HATJXCWF154Qffdd59WrFihn/zkJ3K73erq6jrnJsrExER1dHRo5cqVOnr0qGprazVlypTzfuagf1C+t33y1nnVuL1RTbuaNDwwrMSJiZq+YHq4uym4qUDuLHeUf2LAeYQLcJZgMKjq6mr9+Mc/1m233aZHHnlEs2bN0oYNG9Tb26v8/Hzdddddevfdd7V27Vq1tbXp+eefV0VFxcVfYziolj0tatzeKF+dT407GtXX0idJ8sz2qLCycGTdprJQmbMymUrDmEO4AOcRDAb1hz/8QY888oja29t14403qqKiQpmZmWptbdWuXbv00UcfqbKyUk888YTKy8sv63rWWnU1dMm7wxvubo7vOy5JSstNU+GiwnB3k3ttrhKSmEpDfCNcgAtobW3Viy++qC1btqi+vl4nT55UVlaW5s+fr+XLl+vmm2/WhAkTInLtgc4BeXeOrNl4d3jV9E6TAicDSkpJUv7C/PA0Wv6N+ZqYMTEiNQCXinABLkIoFNLQ0JCCwaCSkpLCO8eiKXAyoObdzaO6m/72fhljNGXOlPCOtIJFBcoozGAqDTFFuABjlLVWnQc71bijMdzdtH/WLkmaNH2SCisLlX9TvgoXFWrq3KlyJXKUIKKHcAHGkRPHT4xsf67zyrvdq2PvHVNwOKjktGTl35Af7m6mL5yuCemRmc4DJMIFGNcCgwEde+/Yn++5qfNqoHNAxmWUe21ueN2msLJQk6ZPinW5GEcIF+AKYkNW7fvb1bjj1Bbo7Y3qrO+UJGXOyBx1moDnao9cCUyl4dIQLsAVrq+lb2SDwKm1m+bdzQoFQpqYMVH5N+aHTxOYfv10JadGfyMDxibCBcAow/3DanqnKdzdeOu8GuwelCvRpWnzp6mgsiB8MGdablqsy0WcIlwAXJANWbV90hbekda4o1H+I35JUnZx9qh1m8lXTZZxsQUahAuAS9DT1BO+38a7w6uWPS0KBUNyZ7nD02gFiwo0fcF0JU7kfNwrEeEC4LIN9Q3Jt8sX7m68O70a6htSQlKC8q7LG3UwZ6onNdblIgoIFwCOCwVCav24ddRpAj2+HklSTllO+GDOgkUFyinL4TSBcYhwARAV3Y3do04TaP2oVdZapUxOUeGiU6cJVBZqWsU0JU5gKm2sI1wAxMRg95+fcePd4ZXvbZ+G+4eVOCFReQvy/tzd3FQgdzbPuBlrCBcAcSE4HFTrh62jupve5l5JI8+4Ob1mU3JbidKmsgU63hEuAOJS+FeTHf1n4zKs0YwBTGwCiEvhAAn/x4h/C48dhAsAxw0ODOrtl95WYCgQ61IunpVKrivRzK/MjHUl4wLhAsBxJ7pP6KNXP9JNf3WT/Ef9ypyRqeS0+DyXrONAh5InJau7u1t739xLuDiEcAEQERlTM3T0d0d14H8PaPnzy1V2c1msSzpH4GRAG9ZuUFdDlyr+pUInzIlYlzRucJ42gIgwLqOMggwNnxjW0W1H43K9pK+5T8c/Pa5QMMTzbBxGuACIDCPNuHmGjMvIu8Or0HAo1hWdo/mDZg36B+WZ7VGKJyXW5YwrhAuAiJk6d6rc2W61728P37MSL6y1OvKnI5KVCr9WKFcivw6dxP9NABGTOiVVnms8GvQPqvn95liXM0rwZFC+Op9MgtGMm2fEupxxh3ABEDGuRNfIL24rHd56OK7WXbobu9VxsEOpU1I1dc7UWJcz7hAuACJq5i0z5Up0ybvDq8BgfNz3Yq2Vb5dPQ31Dyr02V+4czi5zGuECIKKmXDNFqbmp6qzvlP
+wP9blhDX8sUGy0sxvzOTpmRFAuACIqIlZE5V3XZ6GTwyrcUdjXEyNDfUOqentJiVMSNDMr83krLIIIFwARJRxGRUtKZIkNbzWED6IMpaOf3pc3Y3dypyZqZyrcmJdzrhEuACIKGOMZnxthpJSktS0q0kDnQMxrcdaq8NbDys4FFTBooK4PZZmrCNcAERcVnGWcspy1Nvcq5Y9LTGtJRQIjXRQRipZWhLTWsYzwgVAxCVOTNTMb8yUDVodqj0U03WXHl+PWj9qlTvbrek3TGe9JUIIFwARZ4xRSVWJTKLR4dcPKzAQmy3J1lo1bm/UoH9Q0yqmKX1aekzquBIQLgCiInd+riblTVLHgQ51HOiITRFWOvDygZFnt9xeIpNA1xIphAuAqHBnuVVQWaDAQED1r9XHZGqsv71fvjqfklKTVLS4iCmxCCJcAESHkcq/WS4Z6eDmg1E/JdlaK9/bPvW19Mkz26PskuyoXv9KQ7gAiApjjAoXFSp1Sqpa9rSoq6ErugVYaf+L+2VDViV3lChhQkJ0r3+FIVwARE3q1FQVLCrQUO9Q1HeNDXQN6MgbR5Q4MVFld5QxJRZhhAuAqDEuo6vuukoy0mcvfBa1qTFrrXw7ferx9cgz2yPPbE9UrnslI1wARI0xRjNvmanUKalq3t2sjoNR2jVmpX3/s082aFX2F2VKdCdG57pXMMIFQFSlT0vXjFtmaKhvSJ+99FlUpsZOtJ3Q4dcPKyklSeV/Wc6UWBQQLgCiy0hXf/tqGZfR/o37I35DpbVWDVsb1Nvcq9z5uZp81eSIXg8jCBcAUWXMyGOFM2ZkqG1vm5p3N0f0pORQIKS9z++VrDT727OVkMwusWggXABEnTvbrbJvlik4FNTHv/s4olNjnQc75d3ulTvHrfI7mRKLFsIFQNQZYzTnnjlKdCfqwMsH1NfaF5HrWGu197/36mTPSRUtKdKkgkkRuQ7ORbgAiImpc6dq+vXT1XusN2IL+4P+Qe37r31yJbo07/55PM44iggXADGRkJygeQ/MkyR9+OsPHV/Yt9aq/tV6dR7q1JRrpqhwUSFTYlFEuACICWOMyu4oU9asLDV/0Kyj24462r0ETwb1/jPvy1qreQ/MU1JqkmOfjS9GuACIGXeOW3Pum6PQcEjvPvWuQgFn7ti31spb55Vvp0+T8idp9t2z6VqijHABEDPGGF1737VKnZKqw68fVtM7TY50L6FASO+se0fBk0HN+Zs5SstNc6BafBmEC4CYypiRoWvuvUaBgYB2Ve+67O7ldNdS/1q9UqekquLvKuhaYoBwARBTxmW04O8XKMWTooNbDqpxe+NldS/BoaDqflanwEBAc/92rjJnZjpXLC4a4QIg5rKKsjR/xXwFBgN667G3LnnnmLVWBzcfVMMfG5Q+PV0L/mEB249jhHABEHPGZXT9D65XVlGWjm47esl37Q90DOjNR99UKBDSwtULlVGYEYFqcTEIFwBxIW1amirXVEqStv3bNnXVd32pgAkFQ9rxxA617W1T3lfzVLGStZZYIlwAxAVjjK7562tUekepenw9evWfXtVw//BFvddaq0P/d0jvPfmeklOTtfg/FmtCxoQIV4wLIVwAxI3EiYm69fFblTEjQwc3H9S2f9+m4HDwgu+x1qrt4za9suoVBQYCWviPCzXjlhl0LTFGuACIG8YYZZdm67aa25SUmqRd1btU97M6BYfOHzDWWh3fd1yb7tuk7qPdKruzTIseXiRXAr/aYo2/AQBxxRijsmVlWvL4EpkEozcffVOv/OMr6j3WG16DsdYqOBTUgZcO6Pff+r2O7z2uwspC3fHkHUpK45iXeMCDpAHEHeMyqlhRoYTEBL32z69p9zO71fBqg8q/Va6pc6dq0D+o+tp6HfnTEQWHgiq5vUR3Pn2n0nLTmA6LE4QLgLjkSnDp2geu1eSvTNbWtVvlrfNqV/WuUa9JnZqqBT9YoIWrFio5PZlgiSOEC4CIuqyzwoyUf0
O+vvPyd+Td6VXDHxvUfbRbSalJyrsuTyVVJSP3shgHrgVHES4AHGdcRr5PfPrtv/7W2Q+eINlSKyOjnuYe7f/P/Y59dM/xHhV9tcixz7vSGUvUA3BYMBhUV1uXbGhs/XpJzUhVSlpKrMsYFwgXAIDj2IoMAHAcay4AxowzJ1rYGRbf6FwAjBktH7To0YRH1fJBS6xLwRcgXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcxgBrrdrb23XkyBG1t7fznHBckay18nf5JUn+Lj/jIM4RLnHM7/erpqZGpaWl8ng8mjVrljwej0pLS1VTUyO/3x/rEoGIO3McLF6yWNZaLV6ymHEQ53jMcZyqra3V8uXL1d/fL+n8D0lKSUnRxo0bVVVVFZMagUg7exzk2lx9T9/TM3pGLWbkmS6Mg/hE5xKHamtrtWzZMg0MDMhae077f/prAwMDWrZsmWpra2NUKRA5jIOxjc4lzvj9fuXn52tgYEChUOgLX+9yueR2u+Xz+ZSZmRn5AoEo+LxxME3Twp1Ls5rDX2ccxB86lzjz7LPPqr+//6KCRZJCoZD6+/v13HPPRbgyIHoYB2Mf4RJHrLVat27dJb335z//ObtnMC4wDsYHwiWOdHR0qL6+/ksPDmut6uvr1dnZGaHKgOhhHIwPhEsc6evru6z39/b2OlQJEDuMg/GBcIkjaWlpl/X+9PR0hyoBYudC46Bd7XpGz6hd7Z/7GsZBfCBc4khOTo6Ki4vD97FcLGOMiouLlZ2dHaHKgOi50DgY1rCa1axhDZ/zPcZBfCFc4ogxRqtWrbqk965evfpLhxIQjxgH4wP3ucQZ7nMBGAfjAZ1LnMnMzNTGjRtljJHLdeG/HpfLJWOMNm3axIDCuMI4GPsIlzhUVVWlzZs3y+12yxhzTpt/+mtut1tbtmzR0qVLY1QpEDmMg7GNcIlTVVVV8vl8qq6uVlFR0ajvFRUVqbq6Wk1NTQwojGuMg7GLNZcxwFqrzs5O9fb2Kj09XdnZ2Sxa4orDOBhbCBcAgOOYFgMAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOO7/AX8dLuaR/xUgAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.perturb(mode='best', mag=0.1)\n",
|
|
"model.get_act(x)\n",
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 98,
|
|
"id": "fd0d2987",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| loss: 5.03e-05 |: 100%|███████████████████████████████████████████| 20/20 [02:59<00:00, 8.99s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"from kan.utils import batch_jacobian, create_dataset_from_data\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"torch.use_deterministic_algorithms(True)\n",
|
|
"\n",
|
|
"def closure():\n",
|
|
" \n",
|
|
" global loss\n",
|
|
" optimizer.zero_grad()\n",
|
|
" \n",
|
|
" jacobian = batch_jacobian(model, x, create_graph=True)\n",
|
|
" hessian = batch_hessian(model, x, create_graph=True)\n",
|
|
" Lqdqd = hessian[:,d:,d:]\n",
|
|
" Lq = jacobian[:,:d]\n",
|
|
" Lqqd = hessian[:,d:,:d]\n",
|
|
"\n",
|
|
" Lqqd_qd_prod = torch.einsum('ijk,ik->ij', Lqqd, qd)\n",
|
|
"\n",
|
|
" qdd_pred = torch.einsum('ijk,ik->ij', torch.linalg.inv(Lqdqd), Lq - Lqqd_qd_prod)\n",
|
|
" loss = torch.mean((qdd - qdd_pred)**2)\n",
|
|
"\n",
|
|
" loss.backward()\n",
|
|
" return loss\n",
|
|
"\n",
|
|
"steps = 20\n",
|
|
"log = 1\n",
|
|
"optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
|
|
"#optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
|
|
"pbar = tqdm(range(steps), desc='description', ncols=100)\n",
|
|
"\n",
|
|
"\n",
|
|
"for _ in pbar:\n",
|
|
" \n",
|
|
" # update grid (with these settings this fires only at step 0)\n",
|
|
" if _ < 5 and _ % 20 == 0:\n",
|
|
" model.update_grid(x)\n",
|
|
" \n",
|
|
" optimizer.step(closure)\n",
|
|
" \n",
|
|
" if _ % log == 0:\n",
|
|
" pbar.set_description(\"| loss: %.2e |\" % loss.cpu().detach().numpy())\n",
|
|
" \n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 99,
|
|
"id": "782f818f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYiUlEQVR4nO3dWXBc1Z0G8O+01K1e1Yska5ctybIheAGxGLCzjRkrGdfUVEp+GPCQpUJCUWP8xkNM4kpRlVQlxYOEk+CkkgeWKkil5CmK2BNlhjABLzhsxhgTCPLSi6213ft+75mHttpqLC/g2923W9+vylVGUkt/2xx9+p9z7jlCSilBRESkIUOlCyAiotrDcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs3VV7oAomogpcTc3BxisRjsdjuampoghKh0WUS6xc6F6CpCoRBGR0cxMDCAlpYW9Pb2oqWlBQMDAxgdHUUoFKp0iUS6JHhZGNHixsfHMTw8jEQiASDfvcyb71qsVivGxsYwNDRUkRqJ9IrhQrSI8fFxbN26FVJKqKp6xY8zGAwQQmD//v0MGKIFGC5EnxIKhdDV1YVkMnnVYJlnMBhgsVjg9/vhcrlKXyBRFeCaC9GnPPPMM0gkEtcVLACgqioSiQSeffbZEldGVD3YuRAtIKXEwMAATp06hc8yNIQQ6Ovrwz/+8Q/uIiMCw4WoyOzsLFpaWm7o9U1NTRpWRFSdOC1GtEAsFruh10ejUY0qIapuDBeiBex2+w293uFwaFQJUXVjuBAt0NTUhP7+/s+8biKEQH9/PzweT4kqI6ouDBeiBYQQePTRRz/Xa3fu3MnFfKKLuKBP9Cl8zoXoxrFzIfoUl8uFsbExCCFgMFx9iMw/ob9v3z4GC9ECDBeiRQwNDWH//v2wWCwQQlw23TX/NovFggMHDmDLli0VqpRInxguRFcwNDQEv9+PkZER9PX1Fb2vr68PIyMjCAQCDBaiRXDNheg6SCnx6quvYvPmzXjllVfw1a9+lYv3RFfBzoXoOgghCmsqLpeLwUJ0DQwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOF6BoymQxOnTqFt99+GwBw4sQJzMzMgLdVEF0Z73MhuoJUKoWXXnoJe/fuxXvvvYdsNgshBKSUcLlc2LJlC3bu3Il169bxCH6iT2HnQrSI6elpfPvb38Z3v/tdWK1W7NmzB0eOHMGxY8fw2muv4fHHH8eHH36IoaEh7N27F7lcrtIlE+lKfaULINKbUCiEb33rW3jnnXfw61//Gtu2bUMul8OuXbsQDAaxatUq7Nq1C9u3b8evfvUr/PCHP0Qul8OOHTvYwRBdxGkxogWklNi1axeefvppPP/889i6dSuEEAgGg1i/fj38fj82bdqEv/zlLzAajVAUBb/4xS/wk5/8BC+//DI2bNhQ6T8CkS5wWoxogU8++QS//e1v8dBDD+HrX//6NTuRuro6PPzww7j77rvx85//HIqilKlSIn1juBAt8Mc//hG5XA7f+973YDAYIKW84q6w+fc1NDTgkUcewcGDB+H1estcMZE+cc2F6CIpJY4cOYJVq1ZBURQ89thjUFUVQH7nWCgUAgCcOnUKjz32GAyG/M9mbW1t2L59O4xGIz744AP09vZW6o9ApBsMF6KLVFXF1NQUOjo6MDk5iZGRkUWnuc6dO4fR0dHCf99yyy146KGH4Ha7MTk5Wc6SiXSL4UK0gFAE0uk0hBAwmUyFcJFSIpvN5j9GCBiNxsJr5j8um80WvZ1oKWO40JKl5lRMvT8F70Ev/If98B7yIuaLYXLVJFavXo3XX3+9sN4SiURw//33Y3p6GuvXr8fevX
tRV1cHALBYLAiFQpidnkXugxwm/jwBZ48Tzh4njFaGDS1NDBdaMtLRNAJHA/Ae8sJ3yAf/G35kYhnUGevQcUcH1vz7GjwoHsQP9vwAJ0+exObNmwu7xYLBIEwmEwDAbrdjcHCw0KVIKfHLX/4SdpsdN998cz6wDuUX9q3N1kLQOHucsHgsfBaGlgSGC9WsSCAC70EvfId98B3yYfLYJKQqYXFb0L2xG1/c9UV0b+xGxx0dMFryQXFn6E787r9/h5/97GfYsGEDHA7HVb+GlBKBQAB79uzBN7/zTdz7nXshpUQ6nEbYGy78mnx3ElJKmGwmNHY3FsLG0eGAoY6bNqn2MFyoJqiKipkPZgpdie+QD6GzIQCAp9+D7o3duP3h29GzsQfNNzVDGBbvHlwuF5544gk8+OCD+NGPfoSf/vSnsFgsi36slBJzc3PYsWMHrFYrduzYASC/JmN2mWF2mdG6rhUAkEvlEPFHEPaGETobwplXz0DJKjDUG9DYeSlsGrsbC0FHVM34hD5VpWwii8DfAoXOxH/Ej1Q4BUO9Ae2D7eje2I3ue7vRs7EH9jb7Z/rciqJgZGQEP/7xj/G1r30Nu3fvRm9vL5577jlEo1F0dXXhG9/4Bt588008/vjjmJ6exgsvvIDBwcHr/hqqoiI2GSvqbjKxDADA1mKDc7kTzm4nnMudMLvMnEqjqsNwoaoQm4zlu5LDPvgO+nD+3fNQcyrMTjO67ulC98Z8kHTe1anJIrqiKPjDH/6A3bt3Y3Z2Fvfccw8GBwfhcrkwNTWFo0eP4vjx49i0aROefPJJrF69+oa+npQSqQupfND4wgifDSM+EwcAmOymonUbe5udU2mkewwX0h2pSsz+fbZovSQ4EQQAuJa70LOpB133dqFnYw9abmkp6TfaqakpvPTSSzhw4AAmJiaQTqfhdrtx2223YXh4GF/60pfQ0NBQkq+dTWYR8UUKnU0kEIGaU1FnrENj14KptK5G1Js5w036wnChisulcgi8GSislfgO+5C8kIQwCLTd2lboSro3dqOxs7EiNaqqikwmA0VRYDQaCzvHylpDTkX0fLRoKi2byN8xY1tmK+puGpwNnEqjimK4UNnFZ+KFjsR3yIdzb52DklVgspvQfU93oSvp3NCJBkdpuoJaIKVEMpgsCpvEbAIA0NDYUDyV1mq/4iYGolJguFBJSSkx9/FcoSPxHvRi7uM5AEBjV2OhI+ne2I3Wta0w1HMt4UZk4pmiqbTouShURUWdKT+V5lruym+B7nSgvoFTaVQ6DBfSVC6dw/l3zheeevcd9iE+E4cQAq3rWi/t4trUA2ePs9Ll1jw1pyJ6LorQ2VB+3cYXQTaZn0qzt9mLp9Ia2SWSdhgudEOSwWR+iuviNFfgbwHk0jkYrUZ0behC96Z8mHTd3QWz01zpcpc8KSUSs4miqbRkMAkAMLvMRWFja7FxKo0+N4YLXTcpJS6cugDfIV/hYcWZkzMAAEe7o6graV3fijpjXYUrpuuRiWUK25/D3jCi56OQqkS9ub54V1pnI+pM/Del68NwoStSsgom352E99DFgx0PehGbigEAlt2yrNCV9GzsgavXxd1JNULJKogGFuxK84WRS+UgDAKOdkdRd2Oyl3/XHFUHhgsVpMIp+I/4C11J4GgA2WQW9eZ6dN7VWdgS3HVPFyzuxY9EodojpUR8Ol40lZYKpQAAFo8lHzQXTxOwNlv5QwYBYLgsWVJKhL3hSw8qHvRh+sQ0pJSwtdgKO7h6NvagfbCd0yFUJB1JXzpNwBtGbDIGqUoYLcaigzkbOxu5A3CJYrgsEWpOxdTxqaKDHSOBCACg+abmovUSz0oPf/qkz0TJKIWDOedDR8koMNQZ4OhwFB3MabJxKm0pYLjUqHQ0Df8b/sIursLdJab83SU9m3rQfW8+UKzN1kqXSzVGqhKxqVhh+3PobAjpSBoAYG2yXjqYs8cJSxPvuKlFDJcaEfFHirqSyfcu3l3ise
TXSe7tQs+mHnTc3sFzqKgiUuFU0bpNfCoOKSWMVmPRJgFHu4NTaTWA4VKFCneXXFwv8R70IuwNAwA8Kz2XupKN3WhefeW7S4gqaeEdN2FvGBF/pHDHjaPDUThNgHfcVCeGS5U489cz8L5+sTM54kM6koah3oCO2zsKi+/d93bD3vrZ7i4h0gtVURGfihcuVIv4IkhH81NptpZLB3N6Vnq4BboKMFyqhKqohd8LIQCx4PdENUhKiVwqh3Q4jVQ4hXQkjUw8g9Y1rVwnrAIMlyox/8/EMKGlTMkqEAbBy9KqAFd2NZZKpvDGy28gl8lVupTrJ4GVd6zEiptXVLoSqhHZbBb+j/xFHbfuScDT6YG7xV3pSmoCw0Vj8XAcx//nOL7y4FcqXco1Rc9HkZhJIGfJ4cRfTzBcSDPZVBaTE5PoXdcLqcr8phKdNt2pUArCIJDOpDF9eprhohGGSwm42l1Y+8W1up3CklIiGohi3+59CJ0JYXD3IKIiWumyqMZY7BYkTycR/CSIvvv64O7V3zdtJavg2J+PIRVKoWtLF9LZdKVLqhmcuFxipJSI+qMYe2AM3te9MDvNMDm484ZKIx1OIxKIFI711xs1pyIdTUPJKjziSGMMlyVkvmMZ+48x+A75sGzNMmz7/TZ4+j2VLo1qlKkx/4NLOpKGHvcOKRkFSiYfLLyZU1sMlyVCSomIP4Kx7WPwHfRh2dp8sDStbqp0aVTD5m+3nD/6RW+yiSzUrAqjxQiDkd8OtcS/zSVgPlj2bd93KVhezAeLXteFqDY0OC+Fix47l0wsA6lKmOwmbm/WGP82a1xRsBwq7lgYLFRqDY4GGOoMSEfTkIr+wiUdzodeQ2ODbnezVSuGSw27LFjWXQyWVQwWKg+jNT/dlE1koWSUSpdTREpZuPTM7DJXuJraw3CpUYt2LC8yWKi86s31MFqMUNIKsolspcu5TPJCfheb2cNw0RrDpQZdcSqMwUJlZqg3oKGxIb/lV2eL+lLJdy5CCFhcvLZbawyXGlMIlgc4FUaVJwwCZrcZUkok5/T1rIuSVfKnixsNhV1tpB2GSw2RUiLiuxgsh31oXdfKqTCqOFuzDQCQmEnoasdYJp5BLpmD0WpEvZXPuGiN4VIjiqbCLgbL8IvDDBaqOOuy/PH48dm4rsIlFUxBzakwu8yoM/LpfK0xXGpA0VTY4fxUGIOF9MLisaDOWIdkMKmbHWNSysI1y7YWG29rLQGGS5Wbnwobe2CMU2GkSya7CUarEdl4FploptLlFEQn84e12tt5e2spMFyqWGGNZfs++A/788HCxXvSmbqGOlibrVCyCuIz8UqXAwBQs/krlYVBwNZq43gpAYZLlSp0LNsXdCy/3wbPgIcDhXRFCAFHhwMAEPVHdbHukollkI6kYbQaYXFzG3IpMFyq0MJgme9Yhl8cZrCQLgkh4OjKh0skEIFUKx8usekYcukcbC021Ju5U6wUGC5VZuEai/+wH63rORVG+mdvtaO+oR6JmUTFn9SXUiJ8NgwAaOxq5GJ+iTBcqkhRsBy5GCwvciqM9M9kN8HSZEE2kUV8urLrLlKRiHgjEELAudzJsVMiDJcqwWChamaoN8DZ44SUEqHToYquu6QjacRn46i31MPeyp1ipcJwqQIMFqoF7j43hBAInQ5BVdSK1CClRNgfRi6Vg6PdAaPNWJE6lgKGi85JmW/hGSxUzeZ3jBltRsSmY4Wj7stOAsGPgwAAd7+b6y0lxHDRMXYsVEuMNiMauxqhZBRcmLhQkamxbDKL0NkQ6ox18PRzHJUSw0WnioLlDQYLVT8hBJpvagYAzP59tuxbkqWUCHvDyEQzsC2zwdLE51tKieGiQ/ODgB0L1RIhBNx9bhitRkQDUSSD5T2CX0qJmRMzkFKi+aZmGOr57a+U+LerM/PBsu+BfflguZVP3lPtMDlMcPe6kUvnMHtytqxTY5lIBsFTQdSZ6tB0E58LKzWGi44UBcvCqbCVDBaqDU
IItK5rhRACUyemynZKspQSMx/OIJvIwtnjhLXJWpavu5QxXHTi08HSdmtbvmNhsFANEULAucIJS5MFiZlE2Z55UTIKJo9NQgiBttvauEusDBguOnBZx3LrxbPCGCxUg+ob6tF2axuklAi8GYBUShsuUkpcOHUB8ek4LE0W7hIrE4ZLhRUW7+8fu9SxcCqMapgQAq1rW2GymxA6E0LYHy5p96LmVATeCECqEu2D7TyoskwYLhW0cFdY4GgAbbe2sWOhJaHB2YC2W9ug5lT4D/lL9sS+lBLBiSBC3hDMTnNhvYdKj+FSIQunwgJvMFhoaRFCoOOODpjsJgQngiV7qFJJK/C+5oVUJTrv7ITJbtL8a9DiGC4VsNhUGIOFlhqzy4zOuzqhKirO/N8Z5FI5TT+/lBKTxyYRPReFtdmKtsE2jq8yYriU2cJgmZ8K4xoLLUVCCHTc2QFbqw3Rc9HCuogWpJRIzCZw9vWzgABWfGUFjFYeUllODJcymr+kqBAst+U7FvdKN4OFliSjxYi+zX0w1BvgO+JDyKvN1mQlo2DiTxPIxDJoubkFzTc3c4yVGcOlTArB8sClYGHHQkudEAKelR503NGBXDqHj1/+GKlQ6oYCRlVUeF/3IjgRhNllRt99fTDU8VtdufFvvAwW2xW27YVt+SO/GSy0xBnqDFjx5RVwLXchMZvARy99hGw8+7kCRlVVnH/7PHyHfTAYDRj4lwGY3WaOswpguJTYYlNh236/jVNhRAvUW+qx+t9Ww9pkxYXTF/Dhf32ITDTzmQJGVVScf+s8Jv48AQDo/adeNA3wDLFKYbiU0JWmwtixEBUTQsDiseAL274Ai9uC4CdBvP/C+4iei15zkV9KiWwyi9P/exqf/OkTSFVi+ZeXo/OuTh7zUkEMlxK6rGNhsBBdkRAC9nY71ty/Bo52B6Lnojj+3HGcefVMfh1GlUWdjJQyf7ryR7M4/txxeA97IeoE+v65Dz2berjOUmE8B6FECh3L3xgsRNdLCAHbMhvWbl+L06+cxtT7Uzjz2hmce/scnMudaOxqhMlugppVEZ+OI3QmhPhMHFKVsC2zYeXQSrj7eH2xHjBcSiCXymH/f+5nx0L0OQghYLKbsOpfV2HZmmXwHfEhfDaMmZMzmDk5c9nHmt1mtA+2o32wHUarkeNMJxguJVDfUI87H7kTakbF1qe3wt3vBoCK3BlOVGmf9/97YRBw97vhWuFCYi6B8Nkw4tNxZJNZGOoMMLvNcHY74eh0FB1GyXGmDwwXjQmDgP+kH0cNR5G7LYcDzx+odEnXFJmJoO/2vkqXQbVEAOHpMI6/elzbz2sFpFVCQCAjM4h4I4BXm0+dSqTg6fBo88kIQjLmNaUoCi5MX9DsGItysTltsNp5Ox9pQ1VVJONJoLqGAYwNRpgaeLilFhguRESkOe7VIyIizTFcqoSUsvCLaKlSFRXpaLpkl4uRdhguVWLy3Uk8UfcEJt+drHQpRBUTn47j8JOHEZ+OV7oUugaGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hUgWklAhdCAEAQhdCPHafliQpJYJzQYRDYQTnghwHOsdw0bFQKITR0VEMDAxg832bIaXE5vs2Y2BgAKOjowiFQpUukajkFo6DNWvXYPSpUaxZu4bjQOd4zbFOjY+PY3h4GIlEAgDQJtvwfXwfv8FvMCnyd7pYrVaMjY1haGiokqUSlcynx4FN2nAH7sBbeAtxkb/TheNAn9i56ND4+Di2bt2KZDK56O2T829LJpPYunUrxsfHK1QpUelwHFQ3hovOhEIhDA8PQ0oJVb36Va6qqkJKieHhYU4NUE3hOKh+DBedeeaZZ5BIJK45oOapqopEIoFnn322xJURlQ/HQfVjuOiIlBJ79uz5XK996qmnuHuGagLHQW1guOjI3NwcJiYmPv
PgkFJiYmICwWCwRJURlQ/HQW1guOhILBa7oddHo1GNKiGqHI6D2sBw0RG73X7F981iFr/BbzCL2St+jMPhKEVZRGV1tXGQQAJv4S0kkLjix3Ac6APDRUeamprQ398PIcRl78sii/M4jyyyl71PCIH+/n54PJ5ylElUUlcbBypUxBCDissX+jkO9IXhoiNCCDz66KOf67U7d+5cdDASVRuOg9rAJ/R1JhQKoaurC8lk8rq2YRoMBlgsFvj9frhcrtIXSFQGHAfVj52LzrhcLoyNjUEIAYPh6v88BoMBQgjs27ePA4pqCsdB9WO46NDQ0BD2798Pi8UCIcRlbf782ywWCw4cOIAtW7ZUqFKi0uE4qG4MF50aGhqC3+/HyMgI+vr6it7X19eHkZERBAIBDiiqaRwH1YtrLlVASolgMIhoNAqHwwGPx8NFS1pyOA6qC8OFiIg0x2kxIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLS3P8Dh3IFDDpuFuQAAAAASUVORK5CYII=\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 100,
|
|
"id": "ad876b9d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m/var/folders/6j/b6y80djd4nb5hl73rv3sv8y80000gn/T/ipykernel_24271/2849209031.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mauto_symbolic\u001b[0;34m(self, a_range, b_range, lib, verbose)\u001b[0m\n\u001b[1;32m 1402\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'fixing ({l},{i},{j}) with 0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1403\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1404\u001b[0;31m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msuggest_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1405\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1406\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mverbose\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36msuggest_symbolic\u001b[0;34m(self, l, i, j, a_range, b_range, lib, topk, verbose, r2_loss_fun, c_loss_fun, weight_simple)\u001b[0m\n\u001b[1;32m 1332\u001b[0m \u001b[0;31m# getting r2 and complexities\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1333\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msymbolic_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1334\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1335\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# zero function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1336\u001b[0m 
\u001b[0mr2s\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, l, i, j, fun_name, fit_params_bool, a_range, b_range, verbose, random, log_history)\u001b[0m\n\u001b[1;32m 488\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspline_postacts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0;31m#y = self.postacts[l][:, j, i]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 490\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msymbolic_fun\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 491\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0;36m1e8\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/Symbolic_KANLayer.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, i, j, fun_name, x, y, random, a_range, b_range, verbose)\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;31m#initialize from x & y and fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 231\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 232\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns_avoid_singularity\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun_avoid_singularity\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/utils.py\u001b[0m in \u001b[0;36mfit_params\u001b[0;34m(x, y, fun, a_range, b_range, grid_number, iteration, verbose, device)\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mb_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrid_number\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0ma_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_grid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmeshgrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindexing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ij'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m \u001b[0mpost_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m 
\u001b[0mb_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0mx_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpost_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0my_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.auto_symbolic()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 80,
|
|
"id": "428571e6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.5\n",
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 x 1.000000 -16.565706 1 1 -2.513141\n",
|
|
"1 cos 1.000000 -16.599499 2 2 -1.719900\n",
|
|
"2 sin 1.000000 -16.599499 2 2 -1.719900\n",
|
|
"3 exp 0.999997 -16.268112 2 2 -1.653622\n",
|
|
"4 x^0.5 0.999977 -14.896568 2 2 -1.379314\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('x',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 1,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.9999996907837526,\n",
|
|
" 1)"
|
|
]
|
|
},
|
|
"execution_count": 80,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.unfix_symbolic(0,0,0)\n",
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "0dea7189",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.4\n",
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 0 0.000000 0.000014 0 0 0.000003\n",
|
|
"1 cos 0.969503 -5.034727 2 2 0.593055\n",
|
|
"2 x^2 0.969092 -5.015413 2 2 0.596917\n",
|
|
"3 sin 0.965249 -4.846400 2 2 0.630720\n",
|
|
"4 x 0.000392 -0.000551 1 1 0.799890\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('0',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 0,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.0,\n",
|
|
" 0)"
|
|
]
|
|
},
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.unfix_symbolic(0,1,0)\n",
|
|
"model.suggest_symbolic(0,1,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 101,
|
|
"id": "ef60542b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARwAAAESCAYAAAAv/mqQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlkElEQVR4nO3de1xUdf4/8BcXAS8wpshNEchWtKjAYVVQ1DTH0HW77UNaW9HS/cZ+KyW6rGRfU3ID22KpLTFL7edX12y9tPWT0tlUMMFNCFoLM0sRVBBBncEbCHy+fxizDjPA+Zw5lzkz7+fjMY/HNvs557zPjPPi3D6fjwdjjIEQQhTgqXYBhBD3QYFDCFEMBQ4hRDEUOIQQxVDgEEIUQ4FDCFEMBQ4hRDHeahcgRHt7O86cOQN/f394eHioXQ4h5CaMMTQ1NSEsLAyent0fw2gicM6cOYPw8HC1yyCEdKOmpgZDhgzpto0mAsff3x/AjR0KCAhQuRpCyM3MZjPCw8Mtv9PuaCJwOk6jAgICKHAIcVJCLnfQRWNCiGIocAghiqHAIYQoRhPXcAghyrra0oZXCypR1XgFkQP74MXpt6O3j5fD63WZwJHrAyLE3fx+wyEYK+st/73/GPC/B6sx9fYgvJf6S4fW7aGFAbjMZjN0Oh1MJpPdu1SdP6AOk6MHYt1jY5UokRCX0NVvqYO90Onp93kzzV/D6e4D2nO0Eb9cYVS4IkK06WpLW7dhAwDGynpcbWkTvQ1NB46QD+jcpRY8/sFXClVEiHYt//Q7SdvZo+nAyfr/3wpqt+f7cw6lMiHuYPvXNYLaff5trehtaDpwDh4/L7it0HAixB21tLZD6N/ki1dbRW9H04Hj4yW8/KLv62SshBBte3ffT4psR9OB82DcYMFtT5vFpzIhru69L4UHTh8HHqbRdOA8Pv5WtUsgxCWYrwm/xmnMmCx6O5oOHB9vTZdPiCYNHtBb9LL0iyWEKIYChxCiGAocQohiKHAIIYrRfOB4ckziUHfxmnyFEKJRbe3K9d/mDpyioiLMnDkTYWFh8PDwwMcff9zjMoWFhdDr9fDz88Ott96K1atXi6nVrr4cQ1BMy/1Csu0S4ioKj54T3LaXg9viDpzLly/j7rvvxttvvy2o/YkTJzB9+nQkJSWhvLwcL774IhYuXIht27ZxF2tP8p2hgtuaWiTZJCEu5c+7vhfcdurtgQ5ti/uZweTkZCQnJwtuv3r1agwdOhR5eXkAgJEjR6K0tBSvv/46Hn74Yd7N21j+6xh8VHrK4fUQ4q5+rG8S3Pa3Yxx72Fb2azglJSUwGAxW702bNg2lpaW4fv263WWam5thNputXl3hGdWP5uwkxNb1duFtE3/h2BGO7IFTV1eH4OBgq/eCg4PR2tqKhoYGu8tkZ2dDp9NZXj3Nuin0MI0GHCXEMV48d2nsUOQuVecJsjpGNe1q4qzMzEyYTCbLq6am+3E6hHbLpO6bhKhL9kHUQ0JCUFdnPTREfX09vL29MXDgQLvL+Pr6wtfXV+7SCCEKk/0IJyEhAUaj9bjCu3fvRnx8PHr1cvQmG7+WVo4TVkKIpLgD59KlS6ioqEBFRQWAG7e9KyoqUF1dDeDG6VBqaqqlfVpaGk6ePImMjAwcOXIE69atw9q1a/Hcc89Jswec3vin+PFYCXE11Q1XFN0ed+CUlpYiLi4OcXFxAICMjAzExcVh6dKlAIDa2lpL+ABAVFQUCgoKsG/fPsTGxuKVV17BW2+9Jckt8Q48l7He3VfdcyNC3MR9bxYquj3uaziTJk1Cd1NZffDBBzbvTZw4EV9//TXvpgQLDfDFGXOzbOsnxFVd4bgn/vCoEIe3p/m+VADwj6eS1C6BEJe34oFYh9fhEoEzKIDuaBEiNymmznaJwCGEaAMFDiFEMW4ZOOfoAjMhqnCZwOnF0ccjOfefMlZCiDacqL+s+DZdJnCCOC4cN9DAf4Qo/gwO4EKB8+iYCL
VLIERTmtuEDy0a2V+abkguEzgLkmgWTkLksv2pSZKsx2UCh2bhJEQ+A/r5SLIe+pUSQhTjtoGj5NQYhJAb3DZw9n13Vu0SCFHN+UvqTGHiUoHTi6Orx0v/qJCtDkKc3W/yD6iyXZcKnHG3Ch9RvvZSm4yVEOLcjjcqO/BWB5cKnLcf1atdAiEu58HYIMnW5VKB089P9jHhCXE7rz40SrJ1uVTgEEKkJ8U4OB3cOnDo1jghynLrwNlzuK7nRoS4mNPnr6q2bZcLHJ4dem6bfAO7E+Ks1Ogl3sHlAmf8bcJvjZvUefaJEFU1NQt/JGTRvdJ2ina5wFn1O7o1TohUnpwULen6XC5w6NY4IdKRehQGlwscQojzcvvAqbtI440S93HpWquq23f7wLn3z1+oXQIhivnvjWWqbt8lAyeUY0B16sNJ3MmXPzaoun2XDJxPaK5xQuxq52gb7i99PLhk4NBc44Q47h+Lpki+TpcMHF7Up4oQW1INnH4zChwAxcfUPa8lRAnOcEdWVOCsWrUKUVFR8PPzg16vx/79+7ttv2nTJtx9993o06cPQkND8dhjj6GxsVFUwXLYWHxM7RIIkV3yW0Vql8AfOFu2bEF6ejqWLFmC8vJyJCUlITk5GdXV1Xbbf/nll0hNTcX8+fPx3Xff4e9//zsOHTqEBQsWOFx8dwb2ET5T4K6jF2SshBDncOHKdcFtfTzkqYE7cHJzczF//nwsWLAAI0eORF5eHsLDw5Gfn2+3/cGDBxEZGYmFCxciKioK48ePxxNPPIHS0lKHi+/OzoUTZF0/Ia7szd/EyrJersBpaWlBWVkZDAaD1fsGgwHFxcV2l0lMTMSpU6dQUFAAxhjOnj2LrVu3YsaMGV1up7m5GWaz2erFK6S/H/cyhJAbDHFhsqyXK3AaGhrQ1taG4OBgq/eDg4NRV2d/MKvExERs2rQJKSkp8PHxQUhICPr374+//vWvXW4nOzsbOp3O8goPD+cpU5SWVp4nFAjRFt47sV6e8pxTibpo7OFhXQxjzOa9DpWVlVi4cCGWLl2KsrIyfP755zhx4gTS0tK6XH9mZiZMJpPlVVNTI6ZMLvmFP8i+DULUUqzyE8YduMZyCAwMhJeXl83RTH19vc1RT4fs7GyMGzcOzz//PADgrrvuQt++fZGUlIQVK1YgNDTUZhlfX1/4+ir78F6e8ScsmjJC0W0SopQtpfL/0RaC6wjHx8cHer0eRqPR6n2j0YjExES7y1y5cgWentab8fK6MQo8Y/I+cHd7aIDgtvToH3FlBzieNesr/AYvN+5TqoyMDLz//vtYt24djhw5gmeeeQbV1dWWU6TMzEykpqZa2s+cORPbt29Hfn4+jh8/jgMHDmDhwoUYPXo0wsLkuTDVYfPvx8q6fkK04sJV4bfEdz8zWbY6uIfHS0lJQWNjI7KyslBbW4uYmBgUFBQgIiICAFBbW2v1TM68efPQ1NSEt99+G88++yz69++PyZMnY+XKldLtRRd0HM/iADcuHEs9whkhWjN4QG/Z1u3B5D6vkYDZbIZOp4PJZEJAgPDTJACIXLxTcNs/3BOBP06L4S2PEKdW3XAFE17fK7h9VU7Xj6zYw/P7dPk/5zw7mL/3pGx1EKKWaSpOC9OZywcOz7QxhLiiq9eFP2PmK1OXhg4uHzg0bQwhwu15Xr4LxoAbBA7vtDFqDzJNiJSutvCNoSvnBWPADQKH1x/+X4naJRAimRf+/o3aJVhxi8Dx5jgv3X+Cv6MoIc7q08O1apdgxS0C547BfLfSCSHycIvA2fA4PXFMSE8OvCDvBWPATQKH94ljZxj7lRBHmThG+APkv2AMuEng8Lonh2bjJNo3a7X9QfHU5DaB08VwPXZdla8MQhRztP6S2iXYcJvA+fNDd6tdAiFuz20C50H9YK72NOQocSe7FJp0wG0Ch3eM1jf3HJGpEkLkV1F1kat9dJi/PIV04jaBw+udPVVql0CIaA+sPqB2CXa5VeBE3C
L/bT9CtOYWGYcU7cytAmfHk+PVLoEQp7P7+XsV25ZbBc6Afj5c7U+fpxvkRHt4/90OClBuhhS3Chxek/+8R+0SCOHmTCP8deZ2gePFcbOq2elHeybE1qVmvjFwlOR2gbNr0US1SyDEaSj1/E0Htwuc20L6cbWnEQCJlhw908TVXqnnbzq4XeDwWvB+kdolECJY8lvO/e/VLQOH5zrOwVN0p4poh7N3yHHLwHn9N9SRk5CkKOVHwnTLwPl1HF9HzspTNM4xcX5lxy9wtc+fmyBTJV1zy8Dh7cg5/e39MlVCiHQeXsM34BbvFEpScMvAAQDOzCGESMBtA+ezp/meP+CdUIwQJfGO3/TEpKEyVdI9tw0c3ucPluwol6kSQhyXs5Nv/KZn771Dpkq657aBw2t7+Vm1SyCkS+tKqrja+3ir89N368Dxcuu9J0R5on5yq1atQlRUFPz8/KDX67F/f/d3cZqbm7FkyRJERETA19cXw4YNw7p160QVLKVdC/n6VdF1HOKMTtRf5mqvdP+pm3EHzpYtW5Ceno4lS5agvLwcSUlJSE5ORnV1dZfLzJo1C1988QXWrl2Lo0ePYvPmzRgxYoRDhUuBt1/Vs5tLZKqEEPEMefu42ivdf+pm3Dfic3NzMX/+fCxYsAAAkJeXh127diE/Px/Z2dk27T///HMUFhbi+PHjGDBgAAAgMjLSsapVUnDEpHYJhNi47uz9GW7CdYTT0tKCsrIyGAwGq/cNBgOKi+0/dPTJJ58gPj4er732GgYPHozhw4fjueeew9WrXfdRam5uhtlstnrJZeqIIK72NH0M0bJg5Qb3s4srcBoaGtDW1obg4GCr94ODg1FXV2d3mePHj+PLL7/Et99+ix07diAvLw9bt27Fk08+2eV2srOzodPpLK/w8HCeMrm8NXsUV/uVuw/LVAkh/H6s45td87Pnp8pUiTCiLhp7dJo3lzFm816H9vZ2eHh4YNOmTRg9ejSmT5+O3NxcfPDBB10e5WRmZsJkMlleNTU1YsoUpLePF1f7tUWnZKqEEH7T3uIbTpR3XG+pcV3DCQwMhJeXl83RTH19vc1RT4fQ0FAMHjwYOp3O8t7IkSPBGMOpU6fwi1/8wmYZX19f+Poqd+zXy1Nb58GEdGjT2L9briMcHx8f6PV6GI1Gq/eNRiMSExPtLjNu3DicOXMGly7959Dvhx9+gKenJ4YMGSKiZOntTp/E1Z73NiQhcuAdxUDN2+EduE+pMjIy8P7772PdunU4cuQInnnmGVRXVyMtLQ3AjdOh1NRUS/vZs2dj4MCBeOyxx1BZWYmioiI8//zzePzxx9G7t3NMTBcV1Jer/T25++QphBAOMzhHMVDzdngH7tviKSkpaGxsRFZWFmpraxETE4OCggJEREQAAGpra62eyenXrx+MRiOefvppxMfHY+DAgZg1axZWrFgh3V4Q4oa0OKmIB2PM6es2m83Q6XQwmUwICJBnlLL1RSewvKBScPtvlhqg66PgHKmE3OT8pRaMWmHsueHP5k8Ygv+ZLs9Ilzy/T+pN9LPU8ZFc7R/I3S1PIYQI8FD+Aa72fzTcKVMlfChwfsY7CuAJvscfCJFUVeMVrvZq9Q7vzDmqcBI6P75ncqob+L50QqRgunKdq/0tTjS6JQXOTXZx3h6f8PpeeQohpBsp7/KNXfzFEnWfLr4ZBc5NQvr7qV0CIT36/izf+bzaTxffjAKnk18OvYWrfd3FazJVQoit85da1C7BIRQ4nax/fDRX+6ScL2SqhBBb97/D97Dfx2njZKpEHAqcTnjn6uG7fEeIY2ou8B1Rx0b2l6cQkShw7OAd65j3rgEhYmj9dAqgwLGLd6zjX71GDwES+T34zpdc7Z2hs2ZnFDh28I51XEPXjYkCTl7oepRMe5yhs2ZnFDhd4D2tOlxN4x0T+Zw+zxc2w+TpcugwCpwu8J5WzVzFd7hLCA/DX/Zxtd+ebui5kQoocLrAe1pFiJwucw5J6awjGV
DgdGOIju/JY+pbReRQdvwCV/voAdzDXCmGAqcbnzydxNWe+lYROTy8hq/v1EdPTZapEsdR4HRDTB8UmreKSEnM9NLOejoFUOD0aNpI+7NRdOWVzyrkKYS4pfTN5VztMw1RMlUiDQqcHuT9No6r/f8eqJWpEuKOdh05y9V+waSRMlUiDQqcHvBOlAcA58zNMlRC3I2Y6Yh4R65UGgWOALyPiE949Z8yVULciSGPb1bNT/97vEyVSIcCRwDeR8T5ngklxL7r7XwTqtw5VNdzI5VR4AjE+0Ed/KFRljqIe9hTUddzo5tMinDeO1M3o8AR6DPO06pH1h2UqRLiDh7/sIyrff78KTJVIi0KHIHE9LylOciJGD/W8c9BJObmhhoocDh8+PhYrvY0BzkR417Oi8XvzxolUyXSo8DhMHb4QO5laDRAwoN3GAoAuHdUqAyVyIMCh1Mvzk9s5p9pNEAiHO8wFFpDgcPpi4x7uNpX0z1ywoF3GIoDLzhvR017KHA4DQ3sw72MmIuAxP0UfVvPvczgAb1lqEQ+FDgiFDzFN2wF70VA4p5SNx7iaq+FJ4s7o8AR4fYh/APG0lEO6Y6YMbG18GRxZxQ4Ik0aFsjVno5ySHd4x8T+zWi+f3/OQlTgrFq1ClFRUfDz84Ner8f+/cKmHz1w4AC8vb0RGxsrZrNOJX9uPPcyl661ylAJ0bqvfjzPvcyrv/6lDJXIjztwtmzZgvT0dCxZsgTl5eVISkpCcnIyqquru13OZDIhNTUVU6Zo4xHsnvT28QLvs52jsnbJUgvRtlnvl3C17+cJ+Hhr8+SEu+rc3FzMnz8fCxYswMiRI5GXl4fw8HDk5+d3u9wTTzyB2bNnIyEhQXSxzubgi/dytW9ppwcBiTUx/x4OvOScU8AIwRU4LS0tKCsrg8FgvcMGgwHFxV0P9Lx+/Xr89NNPePnllwVtp7m5GWaz2erljAYF+HIvY6BpgclNErP5/z0485jFPeEKnIaGBrS1tSE42Hqc3+DgYNTV2e9Of+zYMSxevBibNm2Ct7ew6Suys7Oh0+ksr/DwcJ4yFcV7a/LsNaCNc5wT4pqutrThMucBzt6MSbLUohRRJ4IeHtbDGDLGbN4DgLa2NsyePRvLly/H8OHDBa8/MzMTJpPJ8qqpqRFTpiLE3JrcevCEDJUQrblz6efcy0QF9ZWhEuVwBU5gYCC8vLxsjmbq6+ttjnoAoKmpCaWlpXjqqafg7e0Nb29vZGVl4ZtvvoG3tzf27Nljdzu+vr4ICAiwejmzj9PGcbX/4ydHZKqEaIXpynXw3rP8YDb/nVFnwxU4Pj4+0Ov1MBqNVu8bjUYkJibatA8ICMDhw4dRUVFheaWlpSE6OhoVFRUYM2aMY9U7idjI/tzLfFl5TvpCiGaMeYX/2s2ku/imLHJG3HOCZmRkYM6cOYiPj0dCQgLWrFmD6upqpKWlAbhxOnT69Gls2LABnp6eiImJsVo+KCgIfn5+Nu9r3cbU0fjdhq8Et//dhq9QlTNDxoqIs7ra0oZrnJfxNqaOlqcYhXEHTkpKChobG5GVlYXa2lrExMSgoKAAERERAIDa2toen8lxReNvH8S9zMpPKvHHX98uQzXEmY1axn/tRsy/L2fkwRhz+lsmZrMZOp0OJpPJqa/nVFRdxAOrD3At89Or051+LiEinUvXWhGzjO8B0HWP6DE5NkSmihzH8/vU5uOKTkrMtZzEV409NyIugzdsADh12PCiwJHYRwv4nqQ+e+k69bFyE3UXr3EvwzsJo7OjwJHY6NsGcC9zp4i/ekR7xuZ8wb2MmNlCnBkFjgzWPaLnas8g7q8f0Y6jZ5q4l9nwO232CO8OBY4MxJxzi/nrR7Rj2ltF3MtMiAmSoRJ1UeDIREyfFzHjohDnV3mKv/Mx7xxoWkGBIxMxfV54x0Uh2jD9bWED1N1MzBxoWkCBIyMxg1yLGbmfOK/IxTu5l+Htm6clFDgyEtOTnHfkfuK8xMyiCYh7nk
srKHBkdnAx/5CqwzL5/yoS5zPuNfujIXSHdwoiraHAkVlIfz/uZdqY+L+OxDmImfYFEDcFkZZQ4CjgEOfYx4C4v47EefBO+wIA3y6bJkMlzoUCRwFixj4GgClv7JO2EKIIMReKh/b3Qz8/7sEbNIcCRyFixr756dxl6melMefMzaKWKxJxrU+LKHAUJGYQJTG9i4l6fvnqP7mXcZXBtYSgwFGQ2EGUFm76WuJKiBxm/ZU/bADXGVxLCAochYnp8vDJ4Vq0tLZLXwyRzNWWNnx1mv906kjWfTJU47wocBQmdpqP4S99JnElREojRUz5EjdYh94+vBNGaxsFjgrEDp4eu5xm7XRGYu5KAcCOp/m7vmgdBY5KDrwwmXuZi1ev4/ylFhmqIWKJ/T7EPJvlCihwVDJ4QG9Ry41aQWMgOxMx30fvXp6in83SOgocFYk9tRJ7CE+kJfZ7OPJKssSVaAcFjsqKnrtH1HIVVRelLYRwmfSquLD5+qWpEleiLRQ4Khsa2EfUcrzzXxHpXLrWiir+Qfzg7+ONAf18pC9IQyhwnACdWmmL2Ke/D2e5fufMnlDgOAkxd60ACh2lif28v1lqkLgSbaLAcRJi71oBwPCXCiSshHRFbNgM1vlB16eXxNVoEwWOExF7atXSymheK5ltOPCj6GUPZLpHT3AhKHCczE+vThe1HM1rJZ+2doalnx4VtazYPyKuigLHyXh5euCl6beJWpau58hj2IviTlnFdNR1dRQ4TmjBhGjRy1LoSEvs5+kB8R11XRkFjpNy5FCcQkcajnyOJ+hUyi5RgbNq1SpERUXBz88Per0e+/d3PbPg9u3bMXXqVAwaNAgBAQFISEjArl00ip0QFDrqceTzo+s2XeMOnC1btiA9PR1LlixBeXk5kpKSkJycjOrqarvti4qKMHXqVBQUFKCsrAz33HMPZs6cifLycoeLdweODNBEoSMOhY18PBhjjGeBMWPGYNSoUcjPz7e8N3LkSDzwwAPIzs4WtI477rgDKSkpWLp0qaD2ZrMZOp0OJpMJAQGuPW+PPY+uLsSBqkuil6cfgXCOhE3Rc/eI7qqiZTy/T64jnJaWFpSVlcFgsH5q0mAwoLi4WNA62tvb0dTUhAEDBnTZprm5GWaz2erlzjalTXRoeTrSEcaRz8nLQ3y/OHfCFTgNDQ1oa2tDcHCw1fvBwcGoq6sTtI433ngDly9fxqxZs7psk52dDZ1OZ3mFh4fzlOmSHD1KodDpnqOfz0/ZdBQphKiLxh4eHlb/zRizec+ezZs3Y9myZdiyZQuCgoK6bJeZmQmTyWR51dTUiCnT5VDoyMPRz4VOWYXjCpzAwEB4eXnZHM3U19fbHPV0tmXLFsyfPx8fffQR7r23++EVfX19ERAQYPUiN1DoSIvCRllcgePj4wO9Xg+j0XpYRaPRiMTExC6X27x5M+bNm4e//e1vmDGDviBHSRE6V1vaJKpGm06fv0phowLuyYwzMjIwZ84cxMfHIyEhAWvWrEF1dTXS0tIA3DgdOn36NDZs2ADgRtikpqbizTffxNixYy1HR71794ZOp5NwV9xLVc4Mh34wI5d+Dn2IN7alu98YLcMyd6KN696sLQobcbiv4aSkpCAvLw9ZWVmIjY1FUVERCgoKEBERAQCora21eibn3XffRWtrK5588kmEhoZaXosWLZJuL9yUo//oy+pa3e4UK3IxhY2auJ/DUYO7P4fTEylCY2/GJJfu+3P+UoskM15Q2Nji+X1S4LgIqY5UXPEHdefSXWhqaXV4Pa742UhBtgf/iPOS6scQuXgnTp+/Ksm61PbVj+cRuXgnhY0ToSMcFyPlNRkt/8joc1AOHeG4MSl/HJGLd+KTg/Y75TqrjYXHKGycGB3huCip7z7l/voOPJQYKek6pbRh3w9Y+vkxydbX3xuoWEFhIwRdNCYAgNjlu3Hx6nVJ17ls2m2Yd4/4EQmllvfZN8grPCXpOr9ZaqBZFjhQ4BALqW4Hd+
blAex9Vp3hGL768TxmvV8iy7rpFIofBQ6xIecDfneG9sFHf5iA3j5esm3DdOU64rJ2o122LVDYiEWBQ+y6LXMnWhX4tnNmjMAjScMcXs9L20qw8dB5CSrqGYWNeBQ4pEunz1/FuNf2qLLtj9PGITayv837u0vP4L+2qjPk7MHFUxDS30+VbbsKChzSI3frQ2UPHdVIg+f3yd1bnLiGqpwZqKi6iAdWH1C7FMX9bd4YJI4IVLsMt0QP/rmx2Mj+qMqZAa+eB2t0CRHeN4KWwkY9dIRD8FP2DFWv7SjhSNZ9st5FI8JQ4BAAwOABvVGVMwNlxy/g4TXCZuDQgl0LJyA6zF/tMsjPKHCIFf2tt6AqZwaOnmnCtLeK1C5HtEMv3otBAb5ql0E6ocAhdkWH+aMqZwaqG65gwut71S5HkCEBvtiZPpG6JTgxChzSraGBfSy3jz/YexTLdv2ockW2nL1jKfkPeg6HiPK3oh/xYsFR1bafPnEI0pPvVm375D/owT+iqD0VdXj8wzJZt6FmZ1HSPXrwjyhqcmwIqmJtn9oV22VhY+pojL99kBSlESdDgUNkY4gPQ1V8mNplECdCTxoTQhRDgUMIUYwmTqk6rmubzWaVKyGEdNbxuxRy/0kTgdPU1AQACA8PV7kSQkhXmpqaoNPpum2jidvi7e3tOHPmDPz9/eHh0XXXZrPZjPDwcNTU1LjM7XPaJ21w531ijKGpqQlhYWHw9Oz+Ko0mjnA8PT0xZMgQwe0DAgJc5kvvQPukDe66Tz0d2XSgi8aEEMVQ4BBCFONSgePr64uXX34Zvr6uMywB7ZM20D4Jo4mLxoQQ1+BSRziEEOdGgUMIUQwFDiFEMRQ4hBDFUOAQQhSj+cD505/+hMTERPTp0wf9+/cXtAxjDMuWLUNYWBh69+6NSZMm4bvvvpO3UA4XLlzAnDlzoNPpoNPpMGfOHFy8eLHbZebNmwcPDw+r19ixY5Up2I5Vq1YhKioKfn5+0Ov12L9/f7ftCwsLodfr4efnh1tvvRWrV69WqFLhePZp3759Nt+Hh4cHvv/+ewUr7lpRURFmzpyJsLAweHh44OOPP+5xGUm+I6ZxS5cuZbm5uSwjI4PpdDpBy+Tk5DB/f3+2bds2dvjwYZaSksJCQ0OZ2WyWt1iB7rvvPhYTE8OKi4tZcXExi4mJYb/61a+6XWbu3LnsvvvuY7W1tZZXY2OjQhVb+/DDD1mvXr3Ye++9xyorK9miRYtY37592cmTJ+22P378OOvTpw9btGgRq6ysZO+99x7r1asX27p1q8KVd413n/bu3csAsKNHj1p9J62trQpXbl9BQQFbsmQJ27ZtGwPAduzY0W17qb4jzQdOh/Xr1wsKnPb2dhYSEsJycnIs7127do3pdDq2evVqGSsUprKykgFgBw8etLxXUlLCALDvv/++y+Xmzp3L7r//fgUq7Nno0aNZWlqa1XsjRoxgixcvttv+hRdeYCNGjLB674knnmBjx46VrUZevPvUETgXLlxQoDrHCAkcqb4jzZ9S8Tpx4gTq6upgMBgs7/n6+mLixIkoLlZ/xsmSkhLodDqMGTPG8t7YsWOh0+l6rG/fvn0ICgrC8OHD8fvf/x719fVyl2ujpaUFZWVlVp8vABgMhi7rLykpsWk/bdo0lJaW4vr167LVKpSYfeoQFxeH0NBQTJkyBXv3amN+L3uk+o7cLnDq6uoAAMHBwVbvBwcHW/4/NdXV1SEoKMjm/aCgoG7rS05OxqZNm7Bnzx688cYbOHToECZPnozm5mY5y7XR0NCAtrY2rs+3rq7ObvvW1lY0NDTIVqtQYvYpNDQUa9aswbZt27B9+3ZER0djypQpKCrS5mymUn1HTjk8xbJly7B8+fJu2xw6dAjx8fGit9F5XB3GWLdj7ThK6D7Zqw3oub6UlBTL/46JiUF8fDwiIiKwc+dOPPTQQyKrFo/387XX3t77auLZp+joaERHR1v+OyEhATU1NXj99dcxYc
IEWeuUixTfkVMGzlNPPYVHHnmk2zaRkZGi1h0SEgLgRmKHhoZa3q+vr7dJcCkJ3ad///vfOHv2rM3/d+7cOa76QkNDERERgWPHjnHX6ojAwEB4eXnZ/OXv7vMNCQmx297b2xsDBw6UrVahxOyTPWPHjsXGjRulLk8RUn1HThk4gYGBCAwMlGXdUVFRCAkJgdFoRFxcHIAb5+iFhYVYuXKlLNsEhO9TQkICTCYTvvrqK4wePRoA8K9//QsmkwmJiYmCt9fY2IiamhqrUFWCj48P9Ho9jEYjHnzwQcv7RqMR999/v91lEhIS8Omnn1q9t3v3bsTHx6NXL/XnCRezT/aUl5cr/n1IRbLviOsSsxM6efIkKy8vZ8uXL2f9+vVj5eXlrLy8nDU1NVnaREdHs+3bt1v+Oycnh+l0OrZ9+3Z2+PBh9tvf/tbpbovfddddrKSkhJWUlLA777zT5rb4zfvU1NTEnn32WVZcXMxOnDjB9u7dyxISEtjgwYNV2aeOW8hr165llZWVLD09nfXt25dVVVUxxhhbvHgxmzNnjqV9xy3XZ555hlVWVrK1a9c67W1xofv0l7/8he3YsYP98MMP7Ntvv2WLFy9mANi2bdvU2gUrTU1Nlt8KAJabm8vKy8stt/nl+o40Hzhz585lAGxee/futbQBwNavX2/57/b2dvbyyy+zkJAQ5uvryyZMmMAOHz6sfPFdaGxsZI8++ijz9/dn/v7+7NFHH7W5vXrzPl25coUZDAY2aNAg1qtXLzZ06FA2d+5cVl1drXzxP3vnnXdYREQE8/HxYaNGjWKFhYWW/2/u3Lls4sSJVu337dvH4uLimI+PD4uMjGT5+fkKV9wznn1auXIlGzZsGPPz82O33HILGz9+PNu5c6cKVdvXcdu+82vu3LmMMfm+IxoPhxCiGLe7LU4IUQ8FDiFEMRQ4hBDFUOAQQhRDgUMIUQwFDiFEMRQ4hBDFUOAQQhRDgUMIUQwFDiFEMRQ4hBDF/B+K/2VJMJWQCAAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 300x300 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"x, y = model.get_fun(0,1,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 102,
|
|
"id": "8f77c061",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n",
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import create_dataset_from_data\n",
|
|
"\n",
|
|
"dataset2 = create_dataset_from_data(x[:,None], y[:,None])\n",
|
|
"model2 = KAN(width=[1,1,1])\n",
|
|
"model2.fix_symbolic(0,0,0,'x^2',fit_params_bool=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 103,
|
|
"id": "1c62302d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAf/ElEQVR4nO3de3DU9b3/8ddnc93cCIkRaxEkMaiAWOWmAioVEiQ9g0LH29GWDnUocqmMU88ZqUcOqBRHJUGOhxZrBWobe4zKIBzglKtcLBQELT8ugRytCaAI2TSXJQnZz++PL8kJEGKUb7Kb3edjhmFnv9ndd2ZIXrw/t6+x1loBAOAiT7ALAACEH8IFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4LroYBcAdAbWWp08eVJVVVVKSkpSenq6jDHBLgsIWXQuQCt8Pp8KCgqUnZ2tjIwM9erVSxkZGcrOzlZBQYF8Pl+wSwRCkuFOlEDL1qxZo/Hjx6umpkaS0700auxaEhISVFRUpNzc3KDUCIQqwgVowZo1a5SXlydrrQKBwEW/zuPxyBijlStXEjBAM4QLcB6fz6fu3bvL7/e3GiyNPB6PvF6vSktLlZqa2v4FAp0Acy7AeZYsWaKampo2BYskBQIB1dTUaOnSpe1cGdB50LkAzVhrlZ2drZKSEn2THw1jjDIzM1VcXMwqMkCEC3COr776ShkZGZf0+vT0dBcrAjonhsWAZqqqqi7p9ZWVlS5VAnRuhAvQTFJS0iW9Pjk52aVKgM6NcAGaSU9PV1ZW1jeeNzHGKCsrS2lpae1UGdC5EC5AM8YYTZs27Vu9dvr06UzmA2cxoQ+ch30uwKWjcwHOk5qaqqKiIhlj5PG0/iPSuEP/nXfeIViAZggXoAW5ublauXKlvF6vjDEXDHc1Puf1erVq1Srl5OQEqVIgNBEuwEXk5uaqtLRU+fn5yszMPOdaZmam8vPzVVZWRrAALWDOBWgDa602bNigu+66S+vWrdOIESOYvAdaQecCtIExpmlOJTU1lWABvgbhAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECfI36+nqVlZVp//79kqQjR47o1KlTCgQCQa4MCF3c5hi4CJ/Pp6KiIr355pvat2+fKisrVVdXp/j4eGVkZGj48OGaOHGihg4dqujo6GCXC4QUwgVowfbt2zVjxgx9/PHHGjRokPLy8tS/f38lJSXJ5/Np165dWrFihQ4fPqz7779fzz77rDIyMoJdNhAyCBfgPGvXrtWECROUlJSkuXPnasyYMaqrq1NhYaFqa2uVkpKiBx54QPX19SosLNSsWbPUt29fLVu2TN26dQt2+UBIIFyAZg4dOqTRo0crMTFRhYWF6tOnj4wxKikp0c0336yKigr16tVLu3btUteuXWWt1ZYtW/TQQw/pzjvv1Guvvaa4uLhgfxtA0DGhD5zV0NCg559/XuXl5Vq4cGFTsLTGGKNhw4bphRde0PLly7V69eoOqhYIbYQLcNbhw4e1YsUKjRs3TsOGDfvaYGlkjNE999yjW265RYsXL9aZM2fauVIg9LHEBThr27Ztqqqq0vjx4/Xpp5+qurq66VppaakaGhokSXV1ddq3b59SUlKarl955ZUaN26cZs2apePHj6t79+4dXj8QSggX4KwDBw4oISFBmZmZmjRpkrZu3dp0zVqr2tpaSdLRo0c1atSopmvGGL300ku64YYbVFNTo6NHjxIuiHiEC3CW3+9XdHS04uLiVFtbq9OnT7f4ddbaC66dOXNGXq/3nBACIhnhApx1+eWXy+/3y+fzaciQIUpMTGy65vf7tW3btqYQue2225o2Thpj1KNHD3355Z
fyeDzq2rVrsL4FIGQQLsBZAwYMUH19vXbs2KF58+adc62kpESDBg1SRUWFunXrprfeekupqalN140xeuqpp3TFFVcwJAaI1WJAk8GDByszM1NLlixRdXW1oqKizvnTyBgjj8fT9LzH49GxY8f09ttvKy8vT126dAnidwGEBsIFOCs9PV1Tp07V7t27tWDBgjYvKa6trdWcOXPk9/s1adKkNi9hBsIZw2JAMxMmTNDmzZs1b948JSQkaPLkyYqPj5ckRUdHKzo6uqmLsdaqsrJSzz33nAoLCzV//nxde+21wSwfCBkc/wKc58SJE5oyZYref/995ebmasaMGbr++ut18OBBBQIBxcbG6pprrtGOHTv04osvas+ePZo9e7YmT558zvAZEMkIF6AF1dXVWrx4sRYsWKAvvvhCmZmZys7OVnJyssrLy3Xw4EEdPXpUAwYM0DPPPKM77rhDHg+jzEAjwgVoxfHjx7Vu3Tpt2rRJe/fu1Y4dOzR8+HANHTpUOTk5GjJkiBISEoJdJhByCBegjXbu3KnBgwdr586dGjhwYLDLAUIafTzQRlFRUU3LkAG0jp8SAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOu4nwvQRtZaBQIBeTweGWOCXQ4Q0uhcgG+Ae7kAbRMd7AIAt1hrVVxcrJMnTwa7lEvi8XjUr18/JSYmBrsU4FtjWAxhIxAIaMqUKbrqqquUlJR0Se/T0NCgqKiooHQqH3zwgZ5++mn179+/wz8bcAudC8JKXFycJk6cqG7dun2j1wUCARUXF+vdd9/V9u3bderUKaWlpWnIkCEaO3asrrvuug6Za7HWqqqqSvyfD50d4YKIZq1VRUWF8vPztWjRIp04ceKc6++//75efvllPfLII3ryySd1xRVXMJkPtAHhgohlrVVpaakee+wxrV69WpJ044036gc/+IF69Oihzz//XKtXr9aePXu0YMECbdmyRa+++qoGDhxIwABfg3BBRLLWqqysTD/60Y+0efNmdenSRf/6r/+qRx99VKmpqTLGyFqrJ554Qm+99ZZmz56t3bt367777tPrr7+uO++8k4ABWsG6SkScxqGwKVOmaPPmzcrIyNBvf/tbPfHEE+ratWtTaBhjlJKSop/+9Kd67733dOONN+rvf/+7JkyYoO3btzMvArSCcEHEOXPmjObOnatVq1YpJSVFCxcu1NixYxUVFdXi1xtjdPPNN+uPf/yjbrrpJpWWlurRRx/VoUOHCBjgIggXRBRrrVatWqVXX31VHo9HM2fO1L333vu1S46NMerdu7d+97vfKSsrSwcOHND06dNVXl7eQZUDnQvhgohhrdWxY8f0y1/+UjU1NRo7dqwee+yxi3Ys5zPGqF+/flq4cKG6du2qdevWae7cuTpz5kw7Vw50PoQLIkYgEND8+fO1f/9+XXXVVZozZ468Xu83eg9jjEaOHKmZM2cqKipKixYt0sqVKxkeA85DuCAiWGu1a9cuvf766/J4PHryySfVu3fvb7Xiy+PxaNKkSRo7dqxqamo0c+ZMlZWVETBAM4QLIkJdXZ1efPFF+Xw+3XrrrXr44YcvaSmx1+vVnDlz1LNnT+3fv1+/+tWvGB4DmiFcEPastdqyZYtWrVql+Ph4/eIXv1BycvIlvWfjBP/MmTMVExOjZcuWaePGjXQvwFmEC8JebW2tCgoK5Pf7NXLkSI0cOdKVDZDGGD344IPKyclRVVWVZs+erYqKChcqBjo/wgVhzVqrbdu2af369fJ6vZo+fbri4uJce3+v16unn35aaWlp+stf/qLf//73dC+ACBeEuTNnzmjRokXy+/0aMWKEhg4d6uqxLY0bLCdOnNi0Gu2zzz4jYBDxCBeELWut9u7dq7Vr1yo2NlY/+9nPXO1aGkVFRWnq1Km65ppr9Omnn2rhwoUKBAKufw7QmRAuCF
vWWr3xxhuqrKzUwIED2/Wwye9+97t6/PHHFRUVpaVLl+qTTz6he0FEI1wQtj7//HMtX75cHo9HP/nJT5SQkNBun2WM0QMPPKABAwbo5MmTevnll1majIhGuCAsWWv13nvv6dixY+rVq5fy8vLa/Yj8Ll266IknnlBsbKyWL1+uDz/8kO4FEYtwQViqqqrSH/7wB1lr9cMf/lCXX355u3+mMUZjxozR7bffrqqqKs2fP191dXXt/rlAKCJcEJY+/PBDffLJJ0pJSdF9993XYTf28nq9mjFjhuLj47V27Vpt3ryZ7gURiXBB2GloaFBhYaFqa2s1bNgw9enTp8M+2xijESNGKCcnR36/X/n5+Tp9+nSHfT4QKggXhJ2ysjKtXbtWHo9HDz74oGJiYjr082NjY/Xzn/9cCQkJ2rBhgzZs2ED3gohDuCDs/O1vf1N1dbV69uyp73//+x1+r3tjjG677Tbdfffd5xw9A0QSwgVhZ+TIkdq4caMWLFjQIRP5LYmJidG0adOUmJioDz74QOvWraN7QUQhXBB2YmNj1b9/f919990d3rU0MsZoyJAhGj16tGpra7Vw4ULmXhBRCBeELWNM0MJFurB7Wb9+Pd0LIgbhArST87uXV155he4FEYNwAdpRTEyMpk6dqsTERG3evJmVY4gYhAvQjowxuuWWW+heEHEIF6CdxcTEaMqUKUpISNCmTZu4HTIiAuECtDNjjG699Vbl5OQ0rRyrra0NdllAuyJcgA7QuHIsISFBGzdupHtB2CNcgA7Q2L2MGjVKp0+fpntB2CNcgA4SGxurqVOnyuv1asOGDZyYjLBGuAAdxBijoUOHNnUvr7zyCt0LwhbhAnSg5t3L+vXrtWnTJroXhCXCBehAxhgNGzZMubm5On36tAoKCtj3grBEuAAdLDY2VtOnT29aOfbnP/+Z7gVhh3ABOljjyrHG+73k5+erpqYm2GUBriJcgCCIiYnRjBkzlJycrK1bt2rlypV0LwgrhAsQBMYYDRw4UPfee6/q6+v18ssvq6KiIthlAa4hXIAgiY6O1uOPP6709HTt3r1bf/rTn4JdEuAawgUIEmOM+vXrp0ceeUQNDQ2aP3++jh07FuyyAFcQLkAQeTweTZ06VT169FBxcbF+/etfM/eCsEC4AEFkjNHVV1+tqVOnKioqSlu3blVlZWWwywIuWXSwCwDcZK1VeXm5YmJigl3KNzJ27Fh5PB7l5eVp2bJlwS4HuGSEC8KGMUY9e/bUK6+8oqioqGCX860sXLhQfr9fXbp0CXYpwCUxlgFehAlrbdjMVxhjZIwJdhnAt0a4AABcx4Q+AMB1zLkAbdS8yWfICmgdnQvQRh999JGioqL00UcfBbsUIOQRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLkAbWGtVXl4uSSovLxd3BwdaR7gArfD5fCooKFB2drZGjhwpa61Gjhyp7OxsFRQUyOfzBbtEICQZy3/BgBatWbNG48ePV01NjaSWb3OckJCgoqIi5ebmBqVGIFQRLkAL1qxZo7y8PFlrFQgELvp1Ho9HxhitXLmSgAGaIVyA8/h8PnXv3l1+v7/VYGnk8Xjk9XpVWlqq1NTU9i8Q6ASYcwHOs2TJEtXU1LQpWCQpEAiopqZGS5cubefKgM6DzgVoxlqr7OxslZSUfKMVYcYYZWZmqri4uGk+BohkhAvQzFdffaWMjIxLen16erqLFQGdE8NiQDNVVVWX9PrKykqXKgE6N8IFaCYpKemSXp+cnOxSJUDnRrgAzaSnpysrK+sbz5sYY5SVlaW0tLR2qgzoXAgXoBljjKZNm/atXjt9+nQm84GzmNAHzsM+F+DS0bkA50lNTVVRUZGMMfJ4Wv8Radyh/8477xAsQDOEC9CC3NxcrVy5Ul6vV8aYC4a7Gp/zer1atWqVcn
JyglQpEJoIF+AicnNzVVpaqvz8fGVmZp5zLTMzU/n5+SorKyNYgBYw5wK0gbVWGzZs0F133aV169ZpxIgRTN4DraBzAdrAGNM0p5KamkqwAF+DcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAb5GfX29ysrKtH//fknSkSNHdOrUKQUCgSBXBoQubnMMXITP51NRUZHefPNN7du3T5WVlaqrq1N8fLwyMjI0fPhwTZw4UUOHDlV0dHSwywVCCuECtGD79u2aMWOGPv74Yw0aNEh5eXnq37+/kpKS5PP5tGvXLq1YsUKHDx/W/fffr2effVYZGRnBLhsIGYQLcJ61a9dqwoQJSkpK0ty5czVmzBjV1dWpsLBQtbW1SklJ0QMPPKD6+noVFhZq1qxZ6tu3r5YtW6Zu3boFu3wgJBAuQDOHDh3S6NGjlZiYqMLCQvXp00fGGJWUlOjmm29WRUWFevXqpV27dqlr166y1mrLli166KGHdOedd+q1115TXFxcsL8NIOiY0AfOamho0PPPP6/y8nItXLiwKVhaY4zRsGHD9MILL2j58uVavXp1B1ULhDbCBTjr8OHDWrFihcaNG6dhw4Z9bbA0Msbonnvu0S233KLFixfrzJkz7VwpEPpY4gKctW3bNlVVVWn8+PH69NNPVV1d3XSttLRUDQ0NkqS6ujrt27dPKSkpTdevvPJKjRs3TrNmzdLx48fVvXv3Dq8fCCWEC3DWgQMHlJCQoMzMTE2aNElbt25tumatVW1trSTp6NGjGjVqVNM1Y4xeeukl3XDDDaqpqdHRo0cJF0Q8wgU4y+/3Kzo6WnFxcaqtrdXp06db/Dpr7QXXzpw5I6/Xe04IAZGMcAHOuvzyy+X3++Xz+TRkyBAlJiY2XfP7/dq2bVtTiNx2221NGyeNMerRo4e+/PJLeTwede3aNVjfAhAyCBfgrAEDBqi+vl47duzQvHnzzrlWUlKiQYMGqaKiQt26ddNbb72l1NTUpuvGGD311FO64oorGBIDxGoxoMngwYOVmZmpJUuWqLq6WlFRUef8aWSMkcfjaXre4/Ho2LFjevvtt5WXl6cuXboE8bsAQgPhApyVnp6uqVOnavfu3VqwYEGblxTX1tZqzpw58vv9mjRpUpuXMAPhjGExoJkJEyZo8+bNmjdvnhISEjR58mTFx8dLkqKjoxUdHd3UxVhrVVlZqeeee06FhYWaP3++rr322mCWD4QMjn8BznPixAlNmTJF77//vnJzczVjxgxdf/31OnjwoAKBgGJjY3XNNddox44devHFF7Vnzx7Nnj1bkydPPmf4DIhkhAvQgurqai1evFgLFizQF198oczMTGVnZys5OVnl5eU6ePCgjh49qgEDBuiZZ57RHXfcIY+HUWagEeECtOL48eNat26dNm3apJK9e3V6xw51HT5c/YYOVU5OjoYMGaKEhIRglwmEHMIFaKOGnTtlBw+WZ+dOeQYODHY5QEhjQh9oo6ioKMkYieEv4GvxUwIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHfdzAdrKWikQcI7cNybY1QAhjc4F+Ca4lwvQJtwsDGHDWqsTxcWqPXky2KVcEuPx6LJ+/RSfmBjsUoBvjXBB2LDW6v/Nn6+4q65SnNcr+f1SUlKwy2qb+nrnj9erE1u2qP/TT+s7/fsHuyrgWyNcEFZMXJz63HSTuixa5Pyy/tOfQj9grJWWL5dmz5YdPVo7+vZ1ngM6MQaQEX6sldavl7Ztkz79NNjVtM369dLevdKhQ1JUVLCrAS4Z4YLwc801Us+eUlWVtHNn6HcBfr/04YfO42HDCBeEBcIF4Sc5WRo40AmVTZtCP1w++0w6fFiKj5duvT
XY1QCuIFwQfjwe6Y47nL0oO3dK//hHsCu6OGulv/7VqbFHD6l372BXBLiCcEF4GjzY6WA++0wqLg52NRdnrbRhg/P3wIFSly7BrghwBeGC8HT11c7ci98vbd0aukNjlZXSjh1OlzViBDv/ETYIF4SnhARp6FDn8caNUkNDUMu5qIMHnRVtycnSkCHBrgZwDeGC8DVihLPyavdu6csvg13NhayVPvjA6a5695Z69Qp2RYBrCBeEJ2OkAQOkjAzpiy+kjz4KvaGxM2ec/S2SswTZ6w1uPYCLCBeEr27dpO9979xf4qHk+HEn9KKjpbvuCnY1gKsIF4Sv5r+0N22SamqCW09z1joT+V99JV1xhXTTTUzmI6wQLghfxjj7XRITnWNVDh0KdkX/x1rpf/7HWWgweLAzfAeEEcIF4e3aa6XsbKm6OrR26//jH85kvjHSqFEc+YKwQ7ggvCUmSnfe6Txes8Y5KTkUfPKJ9L//62yaHD6cITGEHcIF4S83V4qJcZYkf/55sKv5vyGx2lrphhtYgoywRLggvBkj3Xyzc27XyZPSli3BHxrz+51wkaScHCkuLrj1AO2AcEH4S0uTbr/dCZWVK4O/W//AAWnfPucUgVGjGBJDWCJcEP6MkfLynEnzbdukY8eCV4u1ztxPdbXUp4/zBwhDhAvCnzHOfVK++11nt/4HHwRvaOz0aWnVKudxbq7TvQBhiHBBZMjIcIbGAgHpvfeCNzS2f7/08cfOUS9jxgSnBqADEC6IDB6PdM89ztDYli1SaWnH19A451NVJfXt66wUY74FYYpwQWQwxjmCv2dP54TkP/+544fGqqulFSucxz/4AUNiCGuECyLHZZc5q7OslYqKOnZDpbXOPpt9+5x7t/zTP9G1IKwRLogcxkjjxzv7Sv7yF+dGXR3FWuntt50J/UGDpOuu67jPBoKAcEHkMMb5xd6nj1RR4Uzsd9TQ2IkTzioxY6Qf/pCNkwh7hAsiS3Ky071IztBYRUX7f6a1zhzPZ59J3/mOdPfdDIkh7BEuiCzGOKvG0tKcnfIdcVJyXZ305pvOMujcXGe/DRDmCBdEnuxs6fvfdyb0ly1z7lTZXqyV9u6Vtm6V4uOlf/5nZ1k0EOb4V47IExUl/fjHUmysc/vjffvar3ux1gmwqipp4EBpyBCGxBARCBdEHmOce6jceKMz57J0afuFy2efOQsHPB4n0Lze9vkcIMQQLohMSUnST37i/NL/r/9yQsBt1kp/+INzUGZWlrNxkq4FEYJwQWRqnNjPypKOHnWGrtzuXo4fl5YscR7/6EfO+WZAhCBcELkuv9zpXoyR3njD3fPGGruWkhKpe3dnIp+uBRGEcEHkMkZ6+GHnNsN//7v02986y4XdcPSo9OtfOyHz4x9LV13lzvsCnQThgsh25ZXSpEnO49dek4qLL314LBCQfvMbp2u5+mpp4kSWHyPi8C8ekc0YZz6kb19n4v3FFy9t34u1ztLm3/zGee/HHqNrQUQiXICMDOlf/kWKiZHeeuvSjuM/fVqaM8c51v9735MmTGCuBRGJcAGMke691znzq7paevpp53bI3zRgrJUKC517tni9zvukpbVPzUCII1wAyTma5d//3TlYcs8e53FdXdtf33jMyzPPOK97+GFp9Gi6FkQswgWQnBC44QZp1ixneOyNN6RFi6SGhq9/rbXO6rApU6SyMmc47N/+zXkfIEIRLkAjY6RHHnFWd9XXOwGxdGnrAWOtM4Q2aZL04YdSt27SggVOB0TXgghGuADNxcZKzz7r7N6vrpYef1x66SXn4Mnz52AaV4Y99JD03/8tdekiFRRIt95KsCDiRQe7ACCkGOOExH/+p7M35d13pV/+0llB9thjzsnG8fHO0S7vvutslDx2TLrsMidYxo9nTwsgwgW4kDFOWLz2mtS7t/Qf/yGtWydt3Ois/oqLk3w+p5sxRrrpJmd/zO23EyzAWYQL0BJjpJQUZ9XYmDHSq686d6386i
tnB358vDNx/+CDzvEul13GUBjQDOGCsGKtVXV5uTxurtS67jpp/nxnY2RpqbNR8rLLnJ33iYlOqJSXu/Zx9adPu/ZeQLAQLggbxhgl9eypg6+8IhMVFexyvjXr9yu2S5dglwFcEmNte92CD+hY1lqFyz9nY4wMw2zoxAgXAIDrWNoCAHAdcy5AWzVv8hmyAlpF5wK01UcfSVFRzt8AWkW4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AG1grVV5ebms5PzN3cGBVhEuQCt8Pp8KCgqUnZ2tu0aOlLVWd40cqezsbBUUFMjn8wW7RCAkGct/wYAWrVmzRuPHj1dNTY0k6XvW6q+SBkrac/Y2xwkJCSoqKlJubm7wCgVCEJ0L0II1a9YoLy9Pfr9f1toLhsEan/P7/crLy9OaNWuCVCkQmuhcgPP4fD51795dfr9fgUCg6fmbpKbO5aNmX+/xeOT1elVaWqrU1NSOLRYIUXQuwHmWLFmimpqac4KlNYFAQDU1NVq6dGk7VwZ0HnQuQDPWWmVnZ6ukpOSCobCLdS6SZIxRZmamiouLZc7OxwCRjM4FaObkyZM6cuTIN15qbK3VkSNHdOrUqXaqDOhcCBegmaqqqkt6fWVlpUuVAJ0b4QI0k5SUdNFrB+QMiR1o5fXJyclulwR0SoQL0Ex6erqysrJanDfxy5lr8bfwOmOMsrKylJaW1t4lAp0C4QI0Y4zRtGnTvtVrp0+fzmQ+cBarxYDzXGyfy8WwzwW4EJ0LcJ7U1FQVFRXJGCOPp/UfEY/HI2OM3nnnHYIFaIZwAVqQm5urlStXyuv1yhhzwXBX43Ner1erVq1STk5OkCoFQhPhAlxEbm6uSktLlZ+fr8zMzHOuZWZmKj8/X2VlZQQL0ALmXIA2sNbq1KlTqqysVHJystLS0pi8B1pBuAAAXMewGADAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1/x+on2d72IR9zAAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 5 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.get_act(dataset2)\n",
|
|
"model2.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 104,
|
|
"id": "096134b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 3.77e-04 | test_loss: 3.76e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:46<00:00, 1.07it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.2\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.fit(dataset2, steps=50);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 105,
|
|
"id": "035e00f2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 3.73e-04 | test_loss: 3.72e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:13<00:00, 3.81it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.3\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.fit(dataset2, steps=50, update_grid=False);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 106,
|
|
"id": "b65775e8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAex0lEQVR4nO3de3CU9b3H8c+zWUg2JBAIAUTEsjEqFzkqVwWqFCWUeKYIY/FoFRwckSJWxvZ4vIMXLCpjgvbM6QnaAuMxWGIPRaixpRYKqKFcOyiXELUmXETIYi6b6/7OH0+SEyDESJ7wbHbfrxmGTJaFb2YIb36/52YZY4wAAHCQx+0BAACRh7gAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA4r9sDAB2BMUYnTpxQWVmZEhISlJycLMuy3B4LCFusXIAWBAIBZWVlKS0tTSkpKRowYIBSUlKUlpamrKwsBQIBt0cEwpLFkyiB5uXl5WnatGmqqKiQZK9eGjSsWuLj45Wbm6v09HRXZgTCFXEBmpGXl6eMjAwZYxQKhc756zwejyzL0rp16wgM0ARxAc4QCATUr18/BYPBFsPSwOPxyOfzqaioSElJSe0/INABcMwFOMPy5ctVUVHRqrBIUigUUkVFhVasWNHOkwEdBysXoAljjNLS0lRYWKjv8q1hWZb8fr8OHjzIWWSAiAtwmq+//lopKSlten9ycrKDEwEdE9tiQBNlZWVten9paalDkwAdG3EBmkhISGjT+xMTEx2aBOjYiAvQRHJyslJTU7/zcRPLspSamqoePXq002RAx0JcgCYsy9K8efPO670PPvggB/OBehzQB87AdS5A27FyAc6QlJSk3NxcWZYlj6flb5GGK/TfeecdwgI0QVyAZqSnp2vdunXy+XyyLOus7a6Gz/l8Pq1fv14TJ050aVIgPBEX4BzS09NVVFSkzMxM+f3+017z+/3KzMxUcXExYQGawTEXoBWMMfrggw80YcIEbdiwQePHj+fgPdACVi5AK1iW1XhMJSkpibAA34K4AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AN+ipqZGxcXF+vTTTyVJhw4d0smTJxUKhVyeDAhfPOYYOIdAIKDc3Fy9+eab2rt3r0pLS1VdXa24uDilpKRo3LhxmjVrlsaMGSOv1+v2uEBYIS5AMz788EPNnz9fe/bs0YgRI5SRkaGhQ4cqISFBgUBA27dv19q1a1VQUKDp06frueeeU0pKittjA2GDuABneP/99zVz5kwlJCTohRde0OTJk1VdXa2cnBxVVVWpa9euuv3221VTU6OcnBwtWLBAgwcP1sqVK9W7d2+3xwfCAnEBmjhw4IAmTZqkLl26KCcnR4MGDZJlWSosLNS1116rU6dOacCAAdq+fbu6d+8uY4w2b96sO+64QzfeeKOWLVum2NhYt78MwHUc0Afq1dXVadGiRSopKdFrr73WGJaWWJalsWPH6sUXX9SaNWv03nvvXaBpgfBGXIB6BQUFWrt2raZOnaqxY8d+a1gaWJalKVOmaPTo0crOzlZtbW07TwqEP05xAept3bpVZWVlmjZtmj7//HOVl5c3vlZUVKS6ujpJUnV1tfbu3auuXbs2vt63b19NnTpVCxYs0NGjR9WvX78LPj8QTogLUG/fvn2Kj4+X3+/X7NmztWXLlsbXjDGqqqqSJB0+fFg333xz42uWZWnJkiW66qqrVFFRocOHDxMXRD3iAtQLBoPyer2KjY1VVVWVKisrm/11xpizXqutrZXP5zstQkA0Iy5AvV69eikYDCoQCGjUqFHq0qVL42vBYFBbt25tjMj111/feOGkZVnq37+/vvrqK3k8HnXv3t
2tLwEIG8QFqDds2DDV1NQoPz9fixcvPu21wsJCjRgxQqdOnVLv3r21atUqJSUlNb5uWZYee+wx9enThy0xQJwtBjQaOXKk/H6/li9frvLycsXExJz2o4FlWfJ4PI2f93g8OnLkiFavXq2MjAx169bNxa8CCA/EBaiXnJysBx54QDt27NDSpUtbfUpxVVWVnn32WQWDQc2ePbvVpzADkYxtMaCJmTNnatOmTVq8eLHi4+M1Z84cxcXFSZK8Xq+8Xm/jKsYYo9LSUj3//PPKycnRK6+8oiuuuMLN8YGwwe1fgDMcP35cc+fO1bvvvqv09HTNnz9fAwcO1P79+xUKhdS5c2dddtllys/P18svv6xdu3bpmWee0Zw5c07bPgOiGXEBmlFeXq7s7GwtXbpUx44dk9/vV1pamhITE1VSUqL9+/fr8OHDGjZsmJ5++mndcMMN8njYZQYaEBegBUePHtWGDRu0ceNG7d69W/n5+Ro3bpzGjBmjiRMnatSoUYqPj3d7TCDsEBeglbZt26aRI0dq27ZtGj58uNvjAGGNdTzQSjExMY2nIQNoGd8lAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jue5AK1kjFEoFJLH45FlWW6PA4Q1Vi7Ad8CzXIDW8bo9AOAUY4wOHjyoEydOuD1Km3g8Hg0ZMkRdunRxexTgvLEthogRCoU0d+5cXXLJJUpISHB7nO+kpqZG5eXl6tatmzZv3qwnn3xSQ4cOdXss4LyxckFEiY2N1axZs9S7d2+3R2k1Y4z+9Kc/6eGHH9aPf/xjDRw4UPyfDx0dG8iAy+rq6vTb3/5Wn3zyiXbs2KGYmBi3RwLajLgALjLGaN++fXr//ffVqVMn3X333erUqZPbYwFtRlwAFxljtHLlSpWUlGjIkCH6wQ9+4PZIgCOIC+CiL7/8Ujk5ObIsS/fcc4+6du3q9kiAI4gL4JKGVUtxcbFSU1M1bdo0t0cCHENcAJcUFRXpN7/5jSTpnnvu6VBnuAHfhrgALgiFQlq2bJm++OILDRgwQHfddRe3lEFEIS7ABdZwJ4HXX39dknT//ferb9++Lk8FOIu4ABdYbW2tlixZoiNHjmjIkCG6++67WbUg4hAX4AIyxmjjxo1atWqVOnXqpF/84hfq2bOn22MBjiMuwAVUUlKiBQsWqKysTBMnTtStt97KqgURibgAF0hdXZ1effVVffzxx0pOTtZTTz0ln8/n9lhAuyAuwAVgjNGmTZu0dOlSSdJDDz2ka6+9llULIhZxAdqZMUZHjhzRz3/+cwUCAY0fP15z587lwWOIaPztBtpZMBjUf/zHf2jXrl26+OKL9dJLL3GbF0Q84gK0o4bjLKtWrZLP59OiRYs0dOhQtsMQ8YgL0E6MMcrNzdWiRYsan5I5ffp0woKoQFyAdtBwAP+hhx5SWVmZfvSjH+nxxx/nWS2IGsQFcJgxRnv27NF9992nY8eOafTo0crKylJiYqLbowEXDHEBHGSMUUFBge655x4VFBToyiuvVHZ2tvr27ct2GKIKcQEcYozR559/rhkzZmj37t3q37+/Xn/9dQ0cOJCwIOoQF8ABxhj985//1IwZM/Txxx+rT58+ys7O1qhRowgLohJxAdrIGKPPPvtMd911l7Zs2aJevXopOztbEyZMICyIWsQFaANjjA4cOKA777zztLBMmjSJsCCqERfgPBljtGvXLk2fPl35+fm66KKL9MYbb2jy5Mnc2gVRj+8A4DyEQiH95S9/0W233aZ//OMf+t73vqeVK1dq0qRJhAUQcQG+s9raWuXk5OjOO+/UZ599psGDBysnJ0c33ngjW2FAPeICtJIxRhUVFXrxxRd1//336/jx4xo7dqx+97vfafjw4YQFaMLr9gBAR2CM0VdffaVHHnlEb731lowxuu2225SZmanevXsTFuAMxAX4Fg0H7u
fNm6ePPvpIsbGx+tnPfqZHH31UCQkJhAVoBnEBzsEYo9raWq1evVqPPPKIiouL1atXL/3yl7/UHXfcwU0ogRYQF6AZxhidPHlSixYt0q9//WsFg0Fdc801evXVVzVq1CjOCAO+BXEBzhAKhbRz5049/PDD2rx5s7xer+6880698MIL3IASaCXiAtQzxqiyslIrVqzQwoULdezYMSUnJ+vJJ5/Uvffeq7i4OMICtBJxAWSHpbCwUE888YR+//vfq7a2ViNHjtSSJUs0evRotsGA74jvGEQ1Y4yqq6v19ttv64c//KHefvtteb1ePfDAA1qzZo2uu+46wgKcB1YuiFrGGH3xxRdauHChVq1apaqqKl1++eVatGiRbrnlFnm9XrbBgPNEXBB1jDGqqqpSbm6uFi5cqEOHDik2NlYzZszQ008/rf79+xMVoI2IC6KKMUb79u3TggUL9Ic//EHV1dVKS0vTwoULNWXKFHXu3JmwAA4gLogKxhh98803euONN7RkyRIdOXJEPp9P9957rx5//HFdcsklRAVwEHFBRDPGqK6uTps2bdKCBQv04Ycfyhijq6++WgsXLlR6ejrHVoB2QFwQsRpOL37ppZf01ltvqaysTD169NCcOXM0b9489ezZk6gA7YS4IOI03Lpl+fLlysrK0pdffqlOnTrplltu0VNPPaVrrrmG04uBdkZcEHE++ugjLVq0SDt37pQxRoMGDdKjjz6qW2+9lavsgQuEuCDieL1eFRYWqmfPnrrvvvs0Z84cnrkCXGDEBRHFGCO/369nn31WQ4YM0cCBA+XxeFRSUuL2aK1WWVnp9ghAmxEXRAzLsnTppZfqV7/6lWJiYvTpp5+6PdJ5CQaD6tatm9tjAG1iGWOM20MATjDGKFL+OluWxTYeOjTiAgBwHOdjAgAcxzEXoJWaLvLZsgJaxsoFaKWdO3cqJiZGO3fudHsUIOwRFwCA44gLAMBxxAUA4DjiAgBwHHEBADiOuAAAHEdcAACOIy4AAMcRFwCA44gLAMBxxAUA4DjiAgBwHHEBADiOuAAAHEdcAACOIy5AKxhjVFJSIkkqKSkRTwcHWkZcgBYEAgFlZWUpLS1NN910k4wxuummm5SWlqasrCwFAgG3RwTCkmX4LxjQrLy8PE2bNk0VFRWSmn/McXx8vHJzc5Wenu7KjEC4Ii5AM/Ly8pSRkSFjjEKh0Dl/ncfjkWVZWrduHYEBmiAuwBkCgYD69eunYDDYYlgaeDwe+Xw+FRUVKSkpqf0HBDoAjrkAZ1i+fLkqKipaFRZJCoVCqqio0IoVK9p5MqDjYOUCNGGMUVpamgoLC7/TGWGWZcnv9+vgwYONx2OAaEZcgCa+/vprpaSktOn9ycnJDk4EdExsiwFNlJWVten9paWlDk0CdGzEBWgiISGhTe9PTEx0aBKgYyMuQBPJyclKTU39zsdNLMtSamqqevTo0U6TAR0LcQGasCxL8+bNO6/3PvjggxzMB+pxQB84A9e5AG3HygU4Q1JSknJzc2VZljyelr9FGq7Qf+eddwgL0ARxAZqRnp6udevWyefzybKss7a7Gj7n8/m0fv16TZw40aVJgfBEXIBzSE9PV1FRkTIzM+X3+097ze/3KzMzU8XFxYQFaAbHXIBWMMbogw8+0IQJE7RhwwaNHz+eg/dAC1i5AK1gWVbjMZWkpCTCAnwL4gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gJ8i5qaGhUXF+vTTz+VJB06dEgnT55UKBRyeTIgfPGYY+AcAoGAcnNz9eabb2rv3r0qLS1VdXW14uLilJKSonHjxmnWrFkaM2aMvF6v2+MCYYW4AM348MMPNX/+fO3Zs0cjRoxQRkaGhg4dqoSEBAUCAW3fvl1r165VQUGBpk+frueee04pKSlujw2EDeICnOH999
/XzJkzlZCQoBdeeEGTJ09WdXW1cnJyVFVVpa5du+r2229XTU2NcnJytGDBAg0ePFgrV65U79693R4fCAvEBWjiwIEDmjRpkrp06aKcnBwNGjRIlmWpsLBQ1157rU6dOqUBAwZo+/bt6t69u4wx2rx5s+644w7deOONWrZsmWJjY93+MgDXcUAfqFdXV6dFixappKREr732WmNYWmJZlsaOHasXX3xRa9as0XvvvXeBpgXCG3EB6hUUFGjt2rWaOnWqxo4d+61haWBZlqZMmaLRo0crOztbtbW17TwpEP44xQWot3XrVpWVlWnatGn6/PPPVV5e3vhaUVGR6urqJEnV1dXau3evunbt2vh63759NXXqVC1YsEBHjx5Vv379Lvj8QDghLkC9ffv2KT4+Xn6/X7Nnz9aWLVsaXzPGqKqqSpJ0+PBh3XzzzY2vWZalJUuW6KqrrlJFRYUOHz5MXBD1iAtQLxgMyuv1KjY2VlVVVaqsrGz21xljznqttrZWPp/vtAgB0Yy4APV69eqlYDCoQCCgUaNGqUuXLo2vBYNBbd26tTEi119/feOFk5ZlqX///vrqq6/k8XjUvXt3t74EIGwQF6DesGHDVFNTo/z8fC1evPi01woLCzVixAidOnVKvXv31qpVq5SUlNT4umVZeuyxx9SnTx+2xABxthjQaOTIkfL7/Vq+fLnKy8sVExNz2o8GlmXJ4/E0ft7j8ejIkSNavXq1MjIy1K1bNxe/CiA8EBegXnJysh544AHt2LFDS5cubfUpxVVVVXr22WcVDAY1e/bsVp/CDEQytsWAJmbOnKlNmzZp8eLFio+P15w5cxQXFydJ8nq98nq9jasYY4xKS0v1/PPPKycnR6+88oquuOIKN8cHwga3fwHOcPz4cc2dO1fvvvuu0tPTNX/+fA0cOFD79+9XKBRS586dddlllyk/P18vv/yydu3apWeeeUZz5sw5bfsMiGbEBWhGeXm5srOztXTpUh07dkx+v19paWlKTExUSUmJ9u/fr8OHD2vYsGF6+umndcMNN8jjYZcZaEBcgBYcPXpUGzZs0MaNG1W4e7cq8/PVfdw4DRkzRhMnTtSoUaMUHx/v9phA2CEuQCvVbdsmM3KkPNu2yTN8uNvjAGGNA/pAK8XExEiWJbH9BXwrvksAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAcz3MBWssYKRSyb7lvWW5PA4Q1Vi7Ad8GzXIBW4WFhiBjGGB0/eFBVJ064PUqbWB6Peg4ZorguXdweBThvxAURwxijT155RbGXXKLY+HipslLqKM+3r6mxf/h8Or55s4Y++aQuGjrU7amA80ZcEFGs2FgNGjZM3ZYtk6qqpLfeksJ9BWCMtGaN9MwzMpMmKX/wYPtzQAfGBjIiT22t9Mc/Sps2SZ9/7vY0rfOXv0i7d0sHDkgxMW5PA7QZcUHkuewy6dJLpbIy6e9/D/9VQDAoffyx/fHYscQFEYG4IPJ07SoNG2ZHZePG8I/LF19IBw9KcXHSdde5PQ3gCOKCyOPxSDfeaF+Lkp8vlZa6PdG5GSNt2yZ9843Uv790+eVuTwQ4grggMo0cKSUk/P+qIFwZI/31r/bPw4dL3bq5PRHgCOKCyDRggH3spaJC2ro1fLfGvvnGPt5iWdL48Vz5j4hBXBCZ4uOlMWPsjz/4QKqrc3eeczlwwF5dJSZKo0a5PQ3gGOKCyDV+vH3m1Y4d0vHjbk9zNmOkv/3NPlvs8svt1RYQIYgLIpNl2WeM9ewpHT0q7doVfltjtbX29S2SNG6c5PO5Ow/gIOKCyNWnj3T11fY/4hs2uD3N2Y4ckXbulLxee5UFRBDigsjl9UoTJtgfb9xoH9wPFw2nIH/9tR3Ba6/lYD4iCnFB5LIs6YYb7IP7Bw5IBQVuT/T/jJHef98+0WDECCklxe2JAEcRF0S2K6
6Q0tLsW8E0XE8SDr75Rtq82Q7gxInc8gURh7ggsiUk2FfrS/ZKobbW1XEa/eMf0mef2RdNjhvHlhgiDnFB5EtPlzp1sm9i+eWXbk9jr57+9Cf7kQBXXcUpyIhIxAWRzbLsg+WXXCKdOGFvRbm9NVZRYa+iJDt8sbHuzgO0A+KCyJecbG89GSO9+677V+vv3y998ol9osHNN7MlhohEXBD5LEu65Rb7oPnWrfb1JW4xRsrLk8rLpUGDpIED3ZsFaEfEBZHPsuznpPTtKx075u7WWGWltH69/XF6ur16ASIQcUF06NVL+v73pVBI+t//dW9r7NNPpT177Fu9TJ7szgzABUBcEB08HmnKFHtrbPNmqbj4ws9gjLRunX3NzeDB9pliHG9BhCIuiA6WZd+Cv39/e2tsw4YLvzVWXi6tXWt/fMstbIkhohEXRI+UFPvsLGOk1aulmpoL92cbY9/6f+9e+9kt//qvrFoQ0YgLoodlSdOmSZ07Sx99ZN9v7EJpCFplpX0vsSuvvHB/NuAC4oLoYVnSyJH2KcCnTtkH9i/U1tjx4/ZZYpYl3XYbF04i4hEXRJfERHv1Ikm5uXZk2psx0p//bD/O+KKLpEmT2BJDxCMuiC6WJd16q9Sjh31a8N/+1v6rl+pq6c037dOg09Oliy9u3z8PCAPEBdHnssvsJz/W1EgrV7bvnZKNkXbvlrZskeLipDvvtE+LBiIcf8sRfbxe6e677Tsl//nP9n2+2mv1YowdsLIyafhwadQotsQQFYgLok/DEyr/5V/sYy4rVrRfXL74Qvr97+3VyowZ9pX5QBQgLohOCQnSzJn2P/q/+530z386/2cYI/3P/0hHj0qpqfaFk6xaECWIC6KTZUlTp9r/6B8+bG9dOb16OXpUWr7c/vjuu+2LOIEoQVwQvXr1slcvkvTb30pFRc793g2rlsJCqV8/+0A+qxZEEeKC6GVZ0k9+In3ve/axkTfesE8XdsLhw9Kvf21HZsYM+0mYQBQhLohuF18szZ5tf7xsmVRQ0PbtsVDIDkthoR2uWbM4/RhRh7/xiG6WZa8sBg+2Vxsvv9y2616MsW9OmZ1t/94//SmrFkQl4gKkpEj//u/2dS85OW27HX9lpfTss9JXX0lXX20f0+FYC6IQcQEabgkzaZL9zJUnnrDj8F0DY4wdp7Vr7etZnnzSvs0MEIWICyDZMVi40L6x5K5d9sfV1a1/f8NtXp5+2n7fT37CDSoR1YgLINkRGDrUjkOnTtJvfmMflK+r+/b3GmMfr5k713588jXXSE89Zf8+QJQiLkADy7Ivdpw1y76p5VNP2RdXthQYY+yLJWfPth9A1ru3lJVlr4BYtSCKERegqc6d7QPyP/qRfbPJn/1MWrLE/vjMYzANZ4bdcYf0xz9K3brZYbnuOsKCqOd1ewAgrFiWlJQk/dd/STEx9k0nn3jCvnvyT39q39k4Lk46csR+7b//2/64Z087LNOmcU0LIOICnM2y7FgsWyalpUn/+Z/26cl//at99ldsrBQI2KsZy7KPsbz8svT97xMWoB5xAZpjWVLXrtIzz0gZGdKvfiVt3CidOGFfgR8XZ1/H8m//Zl+E2bMnW2FAE8QFEcUYo/KSEnmcPFPryiulzEz72peiIvtCyZ497Svvu3Sxo1JS4tgfV1NZ6djvBbiFuCBiWJalhEsv1f5XX5UVE+P2OOfNBIPq3K2b22MAbWIZ016P4AMuLGOMIuWvs2VZsthmQwdGXAAAjuPUFgCA4zjmArRW00U+W1ZAi1i5AK21c6d9YeXOnW5PAoQ94gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFaAVjjEpKSmQk+2eeDg60iLgALQgEAsrKylJaWpom3HSTjDGacNNNSktLU1ZWlgKBgNsjAmHJMvwXDGhWXl6epk2bpoqKCknS1c
bo75KGS9pV/5jj+Ph45ebmKj093b1BgTDEygVoRl5enjIyMhQMBmWMOWsbrOFzwWBQGRkZysvLc2lSIDyxcgHOEAgE1K9fPwWDQYVCocbPXyM1rlx2Nvn1Ho9HPp9PRUVFSkpKurDDAmGKlQtwhuXLl6uiouK0sLQkFAqpoqJCK1asaOfJgI6DlQvQhDFGaWlpKiwsPGsr7FwrF0myLEt+v18HDx6UVX88BohmrFyAJk6cOKFDhw5951ONjTE6dOiQTp482U6TAR0LcQGaKCsra9P7S0tLHZoE6NiIC9BEQkLCOV/bJ3tLbF8L709MTHR6JKBDIi5AE8nJyUpNTW32uElQ9rGWYDPvsyxLqamp6tGjR3uPCHQIxAVowrIszZs377ze++CDD3IwH6jH2WLAGc51ncu5cJ0LcDZWLsAZkpKSlJubK8uy5PG0/C3i8XhkWZbeeecdwgI0QVyAZqSnp2vdunXy+XyyLOus7a6Gz/l8Pq1fv14TJ050aVIgPBEX4BzS09NVVFSkzMxM+f3+017z+/3KzMxUcXExYQGawTEXoBWMMTp58qRKS0uVmJioHj16cPAeaAFxAQA4jm0xAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AAAc939v8OqzRcmF5gAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 5 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 107,
|
|
"id": "6cd26af5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 x^0.5 0.999957 -14.193489 2 2 -1.238698\n",
|
|
"1 sqrt 0.999957 -14.193489 2 2 -1.238698\n",
|
|
"2 log 0.999722 -11.763921 2 2 -0.752784\n",
|
|
"3 1/x^0.5 0.999485 -10.894391 2 2 -0.578878\n",
|
|
"4 1/sqrt(x) 0.999485 -10.894391 2 2 -0.578878\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('x^0.5',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 2,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.9999566254728288,\n",
|
|
" 2)"
|
|
]
|
|
},
|
|
"execution_count": 107,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.suggest_symbolic(1,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"id": "daabd91a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Best value at boundary.\n",
|
|
"r2 is 0.9989821969546337\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([-9.8000, 9.8868, -0.3482, 1.2049]), tensor(0.9990))"
|
|
]
|
|
},
|
|
"execution_count": 72,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import fit_params\n",
|
|
"fit_params(x**2, y, lambda x: x**(1/2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"id": "8dc20f20",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"skipping (0,0,0) since already symbolic\n",
|
|
"fixing (1,0,0) with x^0.5, r2=0.9999494098870415, c=2\n",
|
|
"saving model version 0.4\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.auto_symbolic()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"id": "f27f3daa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 1.19 - 1.08 \\sqrt{1 - 1.0 x_{1}^{2}}$"
|
|
],
|
|
"text/plain": [
|
|
"1.19 - 1.08*sqrt(1 - 1.0*x_1**2)"
|
|
]
|
|
},
|
|
"execution_count": 79,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import ex_round\n",
|
|
"ex_round(model2.symbolic_formula()[0][0], 2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a61540c1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|