GitHub_collection_pykan/tutorials/Example_9_singularity.ipynb

347 lines
75 KiB
Plaintext
Raw Normal View History

2024-04-29 12:35:18 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Example 9: Singularity"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
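"Let's construct a dataset with a singularity: $f(x,y)=\\sin\\big(2(\\log x+\\log y)\\big)$, defined for $x>0, y>0$. Since $\\log x$ and $\\log y$ diverge as $x \\to 0$ or $y \\to 0$, $f$ oscillates infinitely fast near the axes; below we sample $x, y \\in [0.2, 5]$."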
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.00e-03 | test loss: 3.94e-02 | reg: 2.72e+00 : 100%|██| 20/20 [00:11<00:00, 1.78it/s]\n"
]
}
],
"source": [
"from kan import KAN, create_dataset\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 20 grid intervals (grid=20).\n",
"model = KAN(width=[2,1,1], grid=20, k=3, seed=0)\n",
"# target f(x,y) = sin(2*(log(x) + log(y))); column 0 is x, column 1 is y\n",
"f = lambda x: torch.sin(2*(torch.log(x[:,[0]])+torch.log(x[:,[1]])))\n",
"# sample x, y from [0.2, 5]: log is singular at 0, so keep the inputs strictly positive\n",
"dataset = create_dataset(f, n_var=2, ranges=[0.2,5])\n",
"\n",
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvQklEQVR4nO3d6XNUV34+8Od2q1vqTUtrXxBakJCQABvM6g3PEMPYhInHs9jJULiSmVQqiV12/oG8ynuIp1yZSX7jMVMpMuPALA4eMF4BLxixGLSCFrQLqSW11OpF6u57fi+Ye6uvNiR0pd6eT5Wr5l51w7FHt58+53sWSQghQEREpCNDtBtARESJh+FCRES6Y7gQEZHuGC5ERKQ7hgsREemO4UJERLpjuBARke4YLkREpDuGCxER6Y7hQkREumO4EBGR7hguRESkO4YLERHpjuFCRES6Y7gQEZHuUqLdAKJ4IITA6OgopqamYLfbkZ2dDUmSot0sopjFngvRItxuN44fP46qqirk5uaivLwcubm5qKqqwvHjx+F2u6PdRKKYJPEkSqL5nTt3Di+++CJ8Ph+A+70XhdJrsVqtOHXqFA4cOBCVNhLFKoYL0TzOnTuH559/HkIIyLK84OsMBgMkScKZM2cYMEQRGC5Es7jdbpSUlMDv9y8aLAqDwQCLxYK+vj5kZmaufgOJ4gBrLkSzvPPOO/D5fEsKFgCQZRk+nw8nTpxY5ZYRxQ/2XIgiCCFQVVWFzs5OLOfRkCQJFRUVuHPnDmeREYHhQqThcrmQm5u7ovdnZ2fr2CKi+MRhMaIIU1NTK3q/x+PRqSVE8Y3hQhTBbrev6P0Oh0OnlhDFN4YLUYTs7GxUVlYuu24iSRIqKyvhdDpXqWVE8YXhQhRBkiS8+uqrD/Xe1157jcV8oj9jQZ9oFq5zIVo59lyIZsnMzMSpU6cgSRIMhsUfEWWF/unTpxksRBEYLkTzOHDgAM6cOQOLxQJJkuYMdyn3LBYL3n//fTz77LNRailRbGK4EC3gwIED6Ovrw7Fjx1BRUaH5WUVFBY4dO4b+/n4GC9E8WHMhWgIhBNra2nDy5Em8/PLL2LhxI4v3RItgz4VoCSRJgtPpRHZ2NpxOJ4OF6AEYLkREpDuGCxER6Y7hQkREumO4EBGR7hguRESkO4YLERHpjuFCRES6Y7gQEZHuGC5ERKQ7hgsREemO4UJERLpjuBARke4YLkREpDuGC9ESWSwWbNmyBRaLJdpNIYp5PM+FaInC4TB8Ph+sViuMRmO0m0MU0xguRESkOw6LERGR7lKi3QAihSzLGBoawszMTLSbEvfy8vJgtVqj3QxKYgwXihmyLOPChQsIhUIwm83Rbk5cEkJgYmIC+/fvR0VFRbSbQ0mM4UIxQwgBSZLw5JNPoqSkZN7XhMNhCCGQksJf3fmEw2H88Y9/jHYziBguFHsMBsOc2VhCCAwPD+PixYsIh8PYs2cP1q1bB0mSotTK2MX/JhQLWNCnuCCEwLVr1zA4OIjh4WFcvnwZ4XA42s0iogUwXCguhMNhuN1u9XpiYgJ+vz96DSKiRTFcKC4Eg0FNmExPT2NiYiKKLSKixTBcKC4EAgHNFGVZluFyucA1wESxieFCccHr9SIUCmnujYyMRKk1RPQgDBeKC1NTU3N6KW63G7IsR6lFRLQYhgvFPCEEPB7PnPs+nw/BYDAKLSKiB2G4UFyYnJycc29mZoZbxRDFKIYLxTxZljE1NTXnfjAYRCAQiEKLiOhBGC4U88LhMLxe75z7sizPe5+Ioo/hQjFvZmZm3h6KEAJer5fTkYliEPcWo5g3e41LpPkK/WtJCIGpqSl4PB6kpaUhIyMDBoOB+3tR0mO4UMzzer0L7iMWzXAJhUK4fv06bt68Cb/fD5PJhPLycuzduxc2m40BQ0mNw2IU05RpyAsNfXm93qisdZFlGQ0NDbh8+TJ8Ph+EEJiZmUFbWxs++ugjTE9Pr3mbiG
IJw4ViXmTvxGAwaA4S8/v9ax4uQgj09/fjxo0b84Zeb28vvvnmG9aCKKkxXCjmRYaL2WyG0+lUrwOBwJovpAyFQmhoaFD/XkmSUFJSApPJBOB++Ny6dQvj4+Nr2i6iWMJwoZgmy7ImXCwWiyZcFppJtlqEEBgYGMDg4KB6Lz8/HwcPHsTmzZvVe36/H42Njey9UNJiuFBMC4fD8Pl86rXVatWESygUWtO1LkIINDc3qxMMDAYDtm3bhrS0NGzZsgUOh0N9bUdHx7yLP4mSAcOFYtr09LSmOO5wOJCRkaHOxFpo37HVMjk5ib6+PvU6JydHPW7Zbrejurpa/dnU1BS6u7vZe6GkxHChmOb3+zU1FYfDAbvdDqPRqN6bnJxckw9wIQS6u7s1w3DV1dVqrQUAqqqqNBMO2tvbuXMzJSWGC8W02Wtc0tPTYbFYNB/o821quRpkWUZnZ6d6nZqairKyMrUXJUkSnE4nCgoK1NcMDw/zxExKSgwXilmz17goQ0+pqalIS0tTXzc1NbUmvYPJyUnNAWX5+flIT0/XvMZgMKC8vFy9np6eRl9fH4fGKOkwXCimRdZTUlJSYLPZYDQaYbVa1fuLreDXixACg4ODmvpPWVkZDAbtI6RMS05NTVXv9fT0MFwo6TBcKKZFDnmZTCakpaXBYDBoZmUFAoFVXxEvhEBPT4+mLcXFxfNu8ZKeno7s7Gz1enh4WDPjjSgZMFwoZs0+x8VisajF8sjhqGAwCL/fv6ptmZ6exr1799Rrp9OJjIyMeV9rNBpRUlKiXvt8PoyMjLD3QkmF4UIxa3ZoKENikiRpwmWhw8T0NDY2pvk7ioqKkJIy/76vytCYMmSmbBdDlEwYLhSzpqenNdN+HQ6HOgxlt9vnrHVZrZ6BsipfmTQgSRKKi4sXfY/T6YTNZlOvBwcHV70uRBRLGC4Us/x+P0KhkHodWWex2WyansNqTkeWZRkDAwPqtcViQU5OzqJb6qempiIvL0+9Hh8f56mZlFQYLhSzZm+nn56ern6gp6WlaWZkreZCykAggLGxMfXa6XRqZqvNR5IkFBUVqdczMzNwuVysu1DSYLhQTBJCaALDYDDAbrerPzeZTLBYLOr1aq51GR8f18z2KigomDMFeTZJkpCfn6/uJKBMZSZKFgwXilmz17hE9haMRqMmbHw+36psvS+EwL179zT1lsLCwiW9NzMzU1N3ifxziBIdw4ViktJzUZjNZs2qfEmSNFOBA4HAqtQ0Zvc40tLSkJ2dvaQjjFNTUzXrXWb3gIgSGcOFYpIsy5qwmL2fGADNB3c4HIbb7da9pjE9Pa2pt2RkZGiG4xYjSZJmn7FAIMADxChpMFwoJgWDQc23fLvdrqlzSJKErKwsTU1jdHRU93ZMTk5qQi6yjvIgkiQhLy9PM2X63r17LOpTUmC4UEwKBAKYmZlRryPXuETei9zefnR0VNcPbiEEXC6XZjp0ZE9kKZxOp2Y4j+FCyYLhQjHJ5/MtuMZFkZaWprk/Pj6ueY8ehoaG1P9tMpmWXG9RWCwWZGZmqtdjY2Oa0CRKVAwXikmzpxZHrnFRGI1GTd1lampK16J+OByGy+VSrx0Ox7whtxiDwaBZTOn1etfs/BmiaGK4UEyK/ACePe04UuQHdzAY1BTfV8rn82na4XQ6F9xPbDGRbQyFQlxMSUmB4UIxKS8vD/X19Vi/fj3y8/M160UUkiQhJydHs0GknrsPj4+Pa7byz8/PX9aQWGQbI2e6Re6uvFI+nw8ulwvT09NcQ0MxZflfw4jWQGlpKUpLSyGEgCzLC87QSk9PR1pamjqzTAmX5YbAbEIIDA8Pa3YIyM3Nfag/1+FwwGazwe12AwBcLhfC4fBD9YJmt/H27dv48ssvYbPZkJ2djd27d6/ozyTSC3suFJMkSYIkSTAYDEhJSVnwQ91isWgWU46Pj+tSMFfCRZGamqopzC+HyWSC0+lUrycmJnQ7f+bevXsIhUKYmJhAf3//ik
OVSC8MF4prBoMBOTk56rXX69XlbJfZ9ZvlLJ6cTdlnTKHXYspgMKhZ25Oenj7v8CFRNDBcKO7NLpiPjY2tuO7i8Xg
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ccb7ec43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.999988712412588\n",
"r2 is 0.9999928603717329\n",
"r2 is 0.9968394556850537\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9968, grad_fn=<SelectBackward0>)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'log')\n",
"model.fix_symbolic(0,1,0,'log')\n",
"model.fix_symbolic(1,0,0,'sin')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0937db67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.46e-15 | test loss: 6.78e-16 | reg: 2.72e+00 : 100%|██| 20/20 [00:02<00:00, 8.21it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e959cda3",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 \\sin{\\left(2.0 \\log{\\left(2.01 x_{1} \\right)} + 2.0 \\log{\\left(0.62 x_{2} \\right)} + 5.85 \\right)}$"
],
"text/plain": [
"1.0*sin(2.0*log(2.01*x_1) + 2.0*log(0.62*x_2) + 5.85)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
},
{
"cell_type": "markdown",
"id": "16e4da06",
"metadata": {},
"source": [
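"We were lucky -- the singularity was not a problem in this case (the sampled inputs satisfy $x, y \\ge 0.2$, away from the singularity at zero). Now let's consider $f(x,y)=\\sqrt{x^2+y^2}$. The point $x=y=0$ is singular: $f$ is continuous there but not differentiable, because the derivative of $\\sqrt{u}$ blows up as $u \\to 0$."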
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1ce52cec",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.08e-01 | test loss: 2.18e-01 | reg: 2.75e+01 : 100%|██| 20/20 [00:38<00:00, 1.92s/it]\n"
]
}
],
"source": [
"# KAN, create_dataset and torch are already imported in the first code cell.\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=1)\n",
"# target f(x,y) = sqrt(x^2 + y^2): continuous, but not differentiable at x = y = 0\n",
"f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2)\n",
"# default input ranges (pykan default: [-1, 1]), so the samples surround the singular point at the origin\n",
"dataset = create_dataset(f, n_var=2)\n",
"\n",
"# train with sparsity regularization (lamb, lamb_entropy) so that prune() below can remove unused hidden neurons\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3a69ec41",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABIR0lEQVR4nO3de3QU53k/8O/salfalVaXXQmBEEjouhJISICxDQHLgM2tbhM7bdM2Td30NI3b2DltnDRt2ubXy8ml6TkxjpuLnbq1/0nS2OT0NBBszEUCDMExCAnQSkJCQgIh0N6099vM7w86m0Xosrua2Xln9XzO0T+wu/PuuzPzzHt7Xk4QBAGEEEKIhDRKF4AQQkj2oeBCCCFEchRcCCGESI6CCyGEEMlRcCGEECI5Ci6EEEIkR8GFEEKI5Ci4EEIIkRwFF0IIIZKj4EIIIURyFFwIIYRIjoILIYQQyVFwIYQQIjkKLoQQQiRHwYUQQojkcpQuACFqIAgC7HY7vF4vCgoKYLFYwHGc0sUihFnUciFkHi6XCwcOHEB9fT3KysqwZs0alJWVob6+HgcOHIDL5VK6iIQwiaOdKAmZ3TvvvINnnnkGfr8fwL3Wi0hstRiNRrz99tvYvXu3ImUkhFUUXAiZxTvvvIP9+/dDEATwPD/n6zQaDTiOw6FDhyjAEJKAggshM7hcLlRWViIQCMwbWEQajQYGgwHj4+MoLi6Wv4CEqACNuRAywxtvvAG/359UYAEAnufh9/vx5ptvylwyQtSDWi6EJBAEAfX19RgeHkYqlwbHcaipqcHg4CDNIiMEFFwIuc/U1BTKysoW9X6LxSJhiQhRJ+oWIySB1+td1Ps9Ho9EJSFE3Si4EJKgoKBgUe83mUwSlYQQdaPgQsj/EQQBZrMZNTU1Kb+X4zjU1tbCbDbLUDJC1IeCC1nSBEGI/wH3phW/8MILaX3OCy+8QIP5hPwfCi5kyZkZUDiOi/8JgoBPfepTMBqN0GiSvzzElgsh5B4KLmRJmBlQgF8HlcTXAEBJSQkOHjwIjuMWDDAajQYajQZbtmzBs88+i3/8x39EOByW50sQoiIUXEhWmyugzOy+SmzFAMDu3btx6NAhGAyGWV8v/pvBYMAvfvELdHV14atf/Sq+973vYd++fbh+/brM34wQtlFwIVlnvm6vuV4rvi7R7t27MT4+jpdeeumBQf6amhq89NJLuHnzJp588kloNBp87nOfw5EjR+B2u9HR0YGf/OQnKS3EJCSb0CJKkhVmnsbJDKzPFVTmeq3D4YDH44HJZILZbJ7zfV6vF1/+8pfx4x//GM888wz+7d/+jaYokyWHggtRrdlO3VSCSrKvT9fBgwfxhS98AWazGT/4wQ+wadMm2Y5FCGuoW4yoTrLjKHO9N/E9cnr66adx8uRJlJWVYf/+/fj2t7+NWCwm6zEJYQW1XIgqpNPtNddnZHotSiQSwbe+9S18+9vfxpYtW/D9738fK1asyGgZCMk0Ci6EWVIElJmfo+QixzNnzuCzn/0sgsEgXn75Zezdu1exshAiN+oWI0xJZaZXsp+X+DlK2rp1Kzo7O/Hoo4/iD//wD/HFL34RwWBQ0TIRIhdquRAmSNVKme3zlA4qMwmCgDfeeANf+cpXsGbNGrz66qtobm5WuliESIpaLkQxUrdSEj838fNYw3Ecnn32WRw7dgwajQa7du3Cf/zHf9CaGJJVqOVCMkrqFspcn89iUJlNKBTCV7/6Vfzwhz/E7t278fLLL9NmYyQrUHAhspM7oMw8hloCS6J33nkHzz//PHQ6Hb73ve9h+/btSheJkEWhbjEii2QSRUp5LDk/PxN2796Nrq4uNDY24plnnsE///M/IxKJKF0sQtJGLRciqUy0UmY7nlqDykw8z+OVV17B1772NbS2tuLVV19FdXW10sUiJGXUciGLJtfAfDLHFI+XLcTNyg4fPgyHw4GOjg789Kc/VbpYhKSMggtJix
IBJfHYicfMRhs2bMDJkyexb98+PPfcc/jsZz8Lj8ejdLEISRp1i5GkpZsoUo7jZ2tQmc1bb72FF198ERaLBa+++io2btyodJEIWRC1XMiCFpMoUsoyJB57Kfn4xz+Ozs5OWCwW7N+/Hy+99BIlwCTMo5YLmVWmB+bnk41jK+mIRCL45je/iQMHDmDr1q343ve+RwkwCbMouJA4lgIKsHS7wRZy6tQpPPfccwiHw3j55ZexZ88epYtEyAOoW2yJU3JgfqFyJZaH/Nq2bdvQ1dWFhx9+GJ/85Cfx13/915QAkzCHWi5LFGutFBG1VpInCAL+8z//E3//93+PmpoavPrqq2hqalK6WIQAoJbLksJqK0VErZXUcByHT3/603jvvfcAALt27cLrr79OCTAJEyi4ZDnWA4qIBu3T19TUhKNHj+KTn/wkvvSlL+FTn/oUHA6H0sUiSxx1i2UhVru8ZkPdYNI6cuQIXnjhBej1enz/+9/HRz7yEaWLRJYoarlkkUwlipQKdYNJb8+ePejs7ERdXR0+9rGP4V/+5V8oASZRBLVcVE5NrZRE1A0mr1gshldeeQVf//rXsX79erz66quoqqpSulhkCaGWiwqpZRxlNtmacJI1Wq0Wn//853H48GFMTU3hsccew1tvvaV0scgSQsFFJdQcUETUDZZ5YgLMvXv34rOf/Sz+/M//HF6vV+likSWAusUYpnSiSKnQoD0bfvrTn+LFF19EWVkZXnvtNbS3tytdJJLFqOXCIBYSRUqFWivs+O3f/m10dnbCbDZj7969ePnll8HzvNLFIlmKWi6MUOvA/HxobIVNkUgE3/jGN/Dyyy9j27Zt+O53v4vly5crXSySZSi4KCgbAwpA3WBq0dXVheeeew6RSATf+c53sHv3bqWLRLIIdYtlWDYMzM+HusHUY/v27ejq6sJDDz2EP/iDP8CXv/xlhEIhpYtFsgS1XDIkW1spiagbTJ0EQcDrr7+Ov//7v0ddXR1effVVWK1WpYtFVI5aLjLK9laKiNauqBvHcfiTP/kTvPfee+B5Hjt37sR//dd/UQJMsigUXCS2VAKKiLrBskdzczPee+89/P7v/z5efPFF/NEf/RElwCRpo24xCSyFLq+ZaNA+ux0+fBif//znkZeXh+9///vYunWr0kUiKkMtl0VQW6JIqVBrJfvt27cPnZ2dqKmpwUc/+lF87WtfowSYJCXUcknRUmylJKKxlaUlFovh5Zdfxje+8Q20t7fjBz/4ASXAJEmhlksSlto4ymxo0H5p0mq1+Mu//EscOnQId+/eRUdHBw4ePKh0sYgKUHCZAwWUX6NuMLJp0yacOHECTz75JD7zmc/gc5/7HCXAJPOibrEE2ZIoUio0aE9mEgQB//3f/40vfelLWLZsGV577TW0tbUpXSzCIGq5ILsSRUqFWitkNhzH4Xd/93dx8uRJFBUVYc+ePXjllVcoASZ5wJJtuSz1gfn50NgKSUY4HMbXv/51fOc730FHRwf+/d//HeXl5UoXizBiSQUXCijzo24wko7Ozk4899xz8a2Vn3jiCaWLRBiQ9d1iNDCfHOoGI+l67LHHcOrUKWzYsAG/93u/h7/5m7+hBJgke1su1EpJHnWDESkIgoAf/vCH+OpXv4r6+nq89tpraGhoULpYRCFZ1XKhVkpqaO0KkRLHcfjTP/1THD16FJFIBDt27MCbb75JCTCXKNW3XKiFkh4KKkROgUAAf/d3f4c33ngDv/Ebv4Hvfve7MBqNSheLZFBWBBe6QaaO6o2kK5VbRjgcRigUQkFBQcrnG52f6pajdAEWi07A9FC9kXT967/+K9atWyfb5/v9fvh8Pjz77LOyHYPIT/XBhRCSWYODg3jxxRdl+ewTJ07gW9/6Fqqrqym4qFxWDegTQuTHcRy0Wq3kf36/H5/61Kfwla98RemvSCRAwYUQojhBEPCxj30MTz31FHbt2qV0cYgEqFuMEKK4t956C4ODgzhy5AiNB2
YJCi6EEEW53W587nOfwy9+8Qvk5NAtKVtQtxghRDE8z+Opp57C7/zO76C9vV3p4hAJUXAhhCiC53l84QtfQDAYxEs
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ad2e8d6f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmWUlEQVR4nO3dfWxc1Z3G8efOq2fssccvcZyQGuzUQKB0kaoVsC9dEAUjsarK8s+qS1sB21YUSNsVkDZQkrShbNvtkhQVSquiJStUlVUQK23opqJVS6Vtq7YSJUAbUlIS8uLEdjwztmfGM3Pv2T/SeztjO4kdX3vevh8Jqb7GcBru9TPnnN/5XcsYYwQAgI8C1R4AAKDxEC4AAN8RLgAA3xEuAADfES4AAN8RLgAA3xEuAADfES4AAN8RLgAA3xEuAADfES4AAN8RLgAA3xEuAADfES4AAN8RLgAA34WqPQCgHhhjND4+rqmpKbW1tam7u1uWZVV7WEDNYuYCnEUqldLOnTs1NDSkVatWaWBgQKtWrdLQ0JB27typVCpV7SECNcniTZTA/Pbu3atbb71V2WxW0unZi8udtcTjce3evVvDw8NVGSNQqwgXYB579+7VzTffLGOMHMc5498XCARkWZb27NlDwABlCBdgllQqpXXr1imXy501WFyBQECxWExHjhxRMplc/gECdYA9F2CWZ555RtlsdkHBIkmO4yibzWrXrl3LPDKgfjBzAcoYYzQ0NKSDBw9qMY+GZVkaHBzUgQMHqCIDRLgAFcbGxrRq1aol/Xx3d7ePIwLqE8tiQJmpqakl/fzk5KRPIwHqG+EClGlra1vSzycSCZ9GAtQ3wgUo093drfXr1y9638SyLK1fv15dXV3LNDKgvhAuQBnLsnTvvfee189u3LiRzXzgT9jQB2bhnAuwdMxcgFmSyaR2794ty7IUCJz9EXFP6D///PMEC1CGcAHmMTw8rD179igWi8myrDnLXe61WCymF198UTfeeGOVRgrUJsIFOIPh4WEdOXJEO3bs0ODgYMX3BgcHtWPHDh09epRgAebBnguwAMYY/exnP9OHPvQhvfDCC/rbv/1bNu+Bs2DmAiyAZVlKJpMKBAJKJpMEC3AOhAsAwHeECwDAd4QLAMB3hAsAwHeECwDAd4QLAMB3hAsAwHeECwDAd4QLAMB3hAsAwHeECwDAd4QLAMB3hAsAwHe03AcWqFQqKZ1Oq6OjQ6FQqNrDAWoa4QIskDFGtm0rGAzSch84B8IFAOA79lwAAL5j4Rg1w7Zt/epXv9Lk5GS1h1L3/uIv/kK9vb3VHgaaGMtiqBm5XE4333yz8vm82traqj2cumSM0cGDB/XNb35TN910U7WHgybGzAU1wxijQCCg7du366//+q+rPZy6YIzRD37wAw0MDGjDhg0qFov6x3/8R/GZEdVGuKDmhEIhRaPRag+j5hlj9Morr+jzn/+8gsGgNm3apFtuuYVKNtQEwgWoU+l0Wlu2bNHY2Jgk6XOf+5xs267yqIDTqBYD6pBt23riiSf0i1/8wrs2NDSk66+/voqjAv6McAHqjDFGL7/8sr797W/LcRxJUltbm7Zu3aq+vr4qjw44jXAB6ogxRidOnNC2bds0NTUlSQoEAvr4xz+u97///VUeHfBnhAtQR4rFor761a/qtdde865dffXV+tSnPqVgMFjFkQGVCBegThhj9D//8z967rnnvGs9PT3aunWrkslk9QYGzINwAeqAMUbHjh3TV77yFeXzeUmnS7Y/85nP6Morr6T8GDWHcAHqQKlU0mOPPaa33nrLu/aBD3xAH/nIRxQI8Bij9nBXAjXOGKOf/OQnFcthvb292rx5s+LxeBVHBpwZ4QLUuPHxcT366KPKZrOSpGAwqLvvvluXXnopy2GoWYQLUMNs29ZTTz2lffv2edeuueYaffSjH2U5DDWNuxOoUcYY/frXv9bTTz/tNaJMJpPavHkzXaNR8w
gXoEZlMhk98sgjSqfTkiTLsnTnnXfqfe97H8thqHmEC1CDHMfRf/zHf1T0Dnvve9+rT3ziEyyHoS5wlwI1xhij3/72t3riiSe83mGtra168MEH1dXVxawFdYFwAWpMJpPRtm3bND4+Lun0cthtt92m97///QQL6gbhAtQQ27b1rW99S//3f//nXbvssst077330jsMdYVwAWqEMUY//elP9a1vfatiOeyhhx7S6tWrmbWgrhAuQA0wxujQoUN66KGHNDk5Ken0ctjtt9+u6667jmBB3SFcgCozxiiTyWjz5s168803vevXXHONNm7cqFCIt5Gj/hAuQJXl83lt375dL730kndtzZo12r59uzo7O6s4MuD8ES5AFeXzeX3961/Xf/7nf3r7LPF4XNu2bdMVV1zBchjqFvNtoAqMMcrlcvrqV7+qp556SqVSSdKf39HywQ9+kGBBXSNcgBVmjNH4+Li2bdum5557TrZtSzrd7fj222/XXXfdxT4L6h53MLCCjDHat2+fNm/erF/+8pdeQ8pgMKjbbrtNDz74oFpaWqo8SmDpCBdgBRhjlM1m9b3vfU///u//rpMnT3rfi0Qi+ud//mdt2rRJ8Xic5TA0BMIFWEbGGDmOo9/+9rf613/9V7388sve/ooktbe36/7779cdd9yhaDRaxZEC/iJcgGVijNHIyIiefPJJPfvss17rfNfQ0JAeeeQR/d3f/R2tXdBwCBfAZ+4S2AsvvKCdO3fq4MGDFd9vaWnRLbfcok2bNumCCy5gGQwNiXABfOQ4jl555RU9+uij+tnPflaxBGZZli699FI98MADGh4eVjgcJljQsAgXwAfGGE1NTek73/mOnnzySU1MTFR8v7u7W3fccYfuuOMO9fT0ECpoeIQLsETGGP3hD3/Qgw8+qJ/85CfeSXtJikajuummm/Qv//Iv2rBhA2+RRNMgXIAlsG1bP/7xj/W5z31Ohw4d8q5blqUNGzbogQce0A033KBIJMJsBU2FcAHOU6FQ0K5du/TlL39ZmUzGux6Px/WRj3xEn/70p7Vq1SpCBU2JcAEWya0G+/rXv66nnnpKMzMz3vf6+/v1xS9+UTfddJOCwSDBgqZFuACLYIxROp3Www8/rO9///teXzDLsnT11Vfra1/7mi655BJCBU2PcAEWyBij0dFR3X///XrxxRcr+oLdcsst+tKXvkQlGPAnhAuwAMYYHT16VJ/+9Kf18ssve8Hi9gV74IEH1NraSrAAf0K4AOdgjNFbb72lu+++W7/5zW+867FYTPfff78++clPUg0GzEK4AGfhtsi/++679bvf/c673tbWpq1bt+q2227j3SvAPHgqgDMwxujnP/+5Nm7cqLffftu73tXVpUcffVQf+tCHaDgJnAHhAszDcRz97//+r+67776Kd6/09fXpscce0/XXX89pe+AseDqAWQqFgp555hndc889FcFy0UUX6dvf/rY+8IEPECzAOTBzAf7EGKPJyUn927/9m7773e9WHI68/PLL9fjjj+uKK65g4x5YAMIF0J8rwr7whS/oRz/6UUXzyb/6q7/Sjh07NDAwQLAAC0S4oKkZY1QsFrVnzx5t3769ovlkMBjUBz/4QT3yyCP0CAMWiXBB0zLG6MSJE/rKV76i//qv/1I+n/e+F4/Hddddd+nee+/lcCRwHggXNKVSqaSXX35ZW7du1RtvvFHxvXXr1mnr1q36+7//e86wAOeJJwdNxRijVCqlxx9/XN/97nc1PT3tfS8YDOr666/X1q1bNTQ0xGwFWALCBU3DGKNXX31VDz30kH7xi194/cGk068h/sxnPqOPfvSjisfjBAuwRIQLmkKpVNILL7ygLVu26MSJE971QCCgq666Stu2bdOVV17J+RXAJ4QLGpoxRrlcTt/4xjf0zW9+U7lczvteW1ubPvGJT+hTn/qUOjo6mK0APiJc0LCMMcpkMtqyZYu+973veS/2kqShoSFt375d1157Lf3BgG
VAuKAhuRv3DzzwgF544QVvfyUQCOiGG27Ql7/8ZfX39zNbAZYJ4YKG47ZxmR0s4XBYd955pzZt2qS2tjaCBVhGhAs
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f2dc6ceb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 3.72e-03 | test loss: 3.45e-03 | reg: 3.42e+00 : 100%|██| 20/20 [00:03<00:00, 5.13it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2a302028",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqyElEQVR4nO3de4xc1X0H8O+Z587szuzszr5sr8H7sknsgBMrpWrU0qRgI5y0URxFgkaqWtGmCcUFxQkFlKCShBZDUpsoSpo0MbZAiaoYWopBVsIjqhSiCBEwj9RevH6t7d31zM5zZ3Ze9/QPcy/3zK7tfdzZe2fm+5EseWa99vHaZ7/3vH5HSCkliIiILOSyuwFERNR4GC5ERGQ5hgsREVmO4UJERJZjuBARkeUYLkREZDmGCxERWY7hQkRElmO4EBGR5RguRERkOYYLERFZjuFCRESWY7gQEZHlGC5ERGQ5hgsREVnOY3cDiOqBlBLxeBzZbBZtbW2IRqMQQtjdLCLH4siF6DKSyST27t2LkZERdHd3Y2BgAN3d3RgZGcHevXuRTCbtbiKRIwneREk0v8OHD2PHjh3I5XIALo5edPqoJRgM4uDBg9i2bZstbSRyKoYL0TwOHz6M7du3Q0oJTdMu+etcLheEEDh06BADhsiE4UJUJZlMor+/H/l8/rLBonO5XAgEAhgfH0ckEql9A4nqANdciKrs378fuVxuQcECAJqmIZfL4cCBAzVuGVH94MiFyERKiZGREYyNjWExXUMIgcHBQYyOjnIXGREYLkSKWCyG7u7uZX1+NBq1sEVE9YnTYkQm2Wx2WZ+fyWQsaglRfWO4EJm0tbUt6/NDoZBFLSGqbwwXIpNoNIqhoaFFr5sIITA0NITOzs4atYyovjBciEyEELjzzjuX9Lk7d+7kYj7Re7igT1SF51yIlo8jF6IqkUgEBw8ehBACLtflu4h+Qv+pp55isBCZMFyI5rFt2zYcOnQIgUAAQog50136e4FAAM899xy2bt1qU0uJnInhQnQJ27Ztw/j4OPbs2YPBwUHlY4ODg9izZw/Onj3LYCGaB9dciBZASomXXnoJf/Znf4YXXngBH//4x7l4T3QZHLkQLYAQwlhTiUQiDBaiK2C4EBGR5RguRERkOYYLERFZjuFCRESWY7gQEZHlGC5ERGQ5hgsREVmO4UJERJZjuBARkeUYLkREZDmGCxERWY7hQkRElmO4EBGR5Vhyn2iBpJSQUs57eRgRqThyIVoEhgrRwnjsbgBRvWCwEC0cRy5ERGQ5jlzIMSqVCl555RWk02m7m1L3tmzZgt7eXrubQU2MC/rkGPl8Hn/8x3+MfD6PUChkd3PqkpQSo6OjePzxx/Hnf/7ndjeHmhhHLuQYUkq4XC585zvfwZ/+6Z/a3Zy6cfr0abz00ku47bbbAAC33HIL+MxIdmO4kON4PB74/X67m+F4mqbhlVdewZe+9CWMjo7C6/Xitttu48YDcgSGC1GdkVKiWCziiSeewP3334+pqSkAwL333ovBwUGbW0d0EcOFqI5IKXH27Fk88MADePLJJ1EsFo2PCSEQj8dtbB3R+7gVmagOSClRKBTw85//HDfddBP27dunBMvmzZvx9NNPY/v27Ta2kuh9HLkQOZymaXjnnXfw0EMP4emnn0ahUDA+5vF4sGPHDuzevRv9/f0olUo2tpTofQwXIoeSUuL8+fP493//d/zgBz/AhQsXlI/39PTg/vvvx+23346WlhYu5JOjMFyIHEZKiVQqhSeffBJ79+7F8ePHla3FHo8HN910E771rW/h2muvhcvF2W1yHoYLkUNIKZHP5/Hss89i9+7deP3116FpmvJrhoaGcM899+DWW29FMBjkaIUci+FCZDMpJcrlMv73f/8X//qv/4pf/epXc9ZOOjs78Td/8zfYuXMn1qxZw1Ahx2O4ENlI0zS8/fbb2L17N55++mnkcjnl462trfj0pz+NXbt2Yd
OmTXC73Ta1lGhxGC5ENtAX67/3ve/hRz/6EWKxmPJxn8+HT3ziE7jnnnvwR3/0R/B4PBytUF1huBCtICklcrkc/vM//xMPP/wwRkdHlcV6l8uFLVu24Ktf/SpuueUW7gKjusVwIVoBUkqjFtiDDz6Il19+GeVyWfk1Q0NDuPvuu3Hbbbehvb2doUJ1jeFCVGN6yZZHH30Ujz/++Jz7aqLRKG6//Xb8wz/8A1avXs1QoYbAcCGqESklZmdn8fOf/xzf/OY3MTo6qnzc7/fjU5/6FO677z6eV6GGw3AhqgF9F9gDDzyAQ4cOKVuLhRDYvHkzvv71r+Pmm2+Gz+fjaIUaDsOFyEJSSmSzWfzoRz/CI488gsnJSeXjPT09uPPOO/H3f//36OzsZKhQw2K4EFlE0zS89tpruPfee/Hyyy+jUqkYH/N6vfjkJz+JBx54AJs2beIUGDU8hgvRMkkpkclk8P3vfx+PPvronDtVRkZG8LWvfQ2f/exn4ff7OVqhpsBwIVoGTdNw5MgR3HPPPXjhhReUWmCBQACf//zncf/992Pt2rUMFWoqDBeiJdB3gu3btw/f+MY35qytbNq0Cd/61rdw880383Q9NSWGC9EiSSlx4sQJ/NM//RP+67/+SzkMGQwG8dd//de477770NfXx1ChpsVwIVqEcrmMQ4cO4atf/eqccyvXXHMN/uVf/gW33HILRyvU9BguRAugX+C1e/dufPe738XMzIzxMZ/Ph1tvvRUPPvgg+vv7GSpEYLgQXZGUEr///e9x11134cUXX1QW7VevXo1vfOMbuO2223gYksiE4UJ0GeVyGf/93/+NXbt24dSpU8b7LpcLn/jEJ/Dtb38bmzZtYqgQVWG4EM1DP2n/yCOP4N/+7d+UabDW1lb84z/+I77yla8gHA4zWIjmwXAhqiKlxOnTp3H33XfjmWeeUabBBgYG8J3vfAfbt2+Hx8PuQ3Qp7B1EJlJK/PrXv8aXvvQlvPnmm8b7LpcLN910E/bs2YP169dztEJ0BSxwRPSecrmMJ554Ajt27FCCpaWlBV/+8pfxs5/9jMFCtEAcuVDTk1Iin89j9+7dePTRR5HL5YyP9fb24uGHH8att97KsytEi8BwoaYmpUQ8HseuXbvw5JNPKpWMr7vuOnz/+9/H9ddfz1AhWiSGCzUtKSVOnjyJL3zhC3jhhRcgpQRwcX1l+/bteOyxx3DVVVcxWIiWgGsu1JSklHj99dfxmc98Br/85S+NYPH5fLjzzjtx4MABBgvRMnDkQk1HSomXXnoJt99+O06ePGm8HwqF8M///M/44he/yNP2RMvEcKGmomkannnmGXzxi19UyuT39vbisccew2c+8xm43W4bW0jUGBgu1DQqlQp+9rOfYefOnUgkEsb7Q0ND+I//+A/8yZ/8CUcrRBbhmgs1hUqlgv379+OOO+5QgmXz5s04ePAgg4XIYhy5UMPTg+Xuu+9GJpMx3v/Yxz6Gxx9/HIODgwwWIotx5EINTdM0PPnkk3OC5cYbb8RPf/pTBgtRjTBcqGFpmoann34ad911lxIs27Ztw/79+7FmzRoGC1GNcFqMGpKUEi+++CLuuOMOJJNJ4/2tW7di37596O3tZbAQ1RBHLtRw9AOSf/d3f4epqSnj/RtuuAE//vGPGSxEK4DhQg1FSokzZ87gb//2b5UDklu2bMFPfvITrF69msFCtAIYLtQwpJRIp9PYuXMnXnvtNeP94eFh7Nu3D+vWrWOwEK0Qhgs1jHK5jIceegjPPvus8V5PTw9++MMfYuPGjQwWohXEcKGGIKXET3/6U3z3u981riUOBoN45JFHeECSyAYMF6p7Ukq8+uqruPfeezE7OwsAcLvd+PKXv4xbb70VLhf/mxOtNPY6qmtSSsRiMdx11104f/688f5f/MVf4Ctf+Qo8Hu62J7IDw4XqWqlUwoMPPojf/OY3xnubNm3Ct7/9bbS2ttrYMqLmxnChuiWlxMGDB/HjH/
/YuOwrEolgz549vOiLyGYMF6pLUkocO3YM999/v7LOcs899+CGG25gsBDZjOFCdSmXy+G+++5TDkp+8pOfxB133MH
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7cbec774",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with x^2, r2=0.9999921393183026\n",
"fixing (0,1,0) with x^2, r2=0.9999940727994734\n",
"fixing (1,0,0) with sqrt, r2=0.9998914314178492\n"
]
}
],
"source": [
"model.auto_symbolic()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e14000d8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.01 \\sqrt{x_{1}^{2} + 1.0 x_{2}^{2}} - 0.01$"
],
"text/plain": [
"1.01*sqrt(x_1**2 + 1.0*x_2**2) - 0.e-2"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "031fabd6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: nan | test loss: nan | reg: nan : 100%|█████████████████| 20/20 [00:03<00:00, 5.70it/s]\n"
]
}
],
"source": [
"# Known issue: training further after fixing sqrt symbolically gives nan losses; this is a bug to be resolved later\n",
"# (possibly related to the sqrt singularity at the origin -- TODO confirm). The formula above is already close to the ground truth.\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lr=1e-3, update_grid=False);"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}