624 lines
82 KiB
Plaintext
624 lines
82 KiB
Plaintext
![]() |
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 94,
|
||
|
"id": "66865edb",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"from kan import *\n",
"from kan.utils import batch_jacobian, batch_hessian\n",
"\n",
"torch.use_deterministic_algorithms(True)\n",
"torch.set_default_dtype(torch.float64)\n",
"\n",
"seed = 0\n",
"torch.manual_seed(seed)\n",
"\n",
"# Choose one of the three examples below.\n",
"#example = 'harmonic_oscillator'\n",
"#example = 'single_pendulum'\n",
"example = 'relativistic_mass'\n",
"\n",
"# Every example defines:\n",
"#   d   -- dimension of q\n",
"#   x   -- model input [q, qd], shape (n_sample, 2*d)\n",
"#   qdd -- target accelerations. The model learns a scalar Lagrangian L(q, qd);\n",
"#          qdd is recovered from it through the Euler-Lagrange equations.\n",
"\n",
"if example == 'harmonic_oscillator':\n",
"    n_sample = 1000\n",
"    # harmonic oscillator: qdd = -q\n",
"    d = 1\n",
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
"    qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
"    qdd = - q\n",
"    x = torch.cat([q, qd], dim=1)\n",
"\n",
"elif example == 'single_pendulum':\n",
"    n_sample = 1000\n",
"    # single pendulum: qdd = -sin(q)\n",
"    d = 1\n",
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
"    qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
"    qdd = - torch.sin(q)\n",
"    x = torch.cat([q, qd], dim=1)\n",
"\n",
"elif example == 'relativistic_mass':\n",
"    n_sample = 10000\n",
"    # relativistic particle, presumably in units with c = 1 under a constant unit force:\n",
"    # qdd = (1 - qd^2)^(3/2), e.g. from L = -sqrt(1 - qd^2) + q -- TODO confirm\n",
"    d = 1\n",
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
"    # |qd| must stay <= 1 so that 1 - qd^2 >= 0 (a negative base under ** (3/2) gives NaN).\n",
"    # Alternative velocity ranges (commented out):\n",
"    #qd = torch.rand(size=(n_sample,1)) * 1.998 - 0.999\n",
"    #qd = 0.95 + torch.rand(size=(n_sample,1)) * 0.05\n",
"    qd = torch.rand(size=(n_sample,1)) * 2 - 1\n",
"    qdd = (1 - qd**2)**(3/2)\n",
"    x = torch.cat([q, qd], dim=1)\n",
"\n",
"else:\n",
"    raise ValueError(f'unknown example: {example!r}')"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 95,
|
||
|
"id": "ec549451",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.1\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
"from kan.compiler import kanpiler\n",
"from sympy import *\n",
"\n",
"# Compile the symbolic expression v_x**2 over inputs (x, v_x) into a KAN;\n",
"# the training cell below then treats this model as the Lagrangian L(q, qd).\n",
"input_variables = symbols('x v_x')\n",
"symbol_x, symbol_vx = input_variables\n",
"expr = symbol_vx**2\n",
"model = kanpiler(input_variables, expr, grid=20)"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 96,
|
||
|
"id": "f9930812",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASgklEQVR4nO3de3BU5f3H8c/ZEGBjUkJiuCiKSQgoNyWigKZ2LMiK6c3iOPVHrYzgWFrBf0r/AMc6ttqp8kcyeMGOMw7gwLQ2jhTCNN6LhNQb4A2omFSYIBBCWEhMyGX3+f3xsCThEsCc3XM2vF8zTGJ2N/sNzpMP3/NcjmOMMQIAwEUBrwsAAPQ9hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHX9vC4ASAbGGB0+fFhNTU1KT09Xdna2HMfxuizAt+hcgB6Ew2GVlpaqoKBAOTk5ys3NVU5OjgoKClRaWqpwOOx1iYAvOdwsDDiziooKzZ49W83NzZJs9xIT61rS0tJUVlamUCjkSY2AXxEuwBlUVFSouLhYxhhFo9GzPi8QCMhxHJWXlxMwQBeEC3CKcDisESNGqKWlpcdgiQkEAgoGg6qtrVVmZmb8CwSSAHMuwClWrlyp5ubm8woWSYpGo2pubtaqVaviXBmQPOhcgC6MMSooKFBNTY0uZGg4jqO8vDzt3r2bVWSACBegm/r6euXk5PTq9dnZ2S5WBCQnLosBXTQ1NfXq9Y2NjS5VAiQ3wgXoIj09vVevz8jIcKkSILkRLkAX2dnZys/Pv+B5E8dxlJ+fr6ysrDhVBiQXwgXownEcLVy48Du9dtGiRUzmAycwoQ+cgn0uQO/RuQCnyMzMVFlZmRzHUSDQ8xCJ7dB/9dVXCRagC8IFOINQKKTy8nIFg0E5jnPa5a7Y14LBoDZu3KiZM2d6VCngT4QLcBahUEi1tbUqKSlRXl5et8fy8vJUUlKiffv2ESzAGTDnApwHY4zeeecdTZ8+XW+99ZZuvfVWJu+BHtC5AOfBcZyTcyqZmZkEC3AOhAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWEC3AObW1tqqmp0ccffyxJ+vzzz3Xo0CFxtwrg7LifC3AWx48f17p167RixQp98sknam9vl+M4MsYoMzNTM2fO1KJFizRx4kSO4AdOQecCnEFdXZ3mzp2refPmKS0tTcuXL1dVVZW2b9+uTZs2aenSpdq5c6dCoZBWrFihjo4Or0sGfKWf1wUAfhMOh3Xfffdp69ateuGFF3TXXXepo6NDS5YsUUNDg0aPHq0lS5Zozpw5eu655/TII4+oo6NDDz30EB0McAKXxYAujDFasmSJnn/+eb388ssqLi6W4zhqaGjQtddeq9raWhUVFentt99WamqqIpGInnnmGT3xxBNav369pkyZ4vWPAPgCl8WALr766iu9+OKLmj9/vmbNmnXOTiQlJUUPPvigpk6dqqeeekqRSCRBlQL+RrgAXWzYsEEdHR164IEHFAgEZIw566qw2GMDBgzQggULtHnzZu3duzfBFQP+xJwLcIIxRlVVVRo9erQikYgWL16saDQqya4cC4fDkqSamhotXrxYgYD9t9mwYcM0Z84cpaam6osvvlBubq5XPwLgG4QLcEI0GtXBgwd12WWX6cCBAyopKTnjZa5vvvlGpaWlJ/973Lhxmj9/vgYPHqwDBw4ksmTAtwgX4ATHcZSamqrW1lY5jqP+/fufDBdjjNrb27s9Lyb2vPb29m5fBy5mhAtwQiAQUH5+vjZt2qQxY8bovffeOznfcuzYMd1zzz2qq6vTtddeqxUrViglJUWSFAwGFQ6HdXj/fuUtWybt2CHddJN0883SpZd6+SMBniFcgC5uu+02rV69Wj
t27ND06dNPrhZraGhQ//79JUnp6ekqLCw82aUYY/Tss89qUP/+GjdqlLRmjfT00/YbjhljQyb2p6BAYi8MLgKEC9DFjBkzNGrUKP3lL3/RlClTlJGR0ePzjTHat2+fli9frv9bsEBZjz8uGSPt3Stt3ixt2SJVVkovvWS/npNjQybW2Vx/vXQitIC+hE2UwClee+013XvvvZo3b56efPJJBYNBHTly5LRNlP369dPhw4c1f/587dmzRxUVFRoyZMiZv2k4LP3nPzZsNm+W3n9fammRBg6Ubrihs7O56SZp8OCE/rxAPBAuwCkikYhKSkr02GOP6fbbb9ejjz6q3NxcrV69Wo2NjRoxYoTuvPNOffjhh1q6dKnq6uq0du1aFRYWnv+btLdL27d3725iK83GjpWKimzQFBVJublcSkPSIVyAM4hEInrllVf06KOPqr6+XtOmTVNhYaEyMzN18OBBvf/++/r0009VVFSkZcuWacyYMb17Q2OkmhobMrHuZscO+9iwYd3nba67TmJVGnyOcAF6cPDgQa1bt04bN25UdXW1WltbNXjwYE2aNEmzZ8/WLbfcogEDBsTnzRsapKoqGziVldIHH0itrVJamjRlSudltGnTpEGD4lMD8B0RLsB5iEajamtrUyQSUWpq6smVYwnV2ipt3dq9u6mvt5fMJkzo3t1ceSWX0uApwgVIVsZIu3d3djaVldJ//2sfu/zyznmbm2+WJk6U+rE4FIlDuAB9yaFDtquJdTYffWQXD6SnS1OndnY2U6ZI51hmDfQG4QL0ZceP24CJrUrbssXO5QQCdmFAbN6mqMh2O4BLCBfgYhKNSrt2dZ+3qa62j40c2X3eZtw46cQRN8CFIlyAi92BA517bSor7aKBjg67Am3atM55mxtvlC65xOtqkSQIFwDdNTfbZc+x7mbLFunoUbsgYNKk7gsFhg3zulr4FOECoGfRqPTFF91XpX39tX0sP7/7vM3VV9v5HFz0CBcAF27fvs7OprLSHmUTidhz0WJdzc0323PTBg70ulp4gHAB0HtNTfYwzlhnU1Vlv5aaKk2e3P1gzpwcr6tFAhAuANzX0SF99ln3VWm1tfax0aO7z9uMHs1pAn0Q4QIgMfbu7T5v8+mn9pSBSy/tPm9TWCjF67w2JAzhAsAbR4923uOmstJ+3txsg+WGGzq7m5tukrKyvK4WF4hwAeAP7e3SJ590727277ePjR3b2d3cfrs0dKi3teKcCBcA/hT71XTqx0CAOZokwDGpAPwpFiBdg4R/CycNwgWA61pbWlSzfr2ibW1el3LeHGOUPXmyhl5zjdel9AmECwDXHT96VPtff12jfv5zac8eeyhmerrXZZ3Zl19K3/uejhw9qgP//jfh4hLCBUBcDBw6VFesWSPnn/+U1q6VbrnF65JO19oqLV0q1dQo9Q9/UB1zOa7hECAA8REISFdcIX37rbRpkz/nS/bvl3butEfXcD8bVxEuAOLDcWy3EgjYZcXt7V5XdLpt26Rw2C515lgaVxEuAOJn4kS7AXLXrs49K35hjPTuu/bj979vbykA1xAuAOJnyBBp/HjbHXz8sdfVdNfaak8HSEnx53xQkiNcAMRPv372F7cx0ttv+2veZe9eafduG4ATJnhdTZ9DuACIrx/8wIZMZaV0/LjX1VjG2FsENDVJ110nZWd7XVGfQ7gAiK/x4+3tkKurpf/9z+tqOr35pg2ZH/6Qu2fGAX+jAOJr8GB7w7Bvv7Xdix8ujTU22lOYBwywk/nsb3Ed4QIgvgIBacYM+/kbb/gjXHbutHMuV10lXX2119X0SYQLgPhyHNsdpKXZeY6GBm/riS0uaGuzx/j79ViaJEe4AIi//Hx7O+P9+6Xt272tpaPDdlCOI82c6W0tfRjhAiD+Bg60E+eRiFRR4e2lsdpae4vlrCxp6lTmW+KEcAEQf44jhUJ2SfJbb0ktLd7UYYy0ebPd1FlYKA0f7k0dFwHCBUBiTJokXX
aZPeL+yy+9qcEYacMG+3HWLLs7H3FBuABIjMGDpaIi27V4tWqsvt4e+XLJJdL06VwSiyPCBUBiOI70k5/Yj+XliT8
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x200 with 4 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Record activations on the training inputs x (presumably consumed by plot()), then draw the network.\n",
"model.get_act(x)\n",
"model.plot()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 97,
|
||
|
"id": "74a1fb02",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.2\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUf0lEQVR4nO3de3TU5Z3H8c8zucDkQm4MhJAEyNWigKQiKqmeFiQqbrcunp61rqsutF23hf1n3T+gx+26a/fU+kdSWkVPz+kqbe1uF46ughutWBGCeEFURAQSIDMhF3KZXEhCMjPP/hGYEkBE+M0l4f36x0MyM79v5Dz58H2e5/f8jLXWCgAAB7liXQAAYPwhXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjkuMdQHAWGCtVUdHh/r6+pSWlqacnBwZY2JdFhC36FyAC/D7/aqpqVFpaak8Ho9mzZolj8ej0tJS1dTUyO/3x7pEIC4ZHhYGnF9tba2WL1+u/v5+SSPdy2mnu5aUlBRt3LhRVVVVMakRiFeEC3AetbW1WrZsmay1CoVCn/s6l8slY4w2b95MwABnIFyAs/j9fuXn52tgYOCCwXKay+WS2+2Wz+dTZmZm5AsExgDWXICzPPvss+rv77+oYJGkUCik/v5+PffccxGuDBg76FyAM1hrVVpaqoaGBn2ZoWGMUVFRkQ4ePMguMkCECzBKe3u7PB7PZb0/JyfHwYqAsYlpMeAMfX19l/X+3t5ehyoBxjbCBThDWlraZb0/PT3doUqAsY1wAc6Qk5Oj4uLiL71uYoxRcXGxsrOzI1QZMLYQLsAZjDFatWrVJb139erVLOYDp7CgD5yF+1yAy0fnApwlMzNTGzdulDFGLteFh8jpO/Q3bdpEsABnIFyA86iqqtLmzZvldrtljDlnuuv019xut7Zs2aKlS5fGqFIgPhEuwOeoqqqSz+dTdXW1ioqKRn2vqKhI1dXVampqIliA82DNBbgI1lq98cYbWrx4sV5//XV9/etfZ/EeuAA6F+AiGGPCayqZmZkEC/AFCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIF+ALDA0NqaGhQe+//74kae/evTp+/Lh4WgXw+XieC/A5BgcH9eKLL2r9+vX68MMPNTw8LGOMrLXKzMzU0qVLtXr1as2dO5cj+IGz0LkA59HW1qYHHnhAK1asUEpKitatW6edO3dqz5492rZtm9auXatPP/1UVVVVWr9+vQKBQKxLBuJKYqwLAOKN3+/X/fffr927d+vpp5/W3XffrUAgoDVr1qizs1NlZWVas2aN7r33Xj355JP60Y9+pEAgoB/+8Id0MMApTIsBZ7DWas2aNXrqqaf0m9/8RsuWLZMxRp2dnZo3b558Pp8qKyu1detWJSUlKRgM6he/+IUee+wxvfTSS1q4cGGsfwQgLjAtBpzh0KFD+tWvfqWVK1fq9ttv/8JOJCEhQd///vd1ww036PHHH1cwGIxSpUB8I1yAM7z88ssKBAL67ne/K5fLJWvt5+4KO/29CRMm6KGHHtL27dvV2NgY5YqB+MSaC3CKtVY7d+5UWVmZgsGgHn74YYVCIUkjO8f8fr8kqaGhQQ8//LBcrpF/m+Xm5uree+9VUlKSPvnkE82aNStWPwIQNwgX4JRQKKTW1lbl5eWppaVF1dXV553mOnbsmGpqasJ/vvrqq7Vy5UplZWWppaUlmiUDcYtwAU4xxigpKUknT56UMUbJycnhcLHWanh4eNTrTjv9uuHh4VFfB65khAtwisvlUnFxsbZt26by8nK99dZb4fWWnp4e3XPPPWpra9O8efO0fv16JSQkSJLcbrf8fr/amtv06ROf6rV9r6ngpgIVLipUyuSUWP5IQM
wQLsAZbr31Vm3YsEH79u3T4sWLw7vFOjs7lZycLElKS0tTRUVFuEux1uqXv/ylUpNTVV5Srr2/26u6n9VJkiaXT1bBogIVLBoJm+zSbO6FwRWBcAHOsGTJEpWUlOinP/2pFi5cqPT09Au+3lqrpqYmrVu3Tg8+9KAefPRBWWvV3ditxu2N8tZ55d3h1Z5f75G1Vqme1JGwuWkkcPK+mqeE5IQo/XRA9HATJXCWF154Qffdd59WrFihn/zkJ3K73erq6jrnJsrExER1dHRo5cqVOnr0qGprazVlypTzfuagf1C+t33y1nnVuL1RTbuaNDwwrMSJiZq+YHq4uym4qUDuLHeUf2LAeYQLcJZgMKjq6mr9+Mc/1m233aZHHnlEs2bN0oYNG9Tb26v8/Hzdddddevfdd7V27Vq1tbXp+eefV0VFxcVfYziolj0tatzeKF+dT407GtXX0idJ8sz2qLCycGTdprJQmbMymUrDmEO4AOcRDAb1hz/8QY888oja29t14403qqKiQpmZmWptbdWuXbv00UcfqbKyUk888YTKy8sv63rWWnU1dMm7wxvubo7vOy5JSstNU+GiwnB3k3ttrhKSmEpDfCNcgAtobW3Viy++qC1btqi+vl4nT55UVlaW5s+fr+XLl+vmm2/WhAkTInLtgc4BeXeOrNl4d3jV9E6TAicDSkpJUv7C/PA0Wv6N+ZqYMTEiNQCXinABLkIoFNLQ0JCCwaCSkpLCO8eiKXAyoObdzaO6m/72fhljNGXOlPCOtIJFBcoozGAqDTFFuABjlLVWnQc71bijMdzdtH/WLkmaNH2SCisLlX9TvgoXFWrq3KlyJXKUIKKHcAHGkRPHT4xsf67zyrvdq2PvHVNwOKjktGTl35Af7m6mL5yuCemRmc4DJMIFGNcCgwEde+/Yn++5qfNqoHNAxmWUe21ueN2msLJQk6ZPinW5GEcIF+AKYkNW7fvb1bjj1Bbo7Y3qrO+UJGXOyBx1moDnao9cCUyl4dIQLsAVrq+lb2SDwKm1m+bdzQoFQpqYMVH5N+aHTxOYfv10JadGfyMDxibCBcAow/3DanqnKdzdeOu8GuwelCvRpWnzp6mgsiB8MGdablqsy0WcIlwAXJANWbV90hbekda4o1H+I35JUnZx9qh1m8lXTZZxsQUahAuAS9DT1BO+38a7w6uWPS0KBUNyZ7nD02gFiwo0fcF0JU7kfNwrEeEC4LIN9Q3Jt8sX7m68O70a6htSQlKC8q7LG3UwZ6onNdblIgoIFwCOCwVCav24ddRpAj2+HklSTllO+GDOgkUFyinL4TSBcYhwARAV3Y3do04TaP2oVdZapUxOUeGiU6cJVBZqWsU0JU5gKm2sI1wAxMRg95+fcePd4ZXvbZ+G+4eVOCFReQvy/tzd3FQgdzbPuBlrCBcAcSE4HFTrh62jupve5l5JI8+4Ob1mU3JbidKmsgU63hEuAOJS+FeTHf1n4zKs0YwBTGwCiEvhAAn/x4h/C48dhAsAxw0ODOrtl95WYCgQ61IunpVKrivRzK/MjHUl4wLhAsBxJ7pP6KNXP9JNf3WT/Ef9ypyRqeS0+DyXrONAh5InJau7u1t739xLuDiEcAEQERlTM3T0d0d14H8PaPnzy1V2c1msSzpH4GRAG9ZuUFdDlyr+pUInzIlYlzRucJ42gIgwLqOMggwNnxjW0W1H43K9pK+5T8c/Pa5QMMTzbBxGuACIDCPNuHmGjMvIu8Or0HAo1hWdo/mDZg36B+WZ7VGKJyXW5YwrhAuAiJk6d6rc2W61728P37MSL6y1OvKnI5KVCr9WKFcivw6dxP9NABGTOiVVnms8GvQPqvn95liXM0rwZFC+Op9MgtGMm2fEupxxh3ABEDGuRNfIL24rHd56OK7WXbobu9VxsEOpU1I1dc7UWJcz7hAuACJq5i0z5Up0ybvDq8BgfNz3Yq2Vb5dPQ31Dyr02V+4czi5zGuECIKKmXDNFqbmp6qzvlP
+wP9blhDX8sUGy0sxvzOTpmRFAuACIqIlZE5V3XZ6GTwyrcUdjXEyNDfUOqentJiVMSNDMr83krLIIIFwARJRxGRU
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x200 with 4 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Perturb the compiled model before training (mode='best', mag=0.1),\n",
"# then refresh its activations on x and plot it again.\n",
"model.perturb(mode='best', mag=0.1)\n",
"model.get_act(x)\n",
"model.plot()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 98,
|
||
|
"id": "fd0d2987",
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"| loss: 5.03e-05 |: 100%|███████████████████████████████████████████| 20/20 [02:59<00:00, 8.99s/it]\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
"from kan import *\n",
"from kan.utils import batch_jacobian, batch_hessian, create_dataset_from_data\n",
"import numpy as np\n",
"\n",
"torch.use_deterministic_algorithms(True)\n",
"\n",
"# Uses model, x, qd, qdd and d defined in earlier cells.\n",
"\n",
"def closure():\n",
"    \"\"\"LBFGS closure: fit the accelerations implied by the KAN Lagrangian.\n",
"\n",
"    Treating model(q, qd) as a scalar Lagrangian L, the Euler-Lagrange equations\n",
"        L_{qd qd} qdd + L_{qd q} qd = L_q\n",
"    give qdd_pred = L_{qd qd}^{-1} (L_q - L_{qd q} qd).\n",
"    The mean squared error against the target qdd is stored in the global\n",
"    `loss` (read by the progress bar), backpropagated, and returned.\n",
"    \"\"\"\n",
"    global loss\n",
"    optimizer.zero_grad()\n",
"\n",
"    jacobian = batch_jacobian(model, x, create_graph=True)\n",
"    hessian = batch_hessian(model, x, create_graph=True)\n",
"    Lqdqd = hessian[:, d:, d:]  # d^2 L / dqd dqd\n",
"    Lq = jacobian[:, :d]        # dL / dq\n",
"    Lqqd = hessian[:, d:, :d]   # d^2 L / dqd dq\n",
"\n",
"    Lqqd_qd_prod = torch.einsum('ijk,ik->ij', Lqqd, qd)\n",
"\n",
"    # Solve L_{qd qd} qdd = rhs per sample instead of forming an explicit inverse;\n",
"    # unsqueeze/squeeze keeps rhs unambiguously a batch of column vectors.\n",
"    rhs = Lq - Lqqd_qd_prod\n",
"    qdd_pred = torch.linalg.solve(Lqdqd, rhs.unsqueeze(-1)).squeeze(-1)\n",
"    loss = torch.mean((qdd - qdd_pred)**2)\n",
"\n",
"    loss.backward()\n",
"    return loss\n",
"\n",
"steps = 20\n",
"log = 1                    # refresh the progress-bar loss every `log` steps\n",
"grid_update_freq = 20      # refit spline grids to x every `grid_update_freq` steps...\n",
"stop_grid_update_step = 5  # ...but only before this step (i.e. at step 0 only with these values)\n",
"optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn='strong_wolfe', tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
"pbar = tqdm(range(steps), desc='description', ncols=100)\n",
"\n",
"for step in pbar:\n",
"    if step < stop_grid_update_step and step % grid_update_freq == 0:\n",
"        model.update_grid(x)\n",
"\n",
"    optimizer.step(closure)\n",
"\n",
"    if step % log == 0:\n",
"        pbar.set_description('| loss: %.2e |' % loss.item())"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 99,
|
||
|
"id": "782f818f",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYiUlEQVR4nO3dWXBc1Z0G8O+01K1e1Yska5ctybIheAGxGLCzjRkrGdfUVEp+GPCQpUJCUWP8xkNM4kpRlVQlxYOEk+CkkgeWKkil5CmK2BNlhjABLzhsxhgTCPLSi6213ft+75mHttpqLC/g2923W9+vylVGUkt/2xx9+p9z7jlCSilBRESkIUOlCyAiotrDcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs3VV7oAomogpcTc3BxisRjsdjuampoghKh0WUS6xc6F6CpCoRBGR0cxMDCAlpYW9Pb2oqWlBQMDAxgdHUUoFKp0iUS6JHhZGNHixsfHMTw8jEQiASDfvcyb71qsVivGxsYwNDRUkRqJ9IrhQrSI8fFxbN26FVJKqKp6xY8zGAwQQmD//v0MGKIFGC5EnxIKhdDV1YVkMnnVYJlnMBhgsVjg9/vhcrlKXyBRFeCaC9GnPPPMM0gkEtcVLACgqioSiQSeffbZEldGVD3YuRAtIKXEwMAATp06hc8yNIQQ6Ovrwz/+8Q/uIiMCw4WoyOzsLFpaWm7o9U1NTRpWRFSdOC1GtEAsFruh10ejUY0qIapuDBeiBex2+w293uFwaFQJUXVjuBAt0NTUhP7+/s+8biKEQH9/PzweT4kqI6ouDBeiBYQQePTRRz/Xa3fu3MnFfKKLuKBP9Cl8zoXoxrFzIfoUl8uFsbExCCFgMFx9iMw/ob9v3z4GC9ECDBeiRQwNDWH//v2wWCwQQlw23TX/NovFggMHDmDLli0VqpRInxguRFcwNDQEv9+PkZER9PX1Fb2vr68PIyMjCAQCDBaiRXDNheg6SCnx6quvYvPmzXjllVfw1a9+lYv3RFfBzoXoOgghCmsqLpeLwUJ0DQwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOF6BoymQxOnTqFt99+GwBw4sQJzMzMgLdVEF0Z73MhuoJUKoWXXnoJe/fuxXvvvYdsNgshBKSUcLlc2LJlC3bu3Il169bxCH6iT2HnQrSI6elpfPvb38Z3v/tdWK1W7NmzB0eOHMGxY8fw2muv4fHHH8eHH36IoaEh7N27F7lcrtIlE+lKfaULINKbUCiEb33rW3jnnXfw61//Gtu2bUMul8OuXbsQDAaxatUq7Nq1C9u3b8evfvUr/PCHP0Qul8OOHTvYwRBdxGkxogWklNi1axeefvppPP/889i6dSuEEAgGg1i/fj38fj82bdqEv/zlLzAajVAUBb/4xS/wk5/8BC+//DI2bNhQ6T8CkS5wWoxogU8++QS//e1v8dBDD+HrX//6NTuRuro6PPzww7j77rvx85//HIqilKlSIn1juBAt8Mc//hG5XA7f+973YDAYIKW84q6w+fc1NDTgkUcewcGDB+H1estcMZE+cc2F6CIpJY4cOYJVq1ZBURQ89thjUFUVQH7nWCgUAgCcOnUKjz32GAyG/M9mbW1t2L59O4xGIz744AP09vZW6o9ApBsMF6KLVFXF1NQUOjo6MDk5iZGRkUWnuc6dO4fR0dHCf99yyy146KGH4Ha7MTk5Wc6SiXSL4UK0gFAE0uk0hBAwmUyFcJFSIpvN5j9GCBiNxsJr5j8um80WvZ1oKWO40JKl5lRMvT8F70Ev/If98B7yIuaLYXLVJFavXo3XX3+9sN4SiURw//33Y3p6GuvXr8fevX
tRV1cHALBYLAiFQpidnkXugxwm/jwBZ48Tzh4njFaGDS1NDBdaMtLRNAJHA/Ae8sJ3yAf/G35kYhnUGevQcUcH1vz7GjwoHsQP9vwAJ0+exObNmwu7xYLBIEwmEwDAbrdjcHCw0KVIKfHLX/4SdpsdN998cz6wDuUX9q3N1kLQOHucsHgsfBaGlgSGC9WsSCAC70EvfId98B3yYfLYJKQqYXFb0L2xG1/c9UV0b+xGxx0dMFryQXFn6E787r9/h5/97GfYsGEDHA7HVb+GlBKBQAB79uzBN7/zTdz7nXshpUQ6nEbYGy78mnx3ElJKmGwmNHY3FsLG0eGAoY6bNqn2MFyoJqiKipkPZgpdie+QD6GzIQCAp9+D7o3duP3h29GzsQfNNzVDGBbvHlwuF5544gk8+OCD+NGPfoSf/vSnsFgsi36slBJzc3PYsWMHrFYrduzYASC/JmN2mWF2mdG6rhUAkEvlEPFHEPaGETobwplXz0DJKjDUG9DYeSlsGrsbC0FHVM34hD5VpWwii8DfAoXOxH/Ej1Q4BUO9Ae2D7eje2I3ue7vRs7EH9jb7Z/rciqJgZGQEP/7xj/G1r30Nu3fvRm9vL5577jlEo1F0dXXhG9/4Bt588008/vjjmJ6exgsvvIDBwcHr/hqqoiI2GSvqbjKxDADA1mKDc7kTzm4nnMudMLvMnEqjqsNwoaoQm4zlu5LDPvgO+nD+3fNQcyrMTjO67ulC98Z8kHTe1anJIrqiKPjDH/6A3bt3Y3Z2Fvfccw8GBwfhcrkwNTWFo0eP4vjx49i0aROefPJJrF69+oa+npQSqQupfND4wgifDSM+EwcAmOymonUbe5udU2mkewwX0h2pSsz+fbZovSQ4EQQAuJa70LOpB133dqFnYw9abmkp6TfaqakpvPTSSzhw4AAmJiaQTqfhdrtx2223YXh4GF/60pfQ0NBQkq+dTWYR8UUKnU0kEIGaU1FnrENj14KptK5G1Js5w036wnChisulcgi8GSislfgO+5C8kIQwCLTd2lboSro3dqOxs7EiNaqqikwmA0VRYDQaCzvHylpDTkX0fLRoKi2byN8xY1tmK+puGpwNnEqjimK4UNnFZ+KFjsR3yIdzb52DklVgspvQfU93oSvp3NCJBkdpuoJaIKVEMpgsCpvEbAIA0NDYUDyV1mq/4iYGolJguFBJSSkx9/FcoSPxHvRi7uM5AEBjV2OhI+ne2I3Wta0w1HMt4UZk4pmiqbTouShURUWdKT+V5lruym+B7nSgvoFTaVQ6DBfSVC6dw/l3zheeevcd9iE+E4cQAq3rWi/t4trUA2ePs9Ll1jw1pyJ6LorQ2VB+3cYXQTaZn0qzt9mLp9Ia2SWSdhgudEOSwWR+iuviNFfgbwHk0jkYrUZ0behC96Z8mHTd3QWz01zpcpc8KSUSs4miqbRkMAkAMLvMRWFja7FxKo0+N4YLXTcpJS6cugDfIV/hYcWZkzMAAEe7o6graV3fijpjXYUrpuuRiWUK25/D3jCi56OQqkS9ub54V1pnI+pM/Del68NwoStSsgom352E99DFgx0PehGbigEAlt2yrNCV9GzsgavXxd1JNULJKogGFuxK84WRS+UgDAKOdkdRd2Oyl3/XHFUHhgsVpMIp+I/4C11J4GgA2WQW9eZ6dN7VWdgS3HVPFyzuxY9EodojpUR8Ol40lZYKpQAAFo8lHzQXTxOwNlv5QwYBYLgsWVJKhL3hSw8qHvRh+sQ0pJSwtdgKO7h6NvagfbCd0yFUJB1JXzpNwBtGbDIGqUoYLcaigzkbOxu5A3CJYrgsEWpOxdTxqaKDHSOBCACg+abmovUSz0oPf/qkz0TJKIWDOedDR8koMNQZ4OhwFB3MabJxKm0pYLjUqHQ0Df8b/sIursLdJab83SU9m3rQfW8+UKzN1kqXSzVGqhKxqVhh+3PobAjpSBoAYG2yXjqYs8cJSxPvuKlFDJcaEfFHirqSyfcu3l3ise
TXSe7tQs+mHnTc3sFzqKgiUuFU0bpNfCoOKSWMVmPRJgFHu4NTaTWA4VKFCneXXFwv8R70IuwNAwA8Kz2XupKN3Wh
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x200 with 4 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Plot the model after LBFGS training.\n",
"model.plot()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 100,
|
||
|
"id": "ad876b9d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"ename": "KeyboardInterrupt",
|
||
|
"evalue": "",
|
||
|
"output_type": "error",
|
||
|
"traceback": [
|
||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||
|
"\u001b[0;32m/var/folders/6j/b6y80djd4nb5hl73rv3sv8y80000gn/T/ipykernel_24271/2849209031.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mauto_symbolic\u001b[0;34m(self, a_range, b_range, lib, verbose)\u001b[0m\n\u001b[1;32m 1402\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'fixing ({l},{i},{j}) with 0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1403\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1404\u001b[0;31m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msuggest_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1405\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1406\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mverbose\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36msuggest_symbolic\u001b[0;34m(self, l, i, j, a_range, b_range, lib, topk, verbose, r2_loss_fun, c_loss_fun, weight_simple)\u001b[0m\n\u001b[1;32m 1332\u001b[0m \u001b[0;31m# getting r2 and complexities\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1333\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msymbolic_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1334\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1335\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# zero function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1336\u001b[0m 
\u001b[0mr2s\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, l, i, j, fun_name, fit_params_bool, a_range, b_range, verbose, random, log_history)\u001b[0m\n\u001b[1;32m 488\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspline_postacts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0;31m#y = self.postacts[l][:, j, i]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 490\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msymbolic_fun\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 491\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0;36m1e8\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/Symbolic_KANLayer.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, i, j, fun_name, x, y, random, a_range, b_range, verbose)\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;31m#initialize from x & y and fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 231\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 232\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns_avoid_singularity\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun_avoid_singularity\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/utils.py\u001b[0m in \u001b[0;36mfit_params\u001b[0;34m(x, y, fun, a_range, b_range, grid_number, iteration, verbose, device)\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mb_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrid_number\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0ma_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_grid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmeshgrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindexing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ij'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m \u001b[0mpost_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m 
\u001b[0mb_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0mx_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpost_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0my_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Fix every activation to the symbolic function chosen by suggest_symbolic.\n",
"# NOTE: a previous run of this cell was interrupted (KeyboardInterrupt) during this search.\n",
"model.auto_symbolic()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 80,
|
||
|
"id": "428571e6",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.5\n",
|
||
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
||
|
"0 x 1.000000 -16.565706 1 1 -2.513141\n",
|
||
|
"1 cos 1.000000 -16.599499 2 2 -1.719900\n",
|
||
|
"2 sin 1.000000 -16.599499 2 2 -1.719900\n",
|
||
|
"3 exp 0.999997 -16.268112 2 2 -1.653622\n",
|
||
|
"4 x^0.5 0.999977 -14.896568 2 2 -1.379314\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"('x',\n",
|
||
|
" (<function kan.utils.<lambda>(x)>,\n",
|
||
|
" <function kan.utils.<lambda>(x)>,\n",
|
||
|
" 1,\n",
|
||
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
||
|
" 0.9999996907837526,\n",
|
||
|
" 1)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 80,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
"# Undo any symbolic fix on activation (l=0, i=0, j=0) -- presumably (layer, input, output) --\n",
"# and list the ranked symbolic candidates for it.\n",
"model.unfix_symbolic(0,0,0)\n",
"model.suggest_symbolic(0,0,0)"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 47,
|
||
|
"id": "0dea7189",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.4\n",
|
||
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
||
|
"0 0 0.000000 0.000014 0 0 0.000003\n",
|
||
|
"1 cos 0.969503 -5.034727 2 2 0.593055\n",
|
||
|
"2 x^2 0.969092 -5.015413 2 2 0.596917\n",
|
||
|
"3 sin 0.965249 -4.846400 2 2 0.630720\n",
|
||
|
"4 x 0.000392 -0.000551 1 1 0.799890\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"('0',\n",
|
||
|
" (<function kan.utils.<lambda>(x)>,\n",
|
||
|
" <function kan.utils.<lambda>(x)>,\n",
|
||
|
" 0,\n",
|
||
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
||
|
" 0.0,\n",
|
||
|
" 0)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 47,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model.unfix_symbolic(0,1,0)\n",
|
||
|
"model.suggest_symbolic(0,1,0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 101,
|
||
|
"id": "ef60542b",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARwAAAESCAYAAAAv/mqQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlkElEQVR4nO3de1xUdf4/8BcXAS8wpshNEchWtKjAYVVQ1DTH0HW77UNaW9HS/cZ+KyW6rGRfU3ID22KpLTFL7edX12y9tPWT0tlUMMFNCFoLM0sRVBBBncEbCHy+fxizDjPA+Zw5lzkz7+fjMY/HNvs557zPjPPi3D6fjwdjjIEQQhTgqXYBhBD3QYFDCFEMBQ4hRDEUOIQQxVDgEEIUQ4FDCFEMBQ4hRDHeahcgRHt7O86cOQN/f394eHioXQ4h5CaMMTQ1NSEsLAyent0fw2gicM6cOYPw8HC1yyCEdKOmpgZDhgzpto0mAsff3x/AjR0KCAhQuRpCyM3MZjPCw8Mtv9PuaCJwOk6jAgICKHAIcVJCLnfQRWNCiGIocAghiqHAIYQoRhPXcAghyrra0oZXCypR1XgFkQP74MXpt6O3j5fD63WZwJHrAyLE3fx+wyEYK+st/73/GPC/B6sx9fYgvJf6S4fW7aGFAbjMZjN0Oh1MJpPdu1SdP6AOk6MHYt1jY5UokRCX0NVvqYO90Onp93kzzV/D6e4D2nO0Eb9cYVS4IkK06WpLW7dhAwDGynpcbWkTvQ1NB46QD+jcpRY8/sFXClVEiHYt//Q7SdvZo+nAyfr/3wpqt+f7cw6lMiHuYPvXNYLaff5trehtaDpwDh4/L7it0HAixB21tLZD6N/ki1dbRW9H04Hj4yW8/KLv62SshBBte3ffT4psR9OB82DcYMFtT5vFpzIhru69L4UHTh8HHqbRdOA8Pv5WtUsgxCWYrwm/xmnMmCx6O5oOHB9vTZdPiCYNHtBb9LL0iyWEKIYChxCiGAocQohiKHAIIYrRfOB4ckziUHfxmnyFEKJRbe3K9d/mDpyioiLMnDkTYWFh8PDwwMcff9zjMoWFhdDr9fDz88Ott96K1atXi6nVrr4cQ1BMy/1Csu0S4ioKj54T3LaXg9viDpzLly/j7rvvxttvvy2o/YkTJzB9+nQkJSWhvLwcL774IhYuXIht27ZxF2tP8p2hgtuaWiTZJCEu5c+7vhfcdurtgQ5ti/uZweTkZCQnJwtuv3r1agwdOhR5eXkAgJEjR6K0tBSvv/46Hn74Yd7N21j+6xh8VHrK4fUQ4q5+rG8S3Pa3Yxx72Fb2azglJSUwGAxW702bNg2lpaW4fv263WWam5thNputXl3hGdWP5uwkxNb1duFtE3/h2BGO7IFTV1eH4OBgq/eCg4PR2tqKhoYGu8tkZ2dDp9NZXj3Nuin0MI0GHCXEMV48d2nsUOQuVecJsjpGNe1q4qzMzEyYTCbLq6am+3E6hHbLpO6bhKhL9kHUQ0JCUFdnPTREfX09vL29MXDgQLvL+Pr6wtfXV+7SCCEKk/0IJyEhAUaj9bjCu3fvRnx8PHr1cvQmG7+WVo4TVkKIpLgD59KlS6ioqEBFRQWAG7e9KyoqUF1dDeDG6VBqaqqlfVpaGk6ePImMjAwcOXIE69atw9q1a/Hcc89Jswec3vin+PFYCXE11Q1XFN0ed+CUlpYiLi4OcXFxAICMjAzExcVh6dKlAIDa2lpL+ABAVFQUCgoKsG/fPsTGxuKVV17BW2+9Jckt8Q48l7He3VfdcyNC3MR9bxYquj3uaziTJk1Cd1NZffDBBzbvTZw4EV9//TXvpgQLDfDFGXOzbOsnxFVd4bgn/vCoEIe3p/m+VADwj6eS1C6BEJe34oFYh9fhEoEzKIDuaBEiNymmznaJwCGEaAMFDiFEMW4ZOOfoAjMhqnCZwOnF0ccjOfefMlZCiDacqL+s+DZdJnCCOC4cN9DAf4Qo/gwO4EKB8+iYCL
VLIERTmtuEDy0a2V+abkguEzgLkmgWTkLksv2pSZKsx2UCh2bhJEQ+A/r5SLIe+pUSQhTjtoGj5NQYhJAb3DZw9n13Vu0SCFHN+UvqTGHiUoHTi6Orx0v/qJCtDkKc3W/yD6iyXZcKnHG3Ch9RvvZSm4yVEOLcjjcqO/BWB5cKnLcf1atdAiEu58HYIMnW5VKB089P9jHhCXE7rz40SrJ1uVTgEEKkJ8U4OB3cOnDo1jghynLrwNlzuK7nRoS4mNPnr6q2bZcLHJ4dem6bfAO7E+Ks1Ogl3sHlAmf8bcJvjZvUefaJEFU1NQt/JGTRvdJ2ina5wFn1O7o1TohUnpwULen6XC5w6NY4IdKRehQGlwscQojzcvvAqbtI440S93HpWquq23f7wLn3z1+oXQIhivnvjWWqbt8lAyeUY0B16sNJ3MmXPzaoun2XDJxPaK5xQuxq52gb7i99PLhk4NBc44Q47h+Lpki+TpcMHF7Up4oQW1INnH4zChwAxcfUPa8lRAnOcEdWVOCsWrUKUVFR8PPzg16vx/79+7ttv2nTJtx9993o06cPQkND8dhjj6GxsVFUwXLYWHxM7RIIkV3yW0Vql8AfOFu2bEF6ejqWLFmC8vJyJCUlITk5GdXV1Xbbf/nll0hNTcX8+fPx3Xff4e9//zsOHTqEBQsWOFx8dwb2ET5T4K6jF2SshBDncOHKdcFtfTzkqYE7cHJzczF//nwsWLAAI0eORF5eHsLDw5Gfn2+3/cGDBxEZGYmFCxciKioK48ePxxNPPIHS0lKHi+/OzoUTZF0/Ia7szd/EyrJersBpaWlBWVkZDAaD1fsGgwHFxcV2l0lMTMSpU6dQUFAAxhjOnj2LrVu3YsaMGV1up7m5GWaz2erFK6S/H/cyhJAbDHFhsqyXK3AaGhrQ1taG4OBgq/eDg4NRV2d/MKvExERs2rQJKSkp8PHxQUhICPr374+//vWvXW4nOzsbOp3O8goPD+cpU5SWVp4nFAjRFt47sV6e8pxTibpo7OFhXQxjzOa9DpWVlVi4cCGWLl2KsrIyfP755zhx4gTS0tK6XH9mZiZMJpPlVVNTI6ZMLvmFP8i+DULUUqzyE8YduMZyCAwMhJeXl83RTH19vc1RT4fs7GyMGzcOzz//PADgrrvuQt++fZGUlIQVK1YgNDTUZhlfX1/4+ir78F6e8ScsmjJC0W0SopQtpfL/0RaC6wjHx8cHer0eRqPR6n2j0YjExES7y1y5cgWentab8fK6MQo8Y/I+cHd7aIDgtvToH3FlBzieNesr/AYvN+5TqoyMDLz//vtYt24djhw5gmeeeQbV1dWWU6TMzEykpqZa2s+cORPbt29Hfn4+jh8/jgMHDmDhwoUYPXo0wsLkuTDVYfPvx8q6fkK04sJV4bfEdz8zWbY6uIfHS0lJQWNjI7KyslBbW4uYmBgUFBQgIiICAFBbW2v1TM68efPQ1NSEt99+G88++yz69++PyZMnY+XKldLtRRd0HM/iADcuHEs9whkhWjN4QG/Z1u3B5D6vkYDZbIZOp4PJZEJAgPDTJACIXLxTcNs/3BOBP06L4S2PEKdW3XAFE17fK7h9VU7Xj6zYw/P7dPk/5zw7mL/3pGx1EKKWaSpOC9OZywcOz7QxhLiiq9eFP2PmK1OXhg4uHzg0bQwhwu15Xr4LxoAbBA7vtDFqDzJNiJSutvCNoSvnBWPADQKH1x/+X4naJRAimRf+/o3aJVhxi8Dx5jgv3X+Cv6MoIc7q08O1apdgxS0C547BfLfSCSHycIvA2fA4PXFMSE8OvCDvBWPATQKH94ljZxj7lRBHmThG+APkv2AMuEng8Lonh2bjJNo3a7X9QfHU5DaB08VwPXZdla8MQhRztP6S2iXYcJvA+fNDd6tdAiFuz20C50H9YK72NOQocSe7FJp0wG0Ch3eM1jf3HJGpEkLkV1F1kat9dJi/PIV04jaBw+udPVVql0CIaA+sPqB2CXa5VeBE3C
L/bT9CtOYWGYcU7cytAmfHk+PVLoEQp7P7+XsV25ZbBc6Afj5c7U+fpxvkRHt4/90OClBuhhS3Chxek/+8R+0SCOH
|
||
|
"text/plain": [
|
||
|
"<Figure size 300x300 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"x, y = model.get_fun(0,1,0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 102,
|
||
|
"id": "8f77c061",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"checkpoint directory created: ./model\n",
|
||
|
"saving model version 0.0\n",
|
||
|
"saving model version 0.1\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from kan.utils import create_dataset_from_data\n",
|
||
|
"\n",
|
||
|
"dataset2 = create_dataset_from_data(x[:,None], y[:,None])\n",
|
||
|
"model2 = KAN(width=[1,1,1])\n",
|
||
|
"model2.fix_symbolic(0,0,0,'x^2',fit_params_bool=False)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 103,
|
||
|
"id": "1c62302d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAf/ElEQVR4nO3de3DU9b3/8ddnc93cCIkRaxEkMaiAWOWmAioVEiQ9g0LH29GWDnUocqmMU88ZqUcOqBRHJUGOhxZrBWobe4zKIBzglKtcLBQELT8ugRytCaAI2TSXJQnZz++PL8kJEGKUb7Kb3edjhmFnv9ndd2ZIXrw/t6+x1loBAOAiT7ALAACEH8IFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4LroYBcAdAbWWp08eVJVVVVKSkpSenq6jDHBLgsIWXQuQCt8Pp8KCgqUnZ2tjIwM9erVSxkZGcrOzlZBQYF8Pl+wSwRCkuFOlEDL1qxZo/Hjx6umpkaS0700auxaEhISVFRUpNzc3KDUCIQqwgVowZo1a5SXlydrrQKBwEW/zuPxyBijlStXEjBAM4QLcB6fz6fu3bvL7/e3GiyNPB6PvF6vSktLlZqa2v4FAp0Acy7AeZYsWaKampo2BYskBQIB1dTUaOnSpe1cGdB50LkAzVhrlZ2drZKSEn2THw1jjDIzM1VcXMwqMkCEC3COr776ShkZGZf0+vT0dBcrAjonhsWAZqqqqi7p9ZWVlS5VAnRuhAvQTFJS0iW9Pjk52aVKgM6NcAGaSU9PV1ZW1jeeNzHGKCsrS2lpae1UGdC5EC5AM8YYTZs27Vu9dvr06UzmA2cxoQ+ch30uwKWjcwHOk5qaqqKiIhlj5PG0/iPSuEP/nXfeIViAZggXoAW5ublauXKlvF6vjDEXDHc1Puf1erVq1Srl5OQEqVIgNBEuwEXk5uaqtLRU+fn5yszMPOdaZmam8vPzVVZWRrAALWDOBWgDa602bNigu+66S+vWrdOIESOYvAdaQecCtIExpmlOJTU1lWABvgbhAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECfI36+nqVlZVp//79kqQjR47o1KlTCgQCQa4MCF3c5hi4CJ/Pp6KiIr355pvat2+fKisrVVdXp/j4eGVkZGj48OGaOHGihg4dqujo6GCXC4QUwgVowfbt2zVjxgx9/PHHGjRokPLy8tS/f38lJSXJ5/Np165dWrFihQ4fPqz7779fzz77rDIyMoJdNhAyCBfgPGvXrtWECROUlJSkuXPnasyYMaqrq1NhYaFqa2uVkpKiBx54QPX19SosLNSsWbPUt29fLVu2TN26dQt2+UBIIFyAZg4dOqTRo0crMTFRhYWF6tOnj4wxKikp0c0336yKigr16tVLu3btUteuXWWt1ZYtW/TQQw/pzjvv1Guvvaa4uLhgfxtA0DGhD5zV0NCg559/XuXl5Vq4cGFTsLTGGKNhw4bphRde0PLly7V69eoOqhYIbYQLcNbhw4e1YsUKjRs3TsOGDfvaYGlkjNE999yjW265RYsXL9aZM2fauVIg9LHEBThr27Ztqqqq0vjx4/Xpp5+qurq66VppaakaGhokSXV1ddq3b59SUlKarl955ZUaN26cZs2apePHj6t79+4dXj8QSggX4KwDBw4oISFBmZmZmjRpkrZu3dp0zVqr2tpaSdLRo0c1atSopmvGGL300ku64YYbVFNTo6NHjxIuiHiEC3CW3+9XdHS04uLiVFtbq9OnT7f4ddbaC66dOXNGXq/3nBACIhnhApx1+eWXy+/3y+fzaciQIUpMTGy65vf7tW3btqYQue2225o2Thpj1KNHD3355Z
fyeDzq2rVrsL4FIGQQLsBZAwYMUH19vXbs2KF58+adc62kpESDBg1SRUWFunXrprfeekupqalN140xeuqpp3TFFVcwJAaI1WJAk8GDByszM1NLlixRdXW1oqKizvnTyBgjj8fT9LzH49GxY8f09ttvKy8vT126dAnidwGEBsIFOCs9PV1Tp07V7t27tWDBgjYvKa6trdWcOXPk9/s1adKkNi9hBsIZw2JAMxMmTNDmzZs1b948JSQkaPLkyYqPj5ckRUdHKzo6uqmLsdaqsrJSzz33nAoLCzV//nxde+21wSwfCBkc/wKc58SJE5oyZYref/995ebmasaMGbr++ut18OBBBQIBxcbG6pprrtGOHTv04osvas+ePZo9e7YmT558zvAZEMkIF6AF1dXVWrx4sRYsWKAvvvhCmZmZys7OVnJyssrLy3Xw4EEdPXpUAwYM0DPPPKM77rhDHg+jzEAjwgVoxfHjx7Vu3Tpt2rRJe/fu1Y4dOzR8+HANHTpUOTk5GjJkiBISEoJdJhByCBegjXbu3KnBgwdr586dGjhwYLDLAUIafTzQRlFRUU3LkAG0jp8SAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOu4nwvQRtZaBQIBeTweGWOCXQ4Q0uhcgG+Ae7kAbRMd7AIAt1hrVVxcrJMnTwa7lEvi8XjUr18/JSYmBrsU4FtjWAxhIxAIaMqUKbrqqquUlJR0Se/T0NCgqKiooHQqH3zwgZ5++mn179+/wz8bcAudC8JKXFycJk6cqG7dun2j1wUCARUXF+vdd9/V9u3bderUKaWlpWnIkCEaO3asrrvuug6Za7HWqqqqSvyfD50d4YKIZq1VRUWF8vPztWjRIp04ceKc6++//75efvllPfLII3ryySd1xRVXMJkPtAHhgohlrVVpaakee+wxrV69WpJ044036gc/+IF69Oihzz//XKtXr9aePXu0YMECbdmyRa+++qoGDhxIwABfg3BBRLLWqqysTD/60Y+0efNmdenSRf/6r/+qRx99VKmpqTLGyFqrJ554Qm+99ZZmz56t3bt367777tPrr7+uO++8k4ABWsG6SkScxqGwKVOmaPPmzcrIyNBvf/tbPfHEE+ratWtTaBhjlJKSop/+9Kd67733dOONN+rvf/+7JkyYoO3btzMvArSCcEHEOXPmjObOnatVq1YpJSVFCxcu1NixYxUVFdXi1xtjdPPNN+uPf/yjbrrpJpWWlurRRx/VoUOHCBjgIggXRBRrrVatWqVXX31VHo9HM2fO1L333vu1S46NMerdu7d+97vfKSsrSwcOHND06dNVXl7eQZUDnQvhgohhrdWxY8f0y1/+UjU1NRo7dqwee+yxi3Ys5zPGqF+/flq4cKG6du2qdevWae7cuTpz5kw7Vw50PoQLIkYgEND8+fO1f/9+XXXVVZozZ468Xu83eg9jjEaOHKmZM2cqKipKixYt0sqVKxkeA85DuCAiWGu1a9cuvf766/J4PHryySfVu3fvb7Xiy+PxaNKkSRo7dqxqamo0c+ZMlZWVETBAM4QLIkJdXZ1efPFF+Xw+3XrrrXr44YcvaSmx1+vVnDlz1LNnT+3fv1+/+tWvGB4DmiFcEPastdqyZYtWrVql+Ph4/eIXv1BycvIlvWfjBP/MmTMVExOjZcuWaePGjXQvwFmEC8JebW2tCgoK5Pf7NXLkSI0cOdKVDZDGGD344IPKyclRVVWVZs+erYqKChcqBjo/wgVhzVqrbdu2af369fJ6vZo+fbri4uJce3+v16unn35aaWlp+stf/qLf//73dC+ACBeEuTNnzmjRokXy+/0aMWKEhg4d6uqxLY0bLCdOnNi0Gu2zzz4jYBDxCBeELWut9u7dq7Vr1yo2NlY/+9nPXO1aGkVFRWnq1Km65ppr9Omnn2rhwoUKBAKufw7QmRAuCF
vWWr3xxhuqrKzUwIED2/Wwye9+97t6/PHHFRUVpaVLl+qTTz6he0FEI1wQtj7//HMtX75cHo9HP/nJT5SQkNBun2W
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x400 with 5 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.get_act(dataset2)\n",
|
||
|
"model2.plot()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 104,
|
||
|
"id": "096134b0",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"| train_loss: 3.77e-04 | test_loss: 3.76e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:46<00:00, 1.07it"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.2\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.fit(dataset2, steps=50);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 105,
|
||
|
"id": "035e00f2",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"| train_loss: 3.73e-04 | test_loss: 3.72e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:13<00:00, 3.81it"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.3\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.fit(dataset2, steps=50, update_grid=False);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 106,
|
||
|
"id": "b65775e8",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAex0lEQVR4nO3de3CU9b3H8c+zWUg2JBAIAUTEsjEqFzkqVwWqFCWUeKYIY/FoFRwckSJWxvZ4vIMXLCpjgvbM6QnaAuMxWGIPRaixpRYKqKFcOyiXELUmXETIYi6b6/7OH0+SEyDESJ7wbHbfrxmGTJaFb2YIb36/52YZY4wAAHCQx+0BAACRh7gAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA4r9sDAB2BMUYnTpxQWVmZEhISlJycLMuy3B4LCFusXIAWBAIBZWVlKS0tTSkpKRowYIBSUlKUlpamrKwsBQIBt0cEwpLFkyiB5uXl5WnatGmqqKiQZK9eGjSsWuLj45Wbm6v09HRXZgTCFXEBmpGXl6eMjAwZYxQKhc756zwejyzL0rp16wgM0ARxAc4QCATUr18/BYPBFsPSwOPxyOfzqaioSElJSe0/INABcMwFOMPy5ctVUVHRqrBIUigUUkVFhVasWNHOkwEdBysXoAljjNLS0lRYWKjv8q1hWZb8fr8OHjzIWWSAiAtwmq+//lopKSlten9ycrKDEwEdE9tiQBNlZWVten9paalDkwAdG3EBmkhISGjT+xMTEx2aBOjYiAvQRHJyslJTU7/zcRPLspSamqoePXq002RAx0JcgCYsy9K8efPO670PPvggB/OBehzQB87AdS5A27FyAc6QlJSk3NxcWZYlj6flb5GGK/TfeecdwgI0QVyAZqSnp2vdunXy+XyyLOus7a6Gz/l8Pq1fv14TJ050aVIgPBEX4BzS09NVVFSkzMxM+f3+017z+/3KzMxUcXExYQGawTEXoBWMMfrggw80YcIEbdiwQePHj+fgPdACVi5AK1iW1XhMJSkpibAA34K4AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AAAcR1wAAI4jLgAAxxEXAIDjiAsAwHHEBQDgOOICAHAccQEAOI64AN+ipqZGxcXF+vTTTyVJhw4d0smTJxUKhVyeDAhfPOYYOIdAIKDc3Fy9+eab2rt3r0pLS1VdXa24uDilpKRo3LhxmjVrlsaMGSOv1+v2uEBYIS5AMz788EPNnz9fe/bs0YgRI5SRkaGhQ4cqISFBgUBA27dv19q1a1VQUKDp06frueeeU0pKittjA2GDuABneP/99zVz5kwlJCTohRde0OTJk1VdXa2cnBxVVVWpa9euuv3221VTU6OcnBwtWLBAgwcP1sqVK9W7d2+3xwfCAnEBmjhw4IAmTZqkLl26KCcnR4MGDZJlWSosLNS1116rU6dOacCAAdq+fbu6d+8uY4w2b96sO+64QzfeeKOWLVum2NhYt78MwHUc0Afq1dXVadGiRSopKdFrr73WGJaWWJalsWPH6sUXX9SaNWv03nvvXaBpgfBGXIB6BQUFWrt2raZOnaqxY8d+a1gaWJalKVOmaPTo0crOzlZtbW07TwqEP05xAept3bpVZWVlmjZtmj7//HOVl5c3vlZUVKS6ujpJUnV1tfbu3auuXbs2vt63b19NnTpVCxYs0NGjR9WvX78LPj8QTogLUG/fvn2Kj4+X3+/X7NmztWXLlsbXjDGqqqqSJB0+fFg333xz42uWZWnJkiW66qqrVFFRocOHDxMXRD3iAtQLBoPyer2KjY1VVVWVKisrm/11xpizXqutrZXP5zstQkA0Iy5AvV69eikYDCoQCGjUqFHq0qVL42vBYFBbt25tjMj111/feOGkZVnq37+/vvrqK3k8HnXv3t
2tLwEIG8QFqDds2DDV1NQoPz9fixcvPu21wsJCjRgxQqdOnVLv3r21atUqJSUlNb5uWZYee+wx9enThy0xQJwtBjQaOXKk/H6/li9frvLycsXExJz2o4FlWfJ4PI2f93g8OnLkiFavXq2MjAx169bNxa8CCA/EBaiXnJysBx54QDt27NDSpUtbfUpxVVWVnn32WQWDQc2ePbvVpzADkYxtMaCJmTNnatOmTVq8eLHi4+M1Z84cxcXFSZK8Xq+8Xm/jKsYYo9LSUj3//PPKycnRK6+8oiuuuMLN8YGwwe1fgDMcP35cc+fO1bvvvqv09HTNnz9fAwcO1P79+xUKhdS5c2dddtllys/P18svv6xdu3bpmWee0Zw5c07bPgOiGXEBmlFeXq7s7GwtXbpUx44dk9/vV1pamhITE1VSUqL9+/fr8OHDGjZsmJ5++mndcMMN8njYZQYaEBegBUePHtWGDRu0ceNG7d69W/n5+Ro3bpzGjBmjiRMnatSoUYqPj3d7TCDsEBeglbZt26aRI0dq27ZtGj58uNvjAGGNdTzQSjExMY2nIQNoGd8lAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jrgAABxHXAAAjiMuAADHERcAgOOICwDAccQFAOA44gIAcBxxAQA4jue5AK1kjFEoFJLH45FlWW6PA4Q1Vi7Ad8CzXIDW8bo9AOAUY4wOHjyoEydOuD1Km3g8Hg0ZMkRdunRxexTgvLEthogRCoU0d+5cXXLJJUpISHB7nO+kpqZG5eXl6tatmzZv3qwnn3xSQ4cOdXss4LyxckFEiY2N1axZs9S7d2+3R2k1Y4z+9Kc/6eGHH9aPf/xjDRw4UPyfDx0dG8iAy+rq6vTb3/5Wn3zyiXbs2KGYmBi3RwLajLgALjLGaN++fXr//ffVqVMn3X333erUqZPbYwFtRlwAFxljtHLlSpWUlGjIkCH6wQ9+4PZIgCOIC+CiL7/8Ujk5ObIsS/fcc4+6du3q9kiAI4gL4JKGVUtxcbFSU1M1bdo0t0cCHENcAJcUFRXpN7/5jSTpnnvu6VBnuAHfhrgALgiFQlq2bJm++OILDRgwQHfddRe3lEFEIS7ABdZwJ4HXX39dknT//ferb9++Lk8FOIu4ABdYbW2tlixZoiNHjmjIkCG6++67WbUg4hAX4AIyxmjjxo1atWqVOnXqpF/84hfq2bOn22MBjiMuwAVUUlKiBQsWqKysTBMnTtStt97KqgURibgAF0hdXZ1effVVffzxx0pOTtZTTz0ln8/n9lhAuyAuwAVgjNGmTZu0dOlSSdJDDz2ka6+9llULIhZxAdqZMUZHjhzRz3/+cwUCAY0fP15z587lwWOIaPztBtpZMBjUf/zHf2jXrl26+OKL9dJLL3GbF0Q84gK0o4bjLKtWrZLP59OiRYs0dOhQtsMQ8YgL0E6MMcrNzdWiRYsan5I5ffp0woKoQFyAdtBwAP+hhx5SWVmZfvSjH+nxxx/nWS2IGsQFcJgxRnv27NF9992nY8eOafTo0crKylJiYqLbowEXDHEBHGSMUUFBge655x4VFBToyiuvVHZ2tvr27ct2GKIKcQEcYozR559/rhkzZmj37t3q37+/Xn/9dQ0cOJCwIOoQF8ABxhj985//1IwZM/Txxx+rT58+ys7O1qhRowgLohJxAdrIGKPPPvtMd911l7Zs2aJevXopOztbEyZMICyIWsQFaANjjA4cOKA777zztLBMmjSJsCCqERfgPBljtGvXLk2fPl35+fm66KKL9MYbb2jy5Mnc2gVRj+8A4DyEQiH95S9/0W233aZ//OMf+t73vqeVK1dq0qRJhAUQcQG+s9raWuXk5OjOO+/UZ599psGDBysnJ0c33ngjW2FAPeICtJIxRhUVFXrxxRd1//336/jx4xo7dqx+97vfafjw4YQFaMLr9gBAR2CM0VdffaVHHnlEb731lowxuu2225SZmanevXsTFuAMxAX4Fg0H7u
fNm6ePPvpIsbGx+tnPfqZHH31UCQkJhAVoBnEBzsEYo9raWq1evVqPPPKIiouL1atXL/3yl7/UHXfcwU0ogRYQF6A
|
||
|
"text/plain": [
|
||
|
"<Figure size 500x400 with 5 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.plot()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 107,
|
||
|
"id": "6cd26af5",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
||
|
"0 x^0.5 0.999957 -14.193489 2 2 -1.238698\n",
|
||
|
"1 sqrt 0.999957 -14.193489 2 2 -1.238698\n",
|
||
|
"2 log 0.999722 -11.763921 2 2 -0.752784\n",
|
||
|
"3 1/x^0.5 0.999485 -10.894391 2 2 -0.578878\n",
|
||
|
"4 1/sqrt(x) 0.999485 -10.894391 2 2 -0.578878\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"('x^0.5',\n",
|
||
|
" (<function kan.utils.<lambda>(x)>,\n",
|
||
|
" <function kan.utils.<lambda>(x)>,\n",
|
||
|
" 2,\n",
|
||
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
||
|
" 0.9999566254728288,\n",
|
||
|
" 2)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 107,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.suggest_symbolic(1,0,0)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 72,
|
||
|
"id": "daabd91a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Best value at boundary.\n",
|
||
|
"r2 is 0.9989821969546337\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(tensor([-9.8000, 9.8868, -0.3482, 1.2049]), tensor(0.9990))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 72,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from kan.utils import fit_params\n",
|
||
|
"fit_params(x**2, y, lambda x: x**(1/2))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 74,
|
||
|
"id": "8dc20f20",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"skipping (0,0,0) since already symbolic\n",
|
||
|
"fixing (1,0,0) with x^0.5, r2=0.9999494098870415, c=2\n",
|
||
|
"saving model version 0.4\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model2.auto_symbolic()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 79,
|
||
|
"id": "f27f3daa",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/latex": [
|
||
|
"$\\displaystyle 1.19 - 1.08 \\sqrt{1 - 1.0 x_{1}^{2}}$"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
"1.19 - 1.08*sqrt(1 - 1.0*x_1**2)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 79,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from kan.utils import ex_round\n",
|
||
|
"ex_round(model2.symbolic_formula()[0][0], 2)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "a61540c1",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.7"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|