GitHub_collection_pykan/tutorials/Example_9_singularity.ipynb
2024-08-11 13:02:16 -04:00

465 lines
40 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Example 9: Singularity"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
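"Let's construct a dataset from a function with a singularity: $f(x,y)=\\sin(2(\\log x+\\log y))$ for $x>0, y>0$. Since $\\log$ diverges at $0$, the inputs below are sampled from $[0.2, 5]$ to stay away from the singularity."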
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.89e-02 | test_loss: 3.78e-02 | reg: 6.39e+00 | : 100%|█| 20/20 [00:02<00:00, 7.05it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron (width=[2,1,1]). cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,1,1], grid=5, k=3, seed=2)\n",
"f = lambda x: torch.sin(2*(torch.log(x[:,[0]])+torch.log(x[:,[1]])))\n",
"dataset = create_dataset(f, n_var=2, ranges=[0.2,5])\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ccb7ec43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best value at boundary.\n",
"r2 is 0.999884843826294\n",
"saving model version 0.2\n",
"Best value at boundary.\n",
"r2 is 0.9998899102210999\n",
"saving model version 0.3\n",
"r2 is 0.9975605010986328\n",
"saving model version 0.4\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9976)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'log')\n",
"model.fix_symbolic(0,1,0,'log')\n",
"model.fix_symbolic(1,0,0,'sin')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0937db67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.95e-07 | test_loss: 2.91e-07 | reg: 0.00e+00 | : 100%|█| 20/20 [00:01<00:00, 15.68it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e959cda3",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 \\sin{\\left(2.0 \\log{\\left(9.993 x_{1} \\right)} + 2.0 \\log{\\left(10.0 x_{2} \\right)} - 9.209 \\right)}$"
],
"text/plain": [
"1.0*sin(2.0*log(9.993*x_1) + 2.0*log(10.0*x_2) - 9.209)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ex_round(model.symbolic_formula()[0][0], 3)"
]
},
{
"cell_type": "markdown",
"id": "16e4da06",
"metadata": {},
"source": [
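"We were lucky -- the singularity does not seem to be a problem in this case: the recovered formula above is equivalent to the ground truth up to rounding, since $2\\log 9.993 + 2\\log 10 \\approx 9.209$. But let's instead consider $f(x,y)=\\sqrt{x^2+y^2}$. The derivative of $\\sqrt{z}$ diverges at $z=0$, so $x=y=0$ is a singular point."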
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "1ce52cec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 5.17e-03 | test_loss: 5.45e-03 | reg: 5.66e+00 | : 100%|█| 20/20 [00:02<00:00, 7.44it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron (width=[2,1,1]). cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,1,1], grid=5, k=3, seed=0)\n",
"f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3a69ec41",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "abef7aa9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999783635139465\n",
"saving model version 0.2\n",
"r2 is 0.9999676942825317\n",
"saving model version 0.3\n",
"r2 is 0.9997884631156921\n",
"saving model version 0.4\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9998)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'x^2')\n",
"model.fix_symbolic(0,1,0,'x^2')\n",
"model.fix_symbolic(1,0,0,'sqrt')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "aa71848c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewind to model version 1.4, renamed as 2.4\n"
]
}
],
"source": [
"model = model.rewind('0.4')\n",
"model.get_act(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "e14000d8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.01334547419162 \\sqrt{0.999861446076389 \\left(7.53297050423062 \\cdot 10^{-5} - x_{2}\\right)^{2} + \\left(0.000104069324734005 - x_{1}\\right)^{2} + 0.00834810636784406} - 0.0170296430587769$"
],
"text/plain": [
"1.01334547419162*sqrt(0.999861446076389*(7.53297050423062e-5 - x_2)**2 + (0.000104069324734005 - x_1)**2 + 0.00834810636784406) - 0.0170296430587769"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"formula = model.symbolic_formula()[0][0]\n",
"formula"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "c56ee3d5",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.01 \\sqrt{x_{1}^{2} + 1.0 x_{2}^{2} + 0.01} - 0.02$"
],
"text/plain": [
"1.01*sqrt(x_1**2 + 1.0*x_2**2 + 0.01) - 0.02"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ex_round(formula, 2)"
]
},
{
"cell_type": "markdown",
"id": "1fd57d41",
"metadata": {},
"source": [
"With singularity avoiding (`singularity_avoiding=True`), training stays free of NaNs. We use Adam here: LBFGS may still produce NaNs because of its line search, but Adam won't."
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "de708f21",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.09e-07 | test_loss: 1.48e-07 | reg: 0.00e+00 | : 100%|█| 1000/1000 [00:12<00:00, 83.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 1.5\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False, singularity_avoiding=True);"
]
},
{
"cell_type": "markdown",
"id": "6fd34c4c",
"metadata": {},
"source": [
"Without singularity avoiding, NaNs may appear (the run below ends with `nan` losses)."
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "031fabd6",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: nan | test_loss: nan | reg: nan | : 100%|█████████| 1000/1000 [00:11<00:00, 84.83it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 1.6\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "124c9ca4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}