GitHub_collection_pykan/tutorials/Example_9_singularity.ipynb

465 lines
40 KiB
Plaintext
Raw Normal View History

2024-08-11 13:02:16 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Example 9: Singularity"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Let's construct a dataset containing a singularity: $f(x,y)=\\sin(2(\\log x+\\log y))$ with $x>0,\\ y>0$. The logarithm diverges at $0$, so the inputs are sampled from $[0.2, 5]$."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.89e-02 | test_loss: 3.78e-02 | reg: 6.39e+00 | : 100%|█| 20/20 [00:02<00:00, 7.05it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron (width=[2,1,1]). cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,1,1], grid=5, k=3, seed=2)\n",
"# target f(x,y) = sin(2(log x + log y)); log is singular at 0, so ranges=[0.2,5] keeps both inputs strictly positive\n",
"f = lambda x: torch.sin(2*(torch.log(x[:,[0]])+torch.log(x[:,[1]])))\n",
"dataset = create_dataset(f, n_var=2, ranges=[0.2,5])\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsk0lEQVR4nO3deVBUd7o+8OfbNEJDswmIUVxoQFCQIMoSUXGJYiSL0ZmYpbxjYm6ljFfvOFtuJePVm5iZScapUTNOZbuVQEZjzODEiOuNGlRcUMAFUARRIyCoQCMNTQPd5/eHdv/ABVEO9PZ8qlJT4XS373Hy9sN3OecISZIkEBERyUhh7QKIiMjxMFyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpKd0toFENkDSZJQW1sLnU4HtVoNf39/CCGsXRaRzeLIhagLWq0Wa9euRXh4OAIDAxESEoLAwECEh4dj7dq10Gq11i6RyCYJPomS6N52796NuXPnorm5GcCt0YuZedTi4eGBzMxMpKamWqVGIlvFcCG6h927dyMtLQ2SJMFkMt33dQqFAkIIbN++nQFD1AHDhegOWq0WwcHB0Ov1XQaLmUKhgEqlQkVFBXx9fXu/QCI7wDUXojukp6ejubm5W8ECACaTCc3NzcjIyOjlyojsB0cuRB1IkoTw8HCUl5fjYVpDCAGNRoPS0lLuIiMCw4Wokxs3biAwMLBH7/f395exIiL7xGkxog50Ol2P3t/Y2ChTJUT2jeFC1IFare7R+728vGSqhMi+MVyIOvD390doaOhDr5sIIRAaGor+/fv3UmVE9oXhQtSBEAJLlix5pPcuXbqUi/lEt3FBn+gOvM6FqOc4ciG6g6+vLzIzMyGEgELRdYuYr9DfsmULg4WoA4YL0T2kpqZi+/btUKlUEELcNd1l/plKpcKOHTswY8YMK1VKZJsYLkT3kZqaioqKCqxZswYajabTMY1GgzVr1qCyspLBQnQPXHMh6gZJkrB//35MmzYNe/fuxZQpU7h4T9QFjlyIukEIYVlT8fX1ZbAQPQDDhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOF6AHa2tpQWVmJs2fPAgAuXLiAuro6mEwmK1dGZLv4mGOi+9BqtcjMzMSGDRtQVFSExsZGtLa2wt3dHYGBgZg4cSIWLlyI5ORkKJVKa5dLZFMYLkT3cOTIESxbtgynT59GfHw80tLSEBMTA7VaDa1Wi7y8PGzbtg1lZWWYN28eVq1ahcDAQGuXTWQzGC5Ed9izZw8WLFgAtVqNP/7xj5g1axZaW1uxadMmGAwGeHt748UXX0RbWxs2bdqElStXIioqCl999RWCgoKsXT6RTWC4EHVw/vx5zJw5E56enti0aRNGjRoFIQTKy8sRFxeHhoYGhISEIC8vD35+fpAkCYcOHcLLL7+MyZMn4/PPP4ebm5u1T4PI6rigT3Sb0WjEH/7wB9TX1+Nvf/ubJVi6IoTAhAkT8OGHH2Lr1q3YtWtXH1VLZNsYLkS3lZWVYdu2bZgzZw4mTJjwwGAxE0Jg9uzZSEpKwmeffYb29vZerpTI9nGLC9Fthw8fhk6nw9y5c3Hp0iU0NTVZjlVUVMBoNAIAWltbUVRUBG9vb8vxQYMGYc6cOVi5ciWqq6sRHBzc5/UT2RKGC9Ft586dg4eHBzQaDd544w3k5ORYjkmSBIPBAACoqqrC9OnTLceEEPjLX/6C0aNHo7m5GVVVVQwXcnoMF6Lb9Ho9lEol3NzcYDAY0NLScs/XSZJ017H29naoVKpOIUTkzBguRLcNGDAAer0eWq0WiYmJ8PT0tBzT6/
U4fPiwJUTGjx9vuXBSCIGhQ4fi2rVrUCgU8PPzs9YpENkMhgvRbWPHjkVbWxtyc3PxwQcfdDpWXl6O+Ph4NDQ0ICgoCN988w18fX0tx4UQePvttzFw4EBOiRGBu8WILBISEqDRaJCeno6mpia4uLh0+sdMCAGFQmH5uUKhwNWrV/HPf/4TaWlp8PHxseJZENkGhgvRbf7+/viP//gP5OfnY926dd3eUmwwGPDee+9Br9fjjTfe6PYWZiJHxmkxog4WLFiAAwcO4IMPPoCHhwcWLVoEd3d3AIBSqYRSqbSMYiRJQmNjI95//31s2rQJf/3rXxEREWHN8olsBm//QnSH69evY/HixcjKykJqaiqWLVuGkSNHoqSkBCaTCf369UNYWBhyc3OxevVqnDx5Eu+++y4WLVrUafqMyJkxXIjuoampCZ999hnWrVuHmpoaaDQahIeHw8vLC/X19SgpKUFVVRXGjh2LFStWICUlBQoFZ5mJzBguRF2orq7G3r17kZ2djVOnTiE3NxcTJ05EcnIyZsyYgcTERHh4eFi7TCKbw3Ah6qbjx48jISEBx48fx7hx46xdDpFN4zieqJtcXFws25CJqGvsEiIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMfnuRB1kyRJMJlMUCgUEEJYuxwim8aRC9FD4LNciLpHae0CiOQiSRJKS0tRW1tr7VJ6RKFQIDo6Gp6entYuheiRcVqMHIbJZMLixYsxZMgQqNVqa5fzQCaTCcDdo6GDBw9i+fLliImJsUZZRLLgyIUcipubGxYuXIigoCBrl9KlpqYmLF++HDqdDn/+85/h4+MD4NboS6fTgb/zkb3jBDJRH5MkCVlZWVi/fj0yMjKQl5dn7ZKIZMdwIepjbW1tyMjIQHt7O9ra2nD8+HGOVMjhMFyI+lhNTQ3y8/Mt/56bm2tZfyFyFAwXoj5WXFyMuro6y7+fOXMGDQ0NVqyISH4MF6I+JEkScnNz0d7ejuDgYHh4eKCqqgoXL160dmlEsmK4EPUho9GI3NxcAMAzzzwDjUYDvV6PgoICrruQQ2G4EN2DJEmoqKhARkYGcnJyYDQaZfnc+vp6FBUVQaFQYOrUqYiNjQUAHD16lOFCDoXhQnQHSZJw5coVPP/883jttdeQlpaGDRs2yPLlf/HiRdTU1MDb2xsxMTFISkoCABQUFECv1/f484lsBcOF6A6SJGHNmjXIz8+HUqlEY2MjVq5ciStXrvT4c0tKSmAwGDB48GAMHDgQcXFxcHd3x6VLl1BVVSXTGRBZH8OF6A4VFRX49ttvoVAo8NZbb0Gj0eCnn37CN9980+PRS3FxMSRJQmhoKFQqFUJDQzFw4EDcvHkTZ86ckekMiKyP4ULUgSRJ+L//+z9cvXoVw4YNw6JFi/Dyyy9DkiR8++23aGpqeuTPNplMOHfuHAAgMjISCoUCvr6+iIqKgslkwrFjx+Q6DSKrY7gQdWA0GpGVlQVJkjB9+nQMGDAAs2fPhlqtRnFxcY9GF3q9HuXl5RBCYOTIkRBCwMXFBQkJCQCAEydOoK2tTa5TIbIqhgtRBzU1NTh+/DhcXFzw1FNPQQiByMhIjB49Gnq9Hj/88MMjT43V1taiuroarq6uCAsLAwAIIZCQkAClUomSkhLcuHFDztMhshqGC9FtkiShoKAA165dQ1BQEMaOHQshBNzd3TF16lQAwP79+9Ha2vpIn19RUYGbN2/Cx8cHQ4YMsfx8xIgR8PHxQV1dHSoqKmQ5FyJrY7gQdXDo0CEYjUbExMRgwIABAG6NLiZPngxXV1cUFRXh6tWrD/25kiShvLwcbW1tCAoKgp+fn+WYt7c3vLy80N7e3um2MET2jOFCdFtrayuOHj0KAEhOToZS+f8fdxQVFYWBAweirq4OJ0+efKSpsZKSEkiShOHDh0
OlUll+7ubmBi8vL5hMJtTX1/f8RIhsAMOF6LZr166hpKQErq6uSExMhBDCcszf3x8xMTEwmUw4fPjwQ3+2JEk4f/4
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# visualize the trained network\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "ccb7ec43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best value at boundary.\n",
"r2 is 0.999884843826294\n",
"saving model version 0.2\n",
"Best value at boundary.\n",
"r2 is 0.9998899102210999\n",
"saving model version 0.3\n",
"r2 is 0.9975605010986328\n",
"saving model version 0.4\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9976)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# snap each learned activation to a symbolic function; indices are presumably (layer, input neuron, output neuron)\n",
"# layer 0: log on both inputs; layer 1: sin on the hidden neuron. Each call prints its fit quality (r2).\n",
"model.fix_symbolic(0,0,0,'log')\n",
"model.fix_symbolic(0,1,0,'log')\n",
"model.fix_symbolic(1,0,0,'sin')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0937db67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.95e-07 | test_loss: 2.91e-07 | reg: 0.00e+00 | : 100%|█| 20/20 [00:01<00:00, 15.68it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# retrain now that all three activations are symbolic\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e959cda3",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 \\sin{\\left(2.0 \\log{\\left(9.993 x_{1} \\right)} + 2.0 \\log{\\left(10.0 x_{2} \\right)} - 9.209 \\right)}$"
],
"text/plain": [
"1.0*sin(2.0*log(9.993*x_1) + 2.0*log(10.0*x_2) - 9.209)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# symbolic formula of the model output, constants rounded to 3 decimal places\n",
"ex_round(model.symbolic_formula()[0][0], 3)"
]
},
{
"cell_type": "markdown",
"id": "16e4da06",
"metadata": {},
"source": [
"We were lucky: the singularity did not cause problems in this case. Now consider instead $f(x,y)=\\sqrt{x^2+y^2}$, which is singular at $x=y=0$: the derivative of $\\sqrt{u}$, namely $\\frac{1}{2\\sqrt{u}}$, diverges as $u\\to 0$."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "1ce52cec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 5.17e-03 | test_loss: 5.45e-03 | reg: 5.66e+00 | : 100%|█| 20/20 [00:02<00:00, 7.44it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron (width=[2,1,1]). cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,1,1], grid=5, k=3, seed=0)\n",
"# target f(x,y) = sqrt(x^2 + y^2); the default input range (pykan docs: [-1,1]) contains the singular point x=y=0\n",
"f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"\n",
"# train the model\n",
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3a69ec41",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsLElEQVR4nO3deVCUZ7o28OvpZmsWbcBd1NhKjIqOO8ruBiY4xtGJZjyVGhMn43ESPS5J6pwkk8WTjBVHjQtmjtFTM+pxhkkFlxFUghuIorhGJYoiSgK4QyPQTbP08/0R6Q/clbf7bZrrV2WlyqbpG+LdV7/P9goppQQREZGCNGoXQERErofhQkREimO4EBGR4hguRESkOIYLEREpjuFCRESKY7gQEZHiGC5ERKQ4hgsRESmO4UJERIpjuBARkeIYLkREpDiGCxERKY7hQkREimO4EBGR4tzULoCoOZBS4vbt26ioqICvry8CAwMhhFC7LCKnxSsXokcwGo1YsWIFgoOD0bZtW3Tv3h1t27ZFcHAwVqxYAaPRqHaJRE5J8E6URA+WmpqKyZMnw2QyAfj56qVe/VWLt7c3kpKSEBcXp0qNRM6K4UL0AKmpqYiPj4eUElar9aFfp9FoIIRASkoKA4aoAYYL0T2MRiOCgoJgNpsfGSz1NBoNdDodCgsLodfr7V8gUTPAOReie6xfvx4mk+mJggUArFYrTCYTNmzYYOfKiJoPXrkQNSClRHBwMPLz8/E0rSGEgMFgwMWLF7mKjAgMF6JGbt26hbZt2zbp+YGBgQpWRNQ8cViMqIGKioomPb+8vFyhSoiaN4YLUQO+vr5Ner6fn59ClRA1bwwXogYCAwPRo0ePp543EUKgR48eCAgIsFNlRM0Lw4WoASEEZs+e/UzPnTNnDifzie7ihD7RPbjPhajpeOVCdA+9Xo+kpCQIIaDRPLpF6nfob968mcFC1ADDhegB4uLikJKSAp1OByHEfcNd9X+n0+mwY8cOxMbGqlQpkXNiuBA9RFxcHAoLC7F8+XIYDIZGjxkMBixfvhxFRUUMFqIH4JwL0ROQUmLfvn0YPXo09uzZg5EjR3LynugReOVC9ASEELY5Fb1ez2AhegyGCxERKY7hQkREimO4EBGR4hguRESkOIYLEREpjuFCRESKY7gQEZHiGC5ERKQ4hgsRESmO4UJERIpjuBARkeIYLkREpDiGCxERKY7hQkREimO4EBGR4hguRESkOIYL0WPU1NSgqKgI586dAwBcunQJJSUlsFqtKldG5Lx4m2OihzAajUhKSsKmTZuQk5OD8vJyVFdXw8vLC23btkVkZCRmzJiB8PBwuLm5qV0ukVNhuBA9QFZWFubNm4fTp09j6NChiI+PR//+/eHr6wuj0Yjjx49j+/btyMvLw9SpU/HZZ5+hbdu2apdN5DQYLkT3+O677zB9+nT4+vpi0aJFeOmll1BdXY3ExERYLBa0atUKr776KmpqapCYmIhPPvkEffv2xcaNG9G+fXu1yydyCgwXogYuXLiAcePGwcfHB4mJiejTpw+EEMjPz8egQYNQVlaG7t274/jx4/D394eUEpmZmZg2bRpiYmKwbt06eHp6qv1jEKmOE/pEd9XV1eFPf/oTSktLkZCQYAuWRxFCICIiAosXL8a2bduwa9cuB1VL5NwYLkR35eXlYfv27Zg0aRIiIiIeGyz1hBCYOHEihg8fjrVr16K2ttbOlRI5Py5xIbrr0KFDqKiowOTJk3HlyhVUVlbaHissLERdXR0AoLq6Gjk5OWjVqpXt8U6dOmHSpEn45JNPcO3aNQQFBTm8fiJnwnAhuuv8+fPw9vaGwWDAzJkzcfDgQdtjUkpYLBYAQHFxMcaOHWt7TAiBpUuXol+/fjCZTCguLma4UIvHcCG6y2w2w83NDZ6enrBYLKiqqnrg10kp73ustrYWOp2uUQgRtWQMF6K72rVrB7PZDKPRiNDQUPj4+NgeM5vNOH
TokC1EwsLCbBsnhRDo2rUrbty4AY1GA39/f7V+BCKnwXAhumvw4MGoqalBdnY2vvjii0aP5efnY+jQoSgrK0P79u3xz3/+E3q93va4EALvv/8+OnTowCExInC1GJHNsGHDYDAYsH79elRWVkKr1Tb6U08IAY1GY/t7jUaDq1ev4ttvv0V8fDxat26t4k9B5BwYLkR3BQYG4u2338aJEyewcuXKJ15SbLFY8N///d8wm82YOXPmEy9hJnJlHBYjamD69OnIyMjAF198AW9vb8yaNQteXl4AADc3N7i5udmuYqSUKC8vx+eff47ExER8+eWX6NWrl5rlEzkNHv9CdI+bN2/irbfeQnJyMuLi4jBv3jz07t0bubm5sFqt8PDwQM+ePZGdnY0lS5bg1KlTWLhwIWbNmtVo+IyoJWO4ED1AZWUl1q5di5UrV+L69eswGAwIDg6Gn58fSktLkZubi+LiYgwePBgff/wxoqOjodFwlJmoHsOF6BGuXbuGPXv2ID09Hd9//z2ys7MRGRmJ8PBwxMbGIjQ0FN7e3mqXSeR0GC5ET+jo0aMYNmwYjh49iiFDhqhdDpFT43U80RPSarW2ZchE9GjsEiIiUhzDhYiIFMdwISIixTFciIhIcQwXIiJSHMOFiIgUx3AhIiLFMVyIiEhxDBciIlIcw4WIiBTHcCEiIsUxXIiISHEMFyIiUhzDhYiIFMf7uRA9ISklrFYrNBoNhBBql0Pk1HjlQvQUeC8XoifjpnYBREqRUuLixYu4ffu22qU0iUajQUhICHx8fNQuheiZcViMXIbVasVbb72FLl26wNfXF3V1dc1yCOvAgQP44x//iP79+6tdCtEz45ULuRRPT09ER0dj69atOHXqFPr06YMXX3wRQ4cORUBAgNMHjZQSFRUV4Gc+au44gEwup7i4GKtXr8bevXuRkJCAl19+GZGRkdi0aRMsFgvfuIkcgOFCLmfAgAGYPn06PvroI0yYMAEBAQHIzc3Fm2++iTfffBOXL19mwBDZGYfFyOX4+fkhISEBQgjU1dWhoKAAixcvxv/93/9h06ZNOHToED744ANMnToVOp3O6YfKiJojXrmQyxFC2Cby3dzcYDAYsGrVKvztb39Dr169cPnyZfz7v/87fv3rXyMrKwt1dXW8kiFSGMOFXJ4QAh4eHvj1r3+N1NRUvPXWW9DpdEhNTcX48eOxYMEC/PjjjwwYIgUxXKjFEEIgKCgIy5Ytw7/+9S+MGjUKlZWVWLVqFUaPHo01a9agvLycIUOkAIYLtSj1Q2WRkZHYsmUL/ud//gfPP/88rly5gjlz5uBXv/oVDh06hLq6OrVLJWrWGC7UIgkh4Ovri+nTp2P37t2YP38+/Pz8sG/fPkyYMAELFy5ESUkJr2KInhHDhVo0IQQ6d+6MRYsWITk5GaNGjUJ5eTk+//xzTJgwAQcPHoTValW7TKJmh+FCBECr1WL48OFISkrCokWL0KZNG2RlZWHixIlYsmQJ52KInhLDheguIQRatWqFefPmITk5GTExMSgrK8OHH36If/u3f8P58+cZMERPiOFCdA+NRoPBgwcjKSkJH3zwAXx8fJCSkoL4+Hh88803qKmpUbtEIqfHcCF6ACEE9Ho9/vjHPyIxMRF9+/ZFQUEBZsyYgffeew+3bt3iVQzRIzBciB5Bq9UiNjYWKSkpmDZtGmpra7Fq1SpMnjwZp06dYsAQPQTDhegx6jdfrl27Fl9++SXatGmDzMxMTJgwAX//+985TEb0AAwXoicghICXlxdmzpyJzZs3Y/DgwSguLsbMmTPx8ccfczUZ0T0YLkRPQaPRYMSIEdi6dSumTZuGmpoa/PnPf8aMGTNw9epVBgzRXQwXoqckhEDHjh2xZs0afPrpp/D29kZSUhKmTp2K3NxcBgwRGC5Ez0QIAW9vb7zzzjv4+uuv0a5dOxw6dAhTpkzBiRMnGDDU4jFciJrAzc0Nr7zyCv7+97/DYDAgJycHv/nNb5
CVlcWAoRaN4ULURBqNBjExMUhMTETv3r1x6dIlvPbaawwYatEYLkQKEEJg0KBB+Mc//oGQkBBcuXIFv/3tb3Hs2DE
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# visualize the trained network\n",
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "abef7aa9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999783635139465\n",
"saving model version 0.2\n",
"r2 is 0.9999676942825317\n",
"saving model version 0.3\n",
"r2 is 0.9997884631156921\n",
"saving model version 0.4\n"
]
},
{
"data": {
"text/plain": [
"tensor(0.9998)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# snap activations to the structure of sqrt(x^2 + y^2): x^2 on both inputs in layer 0, sqrt in layer 1.\n",
"# Each call prints its fit quality (r2).\n",
"model.fix_symbolic(0,0,0,'x^2')\n",
"model.fix_symbolic(0,1,0,'x^2')\n",
"model.fix_symbolic(1,0,0,'sqrt')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "aa71848c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rewind to model version 1.4, renamed as 2.4\n"
]
}
],
"source": [
"# restore the checkpoint saved right after the symbolic fixes above (version 0.4),\n",
"# so the cells below continue from that state\n",
"model = model.rewind('0.4')\n",
"# recompute activations on the dataset (presumably needed by symbolic_formula below; TODO confirm)\n",
"model.get_act(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "e14000d8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.01334547419162 \\sqrt{0.999861446076389 \\left(7.53297050423062 \\cdot 10^{-5} - x_{2}\\right)^{2} + \\left(0.000104069324734005 - x_{1}\\right)^{2} + 0.00834810636784406} - 0.0170296430587769$"
],
"text/plain": [
"1.01334547419162*sqrt(0.999861446076389*(7.53297050423062e-5 - x_2)**2 + (0.000104069324734005 - x_1)**2 + 0.00834810636784406) - 0.0170296430587769"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# full-precision symbolic formula of the model output\n",
"formula = model.symbolic_formula()[0][0]\n",
"formula"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "c56ee3d5",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.01 \\sqrt{x_{1}^{2} + 1.0 x_{2}^{2} + 0.01} - 0.02$"
],
"text/plain": [
"1.01*sqrt(x_1**2 + 1.0*x_2**2 + 0.01) - 0.02"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# round constants to 2 decimal places; the result is close to the ground truth sqrt(x_1^2 + x_2^2)\n",
"ex_round(formula, 2)"
]
},
{
"cell_type": "markdown",
"id": "1fd57d41",
"metadata": {},
"source": [
"With singularity avoiding: LBFGS may still produce NaNs because of its line search, but Adam should not."
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "de708f21",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 1.09e-07 | test_loss: 1.48e-07 | reg: 0.00e+00 | : 100%|█| 1000/1000 [00:12<00:00, 83.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 1.5\n"
]
}
],
"source": [
"# Adam run with singularity_avoiding=True (compare with the run without it below)\n",
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False, singularity_avoiding=True);"
]
},
{
"cell_type": "markdown",
"id": "6fd34c4c",
"metadata": {},
"source": [
"Without singularity avoiding, NaNs may appear."
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "031fabd6",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: nan | test_loss: nan | reg: nan | : 100%|█████████| 1000/1000 [00:11<00:00, 84.83it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 1.6\n"
]
}
],
"source": [
"# start from the same checkpoint (0.4) as the singularity-avoiding run above, rather than from\n",
"# the model that run already trained, so the two runs differ only in singularity_avoiding\n",
"model = model.rewind('0.4')\n",
"model.fit(dataset, opt=\"Adam\", steps=1000, lr=1e-3, update_grid=False);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "124c9ca4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}