GitHub_collection_pykan/tutorials/Example_9_singularity.ipynb

388 lines
47 KiB
Plaintext
Raw Normal View History

2024-04-29 12:35:18 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Example 9: Singularity"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"Let's construct a dataset that contains a singularity: $f(x,y)=\\sin(2(\\log x+\\log y))$\n",
"with $x>0, y>0$ ($\\log$ is singular at 0)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 5.01e-03 | test loss: 5.28e-03 | reg: 7.54e+00 : 100%|██| 20/20 [00:06<00:00, 3.16it/s]\n"
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
2024-07-13 22:17:48 -04:00
"from kan import *\n",
2024-04-29 12:35:18 -04:00
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 20 grid intervals (grid=20).\n",
"model = KAN(width=[2,1,1], grid=20, k=3, seed=0)\n",
"f = lambda x: torch.sin(2*(torch.log(x[:,[0]])+torch.log(x[:,[1]])))\n",
"dataset = create_dataset(f, n_var=2, ranges=[0.2,5])\n",
"\n",
"# train the model\n",
2024-07-13 22:17:48 -04:00
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAtiklEQVR4nO3deVBUZ74+8OftbvYdRJBFpaGDuC9BXEDUqLhEb6Le0sp1bhnNxGtpMuNMJuY6cXQck2jGuUETU5PSqRtNcsMkQccYvOINyaACccFdWQVUQESBZmtolj6/P2L3D1xRDvT2fKpSU8U5ffg249tPv8t5j5AkSQIREZGMFOYugIiIbA/DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2anMXQCRNZAkCVVVVWhoaIC7uzv8/PwghDB3WUQWiz0XokfQarXYvn07NBoN/P39ERYWBn9/f2g0Gmzfvh1ardbcJRJZJMEnURI9WGpqKhYsWACdTgfg596LkbHX4urqiuTkZCQkJJilRiJLxXAheoDU1FTMmTMHkiTBYDA89DyFQgEhBFJSUhgwRB0wXIjuodVqERISgqampkcGi5FCoYCLiwtKS0vh7e3d8wUSWQHOuRDdY8+ePdDpdF0KFgAwGAzQ6XTYu3dvD1dGZD3YcyHqQJIkaDQaFBUV4UmahhACarUaBQUFXEVGBIYLUSd37tyBv79/t17v5+cnY0VE1onDYkQdNDQ0dOv19fX1MlVCZN0YLkQduLu7d+v1Hh4eMlVCZN0YLkQd+Pn5ITw8/InnTYQQCA8Ph6+vbw9VRmRdGC5EHQgh8Nprrz3Va19//XVO5hPdxQl9onvwPhei7mPPhege3t7eSE5OhhACCsWjm4jxDv19+/YxWIg6YLgQPUBCQgJSUlLg4uICIcR9w13Gn7m4uODQoUOYMWOGmSolskwMF6KHSEhIQGlpKRITE6FWqzsdU6vVSExMRFlZGYOF6AE450LUBZIk4ccff8Rzzz2HtLQ0TJkyhZP3RI/AngtRFwghTHMq3t7eDBaix2C4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQPUZrayvKysqQk5MDALh69Sqqq6thMBjMXBmR5eJjjokeQqvVIjk5GV988QUuX76M+vp6tLS0wNnZGf7+/oiLi8Py5csxceJEqFQqc5dLZFEYLkQPkJWVhTVr1uDChQuIjo7GnDlzMHz4cLi7u0Or1SI7OxsHDx5EYWEhFi1ahM2bN8Pf39/cZRNZDIYL0T2OHDmCpUuXwt3dHe+99x5mz56NlpYWJCUlQa/Xw9PTE4sXL0ZrayuSkpKwceNGDBkyBJ999hkCAgLMXT6RRWC4EHWQn5+PmTNnws3NDUlJSRg8eDCEECgqKsLo0aNRW1uLsLAwZGdnw8fHB5Ik4fjx43jppZcwefJk7N69G05OTuZ+G0Rmxwl9orva29vx7rvvoqamBh999JEpWB5FCIHY2Fi8//77OHDgAA4fPtxL1RJZNoYL0V2FhYU4ePAg5s+fj9jY2McGi5EQAi+88ALGjRuHXbt2oa2trYcrJbJ8XOJCdFdmZiYaGhqwYMEClJSUoLGx0XSstLQU7e3tAICWlhZcvnwZnp6epuNBQUGYP38+Nm7ciIqKCoSEhPR6/USWhOFCdFdubi5cXV2hVquxYsUKZGRkmI5JkgS9Xg8AKC8vx/Tp003HhBD4y1/+gmHDhkGn06G8vJzhQnaP4UJ0V1NTE1QqFZycnKDX69Hc3PzA8yRJuu9YW1sbXFxcOoUQkT1juBDd1bdvXzQ1NUGr1SImJgZubm6mY01NTc
jMzDSFyIQJE0w3Tgoh0L9/f1RWVkKhUMDHx8dcb4HIYjBciO4aM2YMWltbcfLkSWzdurXTsaKiIkRHR6O2thYBAQH4+9//Dm9vb9NxIQTWrVuHwMBADokRgavFiEzGjh0LtVqNPXv2oLGxEUqlstN/RkIIKBQK088VCgVu3ryJb775BnPmzIGXl5cZ3wWRZWC4EN3l5+eH1atX48yZM9ixY0eXlxTr9Xr86U9/QlNTE1asWNHlJcxEtozDYkQdLF26FEePHsXWrVvh6uqKlStXwtnZGQCgUqmgUqlMvRhJklBfX4933nkHSUlJ+OCDDxAZGWnO8oksBrd/IbrH7du3sWrVKnz33XdISEjAmjVrEBUVhby8PBgMBjg6OiIiIgInT57Etm3bcO7cOWzatAkrV67sNHxGZM8YLkQP0NjYiF27dmHHjh24desW1Go1NBoNPDw8UFNTg7y8PJSXl2PMmDHYsGED4uPjoVBwlJnIiOFC9AgVFRVIS0tDeno6zp8/j5MnTyIuLg4TJ07EjBkzEBMTA1dXV3OXSWRxGC5EXXTq1CmMHTsWp06dwrPPPmvucogsGvvxRF2kVCpNy5CJ6NHYSoiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh2f50LURZIkwWAwQKFQQAhh7nKILBp7LkRPgM9yIeoalbkLIJKLJEkoKChAVVWVuUvpFoVCgaFDh8LNzc3cpRA9NQ6Lkc0wGAxYtWoVQkND4e7ubu5yHqu9vR1KpfK+nx87dgzr16/H8OHDzVAVkTzYcyGb4uTkhOXLlyMgIKBb15EkCefPn0dqairGjBmDKVOmPDAInobBYMBXX32Fzz//HH/4wx8QHR1tmsORJAkNDQ3gdz6ydgwXontIkoQzZ85g/vz5KC0thbu7O3bv3o2FCxfKMpFfWlqKtWvXorS0FA4ODkhKSoKTk5MMlRNZDs5OEt2jtbUVW7ZsQWlpKZydndHQ0IDNmzejurq629eWJAnfffcdysrKAABpaWk4ffo0eypkcxguRPfIycnB999/D0dHR7zzzjvo27ev6WfdDYG2tjYcPnwYkiRBqVSisbERn3zyCdrb22WqnsgyMFyIOpAkCYcOHUJdXR2GDh2Kl19+GTNmzEB7ezuSk5O7HQKVlZU4c+YMlEolfvGLX0ClUiElJQWXLl1i74VsCsOFqIOWlhYcOXIEADBr1ix4enpi3rx5UCqV+Omnn1BZWfnU15YkCZcuXcLt27fh7++PN954A6NGjYJWq8Xf/vY3hgvZFIYLUQelpaW4fPkyHB0d8dxzz0EIgbFjx8Lf3x+3bt3CuXPnunX948ePo62tDYMHD0Z4eDheffVVKBQKJCcn4+rVq/K8CSILwHAhukuSJJw7dw41NTUICgrC4MGDAQCBgYEYNmwY2tracPz48afuYbS0tCAjIwMAEBsbCwcHB8ybNw+DBg3CrVu3sHfvXvZeyGYwXIg6yMrKgsFgwLBhw+Dj4wMAUKlUGD9+PADg5MmTaG1tfapr37x5E1euXIGDgwNiY2MhhICfnx+WLVsGIQS+/PLLbg27EVkShgvRXa2trThz5gwAICYmxnTTpBACMTExUKlUyMvLe6rtZSRJwtmzZ1FdXY3AwEAMGTLEdO2FCxciKCgI165dQ1pamnxviMiMGC5Ed1VVVaGgoAAqlQqjR4/udMPkoEGD4O3tjTt37jz13MixY8fQ3t6OESNGoE+fPqafBwUFYfr06TAYDPj666/R1tbW7fdCZG4MF6K7iouLUVVVBS8vL0RGRnY6FhAQgIEDB6KlpQUXLlx44rmR5uZmZGVlAQDi4uI6bSWjUCiwaNEijBgxAtHR0bznhWwCw4UI/3+ZsF6vR2hoKPr27dvpuJOTk2mC/+zZs098/dLSUuTn58PJyQkTJkzo1CsSQmDKlC
lIT0/HunXruBUM2QSGC9Fd58+fBwBERUXB2dm50zEhBEaNGgUAuHLlCvR6fZeva5xvqaurQ3BwMJ555pn7znFwcIC
2024-04-29 12:35:18 -04:00
"text/plain": [
2024-07-13 22:17:48 -04:00
"<Figure size 500x400 with 6 Axes>"
2024-04-29 12:35:18 -04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ccb7ec43",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"r2 is 0.9999974370002747\n",
"r2 is 0.9999890923500061\n",
"r2 is 0.999965488910675\n"
2024-04-29 12:35:18 -04:00
]
},
{
"data": {
"text/plain": [
2024-07-13 22:17:48 -04:00
"tensor(1.0000)"
2024-04-29 12:35:18 -04:00
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fix_symbolic(0,0,0,'log')\n",
"model.fix_symbolic(0,1,0,'log')\n",
"model.fix_symbolic(1,0,0,'sin')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0937db67",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 2.85e-07 | test loss: 2.82e-07 | reg: 7.54e+00 : 100%|██| 20/20 [00:02<00:00, 9.03it/s]\n"
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
2024-07-13 22:17:48 -04:00
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e959cda3",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
2024-07-13 22:17:48 -04:00
"$\\displaystyle - 1.0 \\sin{\\left(2.0 \\log{\\left(2.205 x_{1} \\right)} + 2.0 \\log{\\left(2.018 x_{2} \\right)} + 0.156 \\right)}$"
2024-04-29 12:35:18 -04:00
],
"text/plain": [
2024-07-13 22:17:48 -04:00
"-1.0*sin(2.0*log(2.205*x_1) + 2.0*log(2.018*x_2) + 0.156)"
2024-04-29 12:35:18 -04:00
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2024-07-13 22:17:48 -04:00
"ex_round(model.symbolic_formula()[0][0], 3)"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "markdown",
"id": "16e4da06",
"metadata": {},
"source": [
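"We were lucky: the singularity did not cause a problem here, because the inputs are sampled from $[0.2, 5]$, away from the singularity of $\\log$ at 0. But let's instead consider $f(x,y)=\\sqrt{x^2+y^2}$, for which $x=y=0$ is a singular point (the derivative of $\\sqrt{\\cdot}$ blows up at 0)."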
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 13,
2024-04-29 12:35:18 -04:00
"id": "1ce52cec",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 4.94e-03 | test loss: 5.23e-03 | reg: 5.98e+00 : 100%|██| 20/20 [00:03<00:00, 5.04it/s]\n"
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
2024-07-13 22:17:48 -04:00
"from kan import *\n",
2024-04-29 12:35:18 -04:00
"import torch\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 1 hidden neuron. cubic spline (k=3), 5 grid intervals (grid=5).\n",
2024-07-13 22:17:48 -04:00
"model = KAN(width=[2,1,1], grid=5, k=3, seed=0)\n",
2024-04-29 12:35:18 -04:00
"f = lambda x: torch.sqrt(x[:,[0]]**2+x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"\n",
"# train the model\n",
2024-07-13 22:17:48 -04:00
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 14,
2024-04-29 12:35:18 -04:00
"id": "3a69ec41",
"metadata": {},
"outputs": [
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAshUlEQVR4nO3deVRUV54H8O+tKpZiLUVEDdEGJLgSN0RlNSIYTNJGTexxMt1O7E7iMWZccrqzTVqTNplkNKLGyURjMhqT0J24i4rRCIgbrtGgIrhEAXEBiq2KYqk3f7RVB1xRXtUriu/nHE/O8Vm8H4Rffeve+959QpIkCURERDJSKV0AERE5H4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESy0yhdAFFbIEkSSktLUV1dDS8vL/j5+UEIoXRZRA6LIxeie9Dr9Vi8eDFCQ0Ph7++PoKAg+Pv7IzQ0FIsXL4Zer1e6RCKHJPgkSqI7S09Px4QJE2AwGAD8c/RiYRm1eHh4YO3atUhKSlKkRiJHxXAhuoP09HSMHTsWkiTBbDbf9d+pVCoIIZCWlsaAIWqC4UJ0C71ej8DAQBiNxnsGi4VKpYJWq0VhYSF0Op3tCyRqA7jmQnSLVatWwWAwtChYAMBsNsNgMGD16tU2royo7eDIhagJSZIQGhqK8+fP40FaQwiB4OBg5Ofn8yoyIjBciJq5ceMG/P39W/V6Pz8/GSsiaps4LUbURHV1dateX1VVJVMlRG0bw4WoCS8vr1a93tvbW6ZKiNo2hgtRE35+fggJCXngdRMhBEJCQtCxY0cbVUbUtjBciJoQQmDGjBkP9drXXnuNi/lEN3FBn+gWvM+FqPU4ciG6hU6nw9q1ayGEgEp17xax3KG/bt06BgtREwwXojtISkpCWloatFothBC3TXdZ/k6r1WLr1q1ITExUqFIix8RwIbqLpKQkFBYWIiUlBcHBwc2OBQcHIyUlBUVFRQwWojvgmgtRC0iShN27d2PUqFHYtWsXRo4cycV7onvgyIWoBYQQ1jUVnU7HYCG6D4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgvRfdTX16OoqAinT58GAJw7dw5lZWUwm80KV0bkuPiYY6K70Ov1WLt2Lb755hvk5uaiqqoKdXV1cHd3h7+/P2JiYjB16lRERUVBo9EoXS6RQ2G4EN3B/v37MWvWLJw4cQIREREYO3YswsPD4eXlBb1ejyNHjmDz5s0oKCjApEmT8Le//Q3+/v5Kl03kMBguRLfYsWMHpkyZAi8vL3z44YdITk5GXV0dUlNTYTKZ4OPjg9/97neor69Hamoq5s6di759++Lrr79GQECA0uUTOQSGC1ETZ8+exZgxY+Dp6YnU1FT06dMHQgicP38egwYNQkVFBYKCgnDkyBF06NABkiQhOzsbkydPRnx8PL744gu4ubkp/W0QKY4L+kQ3NTY24oMPPkB5eTk+/fRTa7DcixAC0dHR+Pjjj7Fx40Zs377dTtUSOTaGC9FNBQUF2Lx5M8aPH4/o6Oj7BouFEALjxo3DsGHDsGLFCjQ0NNi4UiLHx0tciG7at28fqqurMWHCBFy8eBE1NTXWY4WFhWhsbAQA1NXVITc3Fz4+Ptbj3bp1w/jx4zF37lyUlJQgMDDQ7vUTORKGC9FNZ86cgYeHB4KDg/Hyyy9j79691mOSJMFkMgEAiouLMXr0aOsxIQQWLlyI/v37w2AwoLi4mOFC7R7Dhegmo9EIjUYDNzc3mEwm1NbW3vHfSZJ027GGhgZotdpmIUTUnjFciG7q3LkzjEYj9Ho9IiMj4enpaT1mNB
qxb98+a4iMGDHCeuOkEALdu3fHtWvXoFKp0KFDB6W+BSKHwXAhumnw4MGor69HTk4OPvroo2bHzp8/j4iICFRUVCAgIAB///vfodPprMeFEHjrrbfQpUsXTokRgVeLEVkNHToUwcHBWLVqFWpqaqBWq5v9sRBCQKVSWf9epVLhypUr+OGHHzB27Fj4+voq+F0QOQaGC9FNfn5+ePXVV3H06FEsWbKkxZcUm0wmvP/++zAajXj55ZdbfAkzkTPjtBhRE1OmTEFWVhY++ugjeHh4YNq0aXB3dwcAaDQaaDQa6yhGkiRUVVVh/vz5SE1NxaJFixAWFqZk+UQOg9u/EN3i+vXrmD59OrZs2YKkpCTMmjULvXv3Rl5eHsxmM1xdXdGzZ0/k5ORgwYIFOH78ON577z1Mmzat2fQZUXvGcCG6g5qaGqxYsQJLlizB1atXERwcjNDQUHh7e6O8vBx5eXkoLi7G4MGD8de//hVxcXFQqTjLTGTBcCG6h5KSEuzatQuZmZn4+eefkZOTg5iYGERFRSExMRGRkZHw8PBQukwih8NwIWqhQ4cOYejQoTh06BCGDBmidDlEDo3jeKIWUqvV1suQieje2CVERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmOz3MhaiFJkmA2m6FSqSCEULocIofGkQvRA+CzXIhaRqN0AURykSQJ+fn5KC0tVbqUVlGpVOjXrx88PT2VLoXooXFajJyG2WzG9OnT8eijj8LT0xONjY3Wp0e2JXv27MF//ud/Ijw8XOlSiB4aRy7kVNzc3DB06FB8+eWXKCkpwcCBA/Hss89i4MCB8PDwcPigkSQJ1dXV4Gc+aus4gUxOx2QyYdOmTcjMzERKSgoSExORlJSE9PR0NDQ08I2byA4YLuR0QkJC8MYbb+CDDz5AQkIC3NzcsH//fkyaNAlvvfUWSktLGTBENsZwIafToUMHvP322/jzn/+MzZs3Y9euXXjmmWdQW1uLTz75BE899RR27tzJUQyRDTFcyCkJISCEgKurKwYOHIg1a9ZgwYIF6Ny5Mw4dOoQJEyZgxowZuHDhAgOGyAYYLuT0hBDw9PTE9OnTsW3bNjzzzDOoq6vD8uXLMWrUKCxZsgR6vZ4hQyQjhgu1GyqVCuHh4fj222+xcuVK9OrVC5cvX8brr7+O5ORkbN26FXV1dQwZIhkwXKhdEUJAq9Vi8uTJ2LVrF95880107NgRBw8exPPPP4+pU6fi7NmzDBiiVmK4ULskhECXLl0wb9487NixA8899xwkScK3336LxMRE/O///i9qamoYMkQPieFC7ZplqmzVqlX45ptvEB4ejqKiIsycOROTJ09Gbm4uA4boITBcqN0TQsDNzQ3jxo3D9u3bMWvWLLi7u2PLli1ITk7Gl19+idraWoYM0QNguBDdJIRA586d8eGHH+L777/HgAEDUFxcjFdffRUvv/wyLl++zIAhaiGGC1ETQghoNBqMHj0aaWlpeOWVV6BSqbBmzRqMHTsW6enpaGxsVLpMIofHcCG6AyEEAgICsGjRInz55Zf4zW9+g1OnTuFf/uVf8MEHH6CqqoqjGKJ7YLgQ3YUQAi4uLnj++eeRlpaG5ORk1NTU4P3338cLL7yAgoICBgzRXTBciO5DCIGwsDB8++23ePfdd+Hp6YktW7bgqaeewtatWzlNRnQHDBeiFhBCwNvbG2+++Sa+/fZbPPbYYygoKMALL7yAlJQUGI1GjmKImmC4ED0AtVqNMWPGYPPmzUhOTkZ1dTXefvttzJw5E2VlZQwYopsYLkQPSAiBkJAQrFmzBjNnzoRKpcLKlSvxb//2b7h06RIDhggMF6KHIoSAr68v5s+fj08++QS+vr5IT0/HpEmTcPr0aQYMtXsMF6JWcHV1xUsvvYSvvvoK3bp1w6FDhzBp0iQcO3
aMAUPtGsOFqJVUKhWefvppfPfddwgODsapU6cwefJkHD58mAFD7RbDhUgGQghERUXhu+++Q1hYGPLz8/H73/8ex48
2024-04-29 12:35:18 -04:00
"text/plain": [
2024-07-13 22:17:48 -04:00
"<Figure size 500x400 with 6 Axes>"
2024-04-29 12:35:18 -04:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 15,
"id": "abef7aa9",
2024-04-29 12:35:18 -04:00
"metadata": {},
"outputs": [
2024-07-13 22:17:48 -04:00
{
"name": "stdout",
"output_type": "stream",
"text": [
"r2 is 0.9999871253967285\n",
"r2 is 0.9999728798866272\n",
"r2 is 0.9998090863227844\n"
]
},
2024-04-29 12:35:18 -04:00
{
"data": {
"text/plain": [
2024-07-13 22:17:48 -04:00
"tensor(0.9998)"
2024-04-29 12:35:18 -04:00
]
},
2024-07-13 22:17:48 -04:00
"execution_count": 15,
2024-04-29 12:35:18 -04:00
"metadata": {},
2024-07-13 22:17:48 -04:00
"output_type": "execute_result"
2024-04-29 12:35:18 -04:00
}
],
"source": [
2024-07-13 22:17:48 -04:00
"model.fix_symbolic(0,0,0,'x^2')\n",
"model.fix_symbolic(0,1,0,'x^2')\n",
"model.fix_symbolic(1,0,0,'sqrt')"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 16,
"id": "e14000d8",
2024-04-29 12:35:18 -04:00
"metadata": {},
"outputs": [
{
2024-07-13 22:17:48 -04:00
"data": {
"text/latex": [
"$\\displaystyle 1.01262303277627 \\sqrt{\\left(8.59418125232242 \\cdot 10^{-5} - x_{2}\\right)^{2} + 0.999965395886852 \\left(- x_{1} - 2.26198704007758 \\cdot 10^{-5}\\right)^{2} + 0.00768977463773129} - 0.0159889459609985$"
],
"text/plain": [
"1.01262303277627*sqrt((8.59418125232242e-5 - x_2)**2 + 0.999965395886852*(-x_1 - 2.26198704007758e-5)**2 + 0.00768977463773129) - 0.0159889459609985"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
2024-04-29 12:35:18 -04:00
}
],
"source": [
2024-07-13 22:17:48 -04:00
"formula = model.symbolic_formula()[0][0]\n",
"formula"
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 17,
"id": "c56ee3d5",
2024-04-29 12:35:18 -04:00
"metadata": {},
"outputs": [
{
"data": {
2024-07-13 22:17:48 -04:00
"text/latex": [
"$\\displaystyle 1.01 \\sqrt{1.0 x_{1}^{2} + x_{2}^{2} + 0.01} - 0.02$"
],
2024-04-29 12:35:18 -04:00
"text/plain": [
2024-07-13 22:17:48 -04:00
"1.01*sqrt(1.0*x_1**2 + x_2**2 + 0.01) - 0.02"
2024-04-29 12:35:18 -04:00
]
},
2024-07-13 22:17:48 -04:00
"execution_count": 17,
2024-04-29 12:35:18 -04:00
"metadata": {},
2024-07-13 22:17:48 -04:00
"output_type": "execute_result"
2024-04-29 12:35:18 -04:00
}
],
"source": [
2024-07-13 22:17:48 -04:00
"ex_round(formula, 2)"
]
},
{
"cell_type": "markdown",
"id": "1fd57d41",
"metadata": {},
"source": [
"With singularity avoiding: fine-tune the symbolic model with `singularity_avoiding=True` (and grid updates disabled)."
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
"execution_count": 11,
2024-07-13 22:17:48 -04:00
"id": "de708f21",
2024-04-29 12:35:18 -04:00
"metadata": {},
"outputs": [
{
2024-07-13 22:17:48 -04:00
"name": "stderr",
2024-04-29 12:35:18 -04:00
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 4.85e-08 | test loss: 4.84e-08 | reg: 5.95e+00 : 100%|██| 20/20 [00:01<00:00, 14.88it/s]\n"
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
2024-07-13 22:17:48 -04:00
"model.fit(dataset, opt=\"LBFGS\", steps=20, update_grid=False, singularity_avoiding=True);"
2024-04-29 12:35:18 -04:00
]
},
{
2024-07-13 22:17:48 -04:00
"cell_type": "markdown",
"id": "6fd34c4c",
2024-04-29 12:35:18 -04:00
"metadata": {},
"source": [
2024-07-13 22:17:48 -04:00
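"Without singularity avoiding, NaN may appear: below, the loss becomes NaN after a few steps, and the subsequent grid update then fails inside `torch.linalg.lstsq`."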
2024-04-29 12:35:18 -04:00
]
},
{
"cell_type": "code",
2024-07-13 22:17:48 -04:00
"execution_count": 18,
2024-04-29 12:35:18 -04:00
"id": "031fabd6",
2024-07-13 22:17:48 -04:00
"metadata": {
"scrolled": true
},
2024-04-29 12:35:18 -04:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: nan | test loss: nan | reg: nan : 25%|████▌ | 5/20 [00:01<00:03, 3.90it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Intel MKL ERROR: Parameter 6 was incorrect on entry to SGELSY.\n",
"\n",
"Intel MKL ERROR: Parameter 6 was incorrect on entry to SGELSY.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"ename": "RuntimeError",
"evalue": "false INTERNAL ASSERT FAILED at \"/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp\":1540, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/var/folders/6j/b6y80djd4nb5hl73rv3sv8y80000gn/T/ipykernel_33275/1949812002.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"LBFGS\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msteps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, dataset, opt, steps, log, lamb, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, update_grid, grid_update_num, loss_fn, lr, start_grid_update_step, stop_grid_update_step, batch, small_mag_threshold, small_reg_factor, metrics, save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder, device, singularity_avoiding, y_th, reg_metric)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mgrid_update_freq\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mstop_grid_update_step\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mupdate_grid\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mstart_grid_update_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 806\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_grid_from_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train_input'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_id\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 807\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 808\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"LBFGS\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/MultKAN.py\u001b[0m in \u001b[0;36mupdate_grid_from_samples\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdepth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 288\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact_fun\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_grid_from_samples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 289\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minitialize_grid_from_another_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/KANLayer.py\u001b[0m in \u001b[0;36mupdate_grid_from_samples\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0mgrid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrid_eps\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mgrid_uniform\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrid_eps\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mgrid_adaptive\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextend_grid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk_extend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcurve2coef\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_eval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrid\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minitialize_grid_from_parent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mparent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/2022/research/code/pykan/kan/spline.py\u001b[0m in \u001b[0;36mcurve2coef\u001b[0;34m(x_eval, y_eval, grid, k, device)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;31m# coef shape: (in_dim, outdim, G+k)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0my_eval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_eval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# y_eval: (in_dim, out_dim, batch, 1)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m coef = torch.linalg.lstsq(mat.to(device), y_eval.to(device),\n\u001b[0m\u001b[1;32m 167\u001b[0m driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]\n\u001b[1;32m 168\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcoef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: false INTERNAL ASSERT FAILED at \"/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp\":1540, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 6 has illegal value. Most certainly there is a bug in the implementation calling the backend library."
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
2024-07-13 22:17:48 -04:00
"model.fit(dataset, opt=\"LBFGS\", steps=20);"
2024-04-29 12:35:18 -04:00
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}