135 lines
33 KiB
Plaintext
135 lines
33 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f8ba3161",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Interpretability 6: Test symmetries of a trained NN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "87f1e596",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGaCAYAAACSWkBBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArDklEQVR4nO3deXhU5d3/8c9kIUhiQYUq9OEB9KmlWgUEWSw7LgSsZVVE9h23StGiPlIEfYo7XlVW2cIiioJ1YRE3CCiLBYJLVWwFW0VbUgF/CSZkOb8/vmUTWZLMzD0z9/t1XV7OCcnkmznzPfnkvs+5TygIgkAAAMBbSa4LAAAAbhEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8FyK6wIAuBMEgfbv3y9JqlKlikKhkOOKALjAyADgsf379ysjI0MZGRmHQgEA/xAGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAAAAzxEGAADwHGEAAADPEQYAAPAcYQAAAM8RBgAA8BxhAPBc8J//0jMypFBIuvdexxUBiDbCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADguRTXBQCooMJCaevWcn1p0nffHfvBL76QNmwoXy01akjnnVe+rwXgTCgIgsB1EQAqYOdOqV4911WY/v2luXNdVwGgjJgmAADAc4QBAAA8RxgAAMBzhAEg3tWtKwVBmf7L3R3o6s6BQso75un23zGuzM936D/OFwDiEmEA8Ex2ttSwobRsmVSpUpVj/n3aNGnt2ujXBcAdwgDgiZIS6b77pHbtpC+/lH72M2nTptAxn/ft/5PatpXuv9++BkDiIwwAHvjqK+nKK6Xf/14qLbUrAP/8Z6lBg2M/t8HF9jljx9rXfPVV9OsFEF2EASDBvfqq/dJ/800pPV3KyrKp/YyMH/78rl3tc9LT7WsaNLDnAJC4CANAgioqksaMkTp2lHbvli6+WNq8WerX7+Rf26+fjRxcfLF9bceO0p132nMCSDyEASAB7dwptW4tPfSQbd94o7Rxo50ncKrq17dViUeOtO0HH5TatJE+/zzs5QJwjDAAJJgXXpAaNbJf5FWrSs8/L02eLFWuXPbnOu00acoUe46qVaX16+1KhD/9KdxVA3CJMAAkiIIC6eabpW7dpL17pWbN7P5F3btX/Lm7d7fnatrUnrtrV+mWW+x7Aoh/hAEgAWzfLrVoYSMAknTHHbZWQDjvX1Svnj3n7bfb9pNP2vfcvj183wOAG4QBIM4tWCBdcomUkyNVry4tX27nCqSmhv97VaokPfywfY/q1e17Nm5sNQCIX4QBIE7l50sDB0p9+9rjtm2lbdukzMzIf+/MTAsCbdtKeXlWw6BBVgeA+EMYAOLQe+9JTZrYegFJSdL48dLrr0u1akWvhp/8xL7nvfdaDXPmSJdeKr3/fvRqABAehAEgjgSB3TugaVPp44/tl/+bb9rKgsnJ0a8nOVkaN0564w2pZk3po4+stunTrVYA8YEwAMSJffuk666z6/4LC6VOnWyovk0b15UdPUVRUCCNGCH16mU1A4h9oSAgvwOxbt
Mm++W6Y4eUkiI98IA0apQNz8eS0lLpsceku+6Sioulc8+VnnnGpg8AxC7CABDDSkulSZNsKeDiYqluXfvl2qyZ68pObONGCy87d9pVDQfDS+jYmyQCiAGEASBG5eba3QWXL7ftHj2kp56SqlVzWtYp27tXGjrUVi+UpM6d7YTH6tVdVgXgh8TYICMASVqzxu4WuHy5lJYmTZ0qLV4cP0FAsloXL7bljNPSpGXLbCnj7GzXlQH4PsIAEENKSqQJE6T27aVdu+xmQZs22Ql58TjEHgrZCY8Hb5L05ZdSu3bSfffZzwogNjBNAMSIXbukPn2kt96y7f79bcnfjAy3dYVLXp7dOyEry7bbt7eVC2vWdFsXAMIAEBNWrpT69ZN275bS021aoG9f11VFxrx5dkvl/HypRg1p/nzpqqtcVwX4jWkCwKGiImnMGLs+f/duO09g8+bEDQKShZ7Nm6WLL7afuWNHew2KilxXBviLkQHAkZ07peuvlzZssO2bbpIeeUSqXNlpWVFTUCCNHm0nGEpS8+bSokV2+SSA6CIMAA4sXSoNHmyX31WtKs2eLXXr5roqN5Yssddi3z67AmH2bKlrV9dVAX5hmgCIooICO4mue3cLAs2a2ZLCvgYByV6LrVvttdi7116Lm2+21wpAdBAGgCjZvt2GwidPtu3f/U5au5ZhcUmqV89eizvusO3Jk6UWLew1AxB5TBMAUTB/vl1vf/AM+nnz7MQ5HGvFCjvJMDfXrqyYNs0uuQQQOYwMABGUlycNGGC/3PLzbcGdnByCwIlkZtodENu2tdesb19p4EB7DCAyCANAhLz3ntSkiS2yk5QkjR8vvfaaVKuW68piX61a0uuv22uWlGT3NGjSxF5TAOHHNAEQZkEgTZ8u3XabVFhov9ieflpq08Z1ZfFpzRqpd29boTEtTXr8cWn48PhcnhmIVYQBIIy+f6e+Tp3sr9oaNVxWFf9277bploN3cOzZ0+7gWLWq07KAhME0ARAmmzZJjRpZEEhJkR59VHr5ZYJAONSoYa/lI4/Ya/vcc/Zab9rkujIgMTAyAFRQaak0aZJ0551ScbFdJvfMM1LTpq4rS0wbN0q9etkKjikp0gMPSKNG2bkFAMqHMABUAMPXbvzQdExWllS9utOygLhFlgbKac0aqWFDCwKVK9v18M8+SxCIhmrVpMWL7e6OaWm2Dxo0sH0CoOwIA0AZlZTYJW/t29sZ7vXr29w1Z7hHVygkjRhhr339+rYv2reXJkywfQTg1DFNAJTBrl3SDTdIq1fb9oAB0pNP2kp5cCcvz+5nkJVl2+3aSQsWsKYDcKoIA8ApWrnSVsNjmdzYxbLPQPkwTQCcRFGR3VQoM9OCQMOG0pYtBIFY1LevtHmznT+we7ftszFjbB8COD5GBoAT2LnTLmPbuNG2b75ZevhhO2EQsaugQLr99sN3iGzeXFq0iDtEAsdDGACOY8kSafBgad8+O3t91iypWzfXVaEsli6VBg1iHwInwzQB8D0FBdJNN0k9etgvkebNpa1b+SUSj7p1s7tENmtmaxN0726jOwUFrisDYgthADjCJ5/YL/8pU2x7zBgpO5vh5XhWt660dq2d9yHZ1EHz5tL27U7LAmIK0wTAf8ybJ914I2eiJ7KVK6V+/ezkwvR0W7Sob1/XVQHuMTIA7+XlSf3723/5+XaNek4OQSARdexo+7ZdO9vX/frZWhF5ea4rA9wiDMBr27ZJTZrYKEBSkq1e99prLFaTyGrVsn08frzt86wsew+8957rygB3mCaAl4LAFg0aNUoqLLRfEIsWSa1bu64M0bRmjdS7t60smZYmPf44y0rDT4QBeGfvXmnIELt0UJI6d5bmzuWOd776/p0ne/SwO09Wq+ayKiC6mCaAVzZulBo1siCQmio9+qj08ssEAZ/VqGHvgUcflVJS7LbIjRrZDZAAXzAyAC+UlkqPPSbddZdUXCzVq2e3G770UteVIZZs2mQrTu7YYcHggQ
dsKimJP5uQ4AgDSHi7d9uVAitW2Pa110ozZkhVq7qtC7Fp3z5p6FDpuedsu1Mnm0aqUcNpWUBEkXeR0FavthsLrVhh9xOYPl165hmCAI6valUbNZo2zd4zy5fbe2jNGteVAZFDGEBCKimxS8c6dLAzxevXtyHgYcM4UxwnFwrZVQWbNtl7Z9cuqX17e0+VlLiuDgg/pgmQcHbtkm64wUYFJGngQOmJJ2zFOaCs8vPtfgZz59p227bSwoWsRYHEQhhAQlmxwlaVy82VMjJsudk+fVxXhUSwYIE0YoSFg+rVpfnzWaUSiYNpAiSEAwekO+6wk71yc22Od/NmggDCp08facsWe2/l5kqZmXbzo6Ii15UBFcfIAOLejh12OdjB68Jvvll6+GE7+QsIt4ICC55PPmnbzZrZSanc2RLxjDCAuLZkiTR4sF0OVq2aNHu21LWr66rgg6VL7b23d69dgTBrltS9u+uqgPJhmgBxqaDAbjfco4cFgebN7W50BAFES7du0tat9t7bt8/eizfdZO9NIN4QBhB3Pv7YhmanTrXtMWOk7GypTh23dcE/devae2/MGNueMsXCwSefOC0LKDOmCRBXsrLsr6/8fFsRbv586aqrXFcFSCtX2pUsu3fbZaxTptg2EA8IA4gLeXkWAubNs+327e1Sr5o13dYFHGnXLrvq4K23bLtfP2nyZLvMFYhlTBMg5m3bJjVpYkEgKUm67z5p1SqCAGJPrVrSa69JEybYe3XePHvvbtvmujLgxBgZQMwKAjsv4Le/lQoLpZ/8RHr6aal1a9eVASeXnS1df72NFqSlSZMm2aJFLIeNWEQYQEzau1caMsQuHZSkq6+W5syxld+AeJGbKw0YIC1bZtvdu0szZ9plsEAsYZoAMWfjRqlRIwsCqanSY49JL71EEED8qV5devll6dFH7b28ZIm9tzdudF0ZcDRGBhAzSkvtoHn33VJxsXTuubay26WXuq4MqLh335Wuu85WzExJkSZOtCmwJP4kQwwgDCAm7N4t9e9vNxqSpGuvlWbMsJXdgESxb5/dRnvxYtvOzLTLZWvUcFsXQCaFc6tXSw0aWBCoXFmaPt1GBAgCSDRVq9p7e/p0e6+vWGE3Pjp4u23AFcIAnCkpke69V+rQQfrqK+nnP7ebDQ0bxhnXSFyhkL3HN22S6te3qw06dJDGj7eeAFxgmgBOfPmldMMN0po1tj1okPTHP9rKbYAv8vOlW26xK2UkqW1baeFCW68AiCbCAKJu+XI7PyA311ZmmzbNggHgqwULpJEjbaXN6tVtsaLMTNdVwSdMEyBqDhyQbr9d6tzZgkCjRtKWLQQBoE8fafNmO38gN1fq1Em64w7rGSAaGBlAVOzYIfXqZfOkkg2NPvywrcwGwBQUWAh48knbbtrUTjisV89tXUh8hAFE3PPP22qC+/bZymtz5khduriuCohdL7xg59Hs3WtXIMyaZasXApHCNAEi5rvvbB60Z08LAi1aSDk5BAHgZLp2tV5p3tx6p0cP6cYbbeQAiATCACLi44/tQDZtmm3feaddOVCnjtu6gHhRp47d7GjMGNueOlVq1sx6Cwg3pgkQdllZ9lfM/v3Sj38szZ8vXXml66qA+PXqq1LfvrZSZ3q6NHmyXZEDhAthAGGTl2chYP582+7QwR7XrOm2LiARfPWVXXXw5pu23a+fhYKMDLd1ITEwTYCwyMmRGje2X/5JSdL999tfMwQBIDxq1pRWrZLuu896bN48qUkTads215UhETAygAoJAmnKFGn0aKmwUPrJT6RFi6RWrVxXBiSu7Gypd29byTMtzW7zPXIky3ij/AgDKLe9e6XBg6WlS2376quluXOls85yWRXgh9xcaeBA6ZVXbLt7d2nmTLt8FygrpglQLhs22GppS5dKqanSpEnSSy8RBIBoqV7deu6xx6wHlyyxVT03bnRdGeIRYQBlUloqPfSQTQN8/rl07rnSO+9It93GECUQbaGQNGqU9P
bb1os7d0otW9rqnqWlrqtDPGGaAKfsX/+yy5lWrrTt666z+7JXreq2LgC2ONGwYdLixbadmWmX+dao4bYuxAdGBnBK3nrLpgVWrpQqV5ZmzLATBQkCQGyoWtXuYzB9uvXoihVSgwbS6tWuK0M8IAzghEpKpHHjbM2Ar76SLrhAevddaehQpgWAWBMK2ejApk3Sz39uPduhg3TvvdbLwPEwTYDj+vJLu3wpO9u2Bw2S/vhHWwENQGzLz5duvVWaPdu227SRFi60y3+B7yMM4ChBEGj//v1auTJZw4en6d//Dikjw4Yee/d2XR2Aslq4UBoxwlYIrV490PTphbrqqhJVqVJFIYb38B+EARxlz558nXnmVEm3S7JLlZ59VvrpT93WBaD8Pv3UTvjduvXgRx7RN9+M1BlnMMwHwzkDOMqHHyZJuk2SNHJkkdavJwgA8e6nP5XWr5dGjCj6z0du+0+vA4aRARwlPz9fGRl3Sfq78vIWKp0TBICEYf19g6Tayst7gP7GIYQBHMUOFnYbtLy8PA4WQAKhv3E8jBMBAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAAAAniMMAADgOcIAAACeIwwAAOA5wgAAAJ4jDAAA4DnCAI5SUvLDjwHEP/obx0MYwFGmTTv8ePp0d3UACD/6G8cTCoIgcF0EYsPevdKFFwbatWu/JKlWrSr68MOQqlVzWhaAMKC/cSKMDOCQ3/xG2rUrpPPOS9f//E+6du0K6bbbXFcFIBzob5wIIwOQJL34otSli5SUJK1bJwWB1KqVVFpq/3bNNa4rBFBe9DdOhpEBKDdXGjbMHt9xh9SihXTZZdLtt9vHhg2T/v1vd/UBKD/6G6eCkQHouuukxYulCy+UNm+W0tLs4wUFUuPG0l/+Yp/zzDNu6wRQdvQ3TgUjA55bvNj+S06WsrIOHygkqXJlad48+7dnn5Wee85dnQDKjv7GqSIMeOzrr6Ubb7TH//u/9lfC9zVuLN19tz0eOVL65z+jVx+A8qO/URZME3gqCOyEopdekho1kjZskCpV+uHPPXBAatZMysmRfv1r6YUXpFAomtUCKAv6G2XFyICn5s+3A0Vqqg0fHu9AIdm/zZtnn/vii9KCBdGrE0DZ0d8oK8KAh774Qrr1Vns8frx00UUn/5qLLpLuvdce33KL9OWXESsPQAXQ3ygPpgk8EwRSZqb06qtS06bS229LKSmn9rXFxdIvfylt2iR17CgtX85wIhBL6G+UFyMDnnnqKTtQVK5sw4eneqCQ7HPnzrUzkleulGbOjFiZAMqB/kZ5EQY8snOnNHq0Pf7DH6T69cv+HD//uX2tJP32t/acANyjv1ERTBN4orRU6tBBWr3aliFdvdqWJi2PkhKpbVtb1rRdO+n118v/XAAqjv5GRbGLPTF5sh0gqlSR5sypWHMnJ9twYpUq0ltvSVOmhKtKAOVBf6OiCAMe+PRTacwYe/zww9J551X8Oc87T3roIXs8Zox9DwDRR38jHJgmSHAlJTZsuH69DSOuWhW+Ib/SUumKK6Q337Qbn2Rn218VAKKD/ka4MDKQ4B57zA4Up58uzZ4d3rm/pCR7ztNPl955R5o0KXzPDeDk6G+ECyMDCewvf5EuuUQqLJRmzZIGDYrM95k1SxoyxC5J2r
JFuuCCyHwfAIfR3wgnwkCCKi62+5b/+c9Sp07SK69EbgGRIJCuvtoWKbn0UvsroizXNwMoG/ob4cY0QYJ64AE7UJxxhi1EEsmVxEIh+x7Vqknvvis9+GDkvhcA+hvhx8hAAsrJsQRfXCwtXCj17h2d77twodSnj93w5N13pQYNovN9AZ/Q34gEwkCCOXDADhTvvSd16yY9/3z01hcPAql7d7sF6sUX2wHjRHdLA1A29DcihWmCBDNhgh0oqleXpk6N7o1GQiFp2jT73u+9J913X/S+N+AD+huRwshAAnn3XTupqKTE/mLo3t1NHc8/L/Xsadckr19vf8kAqBj6G5HEyECC+O47qX9/O1Bcf727A4Uk9egh9epltfTvLxUUuKsFSAT0NyKNMJAgxo6VPvpIOucc6cknXVdjNZxzjtU0dqzraoD4Rn8j0pgmSADr1kmtW9sJPq+8InXu7Loi88or0q9+ZXONa9dKv/yl64qA+EN/IxoYGYhz+fnSgAF2oBg4MHYOFJItVHKwtgEDrFYAp47+RrQQBuLcnXdKf/ubVLt2bK4d/vjj0n/9l/TXv0p33eW6GiC+0N+IFqYJ4tgbb0iXX26PV62yO4zFolWrpKuussdvvCG1b++2HiAe0N+IJkYG4tS33x6+McnIkbF7oJCkK6+URoywx4MGWe0Ajo/+RrQxMhCnhg6VZs6Uzj1X2rZNyshwXdGJ5eXZqmU7dljtM2a4rgiIXfQ3oo0wEIdWrLA7lYVC0urVdqZxPFizRmrb1h6vWCF17Oi0HCAm0d9wgWmCOLNnj91bXJJ+85v4OVBIUps2VrNkP8OePW7rAWIN/Q1XGBmIM337SgsWSD/7mbR1q3Taaa4rKpv9+6VGjaTt2+1nmTfPdUVA7KC/4QphII786U9S165SUpL0zjtSs2auKyqfDRtsgZLSUvuZfv1r1xUB7tHfcIlpgjiRmysNH26Pf/e7+D1QSFLz5tIdd9jjYcPsZwN8Rn/DNUYG4sS110rPPSf94hfSn/8spaW5rqhiCgulxo2lDz+0n+3ZZ11XBLhDf8M1RgbiwLPP2oEiJUXKyor/A4VkP0NWlt0GdfFiDhbwF/2NWEAYiHFffy3deKM9vuce6ZJL3NYTTo0b288k2c/49ddu6wGijf5GrGCaIIYFgZ188/LLdpDYsEFKTXVdVXgVFdn86Nat0jXX2AlHoZDrqoDIo78RSxgZiGHz5tmBolIlG3JLtAOFZD/TwZ/tpZek+fNdVwREB/2NWEIYiFH/+MfhBTzGj7cTixLVRRfZzyhJt94qffGF23qASKO/EWuYJohBQWBLea5aZUNs69bZyUWJrLjYrk3etMlufLJyJcOJSEz0N/0dixgZiEEzZtiBonJlG2JL9AOFdPhM6sqV7Wd/6inXFQGRQX/T37GIMBBjduyQRo+2xxMn2rKkvqhfX/rDH+zx6NHSzp1OywHCjv62x/R37GGaIIaUlkrt29vdv1q3lt56y5Ym9Ulpqd35bO1a+/8bb/j3GiAx0d/0dyxjN8SQJ5+0A0V6ujRnjp9NkpRkP3uVKnb71smTXVcEhAf9TX/HMkYGYsT27VLDhtJ330lTp0ojRriuyK2pU22hktNOk3JypPPPd10RUH7099Ho79hDGIgBJSVSq1bS+vXSFVdIr77KmbZBYGcdv/661KKFDSsmJ7uuCig7+vtY9Hfs8XCgKvY8+qgdKH70I2nWLA4Ukr0Gs2bZa7J+vfTYY64rAsqH/j4W/R17GBlw7MMPbSnSAwek2bOlgQNdVxRbZs+WBg+2G59s2SJdcIHrioBTR3+fGP0dOwgDDhUV2RDZ5s3S1Vfbcp381XC0ILDXZvlyqUkT6Z13EnPZViQe+vvk6O/YwTSBQxMn2oHijDNsIRIOFMcKhWyBkjPOsPu8P/CA64qAU0N/nxz9HTsYGXBk61apaVNbpvPpp6Xrr3ddUWx7+mnphhtsJbN337Uzs4FYRX+XDf3tHmHAgc
JC6dJLpfffl7p3l557jr8aTiYIpB49pKVLpYsvtgNGpUquqwKORX+XHf3tHtMEDkyYYAeKGjXselsOFCcXCtlrVb269N579hoCsYj+Ljv62z1GBqJs40bpsstsWc4lS6Ru3VxXFF+WLLG/IJKS7JKkpk1dVwQcRn9XDP3tDiMDUfTdd1L//naguOEGDhTl0b271Lu3vYb9+9trCsQC+rvi6G93CANRdM890iefSDVrSk884bqa+PXEE/YafvyxNHas62oAQ3+HB/3tBtMEUbJ2rdSmjZ0os2yZ1KmT64ri27Jldn1yKCRlZ0stW7quCD6jv8OL/o4+RgaiIC9PGjDADhSDBnGgCIfOnW01tyCw1zY/33VF8BX9HX70d/QRBqJgzBjps8+k2rVZgzucJk2y1/Rvf7PXGHCB/o4M+ju6mCaIsDfekC6/3B6/9trhxwiP11+3O8EdfNyhg9t64Bf6O7Lo7+hhZCCCvv3Whg0lu3c3B4rwu/xyaeRIezxokL3mQDTQ35FHf0cPIwMRNGSI3abz3HOlbdukjAzXFSWmvDypQQMbqh0yxNY6ByKN/o4O+js6CAMRsny5nQQTCklr1kitWrmuKLFlZ0tt29oJR8uXS5mZritCIqO/o4v+jjymCSLgm28swUrSqFEcKKKhdWvpttvs8ZAh0p49TstBAqO/o4/+jjxGBiKgTx9p4UKpfn1pyxbptNNcV+SH776TGjWyhV/69JHmz3ddERIR/e0G/R1ZhIEwe+EFW4Y0KUl65x2pWTPXFfnlyLXhX3hB6tLFdUVIJPS3W/R35DBNEEa7d0vDh9vjMWM4ULjQrJn0u9/Z4+HDpdxct/UgcdDf7tHfkcPIQJgEgdSzp91166KL7H7caWmuq/JTYaHUpIn0wQd2B7TFi7mNLCqG/o4d9HdkMDIQJs88YweKlBQpK4sDhUtpabYPUlKk55+Xnn3WdUWId/R37KC/I4MwEAZffSXddJM9HjvWTnKBW5dcYneRk2zffP2123oQv+jv2EN/hx/TBBUUBNI110ivvCI1biytXy+lprquCpJUVCQ1b25nfP/qV9KLLzKciLKhv2MX/R1ejAxUUFaWHSgqVbLHHChiR2qq7ZNKlaSXX5bmzXNdEeIN/R276O/wIgxUwD/+If3mN/Z4wgTpwgvd1oNj/eIX0vjx9vjWW22fAaeC/o599Hf4ME1QTkEgXXWV3amseXNp3TopOdl1VfghxcW2StyGDdKVV0orVzKciBOjv+MH/R0ejAyU0/TpdqA47TQbquJAEbtSUqS5c6XKlaVVq6QZM1xXhFhHf8cP+js8CAPl8Nln0u232+OJE6Xzz3dbD07uZz+zfSVJo0dLO3a4rQexi/6OP/R3xTFNUEalpVK7dnYXrTZtpDfftKVJEfvYdzgZ3iPxi31XMbxUZfTHP9qbLT1dmjOHN1s8SUqyfZaebredfeIJ1xUh1tDf8Yv+rhhGBsrgk0+khg2lggJp2rTD65QjvkybJo0cafPBOTkMA8PQ34mB/i4fwsApKimRWrbkjNVEcOSZ4i1aSGvXcoKY7+jvxEF/lw+DYKfokUfsQPGjH0kzZ3KgiGehkDRrlu3L9eulRx91XRFco78TB/1dPowMnIIPPrClSA8csDmpAQNcV4RwmDNHGjTIVjDbsoVFZXxFfycm+rtsCAMnUVRk99DeupX1rxPNkevOX3KJ/WXIcrN+ob8TF/1dNkwTnMQf/mAHijPPtMUsOFAkjlDI9ukZZ9hfDgevU4Y/6O/ERX+XDSMDxxEEgd55Z7/atJFKSqpo0aKQevVyXRUiYdEiqXfvQMnJ+5WdLbVoUUUhfiskNPrbH/T3qSEMHMc33+TrrLMyJElduuRp6dJ0/mpIUEEgde2arxdftP3973/n6cwz0x1XhUiiv/1Bf58apgmO4+OPDz+eNInhw0QWCkmPP354+8h9j8REf/uD/j41hIHjaNDg8OMaNdzVgeg4ch8fue+RmOhvv9
DfJ0cYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPEcYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPEcYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPEcYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPEcYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPBcKgiBwXUQsCoJA+/fvlyRVqVJFoVDIcUWIJPa3X9jffmF/nxxhAAAAzzFNAACA5wgDAAB4jjAAAIDnCAMAAHiOMAAAgOcIAwAAeI4wAACA5wgDAAB4jjAAAIDnCAMAAHiOMAAAgOe8CwP5+fk6++yzFQqFdO6556qoqOgHP6+goEAtW7ZUKBRSWlqaVq9eHd1CERbsb7+wv/3C/g6jwEOPP/54ICmQFMyYMeOYfy8tLQ169OgRSApCoVCwaNEiB1UiXNjffmF/+4X9HR5ehoGCgoLgv//7vwNJQZ06dYLCwsKj/n3UqFGH3lyPPPKIoyoRLuxvv7C//cL+Dg8vw0AQBMHMmTMPvUGmTp166ONHpszbbrvNYYUIJ/a3X9jffmF/V5y3YaC4uDg4//zzA0lB7dq1g8LCwmDJkiVBUlJSICno2bNnUFJS4rpMhAn72y/sb7+wvyvO2zAQBEHwzDPPHEqNgwcPDipXrhxIClq3bh0UFBS4Lg9hxv72C/vbL+zvivE6DJSWlgYNGzY89AaSFFx44YXBnj17Tvh18+fPD4YNGxY0btw4qFSpUiApmDNnTlRqRvmVZ39/8cUXwaRJk4IrrrgiqF27dpCamhqcffbZQbdu3YINGzZEr3iUWXn29549e4JbbrklaN68eXD22WcHlSpVCmrVqhW0a9cueP7554PS0tLo/QAok/Iez7/vwQcfPPT169evj0yxMci7SwuPFAqFNHTo0EPbP/7xj7VixQpVq1bthF93zz33aMaMGfr8889Vs2bNCFeJcCnP/n7iiSc0atQoffbZZ7riiis0evRotWzZUi+++KIuu+wyLV68OAqVozzKs79zc3M1e/Zspaenq0uXLho9erQyMzP14YcfqkePHho+fHgUKkd5lPd4fqSPPvpIv//975Wenh6BCmOc6zTi0vbt24Pq1asfSoHp6enBP//5z5N+3WuvvRbs3LkzCIIgmDhxIiMDcaI8+3vJkiVBdnb2MR/Pzs4OUlNTgzPPPJMhyBhVnv1dXFwcFBUVHfPxb7/9NrjgggsCScEHH3wQqZJRAeU9nh9UXFwcXHrppUHTpk2DPn36MDLgi3/961/q2LGjcnNzddZZZ0myBSz+7//+76Rfe/nll6tOnTqRLhFhVN793a1bN7Vq1eqYj7dq1Urt2rXTN998o/fffz8iNaP8yru/k5OTlZKScszHTz/9dF111VWSpL/+9a/hLxgVUpHj+UEPPvigtm3bptmzZys5OTlSpcYsL8NAfn6+OnfurM8++0wZGRlatWqVunTpIkmaPn26/v73v7stEGEVqf2dmpoqST/4ywPuRGJ/FxQU6M0331QoFNIFF1wQ5opREeHY3x988IHGjx+ve+65RxdeeGGEK45Rrocmoq2oqCjIzMwMJAUpKSnB8uXLgyAIgm3btgWhUCiQFAwaNOiUn49pgtgW7v190Oeffx6kpaUF55xzTlBcXBzuslFO4drfe/bsCcaNGxeMHTs2GD58eFC7du1AUjBu3LgI/wQoi3Ds76KioqBx48ZBgwYNggMHDgRBEAT9+/f3bprAuzAwePDgQ3NKTz311FH/dnDJyuTk5OCTTz45pecjDMS2cO/vIAiCAwcOBK1btw4kBfPmzQt3yaiAcO3vHTt2HHVWempqavDwww9zNUGMCcf+Hj9+fJCSkhJs3rz50McIAwlu3Lhxh944Y8eOPebfP/jgg0OLVFx77bWn9JyEgdgVif1dUlJy6OSioUOHhrtkVEAk9ndxcXGwY8eOYO
LEiUGlSpWCrl27/uAJhoi+cOzvnJycIDU1NbjzzjuP+jhhIIEduVxl//79j/t5vXr1OnRDi61bt570eQkDsSkS+7u0tDQYNGhQICno06cPK5rFkEj195EeeuihQFIwZcqUihWLCgvX/m7QoEFQv379Y64IIgwkqGXLlgUpKSmBpODyyy8/NC/0Qz766KMgOTk5kBR06tTppM9NGIg9kdjfJSUlwcCBAwNJwfXXX895AjEkkv19pJycnDKNKiAywrm/j5wKOtF/L7zwQgR/otjgxWnQnTp1Ou59rr+vfv36Ki4ujnBFiKRw7+/S0lINGTJEc+bM0XXXXaf58+d7eelRrIpWf+/atUsSV4+4Fs79PXjw4B/8eHZ2tj799FNdc801qlGjhurWrVueUuMK72rgBEpLSzV48GDNnTtXPXv21IIFCwgCCSwnJ0f16tVT1apVj/r4N998o7vvvluSlJmZ6aI0RMDMmTN/8OMDBgzQp59+qrvuukvNmzePclVuEAbKYebMmVq3bp0kHVpwZubMmVq9erUkqUuXLoeuc0V8mzBhgubOnauMjAydf/75uv/++4/5nC5duqhhw4bRLw5hN3fuXM2cOVPt2rVTnTp1lJ6ers8//1zLli1TXl6eunfvrt69e7suEwg7wkA5rFu3TllZWUd97O2339bbb78tSapbty5hIEHs3LlTkpSXl3fc1czq1q1LGEgQPXr00L59+7RhwwZlZ2dr//79OvPMM9WyZUv169dPvXr1UigUcl0mEHahIAgC10UAAAB3vFyOGAAAHEYYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHOEAQAAPEcYAADAc4QBAAA8RxgAAMBzhAEAADxHGAAAwHP/H9TOIwd5kBhbAAAAAElFTkSuQmCC",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import *\n",
|
|
"from kan.hypothesis import plot_tree\n",
|
|
"\n",
|
|
"f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2\n",
|
|
"x = torch.rand(100,4) * 2 - 1\n",
|
|
"plot_tree(f, x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "58c2ece4-a8dc-4b4e-83cc-49a3f04c1ec5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda\n",
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 1.58e-03 | test_loss: 4.79e-03 | reg: 2.38e+01 | : 100%|█| 100/100 [00:20<00:00, 4.93"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"print(device)\n",
|
|
"\n",
|
|
"dataset = create_dataset(f, n_var=4, device=device)\n",
|
|
"model = KAN(width=[4,5,5,1], seed=0, device=device)\n",
|
|
"model.fit(dataset, steps=100);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "c02037c9-c903-4fc8-96bf-c78609ce0696",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGaCAYAAACSWkBBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAta0lEQVR4nO3dd3hUZfr/8c+EhCCJCyKo4NcvqCsirALSXTq6GHBdqiKC9GqDRRb1J6ugKzaEa6VLDSCCAl+lig0CShMIllXRBdxVLGQpXgkEUs7vj3tpIiXJzJyZed6v6/JyzpByZ87cJ588zznPCXie5wkAADgrzu8CAACAvwgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOC7e7wIA+MfzPB06dEiSVLJkSQUCAZ8rAuAHRgYAhx06dEjJyclKTk4+HgoAuIcwAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIA4DjZkjyJCUlJ0tPPOFzNQD8QBgAAMBxhAEAABxHGAAAwHGEASCW7dwZHV8TgK8IA0CsWrtWuuYaqW9fKSOj6F/vm2+kdu2kqlWlr78u+tcDEDEIA0AsysuT7rtPys+XXn5ZqlxZmjjRtgvqyBHpqaek666TFi+27QcfDH7NAHxDGABi0d69UpkyJ7b375cGDpTq1JHWrz//r7NsmVStmjR8uHT4sD0XCEhJSSe2AUQ9wgAQiy67THr/fWnFCqlmzRPPb90q/f73Us+e0k8/nfnzd+6Ubr9duu026Z//PPF8ixbSpk3SggXSBReErn4AYUUYAGLZrbdKW7ZI8+ZJv/2tPed50owZ0rXXKn7iRBU7+eOzs23hoWrVpCVLTjxfq5a0apX0zjtS7dph/AEAhEPA8zzP7yIAhEFurp0/MHKk9MMPx5/O10l/FQQCFhaOueYaO1+gY0f7NwAxiTAAuObQIWnsWOm556SDB3/9Y8qXl/76V6l3byk+PqzlAQg/wgDgqq+/Vn7Dhor78cdTn7/uOiktTSpb1p+6AIQd5wwArjl8WHr2WalOndODgCR9/rlUvbo0ZYpNLQCIeYwMAK7Iy5OmT7cTBPfsOf70Wc8ZqFxZ+tvfpA4dwlgogHBjZABwwaJFdoVA374ngkCpUjry3HOae/LHDR1q5wqUKGHbO3bYyYP16tmligBiEmEAiGWrV9sv8vbtpS+/tOcCAal7d2nHDuUOHKi8kz/+ggukESOkzz6zNQaO2bRJat7cLlVMTw9b+QDCgzAAxKIff5RSUqRmzewX+TE1a0offGDrDFxyyZk//6qrbJ2BpUulq68+8fxbb0k33ijdfTcrEAIxhDAAxKKLL5a+/fbE9kUXSePHSx99JDVocP5fp3VrGyUYOfLEioOeJ+3bxwqEQAwhDACxKD7efvnHxdlaATt22L0J4grR8omJdm+Czz+X2rSx7ZdeCnrJAPzDaiJArGrcWPrqKxvyD4aKFe2uhTt3Bu9rAogIjAwAsSwUv7QJAkDMIQwAAOA4wgAAAI4jDAAA4DjCAOAoz5Mee0zqISkg6c+DM+U9/oTPVQHwA/cmAByUkWGLEC5bliUp+b/PZuq225I0YwY3LARcQxgAHJOWJnXuLH33nVS8uKdnnjkkSXr44ZI6ej
Sgyy+X5s2TGjXyuVAAYUMYAByRlyc9/bTdtDA/X7r2Wmn+fLtbsSRt3y7dcYetTxQXZ7coeOQRqVgxX8sGEAaEAcAB338vdekivfeebXfrJo0bJyUnn/pxmZnSvfdKqam23by5NGeOVL58eOsFEF6EASDGvfWW1LWrtHevlJQkTZgg3XPP2T8nNdVWL87KksqVk2bPllq2DE+9AMKPqwmAGJWTIw0bZncd3rtXuuEGacuWcwcByT7mo4/sc/buta/x8MP2NQHEHkYGgBi0e7d0113Shg22PXCgNHq0VKJEwb7O4cPSkCHSxIm23aCBnVxYsWJQywXgM8IAEGMWL5Z69pQOHJBKlZKmTZPaty/a11y4UOrVSzp4UCpdWpoxw25gCCA2ME0AxIjsbOm++6R27SwI1KsnbdtW9CAg2dfYtk2qW9e+dtu20v332/cEEP0IA0AM2LHDhvDHj7ftoUOltWulK68M3ve48kr7mg89ZNvjxtn33LEjeN8DgD+YJgCi3Jw5Uv/+duZ/2bJ2JUBKSmi/54oVdpJhRoZdnjhxol26CCA6MTIARKmsLKlHD7tsMCtLatrUFg4KdRCQ7Hukp9v3zMy0Gnr2tDoARB/CABCFPv5Yql1bmjnzxGqB77wjVagQvhouv9y+5xNPWA0zZkh16kiffBK+GgAEB9MEQBTxPGnyZGnQIOnIEfvl/8orUpMm/ta1erXd7+D77+3yxbFjpb59pUDA37oAnB/CABAlDh6U+vSRXnvNtlu1spGBcuV8Leu4vXttmeMVK2z7jjukKVPs8kYAkY1pAiAKbNok1axpQSA+XnrhBWnJksgJApLVsnSp9PzzVuOCBdKNN0qbN/tdGYBzYWQAiGD5+dKYMbYUcG6uVKmS9OqrtoZAJNu4UerUyVZCTEiQnnlGGjyYaQMgUhEGgAiVkWHD7suX23aHDtLLL9sKgNHgwAGb1nj9ddtu3dqmNcqW9bMqAL+GaQIgAq1ZI1WvbkEgMdGu41+wIHqCgGS1Llhgd0lMTJSWLZNq1JDS0vyuDMAvEQaACJKXJ40cKTVvLu3ZI1WpYucL9O8fnUPsgYA0YIBNG1x7rfTdd1KzZtKTT9rPCiAyME0ARIg9e2wVv/fft+1u3WzJ3+Rkf+sKlsxMu3fCrFm23by5rZ5Yvry/dQEgDAARYeVKW953714pKcmmBbp29buq0EhNtVsqZ2XZFQizZ0stW/pdFeA2pgkAH+XkSMOG2fK+e/faeQJbtsRuEJAs9GzZIt1wg/3Mt95qr0FOjt+VAe5iZADwye7d0l13SRs22Pa999r6ASVK+FpW2GRnS0OG2AmGklS/vjRvnl0+CSC8CAOADxYtknr1ssvvSpWSpk+X2rXzuyp/LFxor8XBg3YFwvTpUtu2flcFuIVpAiCMsrPtJLr27S0I1Ktnd/9zNQhI9lps22avxYED9lrcd5+9VgDCgzAAhMmOHTYUPn68bf/lL9LatQyLS9KVV9prMXSobY8fLzVoYK8ZgNBjmgAIg9mz7Xr7Y2fQp6baiXM43YoVdpJhRoZdWTFpkl1yCSB0GBkAQigzU+re3X65ZWXZgjvp6QSBs0lJkbZvl5o2tdesa1epRw97DCA0CANAiHz8sVS7ti2yExcnjRghvf22VKGC35VFvgoVpHfesdcsLs7uaVC7tr2mAIKPaQIgyDxPmjxZGjRIOnLEfrG98orUpInflUWnNWukzp1thcbERGnsWKlfv+hcnhmIVIQBIIh+eae+Vq3sr9py5fysKvrt3WvTLcfu4Nixo93BsVQpX8sCYgbTBECQbNok1axpQSA+Xho9WlqyhCAQDOXK2Wv5wgv22r72mr3Wmzb5XRkQGxgZAIooP18aM0Z6+GEpN9cuk3v1ValuXb8ri00bN0qdOtkKjvHx0jPPSIMH27kFAAqHMAAUAcPX/vi16ZhZs6SyZX0tC4haZGmgkNaskWrUsCBQooRdDz9/PkEgHEqXlhYssLs7Ji
baPqhe3fYJgIIjDAAFlJdnl7w1b25nuFepYnPXnOEeXoGA1L+/vfZVqti+aN5cGjnS9hGA88c0AVAAe/ZId98trV5t2927S+PG2Up58E9mpt3PYNYs227WTJozhzUdgPNFGADO08qVthoey+RGLpZ9BgqHaQLgHHJy7KZCKSkWBGrUkLZuJQhEoq5dpS1b7PyBvXttnw0bZvsQwJkxMgCcxe7ddhnbxo22fd990vPP2wmDiFzZ2dJDD524Q2T9+tK8edwhEjgTwgBwBgsXSr16SQcP2tnr06ZJ7dr5XRUKYtEiqWdP9iFwLkwTAL+QnS3de6/UoYP9EqlfX9q2jV8i0ahdO7tLZL16tjZB+/Y2upOd7XdlQGQhDAAn+fJL++U/YYJtDxsmpaUxvBzNKlWS1q618z4kmzqoX1/ascPXsoCIwjQB8F+pqdLAgZyJHstWrpTuucdOLkxKskWLunb1uyrAf4wMwHmZmVK3bvZfVpZdo56eThCIRbfeavu2WTPb1/fcY2tFZGb6XRngL8IAnLZ9u1S7to0CxMXZ6nVvv81iNbGsQgXbxyNG2D6fNcveAx9/7HdlgH+YJoCTPM8WDRo8WDpyxH5BzJsnNW7sd2UIpzVrpM6dbWXJxERp7FiWlYabCANwzoEDUu/edumgJLVuLc2cyR3vXPXLO0926GB3nixd2s+qgPBimgBO2bhRqlnTgkBCgjR6tLRkCUHAZeXK2Xtg9GgpPt5ui1yzpt0ACXAFIwNwQn6+9OKL0iOPSLm50pVX2u2G69TxuzJEkk2bbMXJXbssGDzzjE0lxfFnE2IcYQAxb+9eu1JgxQrbvuMOacoUqVQpf+tCZDp4UOrTR3rtNdtu1cqmkcqV87UsIKTIu4hpq1fbjYVWrLD7CUyeLL36KkEAZ1aqlI0aTZpk75nly+09tGaN35UBoUMYQEzKy7NLx1q0sDPFq1SxIeC+fTlTHOcWCNhVBZs22Xtnzx6peXN7T+Xl+V0dEHxMEyDm7Nkj3X23jQpIUo8e0ksv2YpzQEFlZdn9DGbOtO2mTaW5c1mLArGFMICYsmKFrSqXkSElJ9tys126+F0VYsGcOVL//hYOypaVZs9mlUrEDqYJEBOOHpWGDrWTvTIybI53yxaCAIKnSxdp61Z7b2VkSCkpdvOjnBy/KwOKjpEBRL1du+xysGPXhd93n/T883byFxBs2dkWPMeNs+169eykVO5siWhGGEBUW7hQ6tXLLgcrXVqaPl1q29bvquCCRYvsvXfggF2BMG2a1L6931UBhcM0AaJSdrbdbrhDBwsC9evb3egIAgiXdu2kbdvsvXfwoL0X773X3ptAtCEMIOp88YUNzU6caNvDhklpaVLFiv7WBfdUqmTvvWHDbHvCBAsHX37pa1lAgTFNgKgya5b99ZWVZSvCzZ4ttWzpd1WAtHKlXcmyd69dxjphgm0D0YAwgKiQmWkhIDXVtps3t0u9ypf3ty7gZHv22FUH779v2/fcI40fb5e5ApGMaQJEvO3bpdq1LQjExUlPPimtWkUQQOSpUEF6+21p5Eh7r6am2nt3+3a/KwPOjpEBRCzPs/MC/vxn6cgR6fLLpVdekRo39rsy4NzS0qS77rLRgsREacwYW7SI5bARiQgDiEgHDki9e9ulg5J0223SjBm28hsQLTIypO7dpWXLbLt9e2nqVLsMFogkTBMg4mzcKNWsaUEgIUF68UXpzTcJAog+ZctKS5ZIo0fbe3nhQntvb9zod2XAqRgZQMTIz7eD5qOPSrm50lVX2cpuder4XRlQdJs3S3feaStmxsdLo0bZFFgcf5IhAhAGEBH27pW6dbMbDUnSHXdIU6bYym5ArDh40G6jvWCBbaek2OWy5cr5WxdAJoXvVq+Wqle3IFCihDR5so0IEAQQa0qVsvf25Mn2Xl+xwm58dOx224BfCAPwTV6e9MQTUosW0vffS9ddZzcb6tuXM64RuwIBe49v2iRVqWJXG7RoIY0YYT
0B+IFpAvjiu++ku++W1qyx7Z49pb//3VZuA1yRlSXdf79dKSNJTZtKc+faegVAOBEGEHbLl9v5ARkZtjLbpEkWDABXzZkjDRhgK22WLWuLFaWk+F0VXMI0AcLm6FHpoYek1q0tCNSsKW3dShAAunSRtmyx8wcyMqRWraShQ61ngHBgZABhsWuX1KmTzZNKNjT6/PO2MhsAk51tIWDcONuuW9dOOLzySn/rQuwjDCDkXn/dVhM8eNBWXpsxQ2rTxu+qgMi1eLGdR3PggF2BMG2arV4IhArTBAiZw4dtHrRjRwsCDRpI6ekEAeBc2ra1Xqlf33qnQwdp4EAbOQBCgTCAkPjiCzuQTZpk2w8/bFcOVKzob11AtKhY0W52NGyYbU+cKNWrZ70FBBvTBAi6WbPsr5hDh6RLLpFmz5b+8Ae/qwKi11tvSV272kqdSUnS+PF2RQ4QLIQBBE1mpoWA2bNtu0ULe1y+vL91AbHg++/tqoP33rPte+6xUJCc7G9diA1MEyAo0tOlWrXsl39cnPTUU/bXDEEACI7y5aVVq6Qnn7QeS02VateWtm/3uzLEAkYGUCSeJ02YIA0ZIh05Il1+uTRvntSokd+VAbErLU3q3NlW8kxMtNt8DxjAMt4oPMIACu3AAalXL2nRItu+7TZp5kzp4ov9rApwQ0aG1KOHtHSpbbdvL02dapfvAgXFNAEKZcMGWy1t0SIpIUEaM0Z6802CABAuZctaz734ovXgwoW2qufGjX5XhmhEGECB5OdLzz1n0wDffCNddZX04YfSoEEMUQLhFghIgwdLH3xgvbh7t9Swoa3umZ/vd3WIJkwT4Lz99JNdzrRypW3feafdl71UKX/rAmCLE/XtKy1YYNspKXaZb7ly/taF6MDIAM7L++/btMDKlVKJEtKUKXaiIEEAiAylStl9DCZPth5dsUKqXl1avdrvyhANCAM4q7w86fHHbc2A77+XqlaVNm+W+vRhWgCINIGAjQ5s2iRdd531bIsW0hNPWC8DZ8I0Ac7ou+/s8qW0NNvu2VP6+99tBTQAkS0rS3rgAWn6dNtu0kSaO9cu/wV+iTCAU3iep0OHDmnlymLq1y9R//lPQMnJNvTYubPf1QEoqLlzpf79bYXQsmU9TZ58RC1b5qlkyZIKMLyH/yIM4BT792epTJmJkh6SZJcqzZ8vXXONv3UBKLyvvrITfrdtO/bMC9q3b4AuuohhPhjOGcApPvssTtIgSdKAATlav54gAES7a66R1q+X+vfP+e8zg/7b64BhZACnyMrKUnLyI5L+pczMuUriBAEgZlh/3y3pCmVmPkN/4zjCAE5hBwu7DVpmZiYHCyCG0N84E8aJAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQAAAMcRBgAAcBxhAAAAxxEGAABwHGEAAADHEQYAAHAcYQCnyMv79ccAoh/9jTMhDOAUkyadeDx5sn91AAg++htnEvA8z/O7CESGAwekatU87dlzSJJUoUJJffZZQKVL+1oWgCCgv3E2jAzguAcflPbsCejqq5P0298mac+egAYN8rsqAMFAf+NsGBmAJOmNN6Q2baS4OGndOsnzpEaNpPx8+7fbb/e7QgCFRX/jXBgZgDIypL597fHQoVKDBtJNN0kPPWTP9e0r/ec//t
UHoPDob5wPRgagO++UFiyQqlWTtmyREhPt+exsqVYt6R//sI959VV/6wRQcPQ3zgcjA45bsMD+K1ZMmjXrxIFCkkqUkFJT7d/mz5dee82/OgEUHP2N80UYcNgPP0gDB9rj//f/7K+EX6pVS3r0UXs8YID044/hqw9A4dHfKAimCRzleXZC0ZtvSjVrShs2SMWL//rHHj0q1asnpadLf/qTtHixFAiEs1oABUF/o6AYGXDU7Nl2oEhIsOHDMx0oJPu31FT72DfekObMCV+dAAqO/kZBEQYc9O230gMP2OMRI6Trrz/351x/vfTEE/b4/vul774LWXkAioD+RmEwTeAYz5NSUqS33pLq1pU++ECKjz+/z83NlX7/e2nTJunWW6XlyxlOBCIJ/Y3CYmTAMS+/bAeKEiVs+PB8DxSSfezMmXZG8sqV0tSpISsTQCHQ3ygswoBDdu+Whgyxx08/LVWpUvCvcd119rmS9Oc/29cE4D/6G0XBNIEj8vOlFi2k1attGdLVq21p0sLIy5OaNrVlTZs1k955p/BfC0DR0d8oKnaxI8aPtwNEyZLSjBlFa+5ixWw4sWRJ6f33pQkTglUlgMKgv1FUhAEHfPWVNGyYPX7+eenqq4v+Na++WnruOXs8bJh9DwDhR38jGJgmiHF5eTZsuH69DSOuWhW8Ib/8fOmWW6T33rMbn6Sl2V8VAMKD/kawMDIQ41580Q4UF14oTZ8e3Lm/uDj7mhdeKH34oTRmTPC+NoBzo78RLIwMxLB//EO68UbpyBFp2jSpZ8/QfJ9p06Teve2SpK1bpapVQ/N9AJxAfyOYCAMxKjfX7lv+0UdSq1bS0qWhW0DE86TbbrNFSurUsb8iCnJ9M4CCob8RbEwTxKhnnrEDxUUX2UIkoVxJLBCw71G6tLR5s/Tss6H7XgDobwQfIwMxKD3dEnxurjR3rtS5c3i+79y5UpcudsOTzZul6tXD830Bl9DfCAXCQIw5etQOFB9/LLVrJ73+evjWF/c8qX17uwXqDTfYAeNsd0sDUDD0N0KFaYIYM3KkHSjKlpUmTgzvjUYCAWnSJPveH38sPflk+L434AL6G6HCyEAM2bzZTirKy7O/GNq396eO11+XOna0a5LXr7e/ZAAUDf2NUGJkIEYcPix162YHirvu8u9AIUkdOkidOlkt3bpJ2dn+1QLEAvoboUYYiBHDh0uffy5ddpk0bpzf1VgNl11mNQ0f7nc1QHSjvxFqTBPEgHXrpMaN7QSfpUul1q39rsgsXSr98Y8217h2rfT73/tdERB96G+EAyMDUS4rS+re3Q4UPXpEzoFCsoVKjtXWvbvVCuD80d8IF8JAlHv4Yemf/5SuuCIy1w4fO1b6n/+Rvv5aeuQRv6sBogv9jXBhmiCKvfuudPPN9njVKrvDWCRatUpq2dIev/uu1Ly5v/UA0YD+RjgxMhClfv75xI1JBgyI3AOFJP3hD1L//va4Z0+rHcCZ0d8IN0YGolSfPtLUqdJVV0nbt0vJyX5XdHaZmbZq2a5dVvuUKX5XBEQu+hvhRhiIQitW2J3KAgFp9Wo70zgarFkjNW1qj1eskG691ddygIhEf8MPTBNEmf377d7ikvTgg9FzoJCkJk2sZsl+hv37/a0HiDT0N/zCyECU6dpVmjNHuvZaads26YIL/K6oYA4dkmrWlHbssJ8lNdXvioDIQX/DL4SBKPJ//ye1bSvFxUkffijVq+d3RYWzYYMtUJKfbz/Tn/7kd0WA/+hv+IlpgiiRkSH162eP//KX6D1QSFL9+tLQofa4b1/72QCX0d/wGyMDUeKOO6TXXpN+9zvpo4+kxES/KyqaI0ekWrWkzz6zn23+fL8rAvxDf8NvjAxEgfnz7UARHy/NmhX9BwrJfoZZs+w2qAsWcLCAu+hvRALCQIT74Qdp4EB7/Nhj0o03+ltPMNWqZT+TZD/jDz/4Ww8QbvQ3IgXTBBHM8+zkmyVL7CCxYYOUkOB3VcGVk2Pzo9
u2SbffbiccBQJ+VwWEHv2NSMLIQARLTbUDRfHiNuQWawcKyX6mYz/bm29Ks2f7XREQHvQ3IglhIEL9+98nFvAYMcJOLIpV119vP6MkPfCA9O23/tYDhBr9jUjDNEEE8jxbynPVKhtiW7fOTi6KZbm5dm3ypk1245OVKxlORGyiv+nvSMTIQASaMsUOFCVK2BBbrB8opBNnUpcoYT/7yy/7XREQGvQ3/R2JCAMRZtcuacgQezxqlC1L6ooqVaSnn7bHQ4ZIu3f7Wg4QdPS3Paa/Iw/TBBEkP19q3tzu/tW4sfT++7Y0qUvy8+3OZ2vX2v/ffde91wCxif6mvyMZuyGCjBtnB4qkJGnGDDebJC7OfvaSJe32rePH+10REBz0N/0dyRgZiBA7dkg1akiHD0sTJ0r9+/tdkb8mTrSFSi64QEpPlypX9rsioPDo71PR35GHMBAB8vKkRo2k9eulW26R3nqLM209z846fucdqUEDG1YsVszvqoCCo79PR39HHgcHqiLP6NF2oPjNb6Rp0zhQSPYaTJtmr8n69dKLL/pdEVA49Pfp6O/Iw8iAzz77zJYiPXpUmj5d6tHD74oiy/TpUq9eduOTrVulqlX9rgg4f/T32dHfkYMw4KOcHBsi27JFuu02W66TvxpO5Xn22ixfLtWuLX34YWwu24rYQ3+fG/0dOZgm8NGoUXaguOgiW4iEA8XpAgFboOSii+w+788843dFwPmhv8+N/o4cjAz4ZNs2qW5dW6bzlVeku+7yu6LI9sor0t1320pmmzfbmdlApKK/C4b+9h9hwAdHjkh16kiffCK1by+99hp/NZyL50kdOkiLFkk33GAHjOLF/a4KOB39XXD0t/+YJvDByJF2oChXzq635UBxboGAvVZly0off2yvIRCJ6O+Co7/9x8hAmG3cKN10ky3LuXCh1K6d3xVFl4UL7S+IuDi7JKluXb8rAk6gv4uG/vYPIwNhdPiw1K2bHSjuvpsDRWG0by917myvYbdu9poCkYD+Ljr62z+EgTB67DHpyy+l8uWll17yu5ro9dJL9hp+8YU0fLjf1QCG/g4O+tsfTBOEydq1UpMmdqLMsmVSq1Z+VxTdli2z65MDASktTWrY0O+K4DL6O7jo7/BjZCAMMjOl7t3tQNGzJweKYGjd2lZz8zx7bbOy/K4IrqK/g4/+Dj/CQBgMGybt3CldcQVrcAfTmDH2mv7zn/YaA36gv0OD/g4vpglC7N13pZtvtsdvv33iMYLjnXfsTnDHHrdo4W89cAv9HVr0d/gwMhBCP/9sw4aS3bubA0Xw3XyzNGCAPe7Z015zIBzo79Cjv8OHkYEQ6t3bbtN51VXS9u1ScrLfFcWmzEypenUbqu3d29Y6B0KN/g4P+js8CAMhsny5nQQTCEhr1kiNGvldUWxLS5OaNrUTjpYvl1JS/K4IsYz+Di/6O/SYJgiBffsswUrS4MEcKMKhcWNp0CB73Lu3tH+/r+UghtHf4Ud/hx4jAyHQpYs0d65UpYq0dat0wQV+V+SGw4elmjVt4ZcuXaTZs/2uCLGI/vYH/R1ahIEgW7zYliGNi5M+/FCqV8/vitxy8trwixdLbdr4XRFiCf3tL/o7dJgmCKK9e6V+/ezxsGEcKPxQr570l7/Y4379pIwMf+tB7KC//Ud/hw4jA0HieVLHjnbXreuvt/txJyb6XZWbjhyRateWPv3U7oC2YAG3kUXR0N+Rg/4ODUYGguTVV+1AER8vzZrFgcJPiYm2D+Ljpddfl+bP97siRDv6O3LQ36FBGAiC77+X7r3XHg8fbie5wF833mh3kZNs3/zwg7/1IHrR35GH/g4+pgmKyPOk22+Xli6VatWS1q+XEhL8rgqSlJMj1a9vZ3z/8Y/SG28wnIiCob8jF/0dXIwMFNGsWXagKF7cHnOgiBwJCbZPiheXliyRUlP9rgjRhv6OXPR3cBEGiuDf/5YefNAejxwpVavmbz043e9+J40YYY8feMD2GX
A+6O/IR38HD9MEheR5UsuWdqey+vWldeukYsX8rgq/JjfXVonbsEH6wx+klSsZTsTZ0d/Rg/4ODkYGCmnyZDtQXHCBDVVxoIhc8fHSzJlSiRLSqlXSlCl+V4RIR39HD/o7OAgDhbBzp/TQQ/Z41CipcmV/68G5XXut7StJGjJE2rXL33oQuejv6EN/Fx3TBAWUny81a2Z30WrSRHrvPVuaFJGPfYdz4T0Svdh3RcNLVUB//7u92ZKSpBkzeLNFk7g422dJSXbb2Zde8rsiRBr6O3rR30XDyEABfPmlVKOGlJ0tTZp0Yp1yRJdJk6QBA2w+OD2dYWAY+js20N+FQxg4T3l5UsOGnLEaC04+U7xBA2ntWk4Qcx39HTvo78JhEOw8vfCCHSh+8xtp6lQOFNEsEJCmTbN9uX69NHq03xXBb/R37KC/C4eRgfPw6ae2FOnRozYn1b273xUhGGbMkHr2tBXMtm5lURlX0d+xif4uGMLAOeTk2D20t21j/etYc/K68zfeaH8ZstysW+jv2EV/FwzTBOfw9NN2oChTxhaz4EAROwIB26cXXWR/ORy7ThnuoL9jF/1dMIwMnIHnefrww0Nq0kTKyyupefMC6tTJ76oQCvPmSZ07eypW7JDS0qQGDUoqwG+FmEZ/u4P+Pj+EgTPYty9LF1+cLElq0yZTixYl8VdDjPI8qW3bLL3xhu3v//wnU2XKJPlcFUKJ/nYH/X1+mCY4gy++OPF4zBiGD2NZICCNHXti++R9j9hEf7uD/j4/hIEzqF79xONy5fyrA+Fx8j4+ed8jNtHfbqG/z40wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOC7geZ7ndxGRyPM8HTp0SJJUsmRJBQIBnytCKLG/3cL+dgv7+9wIAwAAOI5pAgAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABxHGAAAwHGEAQAAHEcYAADAcYQBAAAcRxgAAMBxhAEAABznXBjIysrSpZdeqkAgoKuuuko5OTm/+nHZ2dlq2LChAoGAEhMTtXr16vAWiqBgf7uF/e0W9ncQeQ4aO3asJ8mT5E2ZMuW0f8/Pz/c6dOjgSfICgYA3b948H6pEsLC/3cL+dgv7OzicDAPZ2dne//7v/3qSvIoVK3pHjhw55d8HDx58/M31wgsv+FQlgoX97Rb2t1vY38HhZBjwPM+bOnXq8TfIxIkTjz9/csocNGiQjxUimNjfbmF/u4X9XXTOhoHc3FyvcuXKniTviiuu8I4cOeItXLjQi4uL8yR5HTt29PLy8vwuE0HC/nYL+9st7O+iczYMeJ7nvfrqq8dTY69evbwSJUp4krzGjRt72dnZfpeHIGN/u4X97Rb2d9E4HQby8/O9GjVqHH8DSfKqVavm7d+//6yfN3v2bK9v375erVq1vOLFi3uSvBkzZoSlZhReYfb3t99+640ZM8a75ZZbvCuuuMJLSEjwLr30Uq9du3behg0bwlc8Cqww+3v//v3e/fff79WvX9+79NJLveLFi3sVKlTwmjVr5r3++utefn5++H4AFEhhj+e/9Oyzzx7//PXr14em2Ajk3KWFJwsEAurTp8/x7UsuuUQrVqxQ6dKlz/p5jz32mKZMmaJvvvlG5cuXD3GVCJbC7O+XXnpJgwcP1s6dO3XLLbdoyJAhatiwod544w3ddNNNWrBgQRgqR2EUZn9nZGRo+vTpSkpKUps2bTRkyBClpKTos88+U4cOHdSvX78wVI7CKOzx/GSff/65/vrXvyopKSkEFUY4v9OIn3bs2OGVLVv2eApMSkryfvzxx3N+3ttvv+3t3r
3b8zzPGzVqFCMDUaIw+3vhwoVeWlraac+npaV5CQkJXpkyZRiCjFCF2d+5ubleTk7Oac///PPPXtWqVT1J3qeffhqqklEEhT2eH5Obm+vVqVPHq1u3rtelSxdGBlzx008/6dZbb1VGRoYuvvhiSbaAxd/+9rdzfu7NN9+sihUrhrpEBFFh93e7du3UqFGj055v1KiRmjVrpn379umTTz4JSc0ovMLu72LFiik+Pv605y+88EK1bNlSkvT1118Hv2AUSVGO58c8++yz2r59u6ZPn65ixYqFqtSI5WQYyMrKUuvWrbVz504lJydr1apVatOmjSRp8uTJ+te//uVvgQiqUO3vhIQESfrVXx7wTyj2d3Z2tt577z0FAgFVrVo1yBWjKIKxvz/99FONGDFCjz32mKpVqxbiiiOU30MT4ZaTk+OlpKR4krz4+Hhv+fLlnud53vbt271AIOBJ8nr27HneX49pgsgW7P19zDfffOMlJiZ6l112mZebmxvsslFIwdrf+/fv9x5//HFv+PDhXr9+/bwrrrjCk+Q9/vjjIf4JUBDB2N85OTlerVq1vOrVq3tHjx71PM/zunXr5tw0gXNhoFevXsfnlF5++eVT/u3YkpXFihXzvvzyy/P6eoSByBbs/e15nnf06FGvcePGniQvNTU12CWjCIK1v3ft2nXKWekJCQne888/z9UEESYY+3vEiBFefHy8t2XLluPPEQZi3OOPP378jTN8+PDT/v3TTz89vkjFHXfccV5fkzAQuUKxv/Py8o6fXNSnT59gl4wiCMX+zs3N9Xbt2uWNGjXKK168uNe2bdtfPcEQ4ReM/Z2enu4lJCR4Dz/88CnPEwZi2MnLVXbr1u2MH9epU6fjN7TYtm3bOb8uYSAyhWJ/5+fnez179vQkeV26dGFFswgSqv4+2XPPPedJ8iZMmFC0YlFkwdrf1atX96pUqXLaFUGEgRi1bNkyLz4+3pPk3XzzzcfnhX7N559/7hUrVsyT5LVq1eqcX5swEHlCsb/z8vK8Hj16eJK8u+66i/MEIkgo+/tk6enpBRpVQGgEc3+fPBV0tv8WL14cwp8oMjhxGnSrVq3OeJ/rX6pSpYpyc3NDXBFCKdj7Oz8/X71799aMGTN05513avbs2U5eehSpwtXfe/bskcTVI34L5v7u1avXrz6flpamr776SrfffrvKlSunSpUqFabUqMK7GjiL/Px89erVSzNnzlTHjh01Z84cgkAMS09P15VXXqlSpUqd8vy+ffv06KOPSpJSUlL8KA0hMHXq1F99vnv37vrqq6/0yCOPqH79+mGuyh+EgUKYOnWq1q1bJ0nHF5yZOnWqVq9eLUlq06bN8etcEd1GjhypmTNnKjk5WZUrV9ZTTz112se0adNGNWrUCH9xCLqZM2dq6tSpatasmSpWrKikpCR98803WrZsmTIzM9W+fXt17tzZ7zKBoCMMFMK6des0a9asU5774IMP9MEHH0iSKlWqRBiIEbt375YkZWZmnnE1s0qVKhEGYkSHDh108OBBbdiwQWlpaTp06JDKlCmjhg0b6p577lGnTp0UCAT8LhMIuoDneZ7fRQAAAP84uRwxAAA4gTAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4wgDAAA4jjAAAIDjCAMAADiOMAAAgOMIAwAAOI4wAACA4/4/AkvbpkMiFu8AAAAASUVORK5CYII=",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.tree(sym_th=1e-2, sep_th=5e-1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2c2f31d6-be08-4bb8-a678-2c0d3f456722",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|