460 lines
40 KiB
Plaintext
460 lines
40 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Example 9: Singularity"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2571d531",
|
|
"metadata": {},
|
|
"source": [
|
|
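"Let's construct a dataset that contains a singularity: $f(x,y)=\\sin\\left(2(\\log x+\\log y)\\right)$\n",
|
|
"for $x>0, y>0$. The function is singular as $x\\to 0$ or $y\\to 0$; here we sample $x, y \\in [0.2, 5]$, which keeps the data away from the singularity."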
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "2075ef56",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda\n",
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 1.14e-01 | test_loss: 1.29e-01 | reg: 6.34e+00 | : 100%|█| 20/20 [00:03<00:00, 5.03it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"print(device)\n",
|
|
"\n",
|
|
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 5 grid intervals (grid=5).\n",
|
|
"model = KAN(width=[2,1,1], grid=5, k=3, seed=2, device=device)\n",
|
|
"f = lambda x: torch.sin(2*(torch.log(x[:,[0]])+torch.log(x[:,[1]])))\n",
|
|
"dataset = create_dataset(f, n_var=2, ranges=[0.2,5], device=device)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3f95fcdd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtEElEQVR4nO3deVBUV74H8O9pmn0RRFwQdWjouGMGF1AERlEwMTqOZqLJxHrGOJNRg5Xt6cSYp/FpzExMiVuciXmT0SQvasRoFNxGfYC7QXGLoogoS0DAbqRplqb7vj8iXWBcUC709v1UpaaK27fvD8bT3z7n3HuOkCRJAhERkYwUli6AiIjsD8OFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZKS1dAJEtkCQJ5eXl0Ol08PLygr+/P4QQli6LyGqx50L0EFqtFitXroRarUZAQACCg4MREBAAtVqNlStXQqvVWrpEIqskuBMl0f3t3bsXkyZNgl6vB/Bz76VBQ6/Fw8MDycnJSEhIsEiNRNaK4UJ0H3v37sXYsWMhSRJMJtMDX6dQKCCEQEpKCgOGqBGGC9E9tFotgoKCUF1d/dBgaaBQKODu7o6CggL4+vq2foFENoBzLkT32LBhA/R6fbOCBQBMJhP0ej02btzYypUR2Q72XIgakSQJarUaubm5eJymIYSASqXC1atXeRcZERguRE2UlZUhICCgRef7+/vLWBGRbeKwGFEjOp2uRedXVlbKVAmRbWO4EDXi5eXVovO9vb1lqoTItjFciBrx9/dHSEjIY8+bCCEQEhKC9u3bt1JlRLaF4ULUiBACiYmJT3TunDlzOJlPdBcn9InuwedciFqOPReie/j6+iI5ORlCCCgUD28iDU/ob9u2jcFC1AjDheg+EhISkJKSAnd3dwghfjHc1fAzd3d3pKamIj4+3kKVElknhgvRAyQkJKCgoABJSUlQqVRNjqlUKiQlJaGwsJDBQnQfnHMhagZJknDo0CHExcXhwIEDGDFiBCfviR6CPReiZhBCmOdUfH19GSxEj8BwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhegSDwYDCwkJcunQJAHDt2jXcvn0bJpPJwpURWS9uc0z0AFqtFsnJyfj6669x8eJFVFZWoq6uDm5ubggICEB0dDReffVVREVFQalUWrpcIqvCcCG6j2PHjuHNN9/EuXPnMHjwYIwdOxZhYWHw8vKCVqtFZmYmdu7ciZycHEyePBlLlixBQECApcsmshoMF6J77Nu3D9OmTYOXlxeWLVuGZ599FnV1ddi0aRNqa2vh4+ODKVOmwGAwYNOmTVi0aBH69u2LL7/8Ep06dbJ0+URWgeFC1MiVK1cwZswYeHp6YtOmTejTpw+EEMjNzUV4eDgqKioQHByMzMxM+Pn5QZIkHD58GC+99BJ+85vf4PPPP4erq6ulfw0ii+OEPtFdRqMRH374ITQaDdasWWMOlocRQmD48OH429/+hh07dmDPnj1tVC2RdWO4EN2Vk5ODnTt3YuLEiRg+fPgjg6WBEAITJkxAZGQk1q9fj/r6+laulMj68RYXoruOHj0KnU6HSZMmIS8vD1VVVeZjBQUFMBqNAIC6ujpcvHgRPj4+5uOBgYGYOHEiFi1ahOLiYgQFBbV5/UTWhOFCdNfly5fh4eEBlUqF1157DUeOHDEfkyQJtbW1AICioiKMHj3afEwIgU8++QT9+/eHXq9HUVERw4UcHsOF6K7q6moolUq4urqitrYWNTU1932dJEm/OFZfXw93d/cmIUTkyBguRHd17NgR1dXV0Gq1iIiIgKenp/lYdX
U1jh49ag6RYcOGmR+cFEKge/fuuHXrFhQKBfz8/Cz1KxBZDYYL0V0DBw6EwWDAyZMn8de//rXJsdzcXAwePBgVFRXo1KkTNm/eDF9fX/NxIQTmz5+Pzp07c0iMCLxbjMhsyJAhUKlU2LBhA6qqquDk5NTkvwZCCCgUCvPPFQoFfvrpJ2zduhVjx45Fu3btLPhbEFkHhgvRXf7+/nj99ddx+vRprFq1qtm3FNfW1uK///u/UV1djddee63ZtzAT2TMOixE1Mm3aNKSnp+Ovf/0rPDw8MHPmTLi5uQEAlEollEqluRcjSRIqKyuxdOlSbNq0CStWrEDPnj0tWT6R1eDyL0T3KC0txezZs7Fr1y4kJCTgzTffRO/evZGdnQ2TyQQXFxeEhobi5MmTWL58ObKysrB48WLMnDmzyfAZkSNjuBDdR1VVFdavX49Vq1ahpKQEKpUKarUa3t7e0Gg0yM7ORlFREQYOHIiFCxciNjYWCgVHmYkaMFyIHqK4uBgHDhxAWloazp49i5MnTyI6OhpRUVGIj49HREQEPDw8LF0mkdVhuBA106lTpzBkyBCcOnUKgwYNsnQ5RFaN/XiiZnJycjLfhkxED8dWQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7LifC1EzSZIEk8kEhUIBIYSlyyGyauy5ED0G7uVC1DxKSxdAJAeDwYCbN2/CZDJZupQWE0Kge/fucHFxsXQpRE+M4UJ2oaCgALNmzcLAgQMtXcoTkSTJPNSWmZmJTz/9FCEhIRauiujJMVzILkiShLCwMCxdutTSpTyRdevW4cSJE4iPj0d9fT04FUq2juFCdsfWJttNJhMOHDiA7du3w9PTE15eXpYuiajFODtJZGF1dXW4efMmAEClUtlcOBLdD8OF6D4kSYJOp8PZs2dRXl7eqsNUOp0Ot27dghACwcHBrXYdorbEcCG6j9u3b2PKlCkYNmwYRo8ejcuXL7dawJSXl0Or1cLZ2Rndu3dvlWsQtTWGC9E9JEnCp59+ij179qCmpgZnz57FggULYDAYWuV6N2/eRHV1Nby9vdGlS5dWuQZRW2O4EN2jrKwMGzduhCRJiIyMhFKpxL59+3DmzBnZryVJEnJycmA0GtGpUyf4+fnJfg0iS2C4EDUiSRIOHz6MGzduwNfXF0lJSejbty/0ej2+/fbbVhkay87OBgD06NEDbm5usr8/kSUwXIgakSQJqampMBqNGDRoEJ5++mlMmDABALB3717odDpZr2cymZCTkwMACA0N5fIyZDf4L5moEZ1OhyNHjgAAEhISoFQqMWbMGLi5uSE3NxeXLl2S9Xo1NTXIy8sDAPTq1UvW9yayJIYLUSPXrl3DzZs34erqiuHDh0MIgV69eiE4OBg1NTXIyMiQdWhMo9Hg1q1bcHJy4nIvZFcYLkR3SZKEzMxMVFdXo2vXrlCr1QAAb29vREREAADS0tJkXRyzqKgIlZWVcHd3523IZFcYLkSNHDt2DADQv39/+Pj4mH8eGxsLIQTOnTsHjUYj2/Vyc3NhMBjQvn17BAQEyPa+RJbGcCG6q+GZFgAYMmSIeXJdCIGBAwfCw8MDJSUluHr1qizXkyQJV65cgSRJCAoK4ppiZFcYLkR33bp1Czdu3ICTkxPCw8ObrPHVvXt3dOvWDXV1dTh16pRs8y4NtyGrVCo4OzvL8p5E1oDhQnTX1atXUVFRAR8fHzz11FNNjnl6emLAgAEAgOPHj8sSLnV1dbh+/ToAoGfPni1+PyJrwnAhws9DVOfOnYPRaES3bt3QsWPHJseFEOZJ/fPnz0Ov17f4mjqdDoWFhRBCQK1WczVksisMF6K7srKyAPz8vMm9T8o3zLs4OzsjPz8fhYWFLb7erVu3oNFo4OLiwtWQye4wXIgA1NbW4vLlywCAAQMG3LcXERoaCn
9/f+h0Oly8eLHF17xx4waqq6vh4+ODwMDAFr8fkTVhuBDh5yX28/PzoVAo0K9fv/uGS/v27REaGgpJklo8qS9JErKzs2EymdC5c2f4+vq2oHoi68NwIcLPy95rtVp4eHggNDT0vq9xdnZGeHg4ACAzMxNGo7FF1zx//jwAQK1Wc8FKsjsMF3J4Db2Iuro6dOzYEZ07d37ga4cMGQIhBLKzs6HVap/4mgaDwTwM17dvX07mk91huBABuHDhAgAgODj4gQ8zCiHQr18/uLu7o7S0FLm5uU98PY1Gg7y8PAgh0L9/f4YL2R2GCzk8k8lkXu24d+/ecHJyeuBru3fvjsDAQNTW1iIrK+uJ511u3LiB27dvw8PDg6shk11iuJDD0+v1uHbtGgCgX79+D32tl5cX+vbtCwA4ceLEE11PkiT8+OOPqK2tRadOndC1a9cneh8ia8ZwIYdXWlqKkpISKJVK9OzZ86FDVAqFwvww5ZkzZ1BdXf1E1zx9+jSAnyfzuaYY2SOGCzm8vLw86HQ6eHt7o0ePHg99rRACgwYNglKpRF5eHoqKih77evX19eY5nrCwsIcOwxHZKoYLOTRJknD58mUYjUZ06dIFHTp0eOQ5vXv3hr+/PyorK3Hu3LnHnnfRarXIycmBEOIXC2QS2QuGC9kdSZIe6wO/oRcREhLSrOdNAgIC0KtXL5hMJvOWyI8jJycHpaWl8PDwMM/fENkbhgvZFb1ej3/+85/QaDTNCpj6+nrz8yZ9+vQx7+HyMEqlEkOHDgXw8wrJdXV1za6vYbfLuro6BAUFcfdJslsMF7IbBoMBc+fOxaxZszB79mxUVVU98hydToe8vDwAzX+YUQiB4cOHw8nJCdnZ2Y+1iKUkSebdLgcMGABPT89mn0tkSxguZDecnJwQEhIChUKBrVu3Yu3atY/c7764uBhlZWVwcXH5xR4uDxMWFgZ/f39otVpkZmY2exiuqqrKvNvl0KFDOd9CdovhQnZDoVDg9ddfxx//+EeYTCYkJSU9ckviq1evorq6Gn5+fujWrVuzr9WxY0eEhYVBkiQcOHCg2efl5eXh5s2bcHFxwaBBgxguZLcYLmRXnJ2dMW/ePKhUKty6dQufffbZA3svDQ8zmkwmdOvW7bFWJlYqlRg5ciQA4MiRI9DpdI88R5IknDhxAlVVVejSpQt3nyS7xnAhuxMYGIhXX30VALB161b89NNP931dw+6TwM/bDLu6ujb7GkIIjBgxAm5ubsjNzTUvH/MwkiTh4MGDAIDw8HAus092jeFCdkcIgcmTJ6Njx44oLCxESkrKfedEampqmmwQ9rh69+4NlUqFmpoa7N+//5HzLhqNxrxkTFxcXLPuTCOyVfzXTXape/fuiI+PhyRJ2Lx5831vFy4rK0NBQQGcnJweuEHYw3h5eSEuLg4AkJqaipqamge+VpIknD17FoWFhfDy8kJMTAznW8iuMVzILikUCkyZMgVKpRI//PCDuYfSWE5ODioqKuDt7Q21Wv1E1xk3bhycnZ1x/vz5Rw6N7d69GwaDAb169YJKpXqi6xHZCoYL2SUhBIYOHYqQkBDodDrs2LGjybCVJEnIyspCfX09goKC0KlTpye6xqBBg6BWq1FVVYXvvvvugUNjd+7cwZ49ewAAzzzzDHeeJLvHcCG71a5dOzz33HMAgO3btzd5qFKSJJw8eRIAzBuAPQkfHx+MHz8eAJCcnAyNRvOL1zRM5GdnZ8PLywvjx4/nkBjZPYYL2S0hBCZOnAg3NzdcunSpycOOOp3OfKdYZGTkE3/YCyEwZcoU+Pj4ICcnB/v27ftF76W+vh5ffPEFjEYjYmJiHrlnDJE9YLiQXQsLC8OAAQNQV1eHzZs3mz/4r1+/jvz8fLi4uGDw4MEt6kn06tULcXFxMBqNWLduHfR6vfmYJEm4cOEC0tLS4OTkhFdeeQXOzs4t/r2IrB3Dheyau7s7fv/73wMAUlJSUFxcDEmScPToUej1enTt2v
Wxln25H6VSiVmzZsHd3R3Hjx9HcnKyOcQMBgNWrlwJnU6HXr16YeTIkRwSI4fAcCG7JoTAhAkTzM+8JCcnw2g0Yvfu3QCAiIiIFj/M2LCQ5bhx41BfX49Fixbhxx9/hCRJ2LFjB7Zs2QInJye8/vrraNeunQy/FZH1Y7iQ3evRowd+97vfQZIkfPrpp8jIyMDRo0ehUChkm1x3dnbGokWL0L17d9y4cQPPP/883njjDSQmJqK2thbx8fF4+eWX2Wshh8FwIbsnhMDs2bPRoUMHXLlyBX/4wx+g0WigUqlkG6YSQuCpp57CunXr0LlzZ1y5cgVr1qxBaWkp+vfvj08++eSJ70gjskUMF7J7Qgj07t0b7777LlxcXFBSUgIXFxe88847zdrW+HGuk5CQgNTUVMyYMQPx8fGYO3cudu7ciaeeeoq9FnIoSksXQCS3+z3IKITArFmz0LFjR+zfvx9xcXF44YUXHvj6lggLC8Pf//53mEymJuuHyX0dImsmJP6LJztw/fp1zJw5E5GRkZYupcWOHz+OtWvXIiQkxNKlED0xhgvZhbq6OuTm5sJoNFq6lBZTKBQICQmBi4uLpUshemIMFyIikh3nXIiaqfH3ME7OEz0c7xYjaqYzZ87AyckJZ86csXQpRFaP4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGC1EzSJIEjUYDANBoNOAGrkQPx3AhegitVouVK1dCrVZj1KhRkCQJo0aNglqtxsqVK6HVai1dIpFVEhK/ghHd1969ezFp0iTo9XoA99/m2MPDA8nJyUhISLBIjUTWiuFCdB979+7F2LFjIUkSTCbTA1+nUCgghEBKSgoDhqgRhgvRPbRaLYKCglBdXf3QYGmgUCjg7u6OgoIC+Pr6tn6BRDaAcy5E99iwYQP0en2zggUATCYT9Ho9Nm7c2MqVEdkO9lyIGpEkCWq1Grm5uY91R5gQAiqVClevXjXPxxA5MoYLUSNlZWUICAho0fn+/v4yVkRkmzgsRtSITqdr0fmVlZUyVUJk2xguRI14eXm16Hxvb2+ZKiGybQwXokb8/f0REhLy2PMmQgiEhISgffv2rVQZkW1huBA1IoRAYmLiE507Z84cTuYT3cUJfaJ78DkXopZjz4XoHr6+vkhOToYQAgrFw5tIwxP627ZtY7AQNcJwIbqPhIQEpKSkwN3dHUKIXwx3NfzM3d0dqampiI+Pt1ClRNaJ4UL0AAkJCSgoKEBSUhJUKlWTYyqVCklJSSgsLGSwEN0H51yImkGSJBw6dAhxcXE4cOAARowYwcl7oodgz4WoGYQQ5jkVX19fBgvRIzBciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIHsFgMKCwsBCXLl0CAFy7dg23b9+GyWSycGVE1ovbHBM9gFarRXJyMr7++mtcvHgRlZWVqKurg5ubGwICAhAdHY1XX30VUVFRUCqVli6XyKowXIju49ixY3jzzTdx7tw5DB48GGPHjkVYWBi8vLyg1WqRmZmJnTt3IicnB5MnT8aSJUsQEBBg6bKJrAbDhege+/btw7Rp0+Dl5YVly5bh2WefRV1dHTZt2oTa2lr4+PhgypQpMBgM2LRpExYtWoS+ffviyy+/RKdOnSxdPpFVYLgQNXLlyhWMGTMGnp6e2LRpE/r06QMhBHJzcxEeHo6KigoEBwcjMzMTfn5+kCQJhw8fxksvvYTf/OY3+Pzzz+Hq6mrpX4PI4jihT3SX0WjEhx9+CI1GgzVr1piD5WGEEBg+fDj+9re/YceOHdizZ08bVUtk3RguRHfl5ORg586dmDhxIoYPH/7IYG
kghMCECRMQGRmJ9evXo76+vpUrJbJ+vMWF6K6jR49Cp9Nh0qRJyMvLQ1VVlflYQUEBjEYjAKCurg4XL16Ej4+P+XhgYCAmTpyIRYsWobi4GEFBQW1eP5E1YbgQ3XX58mV4eHhApVLhtddew5EjR8zHJElCbW0tAKCoqAijR482HxNC4JNPPkH//v2h1+tRVFTEcCGHx3Ahuqu6uhpKpRKurq6ora1FTU3NfV8nSdIvjtXX18Pd3b1JCBE5MoYLOby8vDykpaXhyJEj0Ov10Gq1iIiIgKenp/k11dXVOHr0qDlEhg0bZn5wUgiB7t2749atW6ivr0dOTg4GDx4MNzc3S/1KRBbHW5HJ4eTn5yM9PR1paWlIS0vDzZs3zQGRk5ODtWvXYsaMGU3Oyc3NxeDBg1FRUYFf/epX+OGHH+Dr62s+LoTA/PnzsXz5ciiVSri5uSEiIgIxMTGIjY3F4MGDeYsyORSGC9m9oqIipKWlmQPl+vXrAICwsDDExMQgJiYG0dHRMBqNGD58OPz8/LBnz54mE/YPes4F+HmYrKioCLGxsRg3bhymTZuGjIwMpKenIz09HRUVFXB3d0dkZKQ5bAYOHAgXFxeL/D2I2gLDhexOSUlJkzDJyckBAPTt29f84R4dHY327dv/4ty1a9fi7bffxoIFC/CXv/zFPPT1sHCpqanBG2+8gZ07d+LgwYPo2bOn+f2MRiPOnz9vruXw4cO4c+cOPDw8MHToUMTGxiI2Nha//vWv4ezs3AZ/HaK2wXAhm1daWmruJaSlpSE7OxsA0KtXryZh0py1v6qqqjB9+nSkpqbigw8+wMyZM+Hm5obr169jyJAh5mGxkydPwtfXF5WVlVi6dCn+8Y9/YMWKFXjllVce+v719fU4e/asOfyOHDkCnU4HLy8vDBs2zBw2AwYM4GKYZNMYLmRzysvLzcNOaWlp+PHHHwEAarXaHCYxMTFPvM5XaWkpZs+ejV27diEhIQFvvvkmevfujezsbJhMJri4uCA0NBQnT57E8uXLkZWVhcWLF2PmzJlwcnJ6rGsZDAacOXPGHDZHjx6FXq+Hj48PoqKizGHTv3//x35vIktiuJDV02q1yMjIMH8Anz9/HgCgUqmahElgYKBs16yqqsL69euxatUqlJSUQKVSQa1Ww9vbGxqNBtnZ2SgqKsLAgQOxcOFCxMbGQqFo+YIXdXV1yMzMNAfnsWPHUFNTg3bt2iE6Otr8u/br10+W6xG1FoYLWZ2KigocPnzYHCbnzp2DJEno0aOH+cM1Nja2TR5ULC4uxoEDB5CWlobc3FzU1NTAz88P/fr1Q3x8PCIiIuDh4dFq16+trcWpU6fMYXPixAnU1tbCz8/PHDaxsbHNWgeNqC0xXMjiKisrceTIEfMHaFZWFkwmE7p27Wr+8IyNjUWPHj0sWqfRaIQkSVAoFBbrNdTU1ODEiRPmv9XJkydhMBjQoUOHJmHTs2dPhg1ZFMOF2pxOp8OxY8fMPZPTp0/DaDSiS5cuTXomwcHB/IB8BL1ej+PHj5vD5ocffkB9fT06duxo/jvGxsYiNDSUf0tqUwwXanUNH4ANYdL4A7BxmPADsOV0Op35b52WltYkuBuHDYObWhvDhWTXMHTTECaNh24aHlrk0E3buHPnjrmX2HjIMSgoqEnYWHrIkewPw4VarGHSuSFMGk86Nw4TTjpbXsPNEg3DaI1vlmgcNlzVmVqK4UKPreF22YZvw8ePHzffLtsQJrxd1jZoNJomYdNwm3dwcHCTmym6dOli4UrJ1jBc6JEaP+jX8OxFw4N+w4cPN8+b8EE/21deXm6+DbzxA6qhoaHmoGnJA6rkOBgu9AuNlyhJS0sz79Do5eWFqKgo8/AJlyixf6WlpeYHWBsvrdOzZ88mYdOhQwcLV0rWhuFC5sUVGz5Ajhw5Yl5csWG9q5iYGC6uSCguLm4SNg2Lgvbp08ccNg9aFJQcC8PFAZlMJly4cME8zp6RkYGKigq4ub
mZV+qNiYnhsvD0SEVFRU32xmnYzqB///7mf0fR0dFN9r4hx8BwcQCSJOHHH380fwBkZGRAo9HA1dUVERER5g8BbmhFLfWgjdjCwsLMPZuoqCi0a9fO0qVSK2O42CFJkpCdnd0kTMrKyuDs7IwhQ4aYwyQiIoJb8VKrysvLaxI2hYWFUCgUePrpp81hM2zYMHh7e1u6VJIZw8UOSJKEnJwccwNOT0/HrVu3oFQqMWjQIHOYREZGtuoii0QPI0kSrl+/bv53mpaWhuLiYjg5OSE8PNwcNkOHDoWnp6ely6UWYrjYoHsbaXp6On766adfNNLIyEh4eXlZulyi+2rOl6LY2NhWX3maWgfDxUbcuHGjSSMsKCj4xfDC0KFDm+z7TmRL7h3OTU9PR3l5OVxcXDB48GAO59oYhouNGDBgAK5evcqJUXIYJpMJly5d+sWNKF9++SWef/55S5dHj8BwsREmkwlCCK7NRQ5LkiRIksR2YCMYLkREJDuu3SEDg8GA/Px8mEwmS5fSYkIIdOvWjQ9P0mNhG6B7MVxkUFhYiDlz5iA8PNzcdbfV1YBPnz6NVatWQaVSWboUsiGFhYVITEy0mzawevVqtoEWYrjIQJIk9OvXD+Hh4di2bRtGjBiBadOmWbqsJ7JgwQJwpJQeV0MbGDBgAHbt2oWoqCjMmDHD0mU9kffee49tQAYMFxnt3bsXW7duRUVFBV5++WWbW+SRDYpaQgiB//u//8PmzZtRUVGB6dOn29wWDGwD8rHNfqsVEkIgLi4OQgicPXsWJSUlli6JqM0NGTIEAHDp0iVUVFRYuBqyJIaLjIYMGYL27dujrKwMp06d4rcgcjgDBgyAu7s7SkpKkJuba+lyyIIYLjLq0qULnn76aZhMJuzbt4/hQg6nR48eCAwMRE1NDU6fPs024MAYLjJSKpUYPXo0AJj3SCFyJN7e3ggLCwMAHDlyhOHiwBguMhsxYgQ8PT2Rn5+Ps2fPWrocojalUCgQFRUF4OdbeisrKy1cEVkKw0VmoaGh6N27NwwGA1JTU/nNjRxOZGQk3NzcUFBQYN4GmRwPw0Vmbm5uGDNmDADg3//+N+7cuWPhiojaVmhoKLp164aamhocPnyYX7AcFMNFZkIIjBkzBh4eHrh27RrOnDnDxkUOxdvbG5GRkQCAgwcPwmg0WrgisgSGSyvo3bs3+vXrB4PBgB07dli6HKI2JYTAqFGjIIRAVlYWiouLLV0SWQDDpRW4ublh/PjxAH5+ar+8vNzCFRG1HSEEIiMj0aFDB5SVleHYsWPsvTsghksrEELg2WefRbt27ZCfn4+MjAw2LnIogYGBGDRoEEwmE3bt2mUXqyXT42G4tJKQkBAMGzYMRqMRmzZt4rgzORQnJyc899xzEEIgIyODyyE5IIZLK1EqlZgyZQoUCgUyMjJ4SyY5lIa19vz9/VFcXIxDhw6x9+5gGC6tRAiBESNGIDg4GFqtFt9++y0bFzmUoKAgREdHQ5IkbN68GQaDwdIlURtiuLQif39/TJo0CQCwZcsWlJWVWbgiorajUCjw4osvQqlU4tixY7h48aKlS6I2xHBpRUIITJkyBX5+frh+/Tp27drF3gs5DCEEYmJi0LNnT+h0Onz55Zec2HcgDJdWplar8cwzz8BkMuHzzz+HTqezdElEbcbHxwdTp06FEALbtm1DXl6epUuiNsJwaWUKhQJ/+tOf4OnpiXPnzmH37t3svZDDEELg97//Pbp164aSkhKsX7+evRcHwXBpZUII/PrXv8bo0aNhNBqxevVqrhRLDqVLly6YPn06hBD46quvcOnSJX7BcgAMlzagVCoxZ84ceHp6IisrC1u3bmXjIochhMC0adOgVqtRVlaGDz/8EHV1dZYui1oZw6UNCCEwaNAgTJgwAUajEUlJSVxviRxKx44dMW/ePDg7O2Pnzp28Nd8BMFzaiJOTE9566y0EBAQgJycHq1at4lP75DCEEJg4cSKee+45GAwGLF
y4EOfPn2fA2DGGSxsRQqBnz56YNWsWhBD45z//yTXHyKG4urpi6dKlCAkJQVFRERITE1FcXMw2YKcYLm2o4c6xwYMHo7KyEvPnz0dJSQkbFzkEIQR+9atfYcWKFWjXrh1OnjyJxMREaDQatgE7xHBpY76+vli2bBn8/PyQlZWFhQsXora21tJlEbWJhjXHli5dCjc3N6SmpuL1119nwNghhksbE0IgIiICf/nLX6BUKvHNN99gzZo1qK+vt3RpRG1CoVDgP/7jP/D+++/DxcUF27dvx/Tp01FYWMiAsSMMFwtQKBT44x//iJdffhlGoxHLli3DV199xYfLyGEolUokJibigw8+gJubG/bu3YtJkybhxIkTbAd2guFiIa6urliyZAni4+NRXV2NefPm4auvvuIdZOQwnJ2dMXv2bKxevRodOnTAuXPn8Pzzz2PNmjWoqqpiL8bGMVwsRAgBPz8/fPrpp4iOjkZlZSXefvttrFu3jg+YkcNQKpV46aWXsGXLFoSFhaG8vBzz58/H5MmTkZmZyV6MDWO4WJAQAp07d8a//vUvxMXFoaqqCgsWLMD8+fNRUVHBb27kEIQQiIyMxI4dO/CnP/0JLi4uOHDgAMaNG4eFCxfyjkobxXCxsIaA+eKLLzB58mQYjUasW7cOf/jDH3DlyhU2KnIIDe3gk08+wf/+7//i6aefRkVFBZYvX46EhAR8++23qKmpYXuwIQwXKyCEgL+/P9auXYt58+bB3d0dBw8exPjx4/Hdd9/BYDCwUZFDUCqVSEhIQEpKCv7rv/4LAQEByM7OxowZM/Dyyy8jKyuLQ2U2guFiJYQQ8PDwwLvvvosvvvgCPXr0QH5+PmbMmIH//M//5NAAOQwhBNq3b4+5c+di9+7d+N3vfgchBFJSUjB27FgsWLAABQUFbA9WjuFiZZycnPDcc8/h+++/xzPPPAODwYDPPvsM48aNQ2pqKnsx5DAUCgX69OmDf/3rX9i4cSP69+8PrVaLFStWYNSoUVi1ahVu3brF9mClGC5WSAiB0NBQbNy4EUuWLEGHDh1w8eJFTJ06FX/+85+RnZ3NoQFyCEIIuLi4YPz48di9ezc++OADBAYG4saNG3j33XcxcuRIfPzxx7h+/TqMRiODxoowXKyUEAKenp5ITEzE999/j9GjR6O+vh7ffPMN4uPj8f777+PatWswmUxsUGT3GuYl33nnHfz73/9GYmIi/P39ce3aNSxcuBAxMTGYPn06du/ezaVkrATDxcopFAoMGDDAvEyMWq1GeXk5VqxYgREjRmD27NlIS0vDnTt3IEkSGxXZNSEEgoOD8dFHH+HQoUN455130KNHD9y+fRtbtmzBCy+8gOjoaMybNw/Hjx+HXq9nm7AQhosNaJjsnzp1Kvbv34/FixcjODgY5eXl2LBhA377298iNjYWb7/9Nvbv34/S0lL2aMiuKRQKhIaGYvHixUhPT8f69esRHx8PT09PXLt2DatXr8aYMWMQHx+P1atXIz8/n0PJbYzhYkOEEOjYsSPeeustHDp0CKtXr0ZUVBRcXFyQnZ2Nv//975g0aRKioqIwffp0bN++HWVlZQwZslsNbeLFF1/E1q1bkZ6ejiVLlmDQoEEQQiAzMxPz5s1DTEwM5s2bh8uXL/OLVxthuNighgb1yiuvYNeuXTh48CCWLl2KmJgYeHt7o7CwEFu2bMHUqVMRExODhQsXmudniOyREALOzs7o2bMn3nrrLezbtw979uzBrFmz0LVrV5SUlGDNmjWIi4vD3LlzkZeXx4BpZQwXGyaEgKurK/r374833ngD33//PQ4fPox169Zh7Nix8PX1RV5eHpYvX464uDh88MEH+Omnn9ioyK4JIeDu7o6IiAh8/PHHSE9Px0cffYSnnnoKGo0Ga9euxahRo7By5UpotVq2h1bCcLETDbdsqlQqTJ06Fd988w3S0tKwcOFCBAcHo7S0FB9//DFGjRqF//mf/zHfAEBkzxQKBQ
IDA5GYmIgDBw7go48+Qo8ePVBUVIT33nsPv/3tb5GRkcHVyFsBw8UOCSGgVCoREhKCuXPn4uDBg3j//ffRqVMnXL9+HW+88QaeeeYZfP3115z8J4cghECHDh2QmJhovpXZ09MTJ0+exPPPP4/FixdDq9Vauky7wnCxc0IIdOrUCfPmzcP+/fvxyiuvwNPTE2fOnMGf//xnDB8+HC+++CLeeustbNu2zdLlErUqIQSCgoKwbNkybN++HZGRkdDpdPj444/xwgsv4MKFC/yiJROGi4MQQiAkJASrVq1CamoqXnzxRfj6+iI/Px87d+7EP/7xD6Snp1u6TKI24eTkhKFDh+K7777D3Llz4enpiYyMDLz33nvcT0kmSksXYG+s/VuPQqFAeHg4PvvsM9y4cQMZGRk4c+YMysvLER4ejitXrli6RLJx1t4GGmvXrh3ef/99DB06FB9++CHmz5+P77//3tJl2QWGiwyEELhw4QKWLl1q6VKeSIcOHeDv74+bN2/i3LlzEEJYuiSyMQ1tYMmSJZYu5YlFRkZi3759bAMyEZItfc2wUnV1deaF82ydQqGASqWCi4uLpUshG8I2QPdiuBARkew4oW8jJEniLcPk8NgObAfDxUacPXsWnp6eOHv2rKVLIbIYtgPbwXAhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw8UGSJIEjUbT5H+JHA3bgW1huFgxrVaLlStXQq1WY+TIkaitrcXIkSOhVquxcuVKaLVaS5dI1OrYDmyTkBj/Vmnv3r2YNGkS9Ho9ADT5liaEAAB4eHggOTkZCQkJFqmRqLWxHdguhosV2rt3L8aOHWveL/xBFAoFhBBISUlhwyK7w3Zg2xguVkar1SIoKAjV1dUPbVANFAoF3N3dUVBQAF9f39YvkKgNsB3YPs65WJkNGzZAr9c3q0EBgMlkgl6vx8aNG1u5MqK2w3Zg+9hzsSKSJEGtViM3N/ex7oQRQkClUuHq1avmcWgiW8V2YB8YLlakrKwMAQEBLTrf399fxoqI2h7bgX3gsJgV0el0LTq/srJSpkqILIftwD4wXKyIl5dXi8739vaWqRIiy2E7sA8MFyvi7++PkJCQxx4vFkIgJCQE7du3b6XKiNoO24F9YLhYESEEEhMTn+jcOXPmcBKT7ALbgX3ghL6V4f39RGwH9oA9Fyvj6+uL5ORkCCGgUDz8/56GJ5O3bdvGBkV2he3A9jFcrFBCQgJSUlLg7u4OIcQvuvkNP3N3d0dqairi4+MtVClR62E7sG0MFyuVkJCAgoICJCUlQaVSNTmmUqmQlJSEwsJCNiiya2wHtotzLjZAkiTcvn0blZWV8Pb2Rvv27TlpSQ6H7cC2MFyIiEh2HBYjIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2/w+lxIsMjvXaAAAAAABJRU5ErkJggg==",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 6 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ccb7ec43",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"r2 is 0.9974619150161743\n",
|
|
"saving model version 0.2\n",
|
|
"r2 is 0.997527003288269\n",
|
|
"saving model version 0.3\n",
|
|
"r2 is 0.9740613698959351\n",
|
|
"saving model version 0.4\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.9741, device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fix_symbolic(0,0,0,'log')\n",
|
|
"model.fix_symbolic(0,1,0,'log')\n",
|
|
"model.fix_symbolic(1,0,0,'sin')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "0937db67",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 2.66e-07 | test_loss: 2.75e-07 | reg: 0.00e+00 | : 100%|█| 20/20 [00:01<00:00, 15.69it"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.5\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "e959cda3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle - 1.0 \\sin{\\left(2.0 \\log{\\left(5.017 x_{1} \\right)} + 2.0 \\log{\\left(1.512 x_{2} \\right)} - 7.194 \\right)}$"
|
|
],
|
|
"text/plain": [
|
|
"-1.0*sin(2.0*log(5.017*x_1) + 2.0*log(1.512*x_2) - 7.194)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ex_round(model.symbolic_formula()[0][0], 3)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "16e4da06",
|
|
"metadata": {},
|
|
"source": [
|
|
"We were lucky -- the singularity does not seem to be a problem in this case. But let's instead consider $f(x,y)=\\sqrt{x^2+y^2}$, for which $x=y=0$ is a singular point."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "1ce52cec",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 3.65e-03 | test_loss: 3.97e-03 | reg: 4.84e+00 | : 100%|█| 20/20 [00:03<00:00, 5.36it\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 5 grid intervals (grid=5).\n",
|
|
"model = KAN(width=[2,1,1], grid=5, k=3, seed=0)\n",
|
|
"f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2)\n",
|
|
"dataset = create_dataset(f, n_var=2)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "3a69ec41",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 500x400 with 6 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.plot()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "abef7aa9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"r2 is 0.9999973773956299\n",
|
|
"saving model version 0.2\n",
|
|
"r2 is 0.9999948740005493\n",
|
|
"saving model version 0.3\n",
|
|
"r2 is 0.9998846650123596\n",
|
|
"saving model version 0.4\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.9999)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fix_symbolic(0,0,0,'x^2')\n",
|
|
"model.fix_symbolic(0,1,0,'x^2')\n",
|
|
"model.fix_symbolic(1,0,0,'sqrt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "aa71848c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"rewind to model version 0.4, renamed as 1.4\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = model.rewind('0.4')\n",
|
|
"model.get_act(dataset)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "e14000d8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 1.00775534257195 \\sqrt{0.999962771771901 \\left(6.10769914067904 \\cdot 10^{-5} - x_{1}\\right)^{2} + \\left(9.20887777110479 \\cdot 10^{-5} - x_{2}\\right)^{2} + 0.00441348508007971} - 0.00955450534820557$"
|
|
],
|
|
"text/plain": [
|
|
"1.00775534257195*sqrt(0.999962771771901*(6.10769914067904e-5 - x_1)**2 + (9.20887777110479e-5 - x_2)**2 + 0.00441348508007971) - 0.00955450534820557"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"formula = model.symbolic_formula()[0][0]\n",
|
|
"formula"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "c56ee3d5",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/latex": [
|
|
"$\\displaystyle 1.01 \\sqrt{1.0 x_{1}^{2} + x_{2}^{2}} - 0.01$"
|
|
],
|
|
"text/plain": [
|
|
"1.01*sqrt(1.0*x_1**2 + x_2**2) - 0.e-2"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"ex_round(formula, 2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1fd57d41",
|
|
"metadata": {},
|
|
"source": [
|
|
"With singularity avoiding (`singularity_avoiding=True`): LBFGS may still get NaN because of its line search, but Adam won't get NaN."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "de708f21",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 5.11e-04 | test_loss: 5.64e-04 | reg: 0.00e+00 | : 100%|█| 1000/1000 [00:14<00:00, 70.\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 1.5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False, singularity_avoiding=True);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "6fd34c4c",
|
|
"metadata": {},
|
|
"source": [
|
|
"Without singularity avoiding, NaN may appear."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "031fabd6",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: nan | test_loss: nan | reg: nan | : 100%|█████████| 1000/1000 [00:17<00:00, 57.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 1.6\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "124c9ca4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|