GitHub_collection_pykan/hellokan.ipynb

347 lines
105 KiB
Plaintext
Raw Normal View History

2024-04-28 20:28:25 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# Hello, KAN!"
]
},
{
"cell_type": "markdown",
"id": "2571d531",
"metadata": {},
"source": [
"### Initialize KAN"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2075ef56",
"metadata": {},
"outputs": [],
"source": [
"from kan import *\n",
"# Create a KAN with 2D inputs, 1D output, and 5 hidden neurons, using cubic splines (k=3) on 5 grid intervals (grid=5).\n",
"model = KAN(width=[2,5,1], grid=5, k=3, seed=0)"
]
},
{
"cell_type": "markdown",
"id": "3d72e076",
"metadata": {},
"source": [
"### Create dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "46717e8b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1000, 2]), torch.Size([1000, 1]))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a dataset for f(x,y) = exp(sin(pi*x) + y^2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"dataset = create_dataset(f, n_var=2)\n",
"dataset['train_input'].shape, dataset['train_label'].shape"
]
},
{
"cell_type": "markdown",
"id": "8c6add1d",
"metadata": {},
"source": [
"### Plot KAN at initialization"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ac76f858",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB3NklEQVR4nO3dd1QTWfsH8CcJIYTeLIAV7L13RcUKYsECih3r2rEvuig21oq9995FXRXFjl1E1y72XlBAesl8f3/4S15ZG2WSmeD9nON5z9mXzDy5yc0zc+fe50oAgBiGYRiGR1KhA2AYhmFyH5ZcGIZhGN6x5MIwDMPwjiUXhmEYhncsuTAMwzC8Y8mFYRiG4R1LLgzDMAzvWHJhGIZheMeSC8MwDMM7llwYhmEY3rHkwjAMw/COJReGYRiGdyy5MAzDMLxjyYVhGIbhHUsuDMMwDO8MhA6AYfQBAPr48SPFx8eTqakp2djYkEQiEToshhEtdufCMD8RExND8+fPp+LFi1OePHmoaNGilCdPHipevDjNnz+fYmJihA6RYURJwnaiZJjvCwkJofbt21NiYiIRfbl7UVPftRgbG9Pu3bupefPmgsTIMGLFkgvDfEdISAi5ubkRAOI47od/J5VKSSKR0D///MMSDMN8hSUXhvmPmJgYKlCgACUlJf00sahJpVJSKpX08uVLsrS01H6ADKMH2DMXhvmP9evXU2JiYqYSCxERx3GUmJhIGzZs0HJkDKM/2J0Lw3wFABUvXpweP35MWekaEomEHB0dKTIyks0iYxhiyYVhMoiKiqI8efLk6PU2NjY8RsQw+okNizHMV+Lj43P0+ri4OJ4iYRj9xpILw3zl48ePOXq9mZkZT5EwjH5jyYX57UVHR9OKFSvI2dmZatSoke1nJjKZjCZPnkyXL1/O0vMahsmNWHJhfkspKSm0b98+at++Pdnb29OgQYPI2NiYNm7cSF5eXtk6ZqNGjejAgQPUoEEDKleuHE2dOpUePXrEc+QMox9YcmF+GwAoLCyMBg4cSA4ODtS+fXt69uwZzZgxg168eEE7d+6k48eP044dO0gmk2XpDkYikdCgQYPo4cOHdPjwYapduzYFBQVR2bJlydnZmZYvX06fPn3S4rtjGHFhs8WYXO/+/fu0efNm2rJlCz158oQKFixI3t7e1KVLFypbtiwREV27do28vb3p9evXtHDhQsqbNy+1atUqUyv0iYgqVqxI169fp9GjR9OUKVPI0NCQEhMT6cCBA7R161Y6duwYSaVSat68OXXp0oVcXV3JyMhIJ++fYQQBhsmF3r17h4ULF6JmzZqQSqWwtLREnz59cOrUKahUKs3fqVQqzJ49G0ZGRqhevToePHig+f+OHDkCExMTSCQSSCQSEJHmn/q/mZiYICQkBCqVCjNnzoSBgQGqVauGyMjIb+JZtGgR6tSpA4VCgbx582LAgAE4c+ZMhngYJrdgyYXJNRISErBt2za0atUKcrkchoaGaNOmDXbu3InExMRv/v7169do3rw5ZDIZxo4di5SUlG/+Jjo6GvPnz4eTk1OG5OLk5IT58+cjJiYmw99fuXIFxYoVg6mpKdavXw+O47455r179+Dv74/ixYtDoVCgWLFimDhxIu7evctfYzCMwFhyYfRaeno6Tpw4gV69esHCwgJSqRR16tTBkiVL8OHDhx++7tChQ8ifPz8cHBxw7NixX56H4zicPn0aVlZWOH369HeThtrnz5/Ro0cPEBG6dOmC2NjY7/6dSqVCWFgY/vjjD+TLlw8KhQK1atXCggUL8Pbt21+/eYYRMZZcGL3077//YuzYsShYsCCkUimKFy+OSZMmfTMc9V9JSUkYMWIEZDIZ3N3d8f79+0yf88aNG7CxscGNGzcy9fdbtmyBubk5ihYtigsXLvwyrr1796JTp04wNTWFUqmEu7s7tmzZgvj4+EzHyDBiwZILozdevXqF2bNno3LlypBKpbC1tc
WgQYNw/vz5n95JqN25cweVK1eGUqnEggULMvWar2U1uQDA48ePUatWLchkMkybNg3p6em/fM3Hjx+xcuVKNGzYEAqFAtbW1ujVqxdCQ0Mz9XqGEQOWXBhR+/z5MzZs2ICmTZtCJpNBqVSiY8eO2L9//3efkXwPx3FYuXIlTE1NUa5cuSwlh69lJ7kAQGpqKvz8/CCRSNCwYUO8fPky0699/Pgxpk2bhrJly0KhUKBw4cIYO3Ysrl+/nuXkyDC6xJILIzppaWk4fPgwvL29YWJiAqlUikaNGmHVqlWIjo7O0rE+ffqEjh07QiaTYcCAAUhISMh2XNlNLmonT56Eg4MDrK2tsW/fviy9luM4XL58GSNGjIC9vT0UCgWqVKmCWbNmZSlZMYyusOTCiALHcbhy5QqGDx+O/PnzQyqVomzZspgxYwaePn2arWOePXsWRYoUga2tLfbs2ZPjGHOaXAAgKioKbdu2BRFh4MCB353F9iupqak4dOgQunbtCgsLCxgZGaFZs2ZYv379DycPMIyuseTCCOrJkyeYNm0aypQpA6lUCjs7O4wYMQLh4eHZHvZJS0uDv78/5HI5GjZsiOfPn/MSKx/JBfiSSJcuXQojIyOULVsW//77b7aPFRsbi/Xr16N58+YwMjKChYUFvL298c8//yA1NTVHcTJMTrDkwujcp0+fsHLlSjg7O0MqlcLU1BRdu3bFkSNHkJaWlqNjP3nyBPXq1YOhoSGmTJnC6wNwvpKL2q1bt1C+fHkoFAosWrQox89QXrx4gdmzZ6NKlSpQKBSwt7fH8OHDcfnyZfZ8htE5llwYnUhJScG+ffvQoUMHGBkZwcDAAM2aNcPGjRsRFxfHyzm2b98Oa2trODo64vz587wc82t8JxfgyxTkwYMHg4jg7u7+07U5mcVxHG7cuIGxY8eicOHCUCgUKFu2LKZOnYrHjx/zEDXD/BpLLozWcByHc+fOYeDAgbC1tYVUKkWVKlUwZ84cvHr1irfzxMXFwcfHBzKZDJ07d87yQ//M0kZyUdu/fz9sbGxgZ2eH0NBQ3o6bnp6O0NBQ9O7dG9bW1lAoFGjYsCFWrFiBjx8/8nYehvkvllwY3j148AD+/v4oVqwYpFIpChUqhHHjxuHmzZu8n+vatWsoXbo0zM3NsW7dOq0O/2gzuQBf1vG4uLhAIpFg3LhxvD8ziY+Px9atW+Hu7g6lUglTU1N07NgRe/fuRXJyMq/nYhiWXBhevH//HosWLULt2rUhlUphYWGB3r174+TJk1opzKhSqTB37lwYGRmhWrVquH//Pu/n+C9tJxfgy/sKDAyEgYEBatSogYcPH2rlPG/fvsWCBQtQu3ZtKBQK5MuXD3/88QfCwsJYIU2GFyy5MNmWmJiIHTt2wN3dHYaGhjA0NIS7uzu2b9+erSm2mfXmzRu0bNkSMpkMo0ePzvRiypzSRXJRu3TpEpycnGBmZoaNGzdq9Vx3797FX3/9hWLFikGhUKBEiRLw9/fHvXv3tHpeJndjyYXJEpVKhZMnT8LHxweWlpaQSqWoXbs2Fi1alKU6Xdl1+PBh2NnZwd7eHiEhIVo/39d0mVyAL9UJunXrBiKCt7e31tewqFQqnD17FgMGDEDevHmhUChQp04dLFy4UCefLZO7sOTCZMqtW7cwbtw4FCpUCFKpFMWKFYO/v79OhqMAIDk5Gb6+vpDJZHBzc8O7d+90ct6v6Tq5qG3atAlmZmZwdHTExYsXdXLOpKQk7NmzBx06dICpqSmMjY3RunVrbNu2LUdVDpjfB0suzA+9fv0ac+fORZUqVSCVSmFjY4OBAwfi3LlzOl03cffuXVSpUgVKpRLz588XbM2GUMkFAB49eoQaNWrAwMAA06dP12kBy48fP2L58uVo0KABFAoFbGxs4OPjg+PHj7NCmswPseTCZBAXF4eNGzeiefPmMDAwgJGRETp06IB9+/bpfEYRx3FYtWoVzMzMUKZMGVy/fl2n5/8vIZML8K
Xsy/jx4yGRSNCoUSNBaoo9evQIU6dORZkyZaBQKFCkSBGMGzcuR1UGmNyJJRcGaWlpCAkJQdeuXWFqagqpVApnZ2e
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# plot KAN at initialization\n",
"model(dataset['train_input']);\n",
"model.plot(beta=100)"
]
},
{
"cell_type": "markdown",
"id": "ddf67e30",
"metadata": {},
"source": [
"### Train KAN with sparsity regularization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97111d75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.55e-01 | test loss: 1.30e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:16<00:00, 1.21it/s]\n"
]
}
],
"source": [
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
]
},
{
"cell_type": "markdown",
"id": "2f30c3ab",
"metadata": {},
"source": [
"### Plot trained KAN"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3f95fcdd",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJNElEQVR4nO3deXhURdY/8O/tNZ0FIiEgsgkhyAyuCDMjqMgmIKsSUCAhQXBkZwRxAZF3AMURUPbxJzAhYR0ggCKbw6IiOCpGWQbEgAIGZAmQkKT37vr9wdyepukkvVT3re6cz/Pked+RdHd15d577j1VdUpijDEQQgghHKmUbgAhhJDoQ8GFEEIIdxRcCCGEcEfBhRBCCHcUXAghhHBHwYUQQgh3FFwIIYRwR8GFEEIIdxRcCCGEcEfBhRBCCHcUXAghhHBHwYUQQgh3FFwIIYRwR8GFEEIIdxRcCCGEcKdRugGERALGGK5evYqysjLEx8cjKSkJkiQp3SxChEVPLoRUori4GPPnz0dqaiqSk5PRpEkTJCcnIzU1FfPnz0dxcbHSTSRESBLtREmId7t27UK/fv1gNBoB3Hx6kclPLbGxscjLy0PXrl0VaSMhoqLgQogXu3btQo8ePcAYg9PprPD3VCoVJEnCtm3bKMAQ4oaCCyEeiouL0aBBA5hMpkoDi0ylUsFgMKCwsBCJiYmhbyAhEYDGXAjxkJOTA6PR6FNgAQCn0wmj0Yjc3NwQt4yQyEFPLoS4YYwhNTUVP//8M/w5NSRJQtOmTVFQUECzyAgBBRdCblFUVITk5OSgXp+UlMSxRYREJkqLEeKmrKwsqNeXlpZyagkhkY2CCyFu4uPjg3p9QkICp5YQEtkouBDiJikpCSkpKX6Pm0iShJSUFNSqVStELSMkslBwIcSNJEkYO3as369jjGHcuHE0mE/If1FwIcRDZmYmYmNjoVL5fnpotVq0bds2hK0iJLJQcCHEQ2JiIvLy8iBJUpUBRqVSQaVS4cUXX8Tq1auxdetWOByOMLWUEHHRVGRCKuBrbbFNmzahS5cu2LdvHz755BPUr18fmZmZqF27tiLtJkQEFFwIqURxcTFyc3OxYMECnD592vXfU1JSMG7cOGRmZqJmzZqu//7rr78iJycHpaWlSEtLQ+vWrWkchlRLFFwI8QFjDNeuXUNpaSkSEhJQq1atCoOGxWJBXl4evvnmGzz88MPo378/YmJiwtxiQpRFwYWQEMnPz8f69esRFxeHjIwM3H333Uo3iZCwoeBCSAhdvXoVK1euxLlz59C9e3d06tTJr1lohEQqCi6EhJjD4cDOnTuxe/dupKSkICMj45ZxGkKiEQUXQsLk1KlTWLlyJWw2GwYOHIj77rtP6SYREjIUXAgJo/Lycqxbtw5Hjx5Fu3bt0LdvX2i1WqWbRQh3FFwICTPGGA4ePIgtW7YgKSkJmZmZqFevntLNIoQrCi6EKOTixYvIycnBlStX0LdvX7Rr147WxJCoQcGFEAXZbDZ8/PHH2L9/P1q2bIlBgwYhLi5O6WYREjQKLoQI4D//+Q/WrFkDtVqN9PR0NG/eXOkmERIUCi6ECKKkpASrV69GQUEBOnXqhO7du0OtVivdLEICQsGFEIEwxrB3715s27YNDRo0wJAhQ6gAJolIFFwIEdC5c+eQm5uL0tJS9O/fH61bt1a6SYT4hYILIYKyWCzYuHEjvv32WyqASSIOBRdCBPfdd99h/fr1iI+Px5AhQ9C4cWOlm0RIlSi4EBIBrl69ipycHBQWFuKpp55Cx44dqQAmERoFF0IihMPhwI4dO7Bnzx40a9YM6enpVACTCIuCCyERpqCgAKtWrYLdbsfAgQNx7733Kt0kQm5DwYWQCOReAPPRRx9Fnz59qAAmEQoFF0IiFGMMBw4cwEcffYTatWtjyJAhVACTCIOCCyER7rfffkNubi6KiorQp08fKoBJhEDBhZAoYLPZ8NFHH+HLL7/Eff
fdh+eee44KYBJFUXAhJIocO3YMa9euhUajQXp6OlJTU5VuEqmmKLgQEmVKSkqwatUqnDp1igpgEsVQcCEkCjmdTuzduxfbt29Hw4YNMWTIECQlJSndLFKNUHAhJIqdO3cOOTk5KCsrw4ABA/Dwww8r3SRSTVBwISTKmc1mbNy4EYcOHUKbNm2QlpYGvV6vdLNIlKPgQkg1cejQIWzYsAEJCQkYMmQIGjVqpHSTSBSj4EJINVJUVITc3FwUFhaiR48e6NixI62JISFBwYWQasa9AGZqaioGDx5MBTAJdxRcCKmmfvrpJ6xatQoOhwODBg1Cy5YtlW4SiSIUXAipxsrLy7FmzRr85z//wWOPPYbevXtTAUzCBQUXQqo5uQDmli1bkJycjMzMTNx5551KN4tEOAouhBAANwtg5uTk4OrVq+jbty/atm1Lg/0kYBRcCCEuNpsNW7ZswYEDB6gAJgkKBRdCyG2OHj2KtWvXQqvVIiMjA82aNVO6SSTCUHAhhHhVXFyMVatW4fTp0+jcuTO6detGBTCJzyi4EEIq5HQ6sWfPHuzYsQONGjVCRkYGFcAkPqHgQgip0pkzZ7By5UqUl5djwIABaNWqldJNIoKj4EII8YnZbMaGDRvw3Xff4Q9/+AP69etHBTBJhSi4EEJ8xhjDoUOHsHHjRiQkJCAzMxMNGzZUullEQBRcCCF+KyoqQk5ODs6fP4+ePXuiQ4cOtCaG3IKCCyEkIA6HA9u3b8eePXtwzz33YPDgwahRo4bSzSKCoOBCCAmKewHMwYMH4/e//73STSICoOBCCAlaWVkZ1qxZg+PHj+Pxxx9Hr169qABmNUfBhRDCBWMM+/fvx8cff4w6deogMzMTdevWVbpZRCEUXAghXF24cAE5OTm4du0ann76aTzyyCM02F8NUXAhhHBntVqxZcsWHDx4EPfffz/S09Oh0+mUbhYJIwouhBC/+HPJKCoqwqVLlwLa5ZKediIbBRdCiF927tyJu+66K2Tvb7PZYDab0a5du5B9Bgk9jdINIIRElkuXLuHJJ58MyXufOXMGeXl5qF27NgWXCKdSugGEkMijUqm4/wDAhx9+iE6dOin87QgPFFwIIULYtWsXYmNj8dBDDyndFMIBBRdCiOLMZjM+/fRTjB07lgbyowQFF0KIohhj+PDDD/Hggw/ijjvuULo5hBMKLoQQRV28eBFnzpzB4MGDlW4K4YiCCyFEMYwxLF68GP369YNGQ5NXowkFF0KIYr755hs4nU60bdtW6aYQzii4EEIU4XA48M9//hOjRo2iQfwoRMGFEKKIf/7zn2jQoAHq16+vdFNICFBwIYSE3Y0bN/Dtt99i5MiR9NQSpSi4EELCijGGBQsWoEOHDjAYDEo3h4QIBRdCSFgdOnQI5eXl6Nmzp9JNISFEwYUQEjYmkwlr167FyJEjXfXESHSivy4hJCycTicWLVqENm3aoGHDhko3h4QYBRdCSMgxxvDRRx/BbDbj2WefpUH8aoCCCyEkpBhj+OKLL/D111/jL3/5C6XDqgmqt0AICRnGGD7//HPs3LkT48ePR0JCgtJNImFCwYUQEhIOhwMff/wx8vPzMW7cONSrV0/pJpEwouBCCOHuxo0bWLVqFcrLyzFhwgQqpV8NUXAhhHDjdDpx/PhxbNiwAc2bN8ewYcOg1+uVbhZRAAUXQojfTp48iYSEBMTGxkKj0cBms+G3337DwYMHUVhYiJ49e+Lhhx+mwftqjIILIcRvn3zyCSwWCxhjkCQJTqcTer0e9957LwYMGICEhASablzNUXAhhPhFo9Ggc+fOsFqtMJvNsNvt0Gq1iI+Ph06nQ2FhYdCfUatWLQ4tJUqSGGNM6UYQQiKH1WoN+WeoVCramTLCUXAhhISMw+GA0WhEbGws1Gq10s0hYUSjbYSQkPntt98wdepU/Pbbb0o3hYQZBRdCCCHcUXAhhBDCHQUXQggh3FFwIYQQwh0FF0IIIdxRcCGEEMIdBR
dCCCHcUXAhhBDCHQUXQggh3FFwIYQQwh0FF0IIIdxRcCGEEMIdBRdCCCHcUXAhhBDCHQUXQggh3FFwIYQQwh0FF0I
"text/plain": [
"<Figure size 500x400 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "576856cf",
"metadata": {},
"source": [
"### Prune KAN and plot again"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7fe6fb12",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAx+0lEQVR4nO3deXBVVYIG8O++vJeXvOw7BBJCQlbDqiFhE0YgYRHcWy1xsJ0ZLataxrampsYau3Rqpmu6q6wRxpmxptpy13ZQREDAIAiidAIEEAhZyEICCRHIRl7ykrztzB/4jrksEshN7lu+X5V/3AN5OcR7892zK0IIASIiIg0Z9K4AERH5H4YLERFpjuFCRESaY7gQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaM+pdASJfIIRAR0cHent7ER4ejri4OCiKone1iLwWWy5Ev6C7uxsbNmxAZmYmEhISMHnyZCQkJCAzMxMbNmxAd3e33lUk8koKT6Ikur7S0lI89NBDsNlsAK60Xjw8rRaLxYJNmzahpKRElzoSeSuGC9F1lJaWYuXKlRBCwO123/DvGQwGKIqC7du3M2CIhmC4EF2lu7sbEydORH9//y8Gi4fBYEBoaChaWloQHR09+hUk8gEccyG6ynvvvQebzTasYAEAt9sNm82G999/f5RrRuQ72HIhGkIIgczMTDQ2NuJWHg1FUZCeno66ujrOIiMCw4VIpb29HQkJCSP6+ri4OA1rROSb2C1GNERvb++Ivt5qtWpUEyLfxnAhGiI8PHxEXx8REaFRTYh8G8OFaIi4uDhkZGTc8riJoijIyMhAbGzsKNWMyLcwXIiGUBQFzz///G197bp16ziYT/QTDugTXYXrXIhGji0XoqtER0dj06ZNUBQFBsMvPyKeFfqff/45g4VoCIYL0XWUlJRg+/btCA0NhaIo13R3ecpCQ0OxY8cOFBcX61RTIu/EcCG6gZKSErS0tGD9+vVIT09X/Vl6ejrWr1+P1tZWBgvRdXDMhWgYhBCorKzEf/zHf+DFF19Efn4+B++JfgFbLkTDoCgKYmJiEBkZiZiYGAYL0U0wXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0xXIiISHMMF6JhioqKwgMPPICoqCi9q0Lk9XieC9EwuVwu2Gw2WCwWBAUF6V0dIq/GcCEiIs2xW4yIiDRn1LsCRB5utxtNTU0YGBjQuyo+LyUlBREREXpXgwIYw4W8htPpxOeffw6HwwGz2ax3dXzWpUuX8MQTTyA/P1/vqlAAY7iQV1EUBffffz8yMzP1ropPcLvdOHHiBCZOnIj4+Hi4XC787//+LziUSnpjuJDXCQoKgtHIW/NmhBA4d+4cNm3aBLPZjMWLF2PWrFl6V4sIAMOFyGfZ7XaUlpaiv78f/f39+OKLLxjK5DU4W4zIBwkhcPToUZw+fVqWjRs3Dnl5eTrWiuhnDBciHyOEQFdXF/bs2QO32w0AMBqNWLZsGcLCwnSuHdEVDBciH+N2u7F37160t7fLsunTpyM3N1fHWhGpMVyIfIgQAk1NTTh8+LAsi4yMxNKlS7klDXkVhguRD/EM4nsWmiqKgoULFyIxMRGKouhcO6KfMVyIfIRnEL++vl6WpaSkoKioCAYDH2XyLrwjiXyAEAKXL19WDeKbTCYUFxfDYrHoXDuiazFciHyAEALffvutahB/2rRpyMnJYXcYeSWGC5GX86zELy8vl2Xh4eFYsmQJB/HJazFciLyc0+nE119/jf7+flk2f/58JCUlsdVCXovhQuTFhBA4deoUqqurZdn48eMxb948DuKTV+PdSeTF+vr68PXXX8PlcgG4sqnn0qVLER4ernPNiH4Zw4XIS7ndbp
SXl+P8+fOyLDs7G/n5+ewOI6/HcCHyQkIIdHR04LvvvpNns4SEhGDp0qUwmUw6147o5hguRF7Is3/Y5cuXZVlBQQFSU1PZaiGfwHAh8jJCCJw9exZHjx6VZdHR0Vi4cCEH8cln8E4l8jIOhwO7d+9W7R929913Iy4ujq0W8hkMFyIvIoRAVVUVampqZFlycjIKCwsZLORTGC5EXsRms2H37t2qqceLFy/m/mHkcxguRF5CCIFDhw6htbVVlnHqMfkqhguRFxBCoLOzE/v371dNPV6yZAmnHpNPYrgQeQHPrsddXV2ybNasWZg0aRJbLeSTGC5EOhNCoKWl5ZqjixctWsSpx+SzeOcS6ex6ux4vWLAA8fHxbLWQz2K4EOlICIGampprdj3m0cXk63j3Eumov78fu3fvhtPpBPDz1GPueky+juFCpBMhBCoqKnD27FlZNmXKFEydOpXdYeTzGC5EOhBCoLu7G/v27ZNTj81mM5YsWYLg4GCda0c0cgwXIh0IIbB//350dnbKsunTpyM9PZ2tFvILDBeiMSaEQGtrKw4ePCjLIiIicM8993AQn/wG72SiMeZyubB7927YbDZZNnfuXCQmJrLVQn6D4UI0hoQQqK2txalTp2RZUlIS5s2bx1YL+RXezURjqL+/H7t27ZJTjw0GAxYvXoyIiAida0akLYYL0RjxTD0+d+6cLMvIyMD06dPZHUZ+h+FCNAY8ux7v3bsXbrcbwJWpx8XFxZx6TH6J4UI0BtxuN/bu3ava9XjmzJmcekx+i+FCNMqEEGhqalLtehwVFcWpx+TXeGcTjbLBwUGUlpZicHAQAKAoChYuXMhdj8mvMVyIRpEQAkeOHEF9fb0sS0lJQWFhIVst5Nd4dxONEs8g/p49e+QgvslkQklJCSwWi861IxpdDBeiUeJ2u/H1119fs39YdnY2u8PI7zFciEaB5xCwI0eOyLLo6GgUFxcjKChIx5oRjQ2GC5HGhBDo7e3Fjh074HA4AFxZiX/PPfcgISGBrRYKCAwXIo0JIbB37160trbKsoyMDMyePZvBQgGD4UKkISEE6uvr8Ze//EWWhYaGYsWKFTCbzTrWjGhsMVyINOLpDtu2bRsGBgYA/LymZdKkSWy1UEBhuBBpxO12Y9euXaqNKSdNmoQFCxZwTQsFHN7xRBoQQuDEiRMoLy+XZRaLBatXr+aaFgpIDBeiERJC4Mcff8S2bduumR2WlpbG7jAKSAwXohEQQsBms2Hz5s2qxZK5ubmYP38+u8MoYPHOJxoBl8uFnTt34vTp07IsPj4e9913H2eHUUBjuBDdJrfbje+//x5lZWWyzGw24/777+diSQp4DBei2+AZwN+5cydcLheAK+MsS5cuRV5eHoOFAh7DhegWCSFQV1eHzz77TJ7RAgB33XUX7r77bo6zEIHhQnRLhBBobm7Gn//8Z/T29sryrKwsrF69GiaTScfaEXkPhgvRMAkhcO7cOXz44Yfo6uqS5RMmTMCjjz6KsLAwdocR/cSodwWIfIEQAmfPnsUHH3yA9vZ2WZ6YmIgnnngCsbGxDBaiIRguRDchhEBDQwM+/vhj1VqWuLg4rFmzBuPHj2ewEF2F4UL0C9xuN6qqqrBx40b09PTI8tjYWKxZswYpKSkMFqLrYLgQ3YDL5cKhQ4ewdetW9Pf3y/KEhASsWbMGqampDBaiG2C4EF1FCAG73Y49e/Zg7969cr8wABg/fjyeeOIJTJgwgcFC9AsYLkRDCCFgtVqxZcsWHDt2DG63W/5Zeno6Hn/8ccTHxzNYiG6C4UL0EyEE2trasHHjRjQ1NclyRVEwbdo0PPjgg4iMjGSwEA0Dw4UIVwbuKysrsXnzZtUalqCgIMyfPx/Lly+H2WxmsBANE8OFAppnfOXbb7/Fnj17VNu5hIaGYuXKlSgqKoLRyEeF6FbwiaGAJYRAZ2cntmzZgsrKStX4Snx8PB5++GFkZWVxrzCi28BwoYAjhIDb7UZ1dT
W2bNmCS5cuqf48KysLDz/8MLfNJxoBhgsFFCEEenp6sHv3bpSXl6umGRuNRsydOxclJSWwWCwMFqIRYLhQQBBCwOF
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = model.prune()\n",
"model(dataset['train_input'])\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "bd08ad99",
"metadata": {},
"source": [
"### Continue training and replot"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "18a2db11",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 4.71e-03 | test loss: 4.74e-03 | reg: 3.08e+00 : 100%|██| 50/50 [00:07<00:00, 6.51it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "af27aba7",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAyzUlEQVR4nO3deVSV54EG8Oe7XED2fVFQwxpRUBZlE2MSFVTSJMamsZ2kaZM5zZzOMcm0neks58xMz2lzJq0njZ1kZk6TSRMnW9NqEo0LcYtaFYGwCSgiCgIBBFnlAhfu984f5r7hGjUoH3x3eX7n5I/3Vbiv5H48990VIYQAERGRhgx6N4CIiJwPw4WIiDTHcCEiIs0xXIiISHMMFyIi0hzDhYiINMdwISIizTFciIhIcwwXIiLSHMOFiIg0x3AhIiLNMVyIiEhzDBciItIcw4WIiDTHcCEiIs0Z9W4AkSMQQuDKlSu4evUqfH19ERISAkVR9G4Wkd1iz4XoFvr6+rB161YkJCQgLCwMMTExCAsLQ0JCArZu3Yq+vj69m0hklxTeREl0Y0VFRdi4cSNMJhOAa70XK2uvxdvbG9u3b0dBQYEubSSyVwwXohsoKipCYWEhhBBQVfWmf89gMEBRFOzevZsBQzQBw4XoOn19fYiOjsbw8PAtg8XKYDDAy8sLra2tCAwMnP4GEjkAzrkQXeett96CyWSaVLAAgKqqMJlM2LZt2zS3jMhxsOdCNIEQAgkJCbhw4QJu59FQFAWxsbFoaGjgKjIiMFyIbHR3dyMsLGxKXx8SEqJhi4gcE4fFiCa4evXqlL5+cHBQo5YQOTaGC9EEvr6+U/p6Pz8/jVpC5NgYLkQThISEIC4u7rbnTRRFQVxcHIKDg6epZUSOheFCNIGiKNi8efMdfe2zzz7LyXyiL3FCn+g63OdCNHXsuRBdJzAwENu3b4eiKDAYbv2IWHfo79ixg8FCNAHDhegGCgoKsHv3bnh5eUFRlK8Nd1nrvLy8sGfPHuTn5+vUUiL7xHAhuomCggK0trbi5ZdfRmxsrM2fxcbG4uWXX0ZbWxuDhegGOOdCNAlCCNTU1ODll1/G888/j+TkZE7eE90Cey5Ek6AoCoKCguDv74+goCAGC9E3YLgQEZHmGC5ERKQ5hgsREWmO4UJERJpjuBARkeYYLkREpDmGCxERaY7hQkREmmO4EBGR5hguRESkOYYLERFpjuFCRESaY7gQEZHmGC5EkxQQEICHHnoIAQEBejeFyO7xPheiSbJYLDCZTPD29oabm5vezSGyawwXIiLSHIfFiIhIc0a9G0BkpaoqmpqaMDIyondTHN7cuXPh5+endzPIhTFcyG6Mj4/jo48+wtjYGDw9PfVujkMSQqC7uxvf/e53kZycrHdzyIUxXMiuKIqChx56CPHx8Xo3xSEMDw+jvb0ds2fPhpeXFywWC37/+9/r3SwihgvZH4PBAKORb81vIoRAY2Mj3n33XQQFBeHuu+/G8uXLoSiK3k0jYrgQOSohBGpqajA+Po6uri709fUhNTVV72YRAeBqMSKHNTAwgMbGRlmOjIxEZGSkji0i+grDhcgBCSFw8eJFDAwMyLqkpCR4eHjo2CqirzBciByQdUjMugfa3d0dSUlJOreK6CsMFyIH1N/fzyExsmsMFyIHI4TA+fPnMTg4KOsWLlzIITGyKwwXIgejqiqqq6vlkJinpycWLVrEJchkVxguRA6mp6cHFy9elOXo6GhERETo2CKir2O4EDkQIQTOnDkDk8kk65KTk7nplOwOw4XIgYyNjaG6ulqWvb29kZSUxCExsjsMFyIHIYRAe3s7WltbZV1MTAxCQkJ0bBXRjTFciBxIVVUVzGYzgGuHfKampsJg4GNM9ofvSiIHMTQ0hNraWlkOCgpCQkICh8TILjFciByAdW9Ld3e3rEtKSoKvr6+OrSK6OYYLkQNQVRXl5eU2x72kpaWx10J2i+FCZOeEEO
js7LQ57iU6OhrR0dEMF7JbDBciB1BZWYnh4WFZTktLg7u7u44tIro1hguRnRsaGkJVVZUsBwQE8LgXsnsMFyI7JoTA2bNnbSbyFy1ahICAAB1bRfTNGC5EdmxsbAwlJSVyIt/DwwNLly5lr4XsHsOFyE4JIXDp0iU0NzfLupiYGERFRTFcyO4xXIjslKqqOHXqFMbGxgAABoMBWVlZPKSSHALDhcgOWZcf19XVybqIiAgkJiay10IOgeFCZIeEECgtLbVZfpyZmQkvLy8dW0U0eQwXIjvU29uLiooKWQ4KCsKSJUvYayGHwXAhsjPWXsvAwICsy8jIgL+/v46tIro9DBciO9PX14fS0lJZ9vPzw7Jly3i0PjkUvluJ7IgQAmVlZejt7ZV1qampvBCMHA7DhciO9PX14dSpU7Ls4+ODnJwczrWQw2G4ENkJ676Wnp4eWZeamorw8HCGCzkchguRnejp6flar2X58uUMFnJIDBciO6CqKo4fP47+/n5Zt3TpUvZayGExXIh0JoRAe3u7zQqxgIAA9lrIoTFciHRmsVhw6NAhmEwmWZednY3g4GCGCzkshguRjoQQOHfuHGpqamRdeHg4cnJyuK+FHBrfvUQ6Gh4exv79+21OPr733nvh5+enc8uIpobhQqQTVVVRXFyMS5cuybrY2FikpaVxOIwcHsOFSAdCCFy+fBlHjx6Vt0x6enpizZo18PDw0Ll1RFPHcCHSwdjYGPbt22dzOOXSpUsRGxvLXgs5BYYL0QwTQqCiogK1tbWyLjQ0FPfddx8n8clp8J1MNIOsw2FFRUWwWCwAADc3N+Tn5yMoKIi9FnIaDBeiGWQ2m7F792709fXJusWLF/MiMHI6DBeiGaKqKk6ePIm6ujpZFxISgrVr18JoNOrYMiLtMVyIZoAQAhcuXMCBAwegqioAwGg0Yv369QgNDWWvhZwOw4Vomgkh0NfXhw8//NDmiJfMzEykpKQwWMgpMVyIppnZbMbOnTvR3t4u6+bNm4eCggK4ubnp2DKi6cNwIZpG1kMpq6urZZ2fnx82bNgAX19f9lrIaTFciKaJqqooLy/HZ599JnfhG41GFBYWYu7cuQwWcmoMF6JpIIRAQ0MDdu7cKQ+lVBQFeXl5SE9P52ZJcnp8hxNpTAiB1tZWfPDBBxgaGpL1ycnJyM/P5zwLuQSGC5GGrDvw3333XfT29sr6efPmYcOGDfD09ORwGLkEhguRRoQQ6OrqwjvvvIPOzk5ZHxYWhsceewwBAQEMFnIZDBciDUwMltbWVlkfEBCATZs2ITIyksFCLoVnThBNkRACHR0deOedd/DFF1/Iej8/P2zatAl33XUXg4VcDsOFaAqEEGhubsZ7772Hrq4uWe/r64vHHnsMiYmJDBZySQwXojukqirq6uqwfft29Pf3y3o/Pz889thjWLBgAYOFXBbDheg2CSFgsVhw4sQJ7Nu3DyMjI/LPAgMDsWnTJiQkJDBYyKUxXIhugxACJpMJe/fuxalTp+SFXwAQERGBTZs2Yd68eQwWcnkMF6JJEkLgiy++wIcffogLFy7Y/FlsbCy+853vICwsjMFCBIYL0TcSQmBsbAyff/45ioqKMDAwIP/MYDAgPT0d3/rWt3gQJdEEDBeiW7DuX9m7dy9qampshsFmzZqFVatWYcWKFXB3d2ewEE3AcCG6ASEERkdHUVpaikOHDtmsBgOu7bp/+OGHcffdd/MQSqIbYLgQTWBdCXb+/Hns378fTU1N8rh8AHBzc8OSJUuwfv16BAUFsbdCdBMMFyJcCxVVVdHa2orDhw/jzJkz8qh8q6CgIBQUFCAtLQ1Go5HBQnQLDBdyaRND5dixY6itrcXo6KjN33F3d0dqairy8/MRHBzMUCGaBIYLuSTrCrCLFy/ixIkTqK+vh9lstvk7iqJg/vz5WLNmDRITE2EwGBgsRJPEcCGXIYSAEAIDAwOoq6tDaWkpWltbbVaAWUVERGDlypVITU3lHS
xEd4DhQk7NOhk/OjqKlpYWVFRU4MyZM19b/QVc66mEhoZi+fLlSE9Ph4+PD0OF6A4xXMgpCSFgNpvR3t6Ouro61Nb
"text/plain": [
"<Figure size 500x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "cf35d505",
"metadata": {},
"source": [
"### Automatically or manually set activation functions to be symbolic"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b3c0642b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fixing (0,0,0) with sin, r2=0.9999870517513758\n",
"fixing (0,1,0) with x^2, r2=0.9999995406740803\n",
"fixing (1,0,0) with exp, r2=0.9999989605235987\n"
]
}
],
"source": [
"mode = \"auto\" # \"manual\"\n",
"\n",
"if mode == \"manual\":\n",
" # manual mode\n",
" model.fix_symbolic(0,0,0,'sin');\n",
" model.fix_symbolic(0,1,0,'x^2');\n",
" model.fix_symbolic(1,0,0,'exp');\n",
"elif mode == \"auto\":\n",
" # automatic mode\n",
" lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']\n",
" model.auto_symbolic(lib=lib)"
]
},
{
"cell_type": "markdown",
"id": "821ba616",
"metadata": {},
"source": [
"### Continue training to almost machine precision"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c0800415",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.57e-11 | test loss: 1.66e-13 | reg: 3.08e+00 : 100%|██| 50/50 [00:02<00:00, 20.36it/s]\n"
]
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
"cell_type": "markdown",
"id": "e39da499",
"metadata": {},
"source": [
"### Obtain the symbolic formula"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bf44f7e0",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \\sin{\\left(3.14 x_{1} \\right)}}$"
],
"text/plain": [
"1.0*exp(1.0*x_2**2 + 1.0*sin(3.14*x_1))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.symbolic_formula()[0][0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}