GitHub_collection_pykan/tutorials/Example_5_special_functions.ipynb
2024-08-11 13:02:16 -04:00

{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Example 5: Special functions"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Let's construct a dataset from a special function, $f(x,y)={\\rm exp}(J_0(20x)+y^2)$, where $J_0(x)$ is the Bessel function of the first kind of order zero."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 5.15e-01 | test_loss: 5.84e-01 | reg: 5.91e+00 | : 100%|█| 20/20 [00:02<00:00, 7.39it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 3 grid intervals (grid=3).\n",
"model = KAN(width=[2,1,1], grid=3, k=3, seed=2)\n",
"f = lambda x: torch.exp(torch.special.bessel_j0(20*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "markdown",
"id": "2f30c3ab",
"metadata": {},
"source": [
"Plot the trained KAN. The Bessel function shows up in the bottom left."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "187d19f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.70e-02 | test_loss: 9.15e-02 | reg: 7.69e+00 | : 100%|█| 20/20 [00:05<00:00, 3.49it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model = model.refine(20)\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "8d50bcef",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "733a2a41",
"metadata": {},
"source": [
"`suggest_symbolic` does not return anything that matches it, because the Bessel function isn't included in the default `SYMBOLIC_LIB`. We want to add it."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "031db28f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 0 0.000000 0.000014 0 0 0.000003\n",
"1 x 0.002039 -0.002930 1 1 0.799414\n",
"2 J0 0.200055 -0.322009 2 2 1.535598\n",
"3 cos 0.168072 -0.265453 2 2 1.546909\n",
"4 sin 0.168072 -0.265453 2 2 1.546909\n"
]
},
{
"data": {
"text/plain": [
"('0',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 0,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.0,\n",
" 0)"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,0,0)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "4b8549a2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['x', 'x^2', 'x^3', 'x^4', 'x^5', '1/x', '1/x^2', '1/x^3', '1/x^4', '1/x^5', 'sqrt', 'x^0.5', 'x^1.5', '1/sqrt(x)', '1/x^0.5', 'exp', 'log', 'abs', 'sin', 'cos', 'tan', 'tanh', 'sgn', 'arcsin', 'arccos', 'arctan', 'arctanh', '0', 'gaussian', 'J0'])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"SYMBOLIC_LIB.keys()"
]
},
{
"cell_type": "markdown",
"id": "5db9e7cf",
"metadata": {},
"source": [
"Add the Bessel function J0 to the symbolic library. We need to provide a name and a PyTorch implementation; `c` is the complexity assigned to J0."
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "cbde1924",
"metadata": {},
"outputs": [],
"source": [
"add_symbolic('J0', torch.special.bessel_j0, c=1)"
]
},
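{
"cell_type": "markdown",
"id": "j0-sanity-check-note",
"metadata": {},
"source": [
"As a quick sanity check on the function registered above, the following sketch evaluates `torch.special.bessel_j0` at a few points. It uses only PyTorch, not pykan, and the sample points are arbitrary choices for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "j0-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Sample points are arbitrary; J0 is an even function with J0(0) = 1.\n",
"x = torch.tensor([-2.0, 0.0, 2.0])\n",
"y = torch.special.bessel_j0(x)\n",
"print(y)\n",
"\n",
"# Evenness check: J0(-2) should match J0(2).\n",
"print(torch.allclose(y[0], y[2]))"
]
},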
{
"cell_type": "markdown",
"id": "bda24c6d",
"metadata": {},
"source": [
"After adding the Bessel function, we check `suggest_symbolic` again."
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "83e5cfdd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 0 0.000000 0.000014 0 0 0.000003\n",
"1 J0 0.200055 -0.322009 1 1 0.735598\n",
"2 x 0.002039 -0.002930 1 1 0.799414\n",
"3 cos 0.168072 -0.265453 2 2 1.546909\n",
"4 sin 0.168072 -0.265453 2 2 1.546909\n"
]
},
{
"data": {
"text/plain": [
"('0',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 0,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.0,\n",
" 0)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# J0 fitting is not very good\n",
"model.suggest_symbolic(0,0,0)"
]
},
{
"cell_type": "markdown",
"id": "4180de14",
"metadata": {},
"source": [
"The fitting r2 is still low. The ground truth is J0(20x), and the factor 20 lies outside the default search range (-10,10). We therefore widen the search range with `a_range` so that it includes 20. Now J0 appears at the top of the list.\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e78f4674",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 J0 0.998490 -9.361690 1 1 -1.072338\n",
"1 0 0.000000 0.000014 0 0 0.000003\n",
"2 x 0.002039 -0.002930 1 1 0.799414\n",
"3 cos 0.580127 -1.251939 2 2 1.349612\n",
"4 sin 0.580127 -1.251939 2 2 1.349612\n"
]
},
{
"data": {
"text/plain": [
"('J0',\n",
" (<function torch._C._special.special_bessel_j0>,\n",
" J0,\n",
" 1,\n",
" <function torch._C._special.special_bessel_j0>),\n",
" 0.9984899759292603,\n",
" 1)"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,0,0,a_range=(-40,40))"
]
}
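,
{
"cell_type": "markdown",
"id": "a-range-sketch-note",
"metadata": {},
"source": [
"To see why the search range matters, here is a minimal sketch that scans the scale `a` in `c*J0(a*x) + d` over a grid and reports the best R^2 it finds. The helper `best_r2_over_a` is hypothetical and is not part of pykan; pykan's own fitting routine may work differently. The sketch only illustrates that a range excluding `a = 20` cannot recover J0(20x), while a range including it can."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a-range-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical helper (not part of pykan): best R^2 of y ~ c*J0(a*x) + d over a grid of a.\n",
"def best_r2_over_a(x, y, a_min, a_max, n_grid=401):\n",
"    best_r2, best_a = -float('inf'), None\n",
"    for a in torch.linspace(a_min, a_max, n_grid):\n",
"        g = torch.special.bessel_j0(a * x)\n",
"        # The squared correlation equals the R^2 of the best linear fit c*g + d.\n",
"        gc, yc = g - g.mean(), y - y.mean()\n",
"        denom = (gc ** 2).sum() * (yc ** 2).sum()\n",
"        # A constant candidate (e.g. a = 0) has zero variance; score it as 0.\n",
"        r2 = ((gc * yc).sum() ** 2 / denom).item() if denom > 0 else 0.0\n",
"        if r2 > best_r2:\n",
"            best_r2, best_a = r2, a.item()\n",
"    return best_r2, best_a\n",
"\n",
"x = torch.linspace(-1, 1, 1001)\n",
"y = torch.special.bessel_j0(20 * x)  # ground truth with scale 20\n",
"\n",
"print(best_r2_over_a(x, y, -10, 10))  # range excludes a = 20\n",
"print(best_r2_over_a(x, y, -40, 40))  # range includes a = 20"
]
}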
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}