GitHub_collection_pykan/tutorials/Example_10_relativity-addition.ipynb
2024-08-11 17:13:55 -04:00


{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
"# Example 10: Relativistic Velocity Addition"
]
},
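{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we will symbolically regress $f(u,v)=\\frac{u+v}{1+uv}$, the relativistic velocity-addition formula in units where $c=1$. In relativity, the rapidity trick gives $f(u,v)={\\rm tanh}({\\rm arctanh}\\ u+{\\rm arctanh}\\ v)$: mapping each velocity to its rapidity turns velocity addition into ordinary addition. Can we rediscover the rapidity trick with a KAN?"
]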
},
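{
"cell_type": "markdown",
"id": "7c1e2a90",
"metadata": {},
"source": [
"The two forms agree because of the hyperbolic addition formula $\\tanh(a+b)=\\frac{\\tanh a+\\tanh b}{1+\\tanh a\\,\\tanh b}$ with $a={\\rm arctanh}\\ u$ and $b={\\rm arctanh}\\ v$. Here is a quick numerical sanity check. It is a sketch that uses only NumPy, and the helpers $\\texttt{velocity\\_addition}$ and $\\texttt{rapidity\\_addition}$ are hypothetical names, not part of pykan:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def velocity_addition(u, v):\n",
"    # Relativistic velocity addition in units where c = 1\n",
"    return (u + v) / (1 + u * v)\n",
"\n",
"def rapidity_addition(u, v):\n",
"    # Map velocities to rapidities, add them, and map back\n",
"    return np.tanh(np.arctanh(u) + np.arctanh(v))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"u = rng.uniform(-0.9, 0.9, size=1000)\n",
"v = rng.uniform(-0.9, 0.9, size=1000)\n",
"max_gap = np.max(np.abs(velocity_addition(u, v) - rapidity_addition(u, v)))\n",
"print(max_gap)  # only floating-point round-off remains\n",
"```"
]
},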
{
"cell_type": "markdown",
"id": "94056ef6",
"metadata": {},
"source": [
"Initialize the model and create the dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0a59179d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n",
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
}
],
"source": [
"from kan import *\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)\n",
"\n",
"# initialize a [2,1,1] KAN with grid=10 intervals and cubic (k=3) splines\n",
"model = KAN(width=[2,1,1], grid=10, k=3, device=device)\n",
"\n",
"# create dataset for f(u,v) = (u+v)/(1+uv), with both inputs in [-0.9, 0.9]\n",
"f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]])\n",
"dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9], device=device)"
]
},
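{
"cell_type": "markdown",
"id": "b3d58f14",
"metadata": {},
"source": [
"A $[2,1,1]$ KAN computes $f(x_1,x_2)=\\Phi(\\phi_1(x_1)+\\phi_2(x_2))$. It applies one learnable univariate function to each input, sums the results at the single hidden node, and then applies one outer univariate function. The rapidity trick has exactly this shape, with $\\phi_1=\\phi_2={\\rm arctanh}$ and $\\Phi={\\rm tanh}$. The input range $[-0.9,0.9]$ also matters: ${\\rm arctanh}$ diverges at $\\pm 1$, while on this range it is bounded by ${\\rm arctanh}(0.9)\\approx 1.47$. The NumPy sketch below illustrates both points. $\\texttt{kan\\_2\\_1\\_1}$ is a hypothetical helper, and the uniform sampling is an assumption standing in for $\\texttt{create\\_dataset}$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def kan_2_1_1(x, phi1, phi2, Phi):\n",
"    # Shape of a [2,1,1] KAN: sum two univariate edge functions, then apply an outer one\n",
"    return Phi(phi1(x[:, 0]) + phi2(x[:, 1]))\n",
"\n",
"def target(x):\n",
"    # Same target as the dataset: f(u, v) = (u + v) / (1 + u v)\n",
"    return (x[:, 0] + x[:, 1]) / (1 + x[:, 0] * x[:, 1])\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(-0.9, 0.9, size=(1000, 2))  # assumed uniform sampling on [-0.9, 0.9]^2\n",
"\n",
"pred = kan_2_1_1(x, np.arctanh, np.arctanh, np.tanh)\n",
"print(np.max(np.abs(pred - target(x))))  # round-off level\n",
"print(np.arctanh(0.9))  # about 1.472, so the hidden node stays finite\n",
"```"
]
},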
{
"cell_type": "markdown",
"id": "cb1f817e",
"metadata": {},
"source": [
"Train the KAN and plot it"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a87b97b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.28e-03 | test_loss: 2.31e-03 | reg: 6.50e+00 | : 100%|█| 20/20 [00:03<00:00, 5.88it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3f1cfc9d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
{
"cell_type": "markdown",
"id": "5ca6421a",
"metadata": {},
"source": [
"From the plot, we hypothesize that the two first-layer activation functions are the same function. Let's determine what this function is by calling $\\texttt{suggest\\_symbolic}$ on one of them (layer 0, input 1, output 0)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2ccb7048",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 arctanh 0.999992 -15.786788 4 4 -15.786788\n",
"1 tan 0.999825 -12.397871 3 3 -12.397871\n",
"2 arccos 0.998852 -9.753944 4 4 -9.753944\n",
"3 arcsin 0.998852 -9.753944 4 4 -9.753944\n",
"4 sqrt 0.982166 -5.808383 2 2 -5.808383\n"
]
},
{
"data": {
"text/plain": [
"('arctanh',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 4,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.999992311000824,\n",
" 4)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(0,1,0,weight_simple=0.0)"
]
},
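{
"cell_type": "markdown",
"id": "e6a04c27",
"metadata": {},
"source": [
"Roughly speaking, $\\texttt{suggest\\_symbolic}$ fits each candidate $f$ from a symbolic library to the learned activation in the affine form $y\\approx c\\,f(ax+b)+d$. It then ranks the candidates by fitting quality ($R^2$). The sketch below is a simplified, hypothetical re-implementation for illustration only, not pykan's code. It grid-searches $(a,b)$, solves for $(c,d)$ by linear least squares, and uses a synthetic arctanh-shaped curve in place of the real activation. The grids and the candidate set are arbitrary choices.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def affine_fit_r2(x, y, f, a_grid, b_grid):\n",
"    # For each (a, b), fit y ~ c * f(a*x + b) + d by least squares; keep the best R^2\n",
"    ss_tot = np.sum((y - y.mean()) ** 2)\n",
"    best = -np.inf\n",
"    for a in a_grid:\n",
"        for b in b_grid:\n",
"            with np.errstate(all='ignore'):\n",
"                z = f(a * x + b)\n",
"            if not np.all(np.isfinite(z)):\n",
"                continue  # skip (a, b) that leave the domain of f\n",
"            A = np.stack([z, np.ones_like(z)], axis=1)\n",
"            coef, *_ = np.linalg.lstsq(A, y, rcond=None)\n",
"            ss_res = np.sum((y - A @ coef) ** 2)\n",
"            best = max(best, 1 - ss_res / ss_tot)\n",
"    return best\n",
"\n",
"# Synthetic stand-in for a learned activation: an arctanh-shaped curve\n",
"x = np.linspace(-0.9, 0.9, 200)\n",
"y = 0.8 * np.arctanh(x) + 0.1\n",
"\n",
"candidates = {'arctanh': np.arctanh, 'tan': np.tan, 'sin': np.sin, 'x^2': np.square}\n",
"a_grid = np.linspace(-2, 2, 41)\n",
"b_grid = np.linspace(-1, 1, 21)\n",
"scores = {name: affine_fit_r2(x, y, f, a_grid, b_grid) for name, f in candidates.items()}\n",
"for name, r2 in sorted(scores.items(), key=lambda kv: -kv[1]):\n",
"    print(name, round(r2, 6))\n",
"```"
]
},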
{
"cell_type": "markdown",
"id": "0092be41",
"metadata": {},
"source": [
"${\\rm arctanh}$ is at the top of the suggestion list! Since we expect both first-layer activations to be the same function, we set both to ${\\rm arctanh}$, retrain the model, and plot it."
]
},
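{
"cell_type": "markdown",
"id": "4f9b1d63",
"metadata": {},
"source": [
"The numbers in the suggestion table above are consistent with an $R^2$ loss of $\\log_2(1-R^2+10^{-5})$. Because $\\texttt{weight\\_simple=0}$, the total loss column simply repeats the $R^2$ loss, so the ranking reflects fitting quality alone. This formula is inferred from the printed values rather than taken from pykan's source. The short check below recomputes the column from the printed $R^2$ values, which are rounded in the table, hence the small tolerance.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# (function, fitting r2, printed r2 loss) from the suggest_symbolic output above;\n",
"# the arctanh r2 is the full-precision value from the returned tuple\n",
"rows = [\n",
"    ('arctanh', 0.999992311000824, -15.786788),\n",
"    ('tan', 0.999825, -12.397871),\n",
"    ('arccos', 0.998852, -9.753944),\n",
"    ('sqrt', 0.982166, -5.808383),\n",
"]\n",
"\n",
"def r2_loss(r2, eps=1e-5):\n",
"    # Inferred form of the r2-loss column (an assumption, checked against the table)\n",
"    return math.log2(1 - r2 + eps)\n",
"\n",
"for name, r2, printed in rows:\n",
"    print(name, round(r2_loss(r2), 3), printed)\n",
"```"
]
},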
{
"cell_type": "code",
"execution_count": 5,
"id": "1bb96fe1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999759197235107\n",
"saving model version 0.2\n",
"r2 is 0.999992311000824\n",
"saving model version 0.3\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000, device='cuda:0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'arctanh')\n",
"model.fix_symbolic(0,1,0,'arctanh')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "83b852a3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 7.94e-04 | test_loss: 9.43e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:04<00:00, 4.34it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"LBFGS\", steps=20, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9ccd0923",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=10)"
]
},
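{
"cell_type": "markdown",
"id": "4b98a727",
"metadata": {},
"source": [
"Next, we ask $\\texttt{suggest\\_symbolic}$ about the second-layer activation (layer 1, input 0, output 0). ${\\rm tanh}$ is at the top of the suggestion list, so we set it to ${\\rm tanh}$, retrain the model to high precision (train and test losses around $2\\times 10^{-6}$), plot it, and finally extract the symbolic formula."
]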
},
{
"cell_type": "code",
"execution_count": 8,
"id": "99ad38b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" function fitting r2 r2 loss complexity complexity loss total loss\n",
"0 tanh 0.999998 -16.336284 3 3 -16.336284\n",
"1 arctan 0.999435 -10.764618 4 4 -10.764618\n",
"2 cos 0.995899 -7.926177 2 2 -7.926177\n",
"3 sin 0.995899 -7.926177 2 2 -7.926177\n",
"4 gaussian 0.994457 -7.492519 3 3 -7.492519\n"
]
},
{
"data": {
"text/plain": [
"('tanh',\n",
" (<function kan.utils.<lambda>(x)>,\n",
" <function kan.utils.<lambda>(x)>,\n",
" 3,\n",
" <function kan.utils.<lambda>(x, y_th)>),\n",
" 0.9999979138374329,\n",
" 3)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.suggest_symbolic(1,0,0,weight_simple=0.)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "af24c80d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999979138374329\n",
"saving model version 0.5\n"
]
},
{
"data": {
"text/plain": [
"tensor(1.0000, device='cuda:0')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(1,0,0,'tanh')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "01936f17",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.97e-06 | test_loss: 2.06e-06 | reg: 0.00e+00 | : 100%|█| 2000/2000 [00:21<00:00, 93.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.6\n"
]
}
],
"source": [
"model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=2000, update_grid=False, singularity_avoiding=True);"
]
},
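{
"cell_type": "markdown",
"id": "c28e7a51",
"metadata": {},
"source": [
"Why $\\texttt{singularity\\_avoiding=True}$? The fixed activations ${\\rm arctanh}(ax+b)$ are finite only for $|ax+b|<1$, so an optimizer step on the affine parameters could produce infinities or NaNs. The tuple returned by $\\texttt{suggest\\_symbolic}$ earlier contains a function of $(x, y_{th})$, which suggests that pykan keeps a thresholded variant of each symbolic function for this purpose. The sketch below only illustrates the general idea with a hypothetical clipped ${\\rm arctanh}$ whose output is bounded by $y_{th}$. It is not pykan's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def arctanh_guarded(x, y_th=10.0):\n",
"    # Hypothetical guard: clip the input so that |arctanh(x)| never exceeds y_th\n",
"    x_max = np.tanh(y_th)\n",
"    return np.arctanh(np.clip(x, -x_max, x_max))\n",
"\n",
"z = np.array([-1.2, -1.0, 0.0, 0.5, 1.0, 1.2])\n",
"with np.errstate(all='ignore'):\n",
"    print(np.arctanh(z))  # nan beyond +-1, inf at +-1\n",
"print(arctanh_guarded(z))  # finite everywhere, bounded by y_th\n",
"```"
]
},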
{
"cell_type": "code",
"execution_count": 11,
"id": "76bcc188",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b62b0246",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\tanh{\\left(\\operatorname{atanh}{\\left(x_{1} \\right)} + \\operatorname{atanh}{\\left(x_{2} \\right)} \\right)}$"
],
"text/plain": [
"tanh(atanh(x_1) + atanh(x_2))"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"formula = model.symbolic_formula()[0][0]\n",
"nsimplify(ex_round(formula, 4))"
]
}
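,
{
"cell_type": "markdown",
"id": "9a3f6e08",
"metadata": {},
"source": [
"The fitted affine parameters are floating-point numbers, so the raw symbolic formula carries numerical constants. $\\texttt{ex\\_round(formula, 4)}$ rounds those constants. SymPy's $\\texttt{nsimplify}$ then replaces near-rational floats such as $1.0$ with exact numbers, which turns the raw output into the clean rapidity formula. The SymPy sketch below mimics this on a made-up expression. The coefficients and the helper $\\texttt{round\\_floats}$ are hypothetical stand-ins for a trained model's output and for pykan's $\\texttt{ex\\_round}$.\n",
"\n",
"```python\n",
"import sympy as sp\n",
"\n",
"x1, x2 = sp.symbols('x_1 x_2')\n",
"\n",
"# Made-up raw formula with slightly-off constants, as a trained model might produce\n",
"raw = 1.00002 * sp.tanh(0.99999 * sp.atanh(x1) + 1.00001 * sp.atanh(x2) + 0.00003)\n",
"\n",
"def round_floats(expr, digits=4):\n",
"    # Hypothetical stand-in for ex_round: round every float constant in expr\n",
"    return expr.xreplace({c: sp.Float(round(float(c), digits)) for c in expr.atoms(sp.Float)})\n",
"\n",
"clean = sp.nsimplify(round_floats(raw))\n",
"print(clean)\n",
"```"
]
}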
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}