632 lines
82 KiB
Plaintext
632 lines
82 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Physics 1: Lagrangian neural network"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 94,
|
|
"id": "66865edb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"from kan.utils import batch_jacobian, batch_hessian\n",
|
|
"\n",
|
|
"torch.use_deterministic_algorithms(True)\n",
|
|
"torch.set_default_dtype(torch.float64)\n",
|
|
"\n",
|
|
"seed = 0\n",
|
|
"torch.manual_seed(seed)\n",
|
|
"\n",
|
|
"# choose one of: 'harmonic_oscillator', 'single_pendulum', 'relativistic_mass'\n",
|
|
"example = 'relativistic_mass'\n",
|
|
"\n",
|
|
"# Each example samples positions q and velocities qd and computes the\n",
|
|
"# ground-truth accelerations qdd. A Lagrangian L(q, qd) is then learned so\n",
|
|
"# that the Euler-Lagrange equation reproduces qdd from (q, qd).\n",
|
|
"\n",
|
|
"# dimension of q\n",
|
|
"d = 1\n",
|
|
"\n",
|
|
"if example == 'harmonic_oscillator':\n",
|
|
"    n_sample = 1000\n",
|
|
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
"    qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
"    qdd = - q\n",
|
|
"elif example == 'single_pendulum':\n",
|
|
"    n_sample = 1000\n",
|
|
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
"    qd = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
"    qdd = - torch.sin(q)\n",
|
|
"elif example == 'relativistic_mass':\n",
|
|
"    # relativistic particle under a unit constant force (c = 1): qdd = (1 - qd^2)^(3/2)\n",
|
|
"    n_sample = 10000\n",
|
|
"    q = torch.rand(size=(n_sample,1)) * 4 - 2\n",
|
|
"    #qd = torch.rand(size=(n_sample,1)) * 1.998 - 0.999\n",
|
|
"    #qd = 0.95 + torch.rand(size=(n_sample,1)) * 0.05\n",
|
|
"    qd = torch.rand(size=(n_sample,1)) * 2 - 1\n",
|
|
"    qdd = (1 - qd**2)**(3/2)\n",
|
|
"else:\n",
|
|
"    raise ValueError(f'unknown example: {example!r}')\n",
|
|
"\n",
|
|
"# network input: concatenation of q and qd\n",
|
|
"x = torch.cat([q, qd], dim=1)\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 95,
|
|
"id": "ec549451",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.compiler import kanpiler\n",
|
|
"from sympy import *\n",
|
|
"\n",
|
|
"input_variables = symbol_x, symbol_vx = symbols('x v_x')\n",
|
|
"expr = symbol_vx ** 2\n",
|
|
"model = kanpiler(input_variables, expr, grid=20)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 96,
|
|
"id": "f9930812",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASgklEQVR4nO3de3BU5f3H8c/ZEGBjUkJiuCiKSQgoNyWigKZ2LMiK6c3iOPVHrYzgWFrBf0r/AMc6ttqp8kcyeMGOMw7gwLQ2jhTCNN6LhNQb4A2omFSYIBBCWEhMyGX3+f3xsCThEsCc3XM2vF8zTGJ2N/sNzpMP3/NcjmOMMQIAwEUBrwsAAPQ9hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHX9vC4ASAbGGB0+fFhNTU1KT09Xdna2HMfxuizAt+hcgB6Ew2GVlpaqoKBAOTk5ys3NVU5OjgoKClRaWqpwOOx1iYAvOdwsDDiziooKzZ49W83NzZJs9xIT61rS0tJUVlamUCjkSY2AXxEuwBlUVFSouLhYxhhFo9GzPi8QCMhxHJWXlxMwQBeEC3CKcDisESNGqKWlpcdgiQkEAgoGg6qtrVVmZmb8CwSSAHMuwClWrlyp5ubm8woWSYpGo2pubtaqVaviXBmQPOhcgC6MMSooKFBNTY0uZGg4jqO8vDzt3r2bVWSACBegm/r6euXk5PTq9dnZ2S5WBCQnLosBXTQ1NfXq9Y2NjS5VAiQ3wgXoIj09vVevz8jIcKkSILkRLkAX2dnZys/Pv+B5E8dxlJ+fr6ysrDhVBiQXwgXownEcLVy48Du9dtGiRUzmAycwoQ+cgn0uQO/RuQCnyMzMVFlZmRzHUSDQ8xCJ7dB/9dVXCRagC8IFOINQKKTy8nIFg0E5jnPa5a7Y14LBoDZu3KiZM2d6VCngT4QLcBahUEi1tbUqKSlRXl5et8fy8vJUUlKiffv2ESzAGTDnApwHY4zeeecdTZ8+XW+99ZZuvfVWJu+BHtC5AOfBcZyTcyqZmZkEC3AOhAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWEC3AObW1tqqmp0ccffyxJ+vzzz3Xo0CFxtwrg7LifC3AWx48f17p167RixQp98sknam9vl+M4MsYoMzNTM2fO1KJFizRx4kSO4AdOQecCnEFdXZ3mzp2refPmKS0tTcuXL1dVVZW2b9+uTZs2aenSpdq5c6dCoZBWrFihjo4Or0sGfKWf1wUAfhMOh3Xfffdp69ateuGFF3TXXXepo6NDS5YsUUNDg0aPHq0lS5Zozpw5eu655/TII4+oo6NDDz30EB0McAKXxYAujDFasmSJnn/+eb388ssqLi6W4zhqaGjQtddeq9raWhUVFentt99WamqqIpGInnnmGT3xxBNav369pkyZ4vWPAPgCl8WALr766iu9+OKLmj9/vmbNmnXOTiQlJUUPPvigpk6dqqeeekqRSCRBlQL+RrgAXWzYsEEdHR164IEHFAgEZIw566qw2GMDBgzQggULtHnzZu3duzfBFQP+xJwLcIIxRlVVVRo9erQikYgWL16saDQqya4cC4fDkqSamhotXrxYgYD9t9mwYcM0Z84cpaam6osvvlBubq5XPwLgG4QLcEI0GtXBgwd12WWX6cCBAyopKTnjZa5vvvlGpaWlJ/973Lhxmj9/vgYPHqwDBw4ksmTAtwgX4ATHcZSamqrW1lY5jqP+/fufDBdjjNrb27s9Lyb2vPb29m5fBy5mhAtwQiAQUH5+vjZt2qQxY8bovffeOznfcuzYMd1zzz2qq6vTtddeqxUrViglJUWSFAwGFQ6HdXj/fuUtWybt2CHddJN0883SpZd6+SMBniFcgC5uu+02rV69Wj
t27ND06dNPrhZraGhQ//79JUnp6ekqLCw82aUYY/Tss89qUP/+GjdqlLRmjfT00/YbjhljQyb2p6BAYi8MLgKEC9DFjBkzNGrUKP3lL3/RlClTlJGR0ePzjTHat2+fli9frv9bsEBZjz8uGSPt3Stt3ixt2SJVVkovvWS/npNjQybW2Vx/vXQitIC+hE2UwClee+013XvvvZo3b56efPJJBYNBHTly5LRNlP369dPhw4c1f/587dmzRxUVFRoyZMiZv2k4LP3nPzZsNm+W3n9fammRBg6Ubrihs7O56SZp8OCE/rxAPBAuwCkikYhKSkr02GOP6fbbb9ejjz6q3NxcrV69Wo2NjRoxYoTuvPNOffjhh1q6dKnq6uq0du1aFRYWnv+btLdL27d3725iK83GjpWKimzQFBVJublcSkPSIVyAM4hEInrllVf06KOPqr6+XtOmTVNhYaEyMzN18OBBvf/++/r0009VVFSkZcuWacyYMb17Q2OkmhobMrHuZscO+9iwYd3nba67TmJVGnyOcAF6cPDgQa1bt04bN25UdXW1WltbNXjwYE2aNEmzZ8/WLbfcogEDBsTnzRsapKoqGziVldIHH0itrVJamjRlSudltGnTpEGD4lMD8B0RLsB5iEajamtrUyQSUWpq6smVYwnV2ipt3dq9u6mvt5fMJkzo3t1ceSWX0uApwgVIVsZIu3d3djaVldJ//2sfu/zyznmbm2+WJk6U+rE4FIlDuAB9yaFDtquJdTYffWQXD6SnS1OndnY2U6ZI51hmDfQG4QL0ZceP24CJrUrbssXO5QQCdmFAbN6mqMh2O4BLCBfgYhKNSrt2dZ+3qa62j40c2X3eZtw46cQRN8CFIlyAi92BA517bSor7aKBjg67Am3atM55mxtvlC65xOtqkSQIFwDdNTfbZc+x7mbLFunoUbsgYNKk7gsFhg3zulr4FOECoGfRqPTFF91XpX39tX0sP7/7vM3VV9v5HFz0CBcAF27fvs7OprLSHmUTidhz0WJdzc0323PTBg70ulp4gHAB0HtNTfYwzlhnU1Vlv5aaKk2e3P1gzpwcr6tFAhAuANzX0SF99ln3VWm1tfax0aO7z9uMHs1pAn0Q4QIgMfbu7T5v8+mn9pSBSy/tPm9TWCjF67w2JAzhAsAbR4923uOmstJ+3txsg+WGGzq7m5tukrKyvK4WF4hwAeAP7e3SJ590727277ePjR3b2d3cfrs0dKi3teKcCBcA/hT71XTqx0CAOZokwDGpAPwpFiBdg4R/CycNwgWA61pbWlSzfr2ibW1el3LeHGOUPXmyhl5zjdel9AmECwDXHT96VPtff12jfv5zac8eeyhmerrXZZ3Zl19K3/uejhw9qgP//jfh4hLCBUBcDBw6VFesWSPnn/+U1q6VbrnF65JO19oqLV0q1dQo9Q9/UB1zOa7hECAA8REISFdcIX37rbRpkz/nS/bvl3butEfXcD8bVxEuAOLDcWy3EgjYZcXt7V5XdLpt26Rw2C515lgaVxEuAOJn4kS7AXLXrs49K35hjPTuu/bj979vbykA1xAuAOJnyBBp/HjbHXz8sdfVdNfaak8HSEnx53xQkiNcAMRPv372F7cx0ttv+2veZe9eafduG4ATJnhdTZ9DuACIrx/8wIZMZaV0/LjX1VjG2FsENDVJ110nZWd7XVGfQ7gAiK/x4+3tkKurpf/9z+tqOr35pg2ZH/6Qu2fGAX+jAOJr8GB7w7Bvv7Xdix8ujTU22lOYBwywk/nsb3Ed4QIgvgIBacYM+/kbb/gjXHbutHMuV10lXX2119X0SYQLgPhyHNsdpKXZeY6GBm/riS0uaGuzx/j79ViaJEe4AIi//Hx7O+P9+6Xt272tpaPDdlCOI82c6W0tfRjhAiD+Bg60E+eRiFRR4e2lsdpae4vlrCxp6lTmW+KEcAEQf44jhUJ2SfJbb0ktLd7UYYy0ebPd1FlYKA0f7k0dFwHCBUBiTJokXX
aZPeL+yy+9qcEYacMG+3HWLLs7H3FBuABIjMGDpaIi27V4tWqsvt4e+XLJJdL06VwSiyPCBUBiOI70k5/Yj+XliT8l2Ri7t+XAAXsK8qhRiX3/iwzhAiAxHMcu/R0yxK4Yq6lJ7PsbI61bJ0Wj0h132A2UiBvCBUDiDB1qA6axMfGrxo4ckd55x65cu+MOLonFGeECIHECAenOO+0v9tdeS9ylMWOkqiq7DHnsWPsHcUW4AEgcx7GnJA8ZIm3dao+8TwRjpH/8w+6z+fGPpWAwMe97ESNcACTW8OE2YJqapPXrE3NprK7O7q9JS5N++lMuiSUA4QIgsRxHuvtue4msrCz+GypjZ4nt32/32nBQZUIQLgASy3Hs3SlHjpQ+/9xeHotn99LRIa1da9/j7rul/v3j9144iXABkHhZWXbPS1ubtGZNfMNl92575Et2tvSjH3FJLEEIFwCJ5zjSPffYifUNG6SDB+PzPsZIf/+7dOyYvafMFVfE531wGsIFgDcmTpRuvFH65pv4TeyHw9Lf/mYPzLzvPm5nnED8TQPwRv/+0ty59vOXXnJ/Yt8Y6fXXpa++ksaPt5s3uSSWMIQLAG84jt0pn5srbdsmbdrkbvfS2ir99a/2e86daw+rRMIQLgC8k50t3Xuv3an//PN2ZZcbjLGnH1dVSSNGSHfdRdeSYIQLAO84jg2XIUPsJscPPnCne+nokJYvt93LL38pDRvW+++JC0K4APDWyJHSnDl2zqWkpPfdS6xreeMNG1r330/X4gHCBYC3AgHp17+WcnKkjRvtnpTedC9tbdLTT9uw+tWvpKuucq1UnD/CBYD38vKkefOk48elJ5747ivHjLE3InvzTenyy6Xf/Iblxx7hbx2A9wIB6be/tSGzadN337V/+LD0+OP20tqiRdKVV7pfK84L4QLAH4YPl5YssZ//8Y9SdfWFBUwkIi1bZs8ru/56af585lo8RLgA8AfHkX7xC7v3pbZW+t3vpObm83utMdK//iU995zdz/LnP0uDBsW3XvSIcAHgHwMHSk89ZVeQlZdLf/rTue9WaYz02WfSwoV2rubhh+39YuhaPEW4APAPx5EKCqTSUtuBlJTYlV9tbWd+vjHSjh12r8yePfbU48WLpZSUhJaN0xEuAPzFcaTiYtvBpKTYCfqHH7YHXMbmYIyxgbN+vfSzn9l5lqIie1ksPd3T8mH187oAADhNIGCXJvfrJ/3+9/aMsNdft0EycaI97biiQnr3XRsys2ZJL7xgd+JzOcwXCBcA/pSSYg+cvOYaaelSu+u+pKT7c4YOtUuYFy6UMjIIFh8hXADElenNbnvHkaZOtTcUq6qymyP37LHzMZMnS6GQ3csSC5V43tESF8Qxvfo/DwCnO1ZXpy0LFmjQhAnuf/PYryyXu5Tjhw7p0uuv14T773f1+16sCBcArotGImqsq5OJRr0u5YIMHDRIA1kQ4ArCBQDgOpYiAwBcx4Q+gOTR9UILK8N8jc4FQPLYts0uUd62zetKcA6ECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYRLEjDGqL6+Xl9//bXq6+t7d09yIEkZY3TkyBEZyX5kHPga4eJj4XBYpaWlKigoUE5OjnJzc5WTk6OCggKVlpYqHA57XSIQd13HwfQZM2SM0fQZMxgHPsdtjn2qoqJCs2fPVnNzsyR1+1eac+ImSWlpaSorK1MoFPKkRiDeTh0H1xmjjyRNlrSdceBrdC4+VFFRoeLiYrW0tMgYc1r7H/taS0uLiouLVVFR4VGlQPwwDpIbnYvPhMNhjRgxQi0tLYpGo+d8fiAQUDAYVG1trTIzM+NfIJAAZxsHk6STnUvXe1EyDvyHzsVnVq5cqebm5vMKFkmKRqNqbm7WqlWr4lwZkDiMg+RH5+IjxhgVFBSopqbmglbCOI6jvLw87d69++R8DJCseh
oHZ+tcJMaB39C5+Mjhw4dVXV19wUssjTGqrq5WQ0NDnCoDEodx0DcQLj7S1NTUq9c3Nja6VAngHcZB30C4+Eh6enqvXp+RkeFSJYB3ehoHu2Qvie3q4fWMA38gXHwkOztb+fn5F3y92HEc5efnKysrK06VAYnT0zhokZ1raTnD6xgH/kK4+IjjOFq4cOF3eu2iRYuYxESfwDjoG1gt5jPscwEYB30BnYvPZGZmqqysTI7jKBDo+X9PIBCQ4zh69dVXGVDoUxgHyY9w8aFQKKTy8nIFg0E5jnNamx/7WjAY1MaNGzVz5kyPKgXih3GQ3AgXnwqFQqqtrVVJSYny8vK6PZaXl6eSkhLt27ePAYU+jXGQvJhzSQLGGDU0NKixsVEZGRnKyspi0hIXHcZBciFcAACu47IYAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1/w+t/uNoMzLGkwAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.get_act(x)\n",
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 97,
|
|
"id": "74a1fb02",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.2\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAACuCAYAAAD6ZEDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUf0lEQVR4nO3de3TU5Z3H8c8zucDkQm4MhJAEyNWigKQiKqmeFiQqbrcunp61rqsutF23hf1n3T+gx+26a/fU+kdSWkVPz+kqbe1uF46ughutWBGCeEFURAQSIDMhF3KZXEhCMjPP/hGYEkBE+M0l4f36x0MyM79v5Dz58H2e5/f8jLXWCgAAB7liXQAAYPwhXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjkuMdQHAWGCtVUdHh/r6+pSWlqacnBwZY2JdFhC36FyAC/D7/aqpqVFpaak8Ho9mzZolj8ej0tJS1dTUyO/3x7pEIC4ZHhYGnF9tba2WL1+u/v5+SSPdy2mnu5aUlBRt3LhRVVVVMakRiFeEC3AetbW1WrZsmay1CoVCn/s6l8slY4w2b95MwABnIFyAs/j9fuXn52tgYOCCwXKay+WS2+2Wz+dTZmZm5AsExgDWXICzPPvss+rv77+oYJGkUCik/v5+PffccxGuDBg76FyAM1hrVVpaqoaGBn2ZoWGMUVFRkQ4ePMguMkCECzBKe3u7PB7PZb0/JyfHwYqAsYlpMeAMfX19l/X+3t5ehyoBxjbCBThDWlraZb0/PT3doUqAsY1wAc6Qk5Oj4uLiL71uYoxRcXGxsrOzI1QZMLYQLsAZjDFatWrVJb139erVLOYDp7CgD5yF+1yAy0fnApwlMzNTGzdulDFGLteFh8jpO/Q3bdpEsABnIFyA86iqqtLmzZvldrtljDlnuuv019xut7Zs2aKlS5fGqFIgPhEuwOeoqqqSz+dTdXW1ioqKRn2vqKhI1dXVampqIliA82DNBbgI1lq98cYbWrx4sV5//XV9/etfZ/EeuAA6F+AiGGPCayqZmZkEC/AFCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIFwCA4wgXAIDjCBcAgOMIF+ALDA0NqaGhQe+//74kae/evTp+/Lh4WgXw+XieC/A5BgcH9eKLL2r9+vX68MMPNTw8LGOMrLXKzMzU0qVLtXr1as2dO5cj+IGz0LkA59HW1qYHHnhAK1asUEpKitatW6edO3dqz5492rZtm9auXatPP/1UVVVVWr9+vQKBQKxLBuJKYqwLAOKN3+/X/fffr927d+vpp5/W3XffrUAgoDVr1qizs1NlZWVas2aN7r33Xj355JP60Y9+pEAgoB/+8Id0MMApTIsBZ7DWas2aNXrqqaf0m9/8RsuWLZMxRp2dnZo3b558Pp8qKyu1detWJSUlKRgM6he/+IUee+wxvfTSS1q4cGGsfwQgLjAtBpzh0KFD+tWvfqWVK1fq9ttv/8JOJCEhQd///vd1ww036PHHH1cwGIxSpUB8I1yAM7z88ssKBAL67ne/K5fLJWvt5+4KO/29CRMm6KGHHtL27dvV2NgY5YqB+MSaC3CKtVY7d+5UWVmZgsGgHn74YYVCIUkjO8f8fr8kqaGhQQ8//LBcrpF/m+Xm5uree+9VUlKSPvnkE82aNStWPwIQNwgX4JRQKKTW1lbl5eWppaVF1dXV553mOnbsmGpqasJ/vvrqq7Vy5UplZWWppaUlmiUDcYtwAU4xxigpKUknT56UMUbJycnhcLHWanh4eNTrTjv9uuHh4VFfB65khAtwisvlUnFxsbZt26by8nK99dZb4fWWnp4e3XPPPWpra9O8efO0fv16JSQkSJLcbrf8fr/amtv06ROf6rV9r6ngpgIVLipUyuSUWP5IQM
wQLsAZbr31Vm3YsEH79u3T4sWLw7vFOjs7lZycLElKS0tTRUVFuEux1uqXv/ylUpNTVV5Srr2/26u6n9VJkiaXT1bBogIVLBoJm+zSbO6FwRWBcAHOsGTJEpWUlOinP/2pFi5cqPT09Au+3lqrpqYmrVu3Tg8+9KAefPRBWWvV3ditxu2N8tZ55d3h1Z5f75G1Vqme1JGwuWkkcPK+mqeE5IQo/XRA9HATJXCWF154Qffdd59WrFihn/zkJ3K73erq6jrnJsrExER1dHRo5cqVOnr0qGprazVlypTzfuagf1C+t33y1nnVuL1RTbuaNDwwrMSJiZq+YHq4uym4qUDuLHeUf2LAeYQLcJZgMKjq6mr9+Mc/1m233aZHHnlEs2bN0oYNG9Tb26v8/Hzdddddevfdd7V27Vq1tbXp+eefV0VFxcVfYziolj0tatzeKF+dT407GtXX0idJ8sz2qLCycGTdprJQmbMymUrDmEO4AOcRDAb1hz/8QY888oja29t14403qqKiQpmZmWptbdWuXbv00UcfqbKyUk888YTKy8sv63rWWnU1dMm7wxvubo7vOy5JSstNU+GiwnB3k3ttrhKSmEpDfCNcgAtobW3Viy++qC1btqi+vl4nT55UVlaW5s+fr+XLl+vmm2/WhAkTInLtgc4BeXeOrNl4d3jV9E6TAicDSkpJUv7C/PA0Wv6N+ZqYMTEiNQCXinABLkIoFNLQ0JCCwaCSkpLCO8eiKXAyoObdzaO6m/72fhljNGXOlPCOtIJFBcoozGAqDTFFuABjlLVWnQc71bijMdzdtH/WLkmaNH2SCisLlX9TvgoXFWrq3KlyJXKUIKKHcAHGkRPHT4xsf67zyrvdq2PvHVNwOKjktGTl35Af7m6mL5yuCemRmc4DJMIFGNcCgwEde+/Yn++5qfNqoHNAxmWUe21ueN2msLJQk6ZPinW5GEcIF+AKYkNW7fvb1bjj1Bbo7Y3qrO+UJGXOyBx1moDnao9cCUyl4dIQLsAVrq+lb2SDwKm1m+bdzQoFQpqYMVH5N+aHTxOYfv10JadGfyMDxibCBcAow/3DanqnKdzdeOu8GuwelCvRpWnzp6mgsiB8MGdablqsy0WcIlwAXJANWbV90hbekda4o1H+I35JUnZx9qh1m8lXTZZxsQUahAuAS9DT1BO+38a7w6uWPS0KBUNyZ7nD02gFiwo0fcF0JU7kfNwrEeEC4LIN9Q3Jt8sX7m68O70a6htSQlKC8q7LG3UwZ6onNdblIgoIFwCOCwVCav24ddRpAj2+HklSTllO+GDOgkUFyinL4TSBcYhwARAV3Y3do04TaP2oVdZapUxOUeGiU6cJVBZqWsU0JU5gKm2sI1wAxMRg95+fcePd4ZXvbZ+G+4eVOCFReQvy/tzd3FQgdzbPuBlrCBcAcSE4HFTrh62jupve5l5JI8+4Ob1mU3JbidKmsgU63hEuAOJS+FeTHf1n4zKs0YwBTGwCiEvhAAn/x4h/C48dhAsAxw0ODOrtl95WYCgQ61IunpVKrivRzK/MjHUl4wLhAsBxJ7pP6KNXP9JNf3WT/Ef9ypyRqeS0+DyXrONAh5InJau7u1t739xLuDiEcAEQERlTM3T0d0d14H8PaPnzy1V2c1msSzpH4GRAG9ZuUFdDlyr+pUInzIlYlzRucJ42gIgwLqOMggwNnxjW0W1H43K9pK+5T8c/Pa5QMMTzbBxGuACIDCPNuHmGjMvIu8Or0HAo1hWdo/mDZg36B+WZ7VGKJyXW5YwrhAuAiJk6d6rc2W61728P37MSL6y1OvKnI5KVCr9WKFcivw6dxP9NABGTOiVVnms8GvQPqvn95liXM0rwZFC+Op9MgtGMm2fEupxxh3ABEDGuRNfIL24rHd56OK7WXbobu9VxsEOpU1I1dc7UWJcz7hAuACJq5i0z5Up0ybvDq8BgfNz3Yq2Vb5dPQ31Dyr02V+4czi5zGuECIKKmXDNFqbmp6qzvlP
+wP9blhDX8sUGy0sxvzOTpmRFAuACIqIlZE5V3XZ6GTwyrcUdjXEyNDfUOqentJiVMSNDMr83krLIIIFwARJRxGRUtKZIkNbzWED6IMpaOf3pc3Y3dypyZqZyrcmJdzrhEuACIKGOMZnxthpJSktS0q0kDnQMxrcdaq8NbDys4FFTBooK4PZZmrCNcAERcVnGWcspy1Nvcq5Y9LTGtJRQIjXRQRipZWhLTWsYzwgVAxCVOTNTMb8yUDVodqj0U03WXHl+PWj9qlTvbrek3TGe9JUIIFwARZ4xRSVWJTKLR4dcPKzAQmy3J1lo1bm/UoH9Q0yqmKX1aekzquBIQLgCiInd+riblTVLHgQ51HOiITRFWOvDygZFnt9xeIpNA1xIphAuAqHBnuVVQWaDAQED1r9XHZGqsv71fvjqfklKTVLS4iCmxCCJcAESHkcq/WS4Z6eDmg1E/JdlaK9/bPvW19Mkz26PskuyoXv9KQ7gAiApjjAoXFSp1Sqpa9rSoq6ErugVYaf+L+2VDViV3lChhQkJ0r3+FIVwARE3q1FQVLCrQUO9Q1HeNDXQN6MgbR5Q4MVFld5QxJRZhhAuAqDEuo6vuukoy0mcvfBa1qTFrrXw7ferx9cgz2yPPbE9UrnslI1wARI0xRjNvmanUKalq3t2sjoNR2jVmpX3/s082aFX2F2VKdCdG57pXMMIFQFSlT0vXjFtmaKhvSJ+99FlUpsZOtJ3Q4dcPKyklSeV/Wc6UWBQQLgCiy0hXf/tqGZfR/o37I35DpbVWDVsb1Nvcq9z5uZp81eSIXg8jCBcAUWXMyGOFM2ZkqG1vm5p3N0f0pORQIKS9z++VrDT727OVkMwusWggXABEnTvbrbJvlik4FNTHv/s4olNjnQc75d3ulTvHrfI7mRKLFsIFQNQZYzTnnjlKdCfqwMsH1NfaF5HrWGu197/36mTPSRUtKdKkgkkRuQ7ORbgAiImpc6dq+vXT1XusN2IL+4P+Qe37r31yJbo07/55PM44iggXADGRkJygeQ/MkyR9+OsPHV/Yt9aq/tV6dR7q1JRrpqhwUSFTYlFEuACICWOMyu4oU9asLDV/0Kyj24462r0ETwb1/jPvy1qreQ/MU1JqkmOfjS9GuACIGXeOW3Pum6PQcEjvPvWuQgFn7ti31spb55Vvp0+T8idp9t2z6VqijHABEDPGGF1737VKnZKqw68fVtM7TY50L6FASO+se0fBk0HN+Zs5SstNc6BafBmEC4CYypiRoWvuvUaBgYB2Ve+67O7ldNdS/1q9UqekquLvKuhaYoBwARBTxmW04O8XKMWTooNbDqpxe+NldS/BoaDqflanwEBAc/92rjJnZjpXLC4a4QIg5rKKsjR/xXwFBgN667G3LnnnmLVWBzcfVMMfG5Q+PV0L/mEB249jhHABEHPGZXT9D65XVlGWjm47esl37Q90DOjNR99UKBDSwtULlVGYEYFqcTEIFwBxIW1amirXVEqStv3bNnXVd32pgAkFQ9rxxA617W1T3lfzVLGStZZYIlwAxAVjjK7562tUekepenw9evWfXtVw//BFvddaq0P/d0jvPfmeklOTtfg/FmtCxoQIV4wLIVwAxI3EiYm69fFblTEjQwc3H9S2f9+m4HDwgu+x1qrt4za9suoVBQYCWviPCzXjlhl0LTFGuACIG8YYZZdm67aa25SUmqRd1btU97M6BYfOHzDWWh3fd1yb7tuk7qPdKruzTIseXiRXAr/aYo2/AQBxxRijsmVlWvL4EpkEozcffVOv/OMr6j3WG16DsdYqOBTUgZcO6Pff+r2O7z2uwspC3fHkHUpK45iXeMCDpAHEHeMyqlhRoYTEBL32z69p9zO71fBqg8q/Va6pc6dq0D+o+tp6HfnTEQWHgiq5vUR3Pn2n0nLTmA6LE4QLgLjkSnDp2geu1eSvTNbWtVvlrfNqV/WuUa9JnZqqBT9YoIWrFio5PZlgiSOEC4CIuqyzwoyUf0
O+vvPyd+Td6VXDHxvUfbRbSalJyrsuTyVVJSP3shgHrgVHES4AHGdcRr5PfPrtv/7W2Q+eINlSKyOjnuYe7f/P/Y59dM/xHhV9tcixz7vSGUvUA3BYMBhUV1uXbGhs/XpJzUhVSlpKrMsYFwgXAIDj2IoMAHAcay4AxowzJ1rYGRbf6FwAjBktH7To0YRH1fJBS6xLwRcgXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcAACOI1wAAI4jXAAAjiNcxgBrrdrb23XkyBG1t7fznHBckay18nf5JUn+Lj/jIM4RLnHM7/erpqZGpaWl8ng8mjVrljwej0pLS1VTUyO/3x/rEoGIO3McLF6yWNZaLV6ymHEQ53jMcZyqra3V8uXL1d/fL+n8D0lKSUnRxo0bVVVVFZMagUg7exzk2lx9T9/TM3pGLWbkmS6Mg/hE5xKHamtrtWzZMg0MDMhae077f/prAwMDWrZsmWpra2NUKRA5jIOxjc4lzvj9fuXn52tgYEChUOgLX+9yueR2u+Xz+ZSZmRn5AoEo+LxxME3Twp1Ls5rDX2ccxB86lzjz7LPPqr+//6KCRZJCoZD6+/v13HPPRbgyIHoYB2Mf4RJHrLVat27dJb335z//ObtnMC4wDsYHwiWOdHR0qL6+/ksPDmut6uvr1dnZGaHKgOhhHIwPhEsc6evru6z39/b2OlQJEDuMg/GBcIkjaWlpl/X+9PR0hyoBYudC46Bd7XpGz6hd7Z/7GsZBfCBc4khOTo6Ki4vD97FcLGOMiouLlZ2dHaHKgOi50DgY1rCa1axhDZ/zPcZBfCFc4ogxRqtWrbqk965evfpLhxIQjxgH4wP3ucQZ7nMBGAfjAZ1LnMnMzNTGjRtljJHLdeG/HpfLJWOMNm3axIDCuMI4GPsIlzhUVVWlzZs3y+12yxhzTpt/+mtut1tbtmzR0qVLY1QpEDmMg7GNcIlTVVVV8vl8qq6uVlFR0ajvFRUVqbq6Wk1NTQwojGuMg7GLNZcxwFqrzs5O9fb2Kj09XdnZ2Sxa4orDOBhbCBcAgOOYFgMAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOI5wAQA4jnABADiOcAEAOO7/AX8dLuaR/xUgAAAAAElFTkSuQmCC\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.perturb(mode='best', mag=0.1)\n",
|
|
"model.get_act(x)\n",
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 98,
|
|
"id": "fd0d2987",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| loss: 5.03e-05 |: 100%|███████████████████████████████████████████| 20/20 [02:59<00:00, 8.99s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"from kan.utils import batch_jacobian, batch_hessian, create_dataset_from_data\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"torch.use_deterministic_algorithms(True)\n",
|
|
"\n",
|
|
"def closure():\n",
|
|
"    # LBFGS closure: fit the Euler-Lagrange prediction of qdd to the data.\n",
|
|
"    # d/dt(dL/dqd) = dL/dq  =>  L_{qd,qd} qdd = L_q - L_{qd,q} qd\n",
|
|
"    # The current loss is stored in the global `loss` for progress reporting.\n",
|
|
"    global loss\n",
|
|
"    optimizer.zero_grad()\n",
|
|
"\n",
|
|
"    jacobian = batch_jacobian(model, x, create_graph=True)\n",
|
|
"    hessian = batch_hessian(model, x, create_graph=True)\n",
|
|
"    # x = cat([q, qd]), so the first d inputs are q and the last d are qd\n",
|
|
"    Lqdqd = hessian[:,d:,d:]  # d^2L / dqd dqd\n",
|
|
"    Lq = jacobian[:,:d]  # dL / dq\n",
|
|
"    Lqqd = hessian[:,d:,:d]  # d^2L / dqd dq\n",
|
|
"\n",
|
|
"    Lqqd_qd_prod = torch.einsum('ijk,ik->ij', Lqqd, qd)\n",
|
|
"\n",
|
|
"    # solve the linear system instead of forming an explicit inverse (more stable)\n",
|
|
"    qdd_pred = torch.linalg.solve(Lqdqd, Lq - Lqqd_qd_prod)\n",
|
|
"    loss = torch.mean((qdd - qdd_pred)**2)\n",
|
|
"\n",
|
|
"    loss.backward()\n",
|
|
"    return loss\n",
|
|
"\n",
|
|
"steps = 20\n",
|
|
"log = 1\n",
|
|
"# refit spline grids to x every `grid_update_freq` steps, only before `stop_grid_update_step`\n",
|
|
"grid_update_freq = 20\n",
|
|
"stop_grid_update_step = 5\n",
|
|
"optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
|
|
"# alternative: optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
|
|
"pbar = tqdm(range(steps), desc='description', ncols=100)\n",
|
|
"\n",
|
|
"for step in pbar:\n",
|
|
"    if step < stop_grid_update_step and step % grid_update_freq == 0:\n",
|
|
"        model.update_grid(x)\n",
|
|
"\n",
|
|
"    optimizer.step(closure)\n",
|
|
"\n",
|
|
"    if step % log == 0:\n",
|
|
"        pbar.set_description(\"| loss: %.2e |\" % loss.cpu().detach().numpy())\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 99,
|
|
"id": "782f818f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x200 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 100,
|
|
"id": "ad876b9d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m/var/folders/6j/b6y80djd4nb5hl73rv3sv8y80000gn/T/ipykernel_24271/2849209031.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mauto_symbolic\u001b[0;34m(self, a_range, b_range, lib, verbose)\u001b[0m\n\u001b[1;32m 1402\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'fixing ({l},{i},{j}) with 0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1403\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1404\u001b[0;31m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msuggest_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1405\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1406\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mverbose\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36msuggest_symbolic\u001b[0;34m(self, l, i, j, a_range, b_range, lib, topk, verbose, r2_loss_fun, c_loss_fun, weight_simple)\u001b[0m\n\u001b[1;32m 1332\u001b[0m \u001b[0;31m# getting r2 and complexities\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1333\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msymbolic_lib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1334\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1335\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# zero function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1336\u001b[0m 
\u001b[0mr2s\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1e8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, l, i, j, fun_name, fit_params_bool, a_range, b_range, verbose, random, log_history)\u001b[0m\n\u001b[1;32m 488\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspline_postacts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0;31m#y = self.postacts[l][:, j, i]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 490\u001b[0;31m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msymbolic_fun\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfix_symbolic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 491\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 492\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m 
\u001b[0;36m1e8\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/Symbolic_KANLayer.py\u001b[0m in \u001b[0;36mfix_symbolic\u001b[0;34m(self, i, j, fun_name, x, y, random, a_range, b_range, verbose)\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 230\u001b[0m \u001b[0;31m#initialize from x & y and fun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 231\u001b[0;31m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfit_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0ma_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 232\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuns_avoid_singularity\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun_avoid_singularity\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/utils.py\u001b[0m in \u001b[0;36mfit_params\u001b[0;34m(x, y, fun, a_range, b_range, grid_number, iteration, verbose, device)\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mb_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinspace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_range\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgrid_number\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[0ma_grid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_grid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmeshgrid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindexing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'ij'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m \u001b[0mpost_fun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m 
\u001b[0mb_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 238\u001b[0m \u001b[0mx_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpost_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0my_mean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.auto_symbolic()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 80,
|
|
"id": "428571e6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.5\n",
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 x 1.000000 -16.565706 1 1 -2.513141\n",
|
|
"1 cos 1.000000 -16.599499 2 2 -1.719900\n",
|
|
"2 sin 1.000000 -16.599499 2 2 -1.719900\n",
|
|
"3 exp 0.999997 -16.268112 2 2 -1.653622\n",
|
|
"4 x^0.5 0.999977 -14.896568 2 2 -1.379314\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('x',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 1,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.9999996907837526,\n",
|
|
" 1)"
|
|
]
|
|
},
|
|
"execution_count": 80,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.unfix_symbolic(0,0,0)\n",
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "0dea7189",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.4\n",
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 0 0.000000 0.000014 0 0 0.000003\n",
|
|
"1 cos 0.969503 -5.034727 2 2 0.593055\n",
|
|
"2 x^2 0.969092 -5.015413 2 2 0.596917\n",
|
|
"3 sin 0.965249 -4.846400 2 2 0.630720\n",
|
|
"4 x 0.000392 -0.000551 1 1 0.799890\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('0',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 0,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.0,\n",
|
|
" 0)"
|
|
]
|
|
},
|
|
"execution_count": 47,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.unfix_symbolic(0,1,0)\n",
|
|
"model.suggest_symbolic(0,1,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 101,
|
|
"id": "ef60542b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 300x300 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"x, y = model.get_fun(0,1,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 102,
|
|
"id": "8f77c061",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n",
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import create_dataset_from_data\n",
|
|
"\n",
|
|
"dataset2 = create_dataset_from_data(x[:,None], y[:,None])\n",
|
|
"model2 = KAN(width=[1,1,1])\n",
|
|
"model2.fix_symbolic(0,0,0,'x^2',fit_params_bool=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 103,
|
|
"id": "1c62302d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAf/ElEQVR4nO3de3DU9b3/8ddnc93cCIkRaxEkMaiAWOWmAioVEiQ9g0LH29GWDnUocqmMU88ZqUcOqBRHJUGOhxZrBWobe4zKIBzglKtcLBQELT8ugRytCaAI2TSXJQnZz++PL8kJEGKUb7Kb3edjhmFnv9ndd2ZIXrw/t6+x1loBAOAiT7ALAACEH8IFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4DrCBQDgOsIFAOA6wgUA4LroYBcAdAbWWp08eVJVVVVKSkpSenq6jDHBLgsIWXQuQCt8Pp8KCgqUnZ2tjIwM9erVSxkZGcrOzlZBQYF8Pl+wSwRCkuFOlEDL1qxZo/Hjx6umpkaS0700auxaEhISVFRUpNzc3KDUCIQqwgVowZo1a5SXlydrrQKBwEW/zuPxyBijlStXEjBAM4QLcB6fz6fu3bvL7/e3GiyNPB6PvF6vSktLlZqa2v4FAp0Acy7AeZYsWaKampo2BYskBQIB1dTUaOnSpe1cGdB50LkAzVhrlZ2drZKSEn2THw1jjDIzM1VcXMwqMkCEC3COr776ShkZGZf0+vT0dBcrAjonhsWAZqqqqi7p9ZWVlS5VAnRuhAvQTFJS0iW9Pjk52aVKgM6NcAGaSU9PV1ZW1jeeNzHGKCsrS2lpae1UGdC5EC5AM8YYTZs27Vu9dvr06UzmA2cxoQ+ch30uwKWjcwHOk5qaqqKiIhlj5PG0/iPSuEP/nXfeIViAZggXoAW5ublauXKlvF6vjDEXDHc1Puf1erVq1Srl5OQEqVIgNBEuwEXk5uaqtLRU+fn5yszMPOdaZmam8vPzVVZWRrAALWDOBWgDa602bNigu+66S+vWrdOIESOYvAdaQecCtIExpmlOJTU1lWABvgbhAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECfI36+nqVlZVp//79kqQjR47o1KlTCgQCQa4MCF3c5hi4CJ/Pp6KiIr355pvat2+fKisrVVdXp/j4eGVkZGj48OGaOHGihg4dqujo6GCXC4QUwgVowfbt2zVjxgx9/PHHGjRokPLy8tS/f38lJSXJ5/Np165dWrFihQ4fPqz7779fzz77rDIyMoJdNhAyCBfgPGvXrtWECROUlJSkuXPnasyYMaqrq1NhYaFqa2uVkpKiBx54QPX19SosLNSsWbPUt29fLVu2TN26dQt2+UBIIFyAZg4dOqTRo0crMTFRhYWF6tOnj4wxKikp0c0336yKigr16tVLu3btUteuXWWt1ZYtW/TQQw/pzjvv1Guvvaa4uLhgfxtA0DGhD5zV0NCg559/XuXl5Vq4cGFTsLTGGKNhw4bphRde0PLly7V69eoOqhYIbYQLcNbhw4e1YsUKjRs3TsOGDfvaYGlkjNE999yjW265RYsXL9aZM2fauVIg9LHEBThr27Ztqqqq0vjx4/Xpp5+qurq66VppaakaGhokSXV1ddq3b59SUlKarl955ZUaN26cZs2apePHj6t79+4dXj8QSggX4KwDBw4oISFBmZmZmjRpkrZu3dp0zVqr2tpaSdLRo0c1atSopmvGGL300ku64YYbVFNTo6NHjxIuiHiEC3CW3+9XdHS04uLiVFtbq9OnT7f4ddbaC66dOXNGXq/3nBACIhnhApx1+eWXy+/3y+fzaciQIUpMTGy65vf7tW3btqYQue2225o2Thpj1KNHD3355Z
fyeDzq2rVrsL4FIGQQLsBZAwYMUH19vXbs2KF58+adc62kpESDBg1SRUWFunXrprfeekupqalN140xeuqpp3TFFVcwJAaI1WJAk8GDByszM1NLlixRdXW1oqKizvnTyBgjj8fT9LzH49GxY8f09ttvKy8vT126dAnidwGEBsIFOCs9PV1Tp07V7t27tWDBgjYvKa6trdWcOXPk9/s1adKkNi9hBsIZw2JAMxMmTNDmzZs1b948JSQkaPLkyYqPj5ckRUdHKzo6uqmLsdaqsrJSzz33nAoLCzV//nxde+21wSwfCBkc/wKc58SJE5oyZYref/995ebmasaMGbr++ut18OBBBQIBxcbG6pprrtGOHTv04osvas+ePZo9e7YmT558zvAZEMkIF6AF1dXVWrx4sRYsWKAvvvhCmZmZys7OVnJyssrLy3Xw4EEdPXpUAwYM0DPPPKM77rhDHg+jzEAjwgVoxfHjx7Vu3Tpt2rRJe/fu1Y4dOzR8+HANHTpUOTk5GjJkiBISEoJdJhByCBegjXbu3KnBgwdr586dGjhwYLDLAUIafTzQRlFRUU3LkAG0jp8SAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOsIFwCA6wgXAIDrCBcAgOu4nwvQRtZaBQIBeTweGWOCXQ4Q0uhcgG+Ae7kAbRMd7AIAt1hrVVxcrJMnTwa7lEvi8XjUr18/JSYmBrsU4FtjWAxhIxAIaMqUKbrqqquUlJR0Se/T0NCgqKiooHQqH3zwgZ5++mn179+/wz8bcAudC8JKXFycJk6cqG7dun2j1wUCARUXF+vdd9/V9u3bderUKaWlpWnIkCEaO3asrrvuug6Za7HWqqqqSvyfD50d4YKIZq1VRUWF8vPztWjRIp04ceKc6++//75efvllPfLII3ryySd1xRVXMJkPtAHhgohlrVVpaakee+wxrV69WpJ044036gc/+IF69Oihzz//XKtXr9aePXu0YMECbdmyRa+++qoGDhxIwABfg3BBRLLWqqysTD/60Y+0efNmdenSRf/6r/+qRx99VKmpqTLGyFqrJ554Qm+99ZZmz56t3bt367777tPrr7+uO++8k4ABWsG6SkScxqGwKVOmaPPmzcrIyNBvf/tbPfHEE+ratWtTaBhjlJKSop/+9Kd67733dOONN+rvf/+7JkyYoO3btzMvArSCcEHEOXPmjObOnatVq1YpJSVFCxcu1NixYxUVFdXi1xtjdPPNN+uPf/yjbrrpJpWWlurRRx/VoUOHCBjgIggXRBRrrVatWqVXX31VHo9HM2fO1L333vu1S46NMerdu7d+97vfKSsrSwcOHND06dNVXl7eQZUDnQvhgohhrdWxY8f0y1/+UjU1NRo7dqwee+yxi3Ys5zPGqF+/flq4cKG6du2qdevWae7cuTpz5kw7Vw50PoQLIkYgEND8+fO1f/9+XXXVVZozZ468Xu83eg9jjEaOHKmZM2cqKipKixYt0sqVKxkeA85DuCAiWGu1a9cuvf766/J4PHryySfVu3fvb7Xiy+PxaNKkSRo7dqxqamo0c+ZMlZWVETBAM4QLIkJdXZ1efPFF+Xw+3XrrrXr44YcvaSmx1+vVnDlz1LNnT+3fv1+/+tWvGB4DmiFcEPastdqyZYtWrVql+Ph4/eIXv1BycvIlvWfjBP/MmTMVExOjZcuWaePGjXQvwFmEC8JebW2tCgoK5Pf7NXLkSI0cOdKVDZDGGD344IPKyclRVVWVZs+erYqKChcqBjo/wgVhzVqrbdu2af369fJ6vZo+fbri4uJce3+v16unn35aaWlp+stf/qLf//73dC+ACBeEuTNnzmjRokXy+/0aMWKEhg4d6uqxLY0bLCdOnNi0Gu2zzz4jYBDxCBeELWut9u7dq7Vr1yo2NlY/+9nPXO1aGkVFRWnq1Km65ppr9Omnn2rhwoUKBAKufw7QmRAuCF
vWWr3xxhuqrKzUwIED2/Wwye9+97t6/PHHFRUVpaVLl+qTTz6he0FEI1wQtj7//HMtX75cHo9HP/nJT5SQkNBun2WM0QMPPKABAwbo5MmTevnll1majIhGuCAsWWv13nvv6dixY+rVq5fy8vLa/Yj8Ll266IknnlBsbKyWL1+uDz/8kO4FEYtwQViqqqrSH/7wB1lr9cMf/lCXX355u3+mMUZjxozR7bffrqqqKs2fP191dXXt/rlAKCJcEJY+/PBDffLJJ0pJSdF9993XYTf28nq9mjFjhuLj47V27Vpt3ryZ7gURiXBB2GloaFBhYaFqa2s1bNgw9enTp8M+2xijESNGKCcnR36/X/n5+Tp9+nSHfT4QKggXhJ2ysjKtXbtWHo9HDz74oGJiYjr082NjY/Xzn/9cCQkJ2rBhgzZs2ED3gohDuCDs/O1vf1N1dbV69uyp73//+x1+r3tjjG677Tbdfffd5xw9A0QSwgVhZ+TIkdq4caMWLFjQIRP5LYmJidG0adOUmJioDz74QOvWraN7QUQhXBB2YmNj1b9/f919990d3rU0MsZoyJAhGj16tGpra7Vw4ULmXhBRCBeELWNM0MJFurB7Wb9+Pd0LIgbhArST87uXV155he4FEYNwAdpRTEyMpk6dqsTERG3evJmVY4gYhAvQjowxuuWWW+heEHEIF6CdxcTEaMqUKUpISNCmTZu4HTIiAuECtDNjjG699Vbl5OQ0rRyrra0NdllAuyJcgA7QuHIsISFBGzdupHtB2CNcgA7Q2L2MGjVKp0+fpntB2CNcgA4SGxurqVOnyuv1asOGDZyYjLBGuAAdxBijoUOHNnUvr7zyCt0LwhbhAnSg5t3L+vXrtWnTJroXhCXCBehAxhgNGzZMubm5On36tAoKCtj3grBEuAAdLDY2VtOnT29aOfbnP/+Z7gVhh3ABOljjyrHG+73k5+erpqYm2GUBriJcgCCIiYnRjBkzlJycrK1bt2rlypV0LwgrhAsQBMYYDRw4UPfee6/q6+v18ssvq6KiIthlAa4hXIAgiY6O1uOPP6709HTt3r1bf/rTn4JdEuAawgUIEmOM+vXrp0ceeUQNDQ2aP3++jh07FuyyAFcQLkAQeTweTZ06VT169FBxcbF+/etfM/eCsEC4AEFkjNHVV1+tqVOnKioqSlu3blVlZWWwywIuWXSwCwDcZK1VeXm5YmJigl3KNzJ27Fh5PB7l5eVp2bJlwS4HuGSEC8KGMUY9e/bUK6+8oqioqGCX860sXLhQfr9fXbp0CXYpwCUxlgFehAlrbdjMVxhjZIwJdhnAt0a4AABcx4Q+AMB1zLkAbdS8yWfICmgdnQvQRh999JGioqL00UcfBbsUIOQRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLgAA1xEuAADXES4AANcRLkAbWGtVXl4uSSovLxd3BwdaR7gArfD5fCooKFB2drZGjhwpa61Gjhyp7OxsFRQUyOfzBbtEICQZy3/BgBatWbNG48ePV01NjaSWb3OckJCgoqIi5ebmBqVGIFQRLkAL1qxZo7y8PFlrFQgELvp1Ho9HxhitXLmSgAGaIVyA8/h8PnXv3l1+v7/VYGnk8Xjk9XpVWlqq1NTU9i8Q6ASYcwHOs2TJEtXU1LQpWCQpEAiopqZGS5cubefKgM6DzgVoxlqr7OxslZSUfKMVYcYYZWZmqri4uGk+BohkhAvQzFdffaWMjIxLen16erqLFQGdE8NiQDNVVVWX9PrKykqXKgE6N8IFaCYpKemSXp+cnOxSJUDnRrgAzaSnpysrK+sbz5sYY5SVlaW0tLR2qgzoXAgXoBljjKZNm/atXjt9+nQm84GzmNAHzsM+F+DS0bkA50lNTVVRUZGMMfJ4Wv8Radyh/8477xAsQDOEC9CC3NxcrVy5Ul6vV8aYC4a7Gp/zer1atWqVcn
JyglQpEJoIF+AicnNzVVpaqvz8fGVmZp5zLTMzU/n5+SorKyNYgBYw5wK0gbVWGzZs0F133aV169ZpxIgRTN4DraBzAdrAGNM0p5KamkqwAF+DcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAQC4jnABALiOcAEAuI5wAb5GfX29ysrKtH//fknSkSNHdOrUKQUCgSBXBoQubnMMXITP51NRUZHefPNN7du3T5WVlaqrq1N8fLwyMjI0fPhwTZw4UUOHDlV0dHSwywVCCuECtGD79u2aMWOGPv74Yw0aNEh5eXnq37+/kpKS5PP5tGvXLq1YsUKHDx/W/fffr2effVYZGRnBLhsIGYQLcJ61a9dqwoQJSkpK0ty5czVmzBjV1dWpsLBQtbW1SklJ0QMPPKD6+noVFhZq1qxZ6tu3r5YtW6Zu3boFu3wgJBAuQDOHDh3S6NGjlZiYqMLCQvXp00fGGJWUlOjmm29WRUWFevXqpV27dqlr166y1mrLli166KGHdOedd+q1115TXFxcsL8NIOiY0AfOamho0PPPP6/y8nItXLiwKVhaY4zRsGHD9MILL2j58uVavXp1B1ULhDbCBTjr8OHDWrFihcaNG6dhw4Z9bbA0Msbonnvu0S233KLFixfrzJkz7VwpEPpY4gKctW3bNlVVVWn8+PH69NNPVV1d3XSttLRUDQ0NkqS6ujrt27dPKSkpTdevvPJKjRs3TrNmzdLx48fVvXv3Dq8fCCWEC3DWgQMHlJCQoMzMTE2aNElbt25tumatVW1trSTp6NGjGjVqVNM1Y4xeeukl3XDDDaqpqdHRo0cJF0Q8wgU4y+/3Kzo6WnFxcaqtrdXp06db/Dpr7QXXzpw5I6/Xe04IAZGMcAHOuvzyy+X3++Xz+TRkyBAlJiY2XfP7/dq2bVtTiNx2221NGyeNMerRo4e+/PJLeTwede3aNVjfAhAyCBfgrAEDBqi+vl47duzQvHnzzrlWUlKiQYMGqaKiQt26ddNbb72l1NTUpuvGGD311FO64oorGBIDxGoxoMngwYOVmZmpJUuWqLq6WlFRUef8aWSMkcfjaXre4/Ho2LFjevvtt5WXl6cuXboE8bsAQgPhApyVnp6uqVOnavfu3VqwYEGblxTX1tZqzpw58vv9mjRpUpuXMAPhjGExoJkJEyZo8+bNmjdvnhISEjR58mTFx8dLkqKjoxUdHd3UxVhrVVlZqeeee06FhYWaP3++rr322mCWD4QMjn8BznPixAlNmTJF77//vnJzczVjxgxdf/31OnjwoAKBgGJjY3XNNddox44devHFF7Vnzx7Nnj1bkydPPmf4DIhkhAvQgurqai1evFgLFizQF198oczMTGVnZys5OVnl5eU6ePCgjh49qgEDBuiZZ57RHXfcIY+HUWagEeECtOL48eNat26dNm3apJK9e3V6xw51HT5c/YYOVU5OjoYMGaKEhIRglwmEHMIFaKOGnTtlBw+WZ+dOeQYODHY5QEhjQh9oo6ioKMkYieEv4GvxUwIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHeECAHAd4QIAcB3hAgBwHfdzAdrKWikQcI7cNybY1QAhjc4F+Ca4lwvQJtwsDGHDWqsTxcWqPXky2KVcEuPx6LJ+/RSfmBjsUoBvjXBB2LDW6v/Nn6+4q65SnNcr+f1SUlKwy2qb+nrnj9erE1u2qP/TT+s7/fsHuyrgWyNcEFZMXJz63HSTuixa5Pyy/tOfQj9grJWWL5dmz5YdPVo7+vZ1ngM6MQaQEX6sldavl7Ztkz79NNjVtM369dLevdKhQ1JUVLCrAS4Z4YLwc801Us+eUlWVtHNn6HcBfr/04YfO42HDCBeEBcIF4Sc5WRo40AmVTZtCP1w++0w6fFiKj5duvT
XY1QCuIFwQfjwe6Y47nL0oO3dK//hHsCu6OGulv/7VqbFHD6l372BXBLiCcEF4GjzY6WA++0wqLg52NRdnrbRhg/P3wIFSly7BrghwBeGC8HT11c7ci98vbd0aukNjlZXSjh1OlzViBDv/ETYIF4SnhARp6FDn8caNUkNDUMu5qIMHnRVtycnSkCHBrgZwDeGC8DVihLPyavdu6csvg13NhayVPvjA6a5695Z69Qp2RYBrCBeEJ2OkAQOkjAzpiy+kjz4KvaGxM2ec/S2SswTZ6w1uPYCLCBeEr27dpO9979xf4qHk+HEn9KKjpbvuCnY1gKsIF4Sv5r+0N22SamqCW09z1joT+V99JV1xhXTTTUzmI6wQLghfxjj7XRITnWNVDh0KdkX/x1rpf/7HWWgweLAzfAeEEcIF4e3aa6XsbKm6OrR26//jH85kvjHSqFEc+YKwQ7ggvCUmSnfe6Txes8Y5KTkUfPKJ9L//62yaHD6cITGEHcIF4S83V4qJcZYkf/55sKv5vyGx2lrphhtYgoywRLggvBkj3Xyzc27XyZPSli3BHxrz+51wkaScHCkuLrj1AO2AcEH4S0uTbr/dCZWVK4O/W//AAWnfPucUgVGjGBJDWCJcEP6MkfLynEnzbdukY8eCV4u1ztxPdbXUp4/zBwhDhAvCnzHOfVK++11nt/4HHwRvaOz0aWnVKudxbq7TvQBhiHBBZMjIcIbGAgHpvfeCNzS2f7/08cfOUS9jxgSnBqADEC6IDB6PdM89ztDYli1SaWnH19A451NVJfXt66wUY74FYYpwQWQwxjmCv2dP54TkP/+544fGqqulFSucxz/4AUNiCGuECyLHZZc5q7OslYqKOnZDpbXOPpt9+5x7t/zTP9G1IKwRLogcxkjjxzv7Sv7yF+dGXR3FWuntt50J/UGDpOuu67jPBoKAcEHkMMb5xd6nj1RR4Uzsd9TQ2IkTzioxY6Qf/pCNkwh7hAsiS3Ky071IztBYRUX7f6a1zhzPZ59J3/mOdPfdDIkh7BEuiCzGOKvG0tKcnfIdcVJyXZ305pvOMujcXGe/DRDmCBdEnuxs6fvfdyb0ly1z7lTZXqyV9u6Vtm6V4uOlf/5nZ1k0EOb4V47IExUl/fjHUmysc/vjffvar3ux1gmwqipp4EBpyBCGxBARCBdEHmOce6jceKMz57J0afuFy2efOQsHPB4n0Lze9vkcIMQQLohMSUnST37i/NL/r/9yQsBt1kp/+INzUGZWlrNxkq4FEYJwQWRqnNjPypKOHnWGrtzuXo4fl5YscR7/6EfO+WZAhCBcELkuv9zpXoyR3njD3fPGGruWkhKpe3dnIp+uBRGEcEHkMkZ6+GHnNsN//7v02986y4XdcPSo9OtfOyHz4x9LV13lzvsCnQThgsh25ZXSpEnO49dek4qLL314LBCQfvMbp2u5+mpp4kSWHyPi8C8ekc0YZz6kb19n4v3FFy9t34u1ztLm3/zGee/HHqNrQUQiXICMDOlf/kWKiZHeeuvSjuM/fVqaM8c51v9735MmTGCuBRGJcAGMke691znzq7paevpp53bI3zRgrJUKC517tni9zvukpbVPzUCII1wAyTma5d//3TlYcs8e53FdXdtf33jMyzPPOK97+GFp9Gi6FkQswgWQnBC44QZp1ixneOyNN6RFi6SGhq9/rbXO6rApU6SyMmc47N/+zXkfIEIRLkAjY6RHHnFWd9XXOwGxdGnrAWOtM4Q2aZL04YdSt27SggVOB0TXgghGuADNxcZKzz7r7N6vrpYef1x66SXn4Mnz52AaV4Y99JD03/8tdekiFRRIt95KsCDiRQe7ACCkGOOExH/+p7M35d13pV/+0llB9thjzsnG8fHO0S7vvutslDx2TLrsMidYxo9nTwsgwgW4kDFOWLz2mtS7t/Qf/yGtWydt3Ois/oqLk3w+p5sxRrrpJmd/zO23EyzAWYQL0BJjpJQUZ9XYmDHSq686d6386i
tnB358vDNx/+CDzvEul13GUBjQDOGCsGKtVXV5uTxurtS67jpp/nxnY2RpqbNR8rLLnJ33iYlOqJSXu/Zx9adPu/ZeQLAQLggbxhgl9eypg6+8IhMVFexyvjXr9yu2S5dglwFcEmNte92CD+hY1lqFyz9nY4wMw2zoxAgXAIDrWNoCAHAdcy5AWzVv8hmyAlpF5wK01UcfSVFRzt8AWkW4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AABcR7gAAFxHuAAAXEe4AG1grVV5ebms5PzN3cGBVhEuQCt8Pp8KCgqUnZ2tu0aOlLVWd40cqezsbBUUFMjn8wW7RCAkGct/wYAWrVmzRuPHj1dNTY0k6XvW6q+SBkrac/Y2xwkJCSoqKlJubm7wCgVCEJ0L0II1a9YoLy9Pfr9f1toLhsEan/P7/crLy9OaNWuCVCkQmuhcgPP4fD51795dfr9fgUCg6fmbpKbO5aNmX+/xeOT1elVaWqrU1NSOLRYIUXQuwHmWLFmimpqac4KlNYFAQDU1NVq6dGk7VwZ0HnQuQDPWWmVnZ6ukpOSCobCLdS6SZIxRZmamiouLZc7OxwCRjM4FaObkyZM6cuTIN15qbK3VkSNHdOrUqXaqDOhcCBegmaqqqkt6fWVlpUuVAJ0b4QI0k5SUdNFrB+QMiR1o5fXJyclulwR0SoQL0Ex6erqysrJanDfxy5lr8bfwOmOMsrKylJaW1t4lAp0C4QI0Y4zRtGnTvtVrp0+fzmQ+cBarxYDzXGyfy8WwzwW4EJ0LcJ7U1FQVFRXJGCOPp/UfEY/HI2OM3nnnHYIFaIZwAVqQm5urlStXyuv1yhhzwXBX43Ner1erVq1STk5OkCoFQhPhAlxEbm6uSktLlZ+fr8zMzHOuZWZmKj8/X2VlZQQL0ALmXIA2sNbq1KlTqqysVHJystLS0pi8B1pBuAAAXMewGADAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1hAsAwHWECwDAdYQLAMB1/x+on2d72IR9zAAAAABJRU5ErkJggg==\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 5 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.get_act(dataset2)\n",
|
|
"model2.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 104,
|
|
"id": "096134b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 3.77e-04 | test_loss: 3.76e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:46<00:00, 1.07it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.2\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.fit(dataset2, steps=50);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 105,
|
|
"id": "035e00f2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 3.73e-04 | test_loss: 3.72e-04 | reg: 3.35e+00 | : 100%|█| 50/50 [00:13<00:00, 3.81it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.3\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.fit(dataset2, steps=50, update_grid=False);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 106,
|
|
"id": "b65775e8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 5 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 107,
|
|
"id": "6cd26af5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" function fitting r2 r2 loss complexity complexity loss total loss\n",
|
|
"0 x^0.5 0.999957 -14.193489 2 2 -1.238698\n",
|
|
"1 sqrt 0.999957 -14.193489 2 2 -1.238698\n",
|
|
"2 log 0.999722 -11.763921 2 2 -0.752784\n",
|
|
"3 1/x^0.5 0.999485 -10.894391 2 2 -0.578878\n",
|
|
"4 1/sqrt(x) 0.999485 -10.894391 2 2 -0.578878\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('x^0.5',\n",
|
|
" (<function kan.utils.<lambda>(x)>,\n",
|
|
" <function kan.utils.<lambda>(x)>,\n",
|
|
" 2,\n",
|
|
" <function kan.utils.<lambda>(x, y_th)>),\n",
|
|
" 0.9999566254728288,\n",
|
|
" 2)"
|
|
]
|
|
},
|
|
"execution_count": 107,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.suggest_symbolic(1,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"id": "daabd91a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Best value at boundary.\n",
|
|
"r2 is 0.9989821969546337\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(tensor([-9.8000, 9.8868, -0.3482, 1.2049]), tensor(0.9990))"
|
|
]
|
|
},
|
|
"execution_count": 72,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import fit_params\n",
|
|
"fit_params(x**2, y, lambda x: x**(1/2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"id": "8dc20f20",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"skipping (0,0,0) since already symbolic\n",
|
|
"fixing (1,0,0) with x^0.5, r2=0.9999494098870415, c=2\n",
|
|
"saving model version 0.4\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model2.auto_symbolic()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 79,
|
|
"id": "f27f3daa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 1.19 - 1.08 \\sqrt{1 - 1.0 x_{1}^{2}}$"
|
|
],
|
|
"text/plain": [
|
|
"1.19 - 1.08*sqrt(1 - 1.0*x_1**2)"
|
|
]
|
|
},
|
|
"execution_count": 79,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan.utils import ex_round\n",
|
|
"ex_round(model2.symbolic_formula()[0][0], 2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a61540c1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|