135 lines
33 KiB
Plaintext
135 lines
33 KiB
Plaintext
![]() |
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "f8ba3161",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Interpretability 6: Test symmetries of trained NN"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "87f1e596",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGaCAYAAACSWkBBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArDklEQVR4nO3deXhU5d3/8c9kIUhiQYUq9OEB9KmlWgUEWSw7LgSsZVVE9h23StGiPlIEfYo7XlVW2cIiioJ1YRE3CCiLBYJLVWwFW0VbUgF/CSZkOb8/vmUTWZLMzD0z9/t1XV7OCcnkmznzPfnkvs+5TygIgkAAAMBbSa4LAAAAbhEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8FyK6wIAuBMEgfbv3y9JqlKlikKhkOOKALjAyADgsf379ysjI0MZGRmHQgEA/xAGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAPBc8J//0jMypFBIuvdexxUBiDbCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADguRTXBQCooMJCaevWcn1p0nffHfvBL76QNmwoXy01akjnnVe+rwXgTCgIgsB1EQAqYOdOqV4911WY/v2luXNdVwGgjJgmAADAc4QBAAA8RxgAAMBzhAEg3tWtKwVBmf7L3R3o6s6BQso75un23zGuzM936D/OFwDiEmEA8Ex2ttSwobRsmVSpUpVj/n3aNGnt2ujXBcAdwgDgiZIS6b77pHbtpC+/lH72M2nTptAxn/ft/5PatpXuv9++BkDiIwwAHvjqK+nKK6Xf/14qLbUrAP/8Z6lBg2M/t8HF9jljx9rXfPVV9OsFEF2EASDBvfqq/dJ/800pPV3KyrKp/YyMH/78rl3tc9LT7WsaNLDnAJC4CANAgioqksaMkTp2lHbvli6+WNq8WerX7+Rf26+fjRxcfLF9bceO0p132nMCSDyEASAB7dwptW4tPfSQbd94o7Rxo50ncKrq17dViUeOtO0HH5TatJE+/zzs5QJwjDAAJJgXXpAaNbJf5FWrSs8/L02eLFWuXPbnOu00acoUe46qVaX16+1KhD/9KdxVA3CJMAAkiIIC6eabpW7dpL17pWbN7P5F3btX/Lm7d7fnatrUnrtrV+mWW+x7Aoh/hAEgAWzfLrVoYSMAknTHHbZWQDjvX1Svnj3n7bfb9pNP2vfcvj183wOAG4QBIM4tWCBdcomUkyNVry4tX27nCqSmhv97VaokPfywfY/q1e17Nm5sNQCIX4QBIE7l50sDB0p9+9rjtm2lbdukzMzIf+/MTAsCbdtKeXlWw6BBVgeA+EMYAOLQe+9JTZrYegFJSdL48dLrr0u1akWvhp/8xL7nvfdaDXPmSJdeKr3/fvRqABAehAEgjgSB3TugaVPp44/tl/+bb9rKgsnJ0a8nOVkaN0564w2pZk3po4+stunTrVYA8YEwAMSJffuk666z6/4LC6VOnWyovk0b15UdPUVRUCCNGCH16mU1A4h9oSAgvwOxbt
Mm++W6Y4eUkiI98IA0apQNz8eS0lLpsceku+6Sioulc8+VnnnGpg8AxC7CABDDSkulSZNsKeDiYqluXfvl2qyZ68pObONGCy87d9pVDQfDS+jYmyQCiAGEASBG5eba3QWXL7ftHj2kp56SqlVzWtYp27tXGjrUVi+UpM6d7YTH6tVdVgXgh8TYICMASVqzxu4WuHy5lJYmTZ0qLV4cP0FAsloXL7bljNPSpGXLbCnj7GzXlQH4PsIAEENKSqQJE6T27aVdu+xmQZs22Ql58TjEHgrZCY8Hb5L05ZdSu3bSfffZzwogNjBNAMSIXbukPn2kt96y7f79bcnfjAy3dYVLXp7dOyEry7bbt7eVC2vWdFsXAMIAEBNWrpT69ZN275bS021aoG9f11VFxrx5dkvl/HypRg1p/nzpqqtcVwX4jWkCwKGiImnMGLs+f/duO09g8+bEDQKShZ7Nm6WLL7afuWNHew2KilxXBviLkQHAkZ07peuvlzZssO2bbpIeeUSqXNlpWVFTUCCNHm0nGEpS8+bSokV2+SSA6CIMAA4sXSoNHmyX31WtKs2eLXXr5roqN5Yssddi3z67AmH2bKlrV9dVAX5hmgCIooICO4mue3cLAs2a2ZLCvgYByV6LrVvttdi7116Lm2+21wpAdBAGgCjZvt2GwidPtu3f/U5au5ZhcUmqV89eizvusO3Jk6UWLew1AxB5TBMAUTB/vl1vf/AM+nnz7MQ5HGvFCjvJMDfXrqyYNs0uuQQQOYwMABGUlycNGGC/3PLzbcGdnByCwIlkZtodENu2tdesb19p4EB7DCAyCANAhLz3ntSkiS2yk5QkjR8vvfaaVKuW68piX61a0uuv22uWlGT3NGjSxF5TAOHHNAEQZkEgTZ8u3XabVFhov9ieflpq08Z1ZfFpzRqpd29boTEtTXr8cWn48PhcnhmIVYQBIIy+f6e+Tp3sr9oaNVxWFf9277bploN3cOzZ0+7gWLWq07KAhME0ARAmmzZJjRpZEEhJkR59VHr5ZYJAONSoYa/lI4/Ya/vcc/Zab9rkujIgMTAyAFRQaak0aZJ0551ScbFdJvfMM1LTpq4rS0wbN0q9etkKjikp0gMPSKNG2bkFAMqHMABUAMPXbvzQdExWllS9utOygLhFlgbKac0aqWFDCwKVK9v18M8+SxCIhmrVpMWL7e6OaWm2Dxo0sH0CoOwIA0AZlZTYJW/t29sZ7vXr29w1Z7hHVygkjRhhr339+rYv2reXJkywfQTg1DFNAJTBrl3SDTdIq1fb9oAB0pNP2kp5cCcvz+5nkJVl2+3aSQsWsKYDcKoIA8ApWrnSVsNjmdzYxbLPQPkwTQCcRFGR3VQoM9OCQMOG0pYtBIFY1LevtHmznT+we7ftszFjbB8COD5GBoAT2LnTLmPbuNG2b75ZevhhO2EQsaugQLr99sN3iGzeXFq0iDtEAsdDGACOY8kSafBgad8+O3t91iypWzfXVaEsli6VBg1iHwInwzQB8D0FBdJNN0k9etgvkebNpa1b+SUSj7p1s7tENmtmaxN0726jOwUFrisDYgthADjCJ5/YL/8pU2x7zBgpO5vh5XhWt660dq2d9yHZ1EHz5tL27U7LAmIK0wTAf8ybJ914I2eiJ7KVK6V+/ezkwvR0W7Sob1/XVQHuMTIA7+XlSf3723/5+XaNek4OQSARdexo+7ZdO9vX/frZWhF5ea4rA9wiDMBr27ZJTZrYKEBSkq1e99prLFaTyGrVsn08frzt86wsew+8957rygB3mCaAl4LAFg0aNUoqLLRfEIsWSa1bu64M0bRmjdS7t60smZYmPf44y0rDT4QBeGfvXmnIELt0UJI6d5bmzuWOd776/p0ne/SwO09Wq+ayKiC6mCaAVzZulBo1siCQmio9+qj08ssEAZ/VqGHvgUcflVJS7LbIjRrZDZAAXzAyAC+UlkqPPSbddZdUXCzVq2e3G770UteVIZZs2mQrTu7YYcHggQ
dsKimJP5uQ4AgDSHi7d9uVAitW2Pa110ozZkhVq7qtC7Fp3z5p6FDpuedsu1Mnm0aqUcNpWUBEkXeR0FavthsLrVh
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from kan import *\n",
|
||
|
"from kan.hypothesis import plot_tree\n",
|
||
|
"\n",
|
||
|
"f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2\n",
|
||
|
"x = torch.rand(100,4) * 2 - 1\n",
|
||
|
"plot_tree(f, x)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"id": "58c2ece4-a8dc-4b4e-83cc-49a3f04c1ec5",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda\n",
|
||
|
"checkpoint directory created: ./model\n",
|
||
|
"saving model version 0.0\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"| train_loss: 1.58e-03 | test_loss: 4.79e-03 | reg: 2.38e+01 | : 100%|█| 100/100 [00:20<00:00, 4.93"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"saving model version 0.1\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||
|
"print(device)\n",
|
||
|
"\n",
|
||
|
"dataset = create_dataset(f, n_var=4, device=device)\n",
|
||
|
"model = KAN(width=[4,5,5,1], seed=0, device=device)\n",
|
||
|
"model.fit(dataset, steps=100);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "c02037c9-c903-4fc8-96bf-c78609ce0696",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGaCAYAAACSWkBBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAta0lEQVR4nO3dd3hUZfr/8c+EhCCJCyKo4NcvqCsirALSXTq6GHBdqiKC9GqDRRb1J6ugKzaEa6VLDSCCAl+lig0CShMIllXRBdxVLGQpXgkEUs7vj3tpIiXJzJyZed6v6/JyzpByZ87cJ588zznPCXie5wkAADgrzu8CAACAvwgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOC7e7wIA+MfzPB06dEiSVLJkSQUCAZ8rAuAHRgYAhx06dEjJyclKTk4+HgoAuIcwAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIA4DjZkjyJCUlJ0tPPOFzNQD8QBgAAMBxhAEAABxHGAAAwHGEASCW7dwZHV8TgK8IA0CsWrtWuuYaqW9fKSOj6F/vm2+kdu2kqlWlr78u+tcDEDEIA0AsysuT7rtPys+XXn5ZqlxZmjjRtgvqyBHpqaek666TFi+27QcfDH7NAHxDGABi0d69UpkyJ7b375cGDpTq1JHWrz//r7NsmVStmjR8uHT4sD0XCEhJSSe2AUQ9wgAQiy67THr/fWnFCqlmzRPPb90q/f73Us+e0k8/nfnzd+6Ubr9duu026Z//PPF8ixbSpk3SggXSBReErn4AYUUYAGLZrbdKW7ZI8+ZJv/2tPed50owZ0rXXKn7iRBU7+eOzs23hoWrVpCVLTjxfq5a0apX0zjtS7dph/AEAhEPA8zzP7yIAhEFurp0/MHKk9MMPx5/O10l/FQQCFhaOueYaO1+gY0f7NwAxiTAAuObQIWnsWOm556SDB3/9Y8qXl/76V6l3byk+PqzlAQg/wgDgqq+/Vn7Dhor78cdTn7/uOiktTSpb1p+6AIQd5wwArjl8WHr2WalOndODgCR9/rlUvbo0ZYpNLQCIeYwMAK7Iy5OmT7cTBPfsOf70Wc8ZqFxZ+tvfpA4dwlgogHBjZABwwaJFdoVA374ngkCpUjry3HOae/LHDR1q5wqUKGHbO3bYyYP16tmligBiEmEAiGWrV9sv8vbtpS+/tOcCAal7d2nHDuUOHKi8kz/+ggukESOkzz6zNQaO2bRJat7cLlVMTw9b+QDCgzAAxKIff5RSUqRmzewX+TE1a0offGDrDFxyyZk//6qrbJ2BpUulq68+8fxbb0k33ijdfTcrEAIxhDAAxKKLL5a+/fbE9kUXSePHSx99JDVocP5fp3VrGyUYOfLEioOeJ+3bxwqEQAwhDACxKD7efvnHxdlaATt22L0J4grR8omJdm+Czz+X2rSx7ZdeCnrJAPzDaiJArGrcWPrqKxvyD4aKFe2uhTt3Bu9rAogIjAwAsSwUv7QJAkDMIQwAAOA4wgAAAI4jDAAA4DjCAOAoz5Mee0zqISkg6c+DM+U9/oTPVQHwA/cmAByUkWGLEC5bliUp+b/PZuq225I0YwY3LARcQxgAHJOWJnXuLH33nVS8uKdnnjkkSXr44ZI6ej
Sgyy+X5s2TGjXyuVAAYUMYAByRlyc9/bTdtDA/X7r2Wmn+fLtbsSRt3y7dcYetTxQXZ7coeOQRqVgxX8sGEAaEAcAB338vdekivfeebXfrJo0bJyUnn/pxmZnSvfdKqam23by5NGeOVL58eOsFEF6EASDGvfWW1LWrtHevlJQkTZgg3XPP2T8nNdVWL87KksqVk2bPllq2DE+9AMKPqwmAGJWTIw0bZncd3rtXuuEGacuWcwcByT7mo4/sc/buta/x8MP2NQHEHkYGgBi0e7d0113Shg22PXCgNHq0VKJEwb7O4cPSkCHSxIm23aCBnVxYsWJQywXgM8IAEGMWL5Z69pQOHJBKlZKmTZPaty/a11y4UOrVSzp4UCpdWpoxw25gCCA2ME0AxIjsbOm++6R27SwI1KsnbdtW9CAg2dfYtk2qW9e+dtu20v332/cEEP0IA0AM2LHDhvDHj7ftoUOltWulK68M3ve48kr7mg89ZNvjxtn33LEjeN8DgD+YJgCi3Jw5Uv/+duZ/2bJ2JUBKSmi/54oVdpJhRoZdnjhxol26CCA6MTIARKmsLKlHD7tsMCtLatrUFg4KdRCQ7Hukp9v3zMy0Gnr2tDoARB/CABCFPv5Yql1bmjnzxGqB77wjVagQvhouv9y+5xNPWA0zZkh16kiffBK+GgAEB9MEQBTxPGnyZGnQIOnIEfvl/8orUpMm/ta1erXd7+D77+3yxbFjpb59pUDA37oAnB/CABAlDh6U+vSRXnvNtlu1spGBcuV8Leu4vXttmeMVK2z7jjukKVPs8kYAkY1pAiAKbNok1axpQSA+XnrhBWnJksgJApLVsnSp9PzzVuOCBdKNN0qbN/tdGYBzYWQAiGD5+dKYMbYUcG6uVKmS9OqrtoZAJNu4UerUyVZCTEiQnnlGGjyYaQMgUhEGgAiVkWHD7suX23aHDtLLL9sKgNHgwAGb1nj9ddtu3dqmNcqW9bMqAL+GaQIgAq1ZI1WvbkEgMdGu41+wIHqCgGS1Llhgd0lMTJSWLZNq1JDS0vyuDMAvEQaACJKXJ40cKTVvLu3ZI1WpYucL9O8fnUPsgYA0YIBNG1x7rfTdd1KzZtKTT9rPCiAyME0ARIg9e2wVv/fft+1u3WzJ3+Rkf+sKlsxMu3fCrFm23by5rZ5Yvry/dQEgDAARYeVKW953714pKcmmBbp29buq0EhNtVsqZ2XZFQizZ0stW/pdFeA2pgkAH+XkSMOG2fK+e/faeQJbtsRuEJAs9GzZIt1wg/3Mt95qr0FOjt+VAe5iZADwye7d0l13SRs22Pa999r6ASVK+FpW2GRnS0OG2AmGklS/vjRvnl0+CSC8CAOADxYtknr1ssvvSpWSpk+X2rXzuyp/LFxor8XBg3YFwvTpUtu2flcFuIVpAiCMsrPtJLr27S0I1Ktnd/9zNQhI9lps22avxYED9lrcd5+9VgDCgzAAhMmOHTYUPn68bf/lL9LatQyLS9KVV9prMXSobY8fLzVoYK8ZgNBjmgAIg9mz7Xr7Y2fQp6baiXM43YoVdpJhRoZdWTFpkl1yCSB0GBkAQigzU+re3X65ZWXZgjvp6QSBs0lJkbZvl5o2tdesa1epRw97DCA0CANAiHz8sVS7ti2yExcnjRghvf22VKGC35VFvgoVpHfesdcsLs7uaVC7tr2mAIKPaQIgyDxPmjxZGjRIOnLEfrG98orUpInflUWnNWukzp1thcbERGnsWKlfv+hcnhmIVIQBIIh+eae+Vq3sr9py5fysKvrt3WvTLcfu4Nixo93BsVQpX8sCYgbTBECQbNok1axpQSA+Xho9WlqyhCAQDOXK2Wv5wgv22r72mr3Wmzb5XRkQGxgZAIooP18aM0Z6+GEpN9cuk3v1ValuXb8ri00bN0qdOtkKjvHx0jPPSIMH27kFAAqHMAAUAcPX/vi16ZhZs6SyZX0tC4haZGmgkNaskWrUsCBQooRdDz9/PkEgHEqXlhYssLs7Ji
baPqhe3fYJgIIjDAAFlJdnl7w1b25nuFepYnPXnOEeXoGA1L+/vfZVqti+aN5cGjnS9hGA88c0AVAAe/ZId98trV5
|
||
|
"text/plain": [
|
||
|
"<Figure size 640x480 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model.tree(sym_th=1e-2, sep_th=5e-1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "2c2f31d6-be08-4bb8-a678-2c0d3f456722",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.16"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|