GitHub_collection_pykan/tutorials/physics_2A_conservation_law.ipynb
2024-08-11 13:15:16 -04:00

167 lines
13 KiB
Plaintext

{
"cells": [
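{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Physics 2A: Conservation Laws\n",
"\n",
"Train a KAN $H(q, p)$ to discover a conserved quantity of the harmonic oscillator, whose phase-space flow is $(\\dot q, \\dot p) = (p, -q)$. Along a trajectory $\\frac{dH}{dt} = \\nabla H \\cdot (\\dot q, \\dot p)$, so the loss drives the *normalized* gradient $\\nabla H / \\lVert \\nabla H \\rVert$ to be orthogonal to the flow; normalizing rules out the trivial constant solution. The expected answer is the energy $H = \\frac{1}{2}(q^2 + p^2)$, recovered only up to an affine transformation $aH + b$."
]
},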
{
"cell_type": "code",
"execution_count": 2,
"id": "fd0d2987",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"from kan.utils import batch_jacobian, create_dataset_from_data\n",
"import numpy as np\n",
"\n",
"model = KAN(width=[2,1], seed=42)\n",
"\n",
"# Phase-space samples, uniform in [-1, 1]^2: column 0 is the position q, column 1 the momentum p.\n",
"# NOTE(review): x is drawn from the global torch RNG; seed it explicitly if KAN(seed=...) does not.\n",
"x = torch.rand(1000,2) * 2 - 1\n",
"# Harmonic-oscillator flow (dq/dt, dp/dt) = (p, -q). The conserved quantity the model should\n",
"# learn is H = 1/2 * (q**2 + p**2), identifiable only up to an affine transformation a*H + b.\n",
"flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1)\n",
"\n",
"# Floor for the gradient norm: without it a vanishing gradient yields 0/0 = NaN and poisons the loss.\n",
"GRAD_NORM_EPS = 1e-8\n",
"\n",
"def pred_fn(model, x):\n",
"    \"\"\"Return the gradient of the model output w.r.t. x, scaled to unit length per sample.\n",
"\n",
"    NOTE(review): assumes batch_jacobian returns shape (batch, n_inputs) for this\n",
"    scalar-output model, so dim=1 is the input dimension -- confirm.\n",
"    \"\"\"\n",
"    grad = batch_jacobian(model, x, create_graph=True)\n",
"    grad_norm = torch.linalg.norm(grad, dim=1, keepdim=True).clamp_min(GRAD_NORM_EPS)\n",
"    return grad / grad_norm\n",
"\n",
"def loss_fn(grad_normalized, flow):\n",
"    \"\"\"Mean squared projection of the flow onto the normalized gradient.\n",
"\n",
"    H is conserved along the flow iff grad(H) . flow = 0. Using the normalized gradient\n",
"    makes the loss depend only on the gradient direction, so shrinking H towards a\n",
"    constant cannot drive the loss to zero.\n",
"    \"\"\"\n",
"    return torch.mean(torch.sum(flow * grad_normalized, dim=1)**2)\n",
"\n",
"dataset = create_dataset_from_data(x, flow)\n",
"model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "60c88d7f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 500x200 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the trained model to inspect the learned activations before snapping them to symbolic forms below.\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e8cb9a2f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2\n",
"fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2\n",
"saving model version 0.2\n"
]
}
],
"source": [
"# Replace each learned activation with its best-fitting symbolic candidate (in the recorded run: x^2 on both inputs).\n",
"model.auto_symbolic()"
]
},
"cell_type": "code",
"execution_count": 5,
"id": "1b143bf8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle - 1.191 x_{1}^{2} - 1.191 x_{2}^{2} + 2.329$"
],
"text/plain": [
"-1.191*x_1**2 - 1.191*x_2**2 + 2.329"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kan.utils import ex_round\n",
"ex_round(model.symbolic_formula()[0][0], 3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "782f818f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}