{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "hamiltonian-intro",
   "metadata": {},
   "source": [
    "# Learning a conserved quantity of the harmonic oscillator with a KAN\n",
    "\n",
    "The data is the flow $\\dot{x} = p,\\ \\dot{p} = -x$, sampled uniformly on $[-1, 1)^2$ (column 0 is $x$, column 1 is $p$).\n",
    "A `KAN(width=[2,1])` models a scalar $H(x, p)$ and is trained so that its gradient is orthogonal to the flow $f$,\n",
    "i.e. $\\frac{dH}{dt} = \\nabla H \\cdot f = 0$; the loss is the mean of $\\left(f \\cdot \\nabla H / \\lVert \\nabla H \\rVert\\right)^2$.\n",
    "\n",
    "Because $\\nabla H$ is normalized and the projection is squared, the loss cannot pin down the scale, sign or offset of $H$,\n",
    "so $H = \\tfrac{1}{2}(x^2 + p^2)$ is only recoverable up to such a transformation. The symbolic formula at the end,\n",
    "$-1.191\\,x_1^2 - 1.191\\,x_2^2 + 2.329$ (with $x_1 = x$, $x_2 = p$), has exactly that form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd0d2987",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 1.07e-04 | test_loss: 1.17e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:01<00:00, 16.52it"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from kan import KAN\n",
    "from kan.utils import batch_jacobian, create_dataset_from_data, ex_round\n",
    "\n",
    "# NOTE(review): KAN(seed=...) appears to seed torch's global RNG, which is what makes the\n",
    "# torch.rand sample below reproducible; confirm, or call torch.manual_seed explicitly.\n",
    "model = KAN(width=[2,1], seed=42)\n",
    "\n",
    "# Phase-space samples (x, p) in [-1, 1)^2 and the harmonic-oscillator flow (dx/dt, dp/dt) = (p, -x).\n",
    "# The model should learn a conserved quantity of this flow: H = 1/2 * (x**2 + p**2) up to scale, sign and offset.\n",
    "x = torch.rand(1000,2) * 2 - 1\n",
    "flow = torch.cat([x[:,[1]], -x[:,[0]]], dim=1)\n",
    "\n",
    "def pred_fn(model, x):\n",
    "    \"\"\"Return the input gradient of the scalar KAN output, normalized to unit length per sample.\"\"\"\n",
    "    # create_graph=True keeps the gradient differentiable so the loss can backpropagate through it.\n",
    "    # assumes batch_jacobian returns (batch, n_inputs) for a single-output model; TODO confirm\n",
    "    grad = batch_jacobian(model, x, create_graph=True)\n",
    "    # Clamp the norm so a vanishing gradient gives 0 instead of NaN (0/0) and cannot poison training.\n",
    "    norm = torch.linalg.norm(grad, dim=1, keepdim=True).clamp(min=1e-12)\n",
    "    return grad / norm\n",
    "\n",
    "def loss_fn(grad_normalized, flow):\n",
    "    \"\"\"Mean squared projection of the flow onto the normalized grad H; zero iff the flow is orthogonal to grad H at every sample.\"\"\"\n",
    "    return torch.mean(torch.sum(flow * grad_normalized, dim=1)**2)\n",
    "\n",
    "dataset = create_dataset_from_data(x, flow)\n",
    "model.fit(dataset, steps=20, pred_fn=pred_fn, loss_fn=loss_fn);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "60c88d7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the trained network.\n",
    "model.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e8cb9a2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fixing (0,0,0) with x^2, r2=1.0000003576278687, c=2\n",
      "fixing (0,1,0) with x^2, r2=1.0000004768371582, c=2\n",
      "saving model version 0.2\n"
     ]
    }
   ],
   "source": [
    "# Snap each learned activation to a symbolic function (both edges are fixed to x^2 here).\n",
    "model.auto_symbolic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1b143bf8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle - 1.191 x_{1}^{2} - 1.191 x_{2}^{2} + 2.329$"
      ],
      "text/plain": [
       "-1.191*x_1**2 - 1.191*x_2**2 + 2.329"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recovered H with coefficients rounded to 3 decimals (x_1 = x, x_2 = p).\n",
    "ex_round(model.symbolic_formula()[0][0], 3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}