547 lines
76 KiB
Plaintext
547 lines
76 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Example 5: Special functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2571d531",
|
|
"metadata": {},
|
|
"source": [
|
|
"Let's construct a dataset from a target that involves a special function: $f(x,y)={\\rm exp}(J_0(20x)+y^2)$, where $J_0$ is the Bessel function of the first kind of order zero."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "2075ef56",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train loss: 1.40e-01 | test loss: 1.38e-01 | reg: 2.88e+01 : 100%|██| 20/20 [00:30<00:00, 1.50s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import KAN, create_dataset, SYMBOLIC_LIB, add_symbolic\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 20 grid intervals (grid=20).\n",
|
|
"model = KAN(width=[2,5,1], grid=20, k=3, seed=0)\n",
|
|
"f = lambda x: torch.exp(torch.special.bessel_j0(20*x[:,[0]]) + x[:,[1]]**2)\n",
|
|
"dataset = create_dataset(f, n_var=2)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2f30c3ab",
|
|
"metadata": {},
|
|
"source": [
|
|
"Plot the trained KAN; the Bessel function shows up in the bottom left."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3f95fcdd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model = model.prune()\n",
|
|
"model(dataset['train_input'])\n",
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "733a2a41",
|
|
"metadata": {},
|
|
"source": [
|
|
"suggest_symbolic does not return anything that matches it well, since the Bessel function isn't included in the default SYMBOLIC_LIB. We want to add the Bessel function to the library."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "031db28f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"function , r2\n",
|
|
"gaussian , 0.7090268761989152\n",
|
|
"1/x^2 , 0.21051195154680438\n",
|
|
"sin , 0.1822506022370818\n",
|
|
"abs , 0.12418544555819415\n",
|
|
"tan , 0.10407480103502795\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('gaussian',\n",
|
|
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
|
|
" 0.7090268761989152)"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "4b8549a2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"dict_keys(['x', 'x^2', 'x^3', 'x^4', '1/x', '1/x^2', '1/x^3', '1/x^4', 'sqrt', '1/sqrt(x)', 'exp', 'log', 'abs', 'sin', 'tan', 'tanh', 'sigmoid', 'sgn', 'arcsin', 'arctan', 'arctanh', '0', 'gaussian', 'cosh'])"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"SYMBOLIC_LIB.keys()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "cbde1924",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# add bessel function J0 to the symbolic library\n",
|
|
"# we should include a name and a pytorch implementation\n",
|
|
"add_symbolic('J0', torch.special.bessel_j0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bda24c6d",
|
|
"metadata": {},
|
|
"source": [
|
|
"After adding the Bessel function, we check suggest_symbolic again."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "83e5cfdd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"function , r2\n",
|
|
"gaussian , 0.7090268761989152\n",
|
|
"J0 , 0.2681378679614782\n",
|
|
"1/x^2 , 0.21051195154680438\n",
|
|
"sin , 0.1822506022370818\n",
|
|
"abs , 0.12418544555819415\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('gaussian',\n",
|
|
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
|
|
" 0.7090268761989152)"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# J0 now shows up, but it is not ranked first. Why?\n",
|
|
"\n",
|
|
"model.suggest_symbolic(0,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "e78f4674",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"function , r2\n",
|
|
"J0 , 0.9717763100936939\n",
|
|
"gaussian , 0.7494106253678943\n",
|
|
"sin , 0.49679878395526067\n",
|
|
"1/x^2 , 0.21051195158162733\n",
|
|
"abs , 0.12435207425739554\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('J0',\n",
|
|
" (<function torch._C._special.special_bessel_j0>, J0),\n",
|
|
" 0.9717763100936939)"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# This is because the ground truth is J0(20x), and the scale factor 20 is too large:\n",
|
|
"# our default search is in (-10,10)\n",
|
|
"# so we need to set the search range bigger in order to include 20\n",
|
|
"# now J0 appears at the top of the list\n",
|
|
"\n",
|
|
"model.suggest_symbolic(0,0,0,a_range=(-40,40))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "47fb0d09",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train loss: 1.67e-02 | test loss: 1.80e-02 | reg: 2.87e+00 : 100%|██| 20/20 [00:08<00:00, 2.25it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.train(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "4773e989",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "104199f4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"function , r2\n",
|
|
"J0 , 0.9985560043309399\n",
|
|
"gaussian , 0.6101756259771707\n",
|
|
"sin , 0.5737221152646913\n",
|
|
"tan , 0.08366297315238909\n",
|
|
"1/x , 0.08315973336762218\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('J0',\n",
|
|
" (<function torch._C._special.special_bessel_j0>, J0),\n",
|
|
" 0.9985560043309399)"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.suggest_symbolic(0,0,0,a_range=(-40,40))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fe1f857d",
|
|
"metadata": {},
|
|
"source": [
|
|
"Finish the rest of the symbolic regression."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "eb6c0f43",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"r2 is 0.9985560043309399\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.9986)"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fix_symbolic(0,0,0,'J0',a_range=(-40,40))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "11a27268",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"skipping (0,0,0) since already symbolic\n",
|
|
"fixing (0,1,0) with x^2, r2=0.9999802186534139\n",
|
|
"fixing (1,0,0) with sigmoid, r2=0.9999663092809886\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.auto_symbolic()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "5076005f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "79816b25",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train loss: 1.12e-03 | test loss: 1.17e-03 | reg: 4.76e+01 : 100%|██| 20/20 [00:08<00:00, 2.38it/s]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.train(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "ba171cc4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 4 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "e26b771f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"function , r2\n",
|
|
"exp , 0.9999988610586863\n",
|
|
"cosh , 0.9999699077016541\n",
|
|
"sigmoid , 0.9999693609882967\n",
|
|
"arctan , 0.9999174139339265\n",
|
|
"gaussian , 0.9999096961395885\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"('exp',\n",
|
|
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
|
|
" 0.9999988610586863)"
|
|
]
|
|
},
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.suggest_symbolic(1,0,0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "b939a769",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"r2 is 0.9999988610586863\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(1.0000, grad_fn=<SelectBackward0>)"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fix_symbolic(1,0,0,'exp')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "a0e2813a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train loss: 8.09e-04 | test loss: 8.51e-04 | reg: 4.68e+01 : 100%|██| 20/20 [00:05<00:00, 3.96it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 J_{0}{\\left(- 20.0 x_{1} \\right)}}$"
|
|
],
|
|
"text/plain": [
|
|
"1.0*exp(1.0*x_2**2 + 1.0*J0(-20.0*x_1))"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Why can't we reach machine precision (perhaps because LBFGS stops early)? The symbolic formula is correct, though: J0 is an even function, so J0(-20*x_1) equals J0(20*x_1).\n",
|
|
"model.train(dataset, opt=\"LBFGS\", steps=20);\n",
|
|
"model.symbolic_formula()[0][0]"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|