GitHub_collection_pykan/tutorials/Example_10_relativity-addition.ipynb
kindxiaoming bfe4e84ec1 clean
2024-04-29 12:35:18 -04:00

{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
"# Example 10: Use of lock for Relativistic Velocity Addition"
]
},
{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we will symbolically regress $f(u,v)=\\frac{u+v}{1+uv}$, the relativistic velocity-addition formula in units where $c=1$. In relativity, we know the rapidity trick $f(u,v)={\\rm tanh}({\\rm arctanh}\\ u+{\\rm arctanh}\\ v)$. Can we rediscover the rapidity trick with a KAN?"
]
},
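{
"cell_type": "markdown",
"id": "rapidity-derivation",
"metadata": {},
"source": [
"As a quick check of why the rapidity trick holds, recall the addition formula for the hyperbolic tangent:\n",
"\n",
"$$\\tanh(a+b)=\\frac{\\tanh a+\\tanh b}{1+\\tanh a\\,\\tanh b}.$$\n",
"\n",
"Substituting $a={\\rm arctanh}\\ u$ and $b={\\rm arctanh}\\ v$ (valid for $|u|,|v|<1$) gives $\\tanh a=u$ and $\\tanh b=v$, so\n",
"\n",
"$${\\rm tanh}({\\rm arctanh}\\ u+{\\rm arctanh}\\ v)=\\frac{u+v}{1+uv}=f(u,v).$$\n",
"\n",
"The right-hand side is therefore a two-layer composition: the same univariate function applied to each input, a sum, and then one univariate outer function."
]
},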
{
"cell_type": "markdown",
"id": "94056ef6",
"metadata": {},
"source": [
"Initialize model and create dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a59179d",
"metadata": {},
"outputs": [],
"source": [
"from kan import KAN, create_dataset\n",
"\n",
"# initialize KAN with grid size G=10 and spline order k=3\n",
"model = KAN(width=[2,1,1], grid=10, k=3)\n",
"\n",
"# create dataset\n",
"f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]])\n",
"dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9])"
]
},
{
"cell_type": "markdown",
"id": "cb1f817e",
"metadata": {},
"source": [
"Train KAN and plot"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a87b97b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.28e-04 | test loss: 6.37e-04 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 5.41it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f1cfc9d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "2795dfc8",
"metadata": {},
"source": [
"We notice that the two functions in the first layer look the same. Let's try to lock them, i.e. constrain them to be the same function!"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "17b6b983",
"metadata": {},
"outputs": [],
"source": [
"model.lock(0,[[0,0],[1,0]])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eb976f5a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "8214259e",
"metadata": {},
"source": [
"Now there are lock symbols in their top left corners!"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0298d20a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.13e-04 | test loss: 6.00e-04 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 5.68it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "markdown",
"id": "5ca6421a",
"metadata": {},
"source": [
"After retraining the model, the loss remains similar, so locking does not degrade model behavior. This supports our hypothesis that these two activation functions are the same. Let's now determine what this function is using `suggest_symbolic`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2ccb7048",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"function , r2\n",
"arctanh , 0.9999993678015309\n",
"tan , 0.9998485210873531\n",
"arcsin , 0.998865199664262\n",
"sqrt , 0.9830640000050016\n",
"x^2 , 0.9830517375289431\n"
]
},
{
"data": {
"text/plain": [
"('arctanh',\n",
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
" 0.9999993678015309)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,1,0)"
]
},
{
"cell_type": "markdown",
"id": "0092be41",
"metadata": {},
"source": [
"We can see that ${\\rm arctanh}$ is at the top of the suggestion list! So we can set both to arctanh, retrain the model, and plot it."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1bb96fe1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999992221865773\n",
"r2 is 0.9999993678015309\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'arctanh')\n",
"model.fix_symbolic(0,1,0,'arctanh')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "83b852a3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.39e-04 | test loss: 2.54e-03 | reg: 2.73e+00 : 100%|██| 20/20 [00:03<00:00, 6.33it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "9ccd0923",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "4b98a727",
"metadata": {},
"source": [
"Next we turn to the activation function in the second layer. We will see that ${\\rm tanh}$ is at the top of the suggestion list (${\\rm sigmoid}$ is equivalent to ${\\rm tanh}$ given input/output affine transformations)! So we can set it to ${\\rm tanh}$, retrain the model to machine precision, plot it, and finally get the symbolic formula."
]
},
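{
"cell_type": "markdown",
"id": "tanh-sigmoid-affine",
"metadata": {},
"source": [
"The equivalence between ${\\rm sigmoid}$ and ${\\rm tanh}$ can be made explicit. With $\\sigma(x)=1/(1+e^{-x})$,\n",
"\n",
"$$2\\sigma(2x)-1=\\frac{1-e^{-2x}}{1+e^{-2x}}=\\tanh(x),$$\n",
"\n",
"so rescaling the input by 2 and applying the affine map $y\\mapsto 2y-1$ to the output turns one into the other. Assuming the symbolic fit is free to adjust such input and output affine parameters, the two candidates would be expected to fit about equally well."
]
},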
{
"cell_type": "code",
"execution_count": 11,
"id": "99ad38b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"function , r2\n",
"tanh , 0.9999837308133379\n",
"sigmoid , 0.9999837287987492\n",
"arctan , 0.9995498634842791\n",
"sin , 0.996256989539414\n",
"gaussian , 0.9938095927784649\n"
]
},
{
"data": {
"text/plain": [
"('tanh',\n",
" (<function kan.utils.<lambda>(x)>, <function kan.utils.<lambda>(x)>),\n",
" 0.9999837308133379)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(1,0,0)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "af24c80d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999837308133379\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000, grad_fn=<SelectBackward0>)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(1,0,0,'tanh')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "01936f17",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.69e-11 | test loss: 5.76e-12 | reg: 2.69e+00 : 100%|██| 20/20 [00:00<00:00, 21.70it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "76bcc188",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "b62b0246",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 \\tanh{\\left(1.0 \\operatorname{atanh}{\\left(1.0 x_{1} \\right)} + 1.0 \\operatorname{atanh}{\\left(1.0 x_{2} \\right)} \\right)}$"
],
"text/plain": [
"1.0*tanh(1.0*atanh(1.0*x_1) + 1.0*atanh(1.0*x_2))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}