GitHub_collection_pykan/tutorials/Interp_4_feature_attribution.ipynb
2024-08-11 13:11:23 -04:00

434 lines
44 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "f8ba3161",
"metadata": {},
"source": [
"# Interpretability 4: Feature attribution"
]
},
{
"cell_type": "markdown",
"id": "6535c1f2",
"metadata": {},
"source": [
"How can we tell which input features matter most? Answering this question is called feature attribution. This notebook shows how to compute feature scores for a KAN."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1d88fa9d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 6.30e-03 | test_loss: 6.26e-03 | reg: 4.54e+00 | : 100%|█| 40/40 [00:21<00:00, 1.82it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"from sympy import *\n",
"\n",
"# let's construct a dataset\n",
"f = lambda x: x[:,0]**2 + 0.3*x[:,1] + 0.1*x[:,2]**3 + 0.0*x[:,3]\n",
"dataset = create_dataset(f, n_var=4)\n",
"\n",
"input_vars = [r'$x_'+str(i)+'$' for i in range(4)]\n",
"\n",
"model = KAN(width=[4,5,1])\n",
"model.fit(dataset, steps=40, lamb=0.001);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "36296de7",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAFICAYAAACcDrP3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAreUlEQVR4nO3de1xUdd4H8M8ZQOQmIOJdEwjdXG3NCxYXSxN1xdXR7PLaLrvp82xbZk/trV33qdwuu6/ap1W09qnU10a6+5hlooZiq4VCaBpZGhkIJC0iIjCjwAwwl9/zh80s4HCZ4QznzJnP+/XilTEMfPn+ZubD7/c754wkhBAgIiKSkU7pAoiISHsYLkREJDuGCxERyY7hQkREsmO4EBGR7BguREQkO4YLERHJjuFCRESyY7gQEZHsGC5ERCQ7hgsREcmO4UJERLJjuBARkewYLkREJDuGCxERyS5Q6QKIfIEQAvX19WhqakJ4eDhiYmIgSZLSZRGpFmcuRN0wGo3IzMxEYmIiYmNjERcXh9jYWCQmJiIzMxNGo1HpEolUSeI7URK5duDAAdxxxx0wmUwArs5eHByzltDQUOzcuRPz589XpEYitWK4ELlw4MABZGRkQAgBu93e5dfpdDpIkoScnBwGDFE7DBeiToxGI0aPHg2z2dxtsDjodDqEhISgqqoKUVFR3i+QyAdwz4Wok6ysLJhMpl4FCwDY7XaYTCa89dZbXq6MyHdw5kLUjhACiYmJqKiogDtPDUmSEB8fj7Nnz/IoMiIwXIg6qKurQ2xsbJ/uHxMTI2NFRL6Jy2JE7TQ1NfXp/o2NjTJVQuTbGC5E7YSHh/fp/hERETJVQuTbGC5E7cTExCAhIcHtfRNJkpCQkIDBgwd7qTIi38JwIWpHkiSsXr3ao/s+9thj3Mwn+g439Ik64XkuRH3HmQtRJ1FRUdi5cyckSYJO1/1TxHGG/nvvvcdgIWqH4ULkwvz585GTk4OQkBBIknTNcpfjcyEhIdi3bx/mzZunUKVE6sRwIerC/PnzUVVVhfXr1yM+Pr7DbfHx8Vi/fj3Onz/PYCFygXsuRL0ghMBHH32E22+/HYcOHcLs2bO5eU/UDc5ciHpBkiTnnkpUVBSDhagHDBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMF6IeWCwWnD9/HmfOnAEAlJeXo6GhAXa7XeHKiNSLb3NM1AWj0YidO3fi73//O4qLi9HY2Ii2tjYMHDgQsbGxSEtLw8qVK5GSkoLAwEClyyVSFYYLkQtHjx7FE088gVOnTmHGjBnIyMjAjTfeiPDwcBiNRhQVFWHv3r0oKyvD3Xffjeeffx6xsbFKl02kGgwXok4++OAD/PSnP0V4eDj+9Kc/YeHChWhra8P27dvR2tqKQYMG4Z577oHFYsH27duxdu1afP/738fWrVsxbNgwpcsnUgWGC1E7paWlWLBgAcLCwrB9+3ZMnDgRkiShoqICU6dOxeXLlxEXF4eioiJER0dDCIGCggL8+Mc/xm233YbNmzcjODhY6V+DSHHc0Cf6js1mwx//+EcYDAa88sorzmDpjiRJSE1NxUsvvYTdu3cjNze3n6olUjeGC9F3ysrKsHfvXixbtgypqak9BouDJEnQ6/W4+eabsWnTJlitVi9XSqR+PMSF6DuFhYVoamrCHXfcgXPnzqG5udl5W1VVFWw2GwCgra0NxcXFGDRokPP2kSNHYtmyZVi7di1qamowevTofq+fSE0YLkTf+frrrxEaGor4+Hg89NBD+Pjjj523CSHQ2toKAKiurkZ6errzNkmS8PLLL2Py5MkwmUyorq5muJDfY7gQfcdsNiMwMBDBwcFobW1FS0uLy68TQlxzm9VqRUhISIcQIvJnDBei7wwdOhRmsx
lGoxEzZ85EWFiY8zaz2YzCwkJniCQnJztPnJQkCWPHjkVtbS10Oh2io6OV+hWIVIPhQvSdadOmwWKx4Pjx43jxxRc73FZRUYEZM2bg8uXLGDZsGN5++21ERUU5b5ckCWvWrMHw4cO5JEYEHi1G5JSUlIT4+HhkZWWhubkZAQEBHT4cJEmCTqdzfl6n0+HChQt49913kZGRgcjISAV/CyJ1YLgQ4eo+SnR0NFatWoXPPvsMGzZs6PUhxa2trXjuuedgNpvx0EMP9foQZiIt47IY+T3H1Y0lScKDDz6I/Px8vPjiiwgNDcXDDz+MgQMHAgACAwMRGBjonMUIIdDY2IgXXngB27dvx7p16zBhwgTFfg8iNeHlX8hvCSHgePjrdP+exF+6dAmrVq3C+++/j/nz5+OJJ57ADTfcgJKSEtjtdgwYMADXX389jh8/jv/5n//B559/jmeffRYPP/xwh+UzIn/GcCG/0z5UJElyuYzV3NyMTZs2YcOGDbh48SLi4+ORmJiIiIgIGAwGlJSUoLq6GtOmTcMzzzyDW2+9tUNAEfk7hgv5jd6ESmc1NTU4dOgQDh8+jC+++ALHjx9HWloaUlJSMG/ePMycOROhoaHeLp3I5zBcyC90tQTmjhMnTiApKQknTpzA9OnT5SyPSHO4oU+a5slspSsBAQHOw5CJqHsMF9IkOUOFiNzHcCHNcQQLQ4VIOQwX0gzOVojUg+FCPo+hQqQ+DBfyae3PrmeoEKkHw4V8khyHFhOR9zBcyKdwCYzINzBcyCcwVIh8C8OFVI+HFhP5HoYLqRZnK0S+i+FCqsNQIfJ9DBdSFR5aTKQNDBdSBc5WiLSF4UKKYqgQaRPDhRTDJTAi7WK4UL/j2fVE2sdwoX7DJTAi/8FwIa9jqBD5H4YLeRXPrifyTwwX8grOVoj8G8OFZMVQISKA4UIy4qHFROTAcKE+46HFRNQZw4U8xiUwIuoKw4XcxlAhop4wXMgtXAIjot5guFCvcLZCRO5guFC3GCpE5AmGC3WJZ9cTkacYLnQNzlaIqK8YLuTEUCEiuTBcCADPricieTFc/BwPLSYib2C4+CkugRGRNzFc/BCPAiMib5OE489X0qT2w+sIFFe3t/88AwcwmUzXfE4IAbvdDp1O16FHly5dwsMPP4x9+/b1Z4lEqsZw8RN2ux2SJHUImM7/5p5L99o/VRx9MxgMWLlyJbKyshAREaFUaUSqw2UxP+FqCazz51zNbOjfOvfmypUrWLlyJbZs2cJgIeqEf6r6AVd/cbv6Nyexvdfc3IwVK1bg9ddfR3R0tNLlEKkOw8UPdJ6RcCO/b5qbm/Hggw9i48aNiI2NVbocIlViuJATA6dnjmD5y1/+ghEjRihdDpFqMVyoAy6Nda19sIwePVrpcohUjeFCTtx36RqDhcg9DBc/wOWuvmGwELmP4ULUjaamJgYLkQcYLn7AnZkLZzn/5jhBct26dQwWIjfxJEqNc/fESIbLVQ0NDVi1ahU2bNiAYcOGKV0Okc/h5V80zpPhZcAAeXl5CA0NRWRkZK/vM2HCBC9WRORbGC4ax3DxTFtb2zWf6+pKBw4DBgzwak1EvoR7LhrnOBu/80f7S+53/qCrQdH5o7i4GCEhISguLnZ5OxH9G8OFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSXaDSBfSFEAL19fVoampCeHg4YmJiIEmS0mWpmqNnV65cwaBBg9izXhJCwGAwAAAMBgOEEOxbD/j89IxW+u
aTMxej0YjMzEwkJiYiNjYWcXFxiI2NRWJiIjIzM2E0GpUuUXU69ywhIYE964X2fZs7dy6EEJg7dy771g0+Pz2jub4JH5ObmyvCwsKEJElCkiQBwPnh+FxYWJjIzc1VulTVYM88w765jz3zjBb75lPhkpubKwICAoROp+vQ/M4fOp1OBAQE+NRAeAt75hn2zX3smWe02jdJCCHkng15g9FoxOjRo2E2m2G323v8ep1Oh5CQEFRVVSEqKsr7BaoQe+YZ9s197JlntNw3n9lzycrKgslk6tUAAIDdbofJZMJbb73l5crUiz3zDPvmPvbMM1rum0/MXIQQSExMREVFBdwpV5IkxMfH4+zZsz55tEVfsGeeYd/cx555Rut984lwqaurQ2xsbJ/uHxMTI2NF6seeeYZ9cx975hmt980nlsWampr6dP/GxkaZKvEd7Jln2Df3sWee0XrffCJcwsPD+3T/iIgImSrxHeyZZ9g397FnntF633wiXGJiYpCQkOD2+qIkSUhISMDgwYO9VJl6sWeeYd/cx555Rut984lwkSQJq1ev9ui+jz32mKo3vbyFPfMM++Y+9swzWu+bT2zoA9o+Htxb2DPPsG/uY888o+W++cTMBQCioqKwc+dOSJIEna77snU6HSRJwnvvvaf6AfAm9swz7Jv72DPPaLpv/X1JgL7q7TV4Dhw4oHSpqsGeeYZ9cx975hkt9s3nwkUIIQwGg8jMzBQJCQkdBiEhIUFkZmYKo9GodImqw555hn1zH3vmGa31zSfDxcFut4tDhw4JAOLQoUPCbrcrXZLqsWeeYd/cx555Rit985k9F1ckSXKuPUZFRan+6Ak1YM88w765jz3zjFb65tPhQkRE6sRwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2TFciIhIdgwXIiKSnc+GS1NTE0pLS3H69GkAQE1NDdra2hSuSv2amppQWVkJADhz5gz+9a9/sW89sFgsOH/+PM6cOQMAKC8vR0NDA+x2u8KVqRsfa+7T0uuaJIQQShfhjoqKCmzevBl79uzBv/71L1gsFrS2tmLQoEG46aab8JOf/ATLli1DRESE0qWqSvu+VVZWwmw2Y8CAAQgLC8PkyZPZNxeMRiN27tyJv//97yguLkZjYyPa2towcOBAxMbGIi0tDStXrkRKSgoCAwOVLlc1+FhznxZf13wmXGw2G/7v//4Pa9asgdlsxg9/+EOkp6dj7NixsNvtKCsrw/79+/HRRx9h6tSp2LhxIyZOnKh02Ypj3zxz9OhRPPHEEzh16hRmzJiBjIwM3HjjjQgPD4fRaERRURH27t2LsrIy3H333Xj++ecRGxurdNmK4mPNfZrumfABNptNvPrqqyIsLEz88Ic/FF988YWwWq2isLBQZGZmiszMTHHmzBnR1tYmDh8+LKZPny4mTJggTp8+rXTpimLfPHPgwAExYsQIkZiYKN59911hMpmE0WgUr732msjMzBR/+9vfhNlsFleuXBFvvPGGGDlypEhPTxc1NTVKl64YPtbcp/We+US4fPTRRyIqKkosX75cNDQ0CLvdLoQQ4r//+78FAAFAbN26VQghhN1uF5WVlSI5OVmkpqYKg8GgYOXKYt/cV1JSIuLi4sSkSZPEl19+6exZeXm5iIyMFABEXFycaGhoEEJc7duRI0fE6NGjxX333SdaWlqULF8xfKy5T+s9U/2GvtlsxrPPPothw4Zh3bp1iIqKgiRJXX69JEkYM2YMNm7ciNLSUmzbtq0fq1UP9s19NpsNf/zjH2EwGPDKK69g4sSJ3fYMuNq31NRUvPTSS9i9ezdyc3P7qVr14GPNff7QM9WHS1FREY4dO4ZHHnkEo0aN6vHJDlwdiClTpuCuu+7Cm2++CZPJ1A+Vqgv75r6ysjLs3bsXy5YtQ2pqaq96Blztm16vx80334
xNmzbBarV6uVJ14WPNff7QM9Uf4pKXl4fg4GDMnTsXZ86c6fDEvXjxovPf3377LU6dOuX8/6ioKOj1emzbtg3nzp3znU0wmbBv7issLERTUxPuuOMOnDt3Ds3Nzc7bqqqqYLPZAABtbW0oLi7GoEGDnLePHDkSy5Ytw9q1a1FTU4PRo0f3e/1K4WPNfX7RM6XX5Xpy3333ifHjx4vS0lIxduxYMXDgQOdHYGCgc20yKCiow20PPvig+Oabb8SQIUPE/v37lf41+h375r7f/OY3IioqSpw5c0bcfvvtHfoSHBzs7JkkSR1uCwkJEX/9619Ffn6+iIiIEJ988onSv0q/4mPNff7QM1XPXIQQaGlpQXBwMAICAtDS0oKWlhaXX2uxWGCxWJz/39bWhgEDBjjv50/YN8+YzWYEBgYiODgYra2tXf7+jv62Z7VaERISAiEEWltb+6NcVeBjzX3+0jNVh4skSRgyZAiOHz8Om82G2bNnw2g0Om8/e/YsKioqAACTJ0/GyJEjnbfdeOONMBqNaG1txeDBg/u7dEWxb57R6XQwmUwwGo2YOXMmwsLCnLeZzWYUFhY6QyQ5Odl54qQkSRg7dixqa2uh0+kQHR2t1K/Q7/hYc5/f9EzJaVNvbNq0SYSEhIgjR44Iq9Xa4WPNmjXO6WNWVlaH22w2m3jzzTfF8OHDRVVVldK/Rr9j33pms9nE8ePHxZo1a8T3v/99IUmSkCRJvPbaa9f0rLS01Hko8rhx40RdXd01fXvyySfF+PHjRX19vfOwUn/Ax5r7/KFnqj9abM6cOYiIiEBWVhaEEAgICHB+6HT/Ll+n03W4raWlBW+99RZSU1MxfPhwBX8DZbBvrrW1teHgwYN49NFHMW7cONx888144403MGPGDGRlZWH8+PHIyspCc3Nzh74EBAQ4v4ckSR36ptPpcOHCBbz77rtYuHAhIiMjYbfbYbPZYLfbIXzjIhge42PNff7QM9WHy7hx43Dvvfdix44dOHDgQK+eqHa7HW+++SZOnjyJ1atXd3hh8Bfs2781NjZi586duO+++zB8+HDMnz8f+/btwx133IEPP/wQFy5cwN/+9jfcf//9WL16NT777DNs2LCh14cUt7a24rnnnoPZbMbPf/7za14kOgeN1sKGjzX3+UXPFJoxueXChQtixowZYsyYMeLgwYPCZrMJIYR4+umnRWBgoAgKChLbtm0TdrtdWCwWsXXrVjFkyBCxZs0aYbVaFa5eOf7ct5qaGrF582axaNEiERISInQ6nZgyZYpYu3atOHnyZJfLVoWFhSIiIkKEhoaKl19+WZhMJmG320V5ebmIiYkRgYGB4vrrr3eeUX358mXx61//WkRERIgtW7Z0WY/dbhc2m+2aJQ6tLJ/582PNU1rvmU+EixBCFBcXi6lTp4rBgweL3//+96KsrEyUlpaKvLw8kZeXJyorK8WpU6fEz3/+cxEZGSlWrVolmpublS5bcf7Ut7Nnz4qXX35ZpKWliYCAABEYGChuu+02sW7dOlFeXt7j/QsLC0V8fLyYM2eO0Ov1IiQkROj1enH48GFRW1sr8vPzxeHDh8XRo0fFpUuXRE5Ojpg9e7aIjIwUK1euFAUFBc4XiO5oNWj86bEmFy33zGeuigwA58+fx3PPPYe3334bgYGBmDhxIsaMGQObzYZz586hpKQEMTExePLJJ3H//fcjODhY6ZJVQat9E0Lgs88+w+7du7F79258+eWXCA4ORnp6OvR6PRYtWtTrKxUfPHgQK1asQFJSErKysgAAmzZtwoYNG3Dx4kXEx8cjMTERERERMBgMKCkpQXV1NaZNm4ZnnnkGiYmJOH78OIYOHYpbbrml10sW4rtlsvZPQ0mSnB++RquPNW/Sas98KlyAq9d/OnPmDHJycnD8+HHU1tYiKCgIcXFxmD17NubNm4ehQ4cqXabqdNW3cePGYc6cOT7TN4vFgoKCAuzatcv53hdRUVFYtGgR9Ho90tPTER
4e7tb33L17Nx555BHMnTsXb7zxRocnb01NDQ4dOoTDhw+joqICLS0tiI6OxqRJkzBv3jzMnDkToaGhAIDa2loUFhYiKioKKSkpCAoKcvv36xw2vhg0fI66T4s987lwaU8IAZvNBkmS1L+5pSKOvgFXj0Zpf3SKGjU3N+ODDz5AdnY2cnJyYDAYMHr0aCxZsgR6vR5paWkevZADwLZt2/CrX/0Ky5cvx/r167t90y+bzQYhRLc9q6+vR0FBAcLCwpCWltanvzK1EDR8jrpPKz3z6XChvrPb7aoMl0uXLiEnJwe7du3CwYMH0dLSgkmTJmHJkiVYsmQJpk6d2ucX2f/93//F2rVrsWLFCrzwwguy9eHy5cs4cuQIgoODkZaWhpCQkD5/T60tn5H2MVz8nN1uV80L1DfffOPcPykoKIAQAsnJyc5Auf7662X5OUIIvPjii1i3bh3+67/+C7/73e9k//2bmppw5MgRSJKEtLQ0t5fqusOgIV/AcPFzSoaLEAJffPEFsrOzsXv3bpw6dQoDBgzA3LlznRvyw4YNk/Vn2u12PPXUU9i8eTOeeuopPProo7J+//ZMJhPy8/NhsVgwa9asDldRlguDhtSK4eLnHC9M/bU0ZrVa8fHHHzsDpbKyEpGRkVi4cCH0ej3mz5+PiIgIr/3sX/ziF9ixYwdeeuklPPDAA175Oe21trYiPz8fJpMJqampXr0elOOp7Ov7NKQNDBc/1x/hYjKZcPDgQezatQvvv/8+GhoaMHLkSOeG/KxZszBgwACv/Xzg6mVfHnroIXzwwQd45ZVXsHTpUq/+vPYcR7hdvnwZKSkpvT48uq+0cEAA+S6GC3llaay+vh45OTnIzs7GBx98ALPZjBtuuMEZKNOmTeu32ZLJZMJPf/pTHDt2DFu2bEF6enq//Nz2rFYrjh49ikuXLuGWW27BiBEj+vXnc/mM+hvDhWQLl8rKSuzevRvZ2dkoKCiAzWbDLbfcgiVLlmDx4sWYMGGCTBX33uXLl3Hvvffiq6++wrZt25CcnNzvNTjY7XZ88sknqK6uRlJSEsaMGaNIHQwa6g8MF+qwbOLu/U6fPu08wuvkyZMICgrCnDlzoNfr8aMf/ajf/0Jvr66uDnfddRfOnz+Pt99+G1OmTFGsFgchBIqKinDu3DlMnToV8fHxitfDoCFvYLiQW/suNpsNhYWFzg35b775BhEREc4N+QULFnjlqCh3nT9/HnfeeSeampqwY8cOfO9731O6JCfHUXJlZWWYPHmyIjM6V3hAAMlJ1e9ESepgNptx6NAhZGdnY+/evairq8Pw4cOxePFi6PV63Hbbbaq63lF5eTmWL1+OwMBA7NmzB+PGjVO6pA4kScKUKVMwYMAAnD59GhaLBZMmTVK6LGeAOP7rCBm73e78PIOGeovhQpAkyflC4njhMBgMzg35AwcOwGQyYfz48VixYgWWLFmCpKQkVZ7ZX1xcjLvuugvR0dF45513FF2W68nEiRMRGBiIU6dOwWKxYMqUKap64W4fJK6CpvPXELXHZTECcHWzuaqqCnv37kV2djYOHz4Mm82GpKQk5xFealpacuXTTz/Fj3/8Y4wbNw7bt29X/3uMf+ebb75BUVERxo4dixkzZqj+xZr7NNQbDBc/JoTAV199hezsbGRnZ6OoqAhBQUGYPXu2c0N+1KhRSpfZK0eOHMEDDzyAKVOmYOvWrV47EdNbqqqq8Mknn2DEiBGYOXOmz1ywkEFDXWG4+BmbzYZjx445j/AqKytDeHg4FixYgCVLlmDhwoWIiopSuky37Nu3Dz/72c8wa9YsbNmyRZYLRSqhpqYGR48eRUxMDJKTk7u9QrMatQ8ZHhBADBc/0NLSgo8++gi7du3C3r17UVtbi6FDhzo35GfPno2BAweq9grJ3dmxYwcef/xxLFq0CK+++qrHl95Xi7q6OhQUFGDQoEFITU31+pULvIlXCPBvDBeNMhqN2L9/P7Kzs5Gbm4umpiYkJC
RAr9dDr9e7XHpR0xWSe2PLli1Ys2YN7r33Xvz5z3/2maWknhgMBuTn5yMkJARpaWkYOHCg0iX1maugcfzXVx5v5B6Gi4acP38ee/bsQXZ2NvLy8mC1WjF9+nTnJesnTpzY7RPZV8JFCIHMzEz86U9/wsMPP4xnnnlG9TW768qVK8jPz0dAQABmzZrlfLdLLeA+jX9guPgwIQS+/vpr5/7J8ePHERAQgFtvvRV6vR6LFy926xIj/X2FZE8IIfDcc8/h1VdfxW9/+1s8/vjjmn1Bam5uxpEjR2C32zFr1iyfO0ihNxg02sVw8TGO61M5AqW0tBShoaFYsGAB9Ho9Fi5ciOjoaI++t9rDxWaz4cknn8TWrVvx/PPP4z//8z+VLsnrzGYz8vPz0drairS0NJ872MIdPCBAWxguPqC1tRV5eXnIzs7Gnj17UFNTgyFDhuBHP/oRli5dijlz5sh2hJRaN/UtFgseffRR7NmzB+vXr8fdd9+tdEn9pq2tDfn5+WhqakJqaipiYmKULqlf8IAA38ZwUakrV644N+T379+PxsZGxMXFQa/XY8mSJUhOTvbKBrYa911aWlrwH//xH8jLy8Prr7+OjIwMpUvqd443WWtoaEBycrLs79Cpdl0FjePfpD4MFxW5cOGC8wz5Dz/8EBaLBTfddJPzDPlJkyZ5/YmktnBpbGzEAw88gJMnTyIrKwu33nqr0iUpxnGO0sWLFzFz5kyfOcFVbtyn8Q0MF4WVlpY6rzB87NgxBAQEIC0tzbkhf9111/VrPWradzEYDLjnnntQUVGBf/zjH5gxY4bSJSnObrfjxIkTqKqqwrRp01R3Uc7+xqBRL986BVgD7HY7Pv30U+eG/JkzZxASEoJ58+bhzTffxMKFC/1mTb07NTU1uPPOO1FfX49du3ap4qrBaqDT6ZCUlISgoCB8+umnsFgsSExMVLosxbi6uCb3adSBM5d+0NbWhsOHDzs35KurqxETE4NFixZBr9dj7ty5qjqPQemlscrKSixfvhxWqxXvvPMOrr/+ekXqULvTp0+jpKQEEydOxMSJE5UuR3UYNMrizMVLGhsbceDAAWRnZ2Pfvn24fPkyrrvuOixfvhx6vR4pKSk+d+2o/lBaWoo777wToaGhyM7O9tt9hd6YPHkygoKC8OWXX8JqteLGG29UuiRV6ektA3hAgHdx5iKjixcvOjfkDx06hLa2NvzgBz9wniH/gx/8wCceyJ6+7XFfff7557jnnnswYsQI7NixA7Gxsf36831VeXk5Tp48ibi4OEydOtUnHmNK4j5N/2C49FFZWZlz/6SwsBCSJCE1NdW5IR8XF6d0iW5TYlO/sLAQ999/PyZMmIB//OMfmj5Z0Bu+/fZbnDhxAqNGjVLtG7mpEYPGexgubhJCoKioyBkoxcXFGDhwINLT06HX65GRkeHzf3H3d7j885//xMqVK5GUlISsrCyEhYX1y8/Vmurqahw7dgxDhw7FLbfcopkLefYXBo28GC69YLFYkJ+f7zxkuKqqCtHR0cjIyMDSpUuRnp6uuRfE/trU3717Nx555BGkp6fj9ddfR3BwsFd/ntbV1taisLAQUVFRSElJ8fm3IFASDwjoG4ZLF5qbm50b8jk5OTAajRgzZozzhMbU1FRNP3H7I1y2bduGX/3qV1i+fDnWr1/PAxxkUl9fj4KCAoSFhSEtLY2BLQMGjfsYLu1cunQJ77//PrKzs3Hw4EG0tLRg8uTJzg35m266yW8eTN5eGvvrX/+KP/zhD1ixYgVeeOEF7hHI7PLly8jPz0dQUBBmzZrls+/OqUZcPusdvw+XiooK5/7Jxx9/DCEEUlJSsGTJEixevNhvz7HwVrgIIfDiiy9i3bp1ePzxx/Hb3/6WT0gvaWpqwpEjRyBJEtLS0hAeHq50SZrDoOma34WLEAKff/65M1BOnTqF4OBg3H777Vi6dCkyMjL87qKAXZH7Csl2ux1PPfUUNm/ejKeffhqrVq2S7X
uTayaTCfn5+bBYLJg1axYGDRqkdEmaxaDpyC/CxWq1oqCgwLkh/+233yIyMhIZGRnQ6/WYN2+eJt+Iqa/k3HexWq144okn8M477+DPf/4z7r//fhkqpN5obW1Ffn4+TCYTUlNTMXjwYKVL0rz270vjr/s0fhEu+/fvx6JFizBq1CgsXrwYS5cuRVpaGgYMGKB0aaomZ7hs2bIFTz/9NF599VXo9fq+F0dusVgsKCgogMlkwoIFC3iYcj9rP6uRJMkv9hg1Fy5ff/21y9BwDGpnLS0tvC4TAE8eBl2FTmVlpcvvb7VaXR5h19DQgJtuusntn09da2pquuZzdrsdLS0tLq9j19raygumyqyr51RXr0WAti5Fo7ljP0NCQtzaM5k+fTq+/PJLL1bkf9zpf1FREd5++22Gi8y6Ojqsq/OxPvzwQyxevNibJfmtzmGipQDpjubC5brrrsO2bdtw33339err/f39MBwkSeowbQfg8t+9eWIMHDiwVz9z//79+OKLL7Bu3TrPCyeX3Fn2EkL47VGR3uR4Tjn+29NSmMYWkbS3LAYAQUFBsFgsPX5dZWUlRo0axZP32nH1cGgfPHKtFWdnZ6OqqgqrVq3ym7/k1Oqrr77CDTfcwHHwAscRl50vBuvq4rByH52pNO38Ju2sWLGiV1+XlJTEYOmk/REtri5LLsffIu+88w4uXrzIYFEBIQQuXLjAcfAyxx9oAK45XFmrNBkur732GvLy8nr8uqlTp3q/GA1p/wTxhBACWVlZaGxsxM9+9jO+oCnMcRHW2267TelS/ELnpWetB4wml8WAqxuaZrO5y9uPHDmC5ORkzlzc5OnUXQiBjRs3YsSIEbjzzju9UBl1x2w2w2g0QpIkBAcHQwiB6upqDBkyBMOHD1e6PM1q/3zpfHJl56UxrS2LaTZctmzZgp/85CddhkdYWBiam5v7uSrf58kbiQkh8OyzzyIlJQVz5871VmnUhaqqKhiNRgwdOhRCCLS1tUGSJERHR2vuat5q091hx8C14aOl2bxmw0UIgYULF2L//v3X3HblyhXk5eXx0EsPufMXlt1ux5NPPom7774b06dP93Jl1FlDQwNqa2sxYcIETb1waYWWw0Wza0KSJGHSpElobm6+5q+z5ORkntvSD2w2Gx5//HGsWrUK3/ve95Qux+8IIXDq1CnceuutmnrR8hW9CQtXh/1rhWZnLsDVAUtNTUVBQYFz4I4ePYqLFy/yEiR90JulMavVitWrV+PXv/414uPj+6s0aufEiROYOnUqL/WikN4ERvsjyLS03wJo9GgxB0mS8MYbb+CXv/wlqqur8fHHH2Pjxo0Mlj7q6UgXi8WCRx55BL/73e8YLAoRQiAoKIjBonJaPmpMczMXV9e1KikpQU5ODgYPHowHHnjgmr8Qrrvuuv4qT7XcfRh095dWQUEBgoOD3ToKacyYMW79fOpebW2t25v13NyXl5zX6/NFmguXro4A626KyieV6ydCT1dF7urzJpPJ5fe3Wq0IDAx0eT9XF1Mkz1mt1ms+Z7fbYTabERIS4vIPAx6WLy9/v3Cl5pbFwsLCrvkoKSlBZGQkSkpKXN5Ors/Md7wAubqtuydBaGjoNR/l5eUYP348ysvLXd5O8goMDLzmw2Qy4Z///CdMJpPL20leXT1nXL2/ixbf50Vz4ULykfOSL0TkXxgu1C0tbzgSkfcwXKhbnL0QkScYLtQjzl6IyF0MFyIikh3DhXrEpTEichfDhXqFS2NE5A6GC/UKZy9E5A6GC/UaZy9E1FsMF+o1zl6IqLcYLuQWzl6IqDcYLkREJDuGC7mFS2NE1BsMF3Ibl8aIqCcMF3IbZy9E1BOGC3mEsxci6g7DhTzC2QsRdYfhQh7j7IWIusJwISIi2TFcyGNcGiOirjBcqE+4NEZErjBcqE84eyEiVxgu1GecvRBRZwwX6jPOXoioM4YLyYbhQkQODBeShWP2QkQEMFxIJlwaI6L2GC4kG27sE5
EDw4Vkw9kLETkwXEhWnL0QEcBwIZlx9kJEAMOFiIi8gOFCsuPSGBExXEh2XBojIoYLeQVnL0T+jeFCXsHZC5F/03y4CCFgMBgAAAaDgS92/cgxe3GMgc1m4xgoRAiB+vp61NbWor6+nmOgACEE6urqcO7cOdTV1Wl/DIRGGQwGsX79epGQkCAAOD8SEhLE+vXrhcFgULpEzTMYDOIvf/mLiI+P5xgohM8D5fnrGGgyXHJzc0VYWJiQJElIktRhQB2fCwsLE7m5uUqXqlmOMWjfe45B/+LzQHn+PAaaC5fc3FwREBAgdDqdyxc2x4dOpxMBAQGaHFSlcQyUxzFQnr+PgSSEdhb+jEYjRo8eDbPZDLvd3uPX63Q6hISEoKqqClFRUd4v0A9wDJTHMVAex0BjG/pZWVkwmUy9GkwAsNvtMJlMeOutt7xcmf/gGCiPY6A8jgGgmZmLEAKJiYmoqKhw6ygMSZIQHx+Ps2fP8g2v+ohjoDyOgfI4BldpJlzq6uoQGxvbp/vHxMTIWJH/4Rgoj2OgPI7BVZpZFmtqaurT/RsbG2WqxH9xDJTHMVAex+AqzYRLeHh4n+4fEREhUyX+i2OgPI6B8jgGV2kmXGJiYpCQkOD2WqUkSUhISMDgwYO9VJn/4Bgoj2OgPI7BVZoJF0mSsHr1ao/u+9hjj2liA01pHAPlcQyUxzG4SjMb+gCPLVcDjoHyOAbK4xhoaOYCAFFRUdi5cyckSYJO1/2vptPpIEkS3nvvPc0MphpwDJTHMVAexwDavHBlb6/nc+DAAaVL1SyOgfI4Bsrz5zHQZLgIcfVKpJmZmS6vRJqZmSmMRqPSJWoex0B5HAPl+esYaGrPxRUhBBoaGtDY2IiIiAgMHjxYMxtmvoJjoDyOgfL8bQw0Hy5ERNT/NLWhT0RE6sBwISIi2TFciIhIdgwXIiKSHcOFiIhkx3AhIiLZMVyIiEh2DBciIpIdw4WIiGTHcCEiItkxXIiISHYMFyIikh3DhYiIZMdwISIi2f0/421YPKKgu+MAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 500x400 with 32 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "8c782f62",
"metadata": {},
"source": [
"Get the feature scores of the input variables:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2693a8c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.8906, 0.5176, 0.1139, 0.0041], grad_fn=<MeanBackward1>)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.feature_score"
]
},
{
"cell_type": "markdown",
"id": "9fb3a0a8",
"metadata": {},
"source": [
"Inspect how individual hidden nodes depend on the input features:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2f80a6e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.8900, 0.5142, 0.1136, 0.0038], grad_fn=<SelectBackward0>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWkAAAESCAYAAAA/niRMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAATnklEQVR4nO3db0yV9/3/8dcRxsGtnmOQeioVka1rw0baxcPqwNJkdj3fUGNGskQWE9EWkxK1BFmbSU3qaprgus3ZrINqLG2a2JZ0xa1JmXqSTcSSJoVg1p+a/VN3aIUyaHYOtRtEuH43jCc5A5TryHrewPORXDe4vD6c97nSPnN5cbzwOI7jCABg0oJUDwAAmBqRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYempHmA6xsfHdfnyZS1atEgejyfV4wDALXMcR8PDw8rJydGCBVNfL8+KSF++fFm5ubmpHgMAZlxvb6+WL18+5Z/PikgvWrRI0rU34/P5UjwNANy6WCym3NzceN+mMisiff0Wh8/nI9IA5pSb3cLlB4cAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAybFf+Y5Vas3PVuqkf4wlzaty7VIwCYYVxJA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAxLKtKNjY3Kz89XZmamgsGgOjo6bnj8kSNHdN999+nLX/6yli1bpkcffVRDQ0NJDQwA84nrSLe0tKi2tla7d+9WT0+PSktLVVZWpkgkMunxp0+fVmVlpaqqqnT27Fm99dZb+uCDD7R169ZbHh4A5jrXkd6/f7+qqqq0detWFRQU6MCBA8rNzVVTU9Okx7///vtauXKlampqlJ+frwceeECPP/64urq6bnl4AJjrXEV6dHRU3d3dCoVCCftDoZA6OzsnXVNSUqKPPvpIbW1tchxHn3zyiX7zm99o3bqpf7P1yMiIYrFYwgYA85GrSA8ODmpsbEyBQCBhfyAQUH9//6RrSkpKdOTIEVVUVCgjI0N33HGHFi9erF/96ldTvk5DQ4P8fn98y83NdTMmAMwZSf3g0OPxJHztOM6EfdedO3dONTU1euaZZ9Td3a1jx47p4sWLqq6unvL719fXKxqNxrfe3t5kxgSAWS/dzcHZ2dlKS0ubcNU8MDAw4er6uoaGBq1Zs0ZPPfWUJOnee+/VV77yFZWWluq5557TsmXLJqzxer3yer1uRgOAOcnVlXRGRoaCwaDC4XDC/nA4rJKSkknXfP7551qwIPFl0tLSJF27AgcATM317Y66ujodPnxYzc3NOn/+vHbu3KlIJBK/fVFfX6/Kysr48evXr1dra6uampp04cIFvffee6qpqdH999+vnJycmXsnADAHubrdIUkVFRUaGhrS3r171dfXp8LCQrW1tSkvL0+S1NfXl/CZ6S1btmh4eFgvvviifvSjH2nx4sVau3atfvrTn87cuwCAOcrjzIJ7DrFYTH6/X9FoVD6fz9Xalbve/R9NZc+lfVN/rBGALdPtGs/uAADDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwDAiDQCGEWkAMIxIA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwDAiDQCGEWkAMIxIA4BhSUW6sbFR+fn5yszMVDAYVEdHxw2PHxkZ0e7du5WXlyev16uvfe1ram5uTmpgAJhP0t0uaGlpUW1trRobG7VmzRodPHhQZWVlOnfunFasWDHpmg0bNuiTTz7Ryy+/rLvuuksDAwO6evXqLQ8PAHOdx3Ecx82C1atXa9WqVWpqaorvKy
goUHl5uRoaGiYcf+zYMf3whz/UhQsXlJWVldSQsVhMfr9f0WhUPp/P1dqVu95N6jVno0v71qV6BADTNN2uubrdMTo6qu7uboVCoYT9oVBInZ2dk6555513VFRUpOeff1533nmn7r77bj355JP697//PeXrjIyMKBaLJWwAMB+5ut0xODiosbExBQKBhP2BQED9/f2Trrlw4YJOnz6tzMxMHT16VIODg9q2bZs+/fTTKe9LNzQ06Nlnn3UzGgDMSUn94NDj8SR87TjOhH3XjY+Py+Px6MiRI7r//vv1yCOPaP/+/Xr11VenvJqur69XNBqNb729vcmMCQCznqsr6ezsbKWlpU24ah4YGJhwdX3dsmXLdOedd8rv98f3FRQUyHEcffTRR/r6178+YY3X65XX63UzGgDMSa6upDMyMhQMBhUOhxP2h8NhlZSUTLpmzZo1unz5sj777LP4vr/85S9asGCBli9fnsTIADB/uL7dUVdXp8OHD6u5uVnnz5/Xzp07FYlEVF1dLenarYrKysr48Rs3btSSJUv06KOP6ty5czp16pSeeuopPfbYY1q4cOHMvRMAmINcf066oqJCQ0ND2rt3r/r6+lRYWKi2tjbl5eVJkvr6+hSJROLH33bbbQqHw3riiSdUVFSkJUuWaMOGDXruuedm7l0AwBzl+nPSqcDnpKeHz0kDs8f/5HPSAIAvFpEGAMOINAAYRqQBwDAiDQCGEWkAMIxIA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwzPXzpDE38UhXwCaupAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwDAiDQCGEWkAMIxIA4BhRBoADCPSAGAYkQYAw4g0ABiWVKQbGxuVn5+vzMxMBYNBdXR0TGvde++9p/T0dH3rW99K5mUBYN5xHemWlhbV1tZq9+7d6unpUWlpqcrKyhSJRG64LhqNqrKyUg899FDSwwLAfOM60vv371dVVZW2bt2qgoICHThwQLm5uWpqarrhuscff1wbN25UcXHxTV9jZGREsVgsYQOA+chVpEdHR9Xd3a1QKJSwPxQKqbOzc8p1r7zyiv7+979rz54903qdhoYG+f3++Jabm+tmTACYM1xFenBwUGNjYwoEAgn7A4GA+vv7J13z17/+Vbt27dKRI0eUnp4+rdepr69XNBqNb729vW7GBIA5Y3rV/C8ejyfha8dxJuyTpLGxMW3cuFHPPvus7r777ml/f6/XK6/Xm8xoADCnuIp0dna20tLSJlw1DwwMTLi6lqTh4WF1dXWpp6dHO3bskCSNj4/LcRylp6frxIkTWrt27S2MDwBzm6vbHRkZGQoGgwqHwwn7w+GwSkpKJhzv8/n04Ycf6syZM/Gturpa99xzj86cOaPVq1ff2vQAMMe5vt1RV1enTZs2qaioSMXFxTp06JAikYiqq6slXbuf/PHHH+u1117TggULVFhYmLB+6dKlyszMnLAfADCR60hXVFRoaGhIe/fuVV9fnwoLC9XW1qa8vDxJUl9f300/Mw0AmB6P4zhOqoe4mVgsJr/fr2g0Kp/P52rtyl3v/o+msufSvnVJr+U8AV+s6XaNZ3cAgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwDAiDQCGEWkAMIxIA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwLCkIt3Y2Kj8/HxlZmYqGAyqo6NjymNbW1v18MMP6/bbb5fP51NxcbGOHz+e9MAAMJ+4jnRLS4tqa2u1e/du9fT0qLS0VGVlZYpEIpMef+rUKT388M
Nqa2tTd3e3vvvd72r9+vXq6em55eEBYK7zOI7juFmwevVqrVq1Sk1NTfF9BQUFKi8vV0NDw7S+xze/+U1VVFTomWeemdbxsVhMfr9f0WhUPp/PzbhauetdV8fPZpf2rUt6LecJ+GJNt2uurqRHR0fV3d2tUCiUsD8UCqmzs3Na32N8fFzDw8PKysqa8piRkRHFYrGEDQDmI1eRHhwc1NjYmAKBQML+QCCg/v7+aX2PX/ziF7py5Yo2bNgw5TENDQ3y+/3xLTc3182YADBnJPWDQ4/Hk/C14zgT9k3mjTfe0E9+8hO1tLRo6dKlUx5XX1+vaDQa33p7e5MZEwBmvXQ3B2dnZystLW3CVfPAwMCEq+v/1tLSoqqqKr311lv63ve+d8NjvV6vvF6vm9EAYE5ydSWdkZGhYDCocDicsD8cDqukpGTKdW+88Ya2bNmi119/XevW8UMbAJguV1fSklRXV6dNmzapqKhIxcXFOnTokCKRiKqrqyVdu1Xx8ccf67XXXpN0LdCVlZV64YUX9J3vfCd+Fb5w4UL5/f4ZfCsAMPe4jnRFRYWGhoa0d+9e9fX1qbCwUG1tbcrLy5Mk9fX1JXxm+uDBg7p69aq2b9+u7du3x/dv3rxZr7766q2/AwCYw1xHWpK2bdumbdu2Tfpn/x3ekydPJvMSAADx7A4AMI1IA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGJbUo0qB+WrlrndTPcIX6tI+fpNSqnElDQCGEWkAMIxIA4BhRBoADCPSAGAYkQYAw4g0ABhGpAHAMCINAIYRaQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAwj0gBgGJEGAMOINAAYRqQBwLCkIt3Y2Kj8/HxlZmYqGAyqo6Pjhse3t7crGAwqMzNTX/3qV/XSSy8lNSwAzDeuI93S0qLa2lrt3r1bPT09Ki0tVVlZmSKRyKTHX7x4UY888ohKS0vV09Ojp59+WjU1NXr77bdveXgAmOvS3S7Yv3+/qqqqtHXrVknSgQMHdPz4cTU1NamhoWHC8S+99JJWrFihAwcOSJIKCgrU1dWln//85/rBD34w6WuMjIxoZGQk/nU0GpUkxWIxt+NqfORz12tmq2TOz3Wcp+mZT+dJSv5cFe45PsOT2Pb/nv0/12uun1vHcW58oOPCyMiIk5aW5rS2tibsr6mpcR588MFJ15SWljo1NTUJ+1pbW5309HRndHR00jV79uxxJLGxsbHN+a23t/eG3XV1JT04OKixsTEFAoGE/YFAQP39/ZOu6e/vn/T4q1evanBwUMuWLZuwpr6+XnV1dfGvx8fH9emnn2rJkiXyeDxuRk6JWCym3Nxc9fb2yufzpXocszhP08N5mp7Zdp4cx9Hw8LBycnJueJzr2x2SJoTScZwbxnOy4yfbf53X65XX603Yt3jx4iQmTS2fzzcr/mNJNc7T9HCepmc2nSe/33/TY1z94DA7O1tpaWkTrpoHBgYmXC1fd8cdd0x6fHp6upYsWeLm5QFg3nEV6YyMDAWDQYXD4YT94XBYJSUlk64pLi6ecPyJEydUVFSkL33pSy7HBYD5xfVH8Orq6nT48GE1Nzfr/Pnz2rlzpyKRiKqrqyVdu59cWVkZP766ulr/+Mc/VFdXp/Pnz6u5uVkvv/yynnzyyZl7F8Z4vV7t2bNnwi0bJOI8TQ/naXrm6nnyOM7NPv8xUWNjo55//nn19fWpsLBQv/zlL/Xggw9KkrZs2aJLly7p5MmT8ePb29u1c+dOnT17Vjk5Ofrxj38cjzoAYGpJRRoA8MXg2R0AYBiRBgDDiDQAGEakAcAwIj3D3D7GdT46deqU1q9fr5ycHHk8Hv32t79N9UjmNDQ06Nvf/rYWLVqkpUuXqry8XH/+859TPZZJTU1Nuvfee+P/0rC4uFi///3vUz3WjCHSM8
jtY1znqytXrui+++7Tiy++mOpRzGpvb9f27dv1/vvvKxwO6+rVqwqFQrpy5UqqRzNn+fLl2rdvn7q6utTV1aW1a9fq+9//vs6ePZvq0WYEH8GbQatXr9aqVavU1NQU31dQUKDy8vJJH+OKa89vOXr0qMrLy1M9imn//Oc/tXTpUrW3t8f/TQKmlpWVpZ/97GeqqqpK9Si3jCvpGTI6Oqru7m6FQqGE/aFQSJ2dnSmaCnPF9WeqZ2VlpXgS28bGxvTmm2/qypUrKi4uTvU4MyKpp+BhomQe4wpMh+M4qqur0wMPPKDCwsJUj2PShx9+qOLiYv3nP//RbbfdpqNHj+ob3/hGqseaEUR6hrl9jCtwMzt27NCf/vQnnT59OtWjmHXPPffozJkz+te//qW3335bmzdvVnt7+5wINZGeIck8xhW4mSeeeELvvPOOTp06peXLl6d6HLMyMjJ01113SZKKior0wQcf6IUXXtDBgwdTPNmt4570DEnmMa7AVBzH0Y4dO9Ta2qo//OEPys/PT/VIs4rjOAm/J3U240p6BtXV1WnTpk0qKipScXGxDh06lPAYV1zz2Wef6W9/+1v864sXL+rMmTPKysrSihUrUjiZHdu3b9frr7+u3/3ud1q0aFH8b2h+v18LFy5M8XS2PP300yorK1Nubq6Gh4f15ptv6uTJkzp27FiqR5sZN/7Vs3Dr17/+tZOXl+dkZGQ4q1atctrb21M9kjl//OMfJ/2FnJs3b071aGZMdn4kOa+88kqqRzPnsccei/8/d/vttzsPPfSQc+LEiVSPNWP4nDQAGMY9aQAwjEgDgGFEGgAMI9IAYBiRBgDDiDQAGEakAcAwIg0AhhFpADCMSAOAYUQaAAz7/2/a0k6oDGraAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 400x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# node index 2 (0-based, i.e. the 3rd node) in layer 1, the hidden layer (layer 0 holds the inputs)\n",
"model.attribute(1,2)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2a297860",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([1.9413e-05, 1.3491e-04, 5.7833e-05, 3.2742e-05],\n",
" grad_fn=<SelectBackward0>)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEVCAYAAAAckrn/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAphElEQVR4nO3df1DTd54/8GeWhKBWIoJNzJRqRMfCemdLuOFgF3XaM/you3DnnNl2NtPpOQ5Mt2Lg7gR/ja6d44fbsZ4D4trh2nV3TpgWqcwVHeJ4ZXVJPeUitcp5dy0urJDhQDehasMP3/eHX/K9mADvD+uKmudj5jMOb17vH3lPm2c+yedDVEIIASIiIgnfmekFEBHRk4OhQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkjaFBRETS1NPpdOjQIfzsZz9DX18fvvvd7+LAgQPIyMiYsL61tRXFxcW4cuUKjEYjtm7dioKCgoCahoYG7Nq1C1999RUSEhLwD//wD/jLv/xLRfMeP34cP//5z9He3o7BwUG4XC68+OKLIdckhEBOTg5OnTqFxsZG5OXlST32e/fuobe3F3PnzoVKpZLqQ0T0OBNCYGhoCEajEd/5zhTnEkKhuro6odFoxPvvvy+uXr0qtmzZIubMmSN++9vfhqz/+uuvxezZs8WWLVvE1atXxfvvvy80Go34+OOP/TVtbW0iIiJClJWVic7OTlFWVibUarX4/PPPFc179OhR8dOf/lS8//77AoBwuVwTPo79+/eL7OxsAUA0NjZKP/6enh4BgAcPHjyeuqOnp2fK50CVEMr+YGFqaiqSk5NRU1Pjb0tMTEReXh7Ky8uD6ktKStDU1ITOzk5/W0FBATo6OuB0OgEAVqsVXq8XJ0+e9NdkZWUhJiYGx44dUzzv9evXYTKZJjzT6OjowLp163DhwgUsXLhQ0ZmGx+PBvHnz0NPTg+joaKk+RESPM6/Xi/j4ePz+97+HTqebtFbR21PDw8Nob29HaWlpQLvFYkFbW1vIPk6nExaLJaAtMzMTtbW1GBkZgUajgdPpRFFRUVDNgQMHpj3vRO7cuYPXXnsNVVVVMBgMU9b7fD74fD7/z0NDQwCA6OhohgYRPVVk3nJX9EH4wMAAxsbGoNfrA9r1ej3cbnfIPm63O2T96OgoBgYGJq0ZH3M6806kqKgI6enpyM3NlaovLy+HTqfzH/Hx8YrmIyJ6mkzr6qkH00gIMWlChap/sF1mTKXzPqipqQlnzpzxn8HI2LZtGzwej//o6emR7ktE9LRRFBpxcXGIiIgIenXf398fdBYwzmAwhKxXq9WIjY2dtGZ8zOnMG8qZM2fw1VdfYd68eVCr1VCr7787t379eqxZsyZkH61W638rim9JEVG4UxQakZGRMJvNcDgcAe0OhwPp6ekh+6SlpQXVt7S0ICUlBRqNZtKa8TGnM28opaWl+OKLL3Dp0iX/AQDvvfcePvjgA+lxiIjClvS1pv/P+KWvtbW14urVq8Jut4s5c+aI69evCyGEKC0tFTabzV8/fsltUVGRuHr1qqitrQ265PY3v/mNiIiIEBUVFaKzs1NUVFRMeMntRPMKIcTg4KBwuVzi008/FQBEXV2dcLlcoq+vb8LHAyi75Nbj8QgAwuPxSPchInqcKXleUxwaQghRXV0tFi1aJCIjI0VycrJobW31/+6NN94Qq1evDqj/7LPPxEsvvSQiIyPF4sWLRU1NTdCYH330kVi+fLnQaDTihRdeEA0NDYrmFUKIDz74IOS1x7t3757wsTA0iCjcKXleU3yfRrjzer3Q6XTweDz8fIOIngpKntf4t6eIiEjatP72FNEf2+LST2d6CY/M9YpXZ3oJRNJ4pkFERNIYGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGk
REJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJG1aoXHo0CGYTCZERUXBbDbj7Nmzk9a3trbCbDYjKioKS5YsweHDh4NqGhoakJSUBK1Wi6SkJDQ2Niqe9/jx48jMzERcXBxUKhUuXboU8PubN29i8+bNWL58OWbPno3nn38ehYWF8Hg8yjeBiCgMKQ6N+vp62O127NixAy6XCxkZGcjOzkZ3d3fI+q6uLuTk5CAjIwMulwvbt29HYWEhGhoa/DVOpxNWqxU2mw0dHR2w2WzYsGEDzp8/r2je27dv43vf+x4qKipCrqW3txe9vb149913cfnyZXz44Yc4deoUNm7cqHQbiIjCkkoIIZR0SE1NRXJyMmpqavxtiYmJyMvLQ3l5eVB9SUkJmpqa0NnZ6W8rKChAR0cHnE4nAMBqtcLr9eLkyZP+mqysLMTExODYsWOK571+/TpMJhNcLhdefPHFSR/PRx99hB//+Me4ffs21Gr1lI/f6/VCp9PB4/EgOjp6ynqansWln870Eh6Z6xWvzvQSKMwpeV5TdKYxPDyM9vZ2WCyWgHaLxYK2traQfZxOZ1B9ZmYmLl68iJGRkUlrxseczryyxjdposDw+Xzwer0BBxFRuFIUGgMDAxgbG4Nerw9o1+v1cLvdIfu43e6Q9aOjoxgYGJi0ZnzM6cwrY3BwEO+88w7y8/MnrCkvL4dOp/Mf8fHx056PiOhJN60PwlUqVcDPQoigtqnqH2yXGVPpvJPxer149dVXkZSUhN27d09Yt23bNng8Hv/R09MzrfmIiJ4GU7+J/3/ExcUhIiIi6NV9f39/0FnAOIPBELJerVYjNjZ20prxMacz72SGhoaQlZWFZ555Bo2NjdBoNBPWarVaaLVaxXMQET2NFJ1pREZGwmw2w+FwBLQ7HA6kp6eH7JOWlhZU39LSgpSUFP+T9UQ142NOZ96JeL1eWCwWREZGoqmpCVFRUYr6ExGFM0VnGgBQXFwMm82GlJQUpKWl4ciRI+ju7kZBQQGA+2/n3LhxA0ePHgVw/0qpqqoqFBcXY9OmTXA6naitrfVfFQUAW7ZswapVq1BZWYnc3FycOHECp0+fxrlz56TnBe7fh9Hd3Y3e3l4AwLVr1wDcP5MxGAwYGhqCxWLBnTt38Ktf/Srgg+0FCxYgIiJC6XYQEYUVxaFhtVoxODiIvXv3oq+vDytWrEBzczMWLVoEAOjr6wu4d8JkMqG5uRlFRUWorq6G0WjEwYMHsX79en9Neno66urqsHPnTuzatQsJCQmor69Hamqq9LwA0NTUhDfffNP/849+9CMAwO7du7Fnzx60t7f77/1YunRpwOPq6urC4sWLlW4HEVFYUXyfRrjjfRqPBu/TIHp0/mj3aRARUXhjaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSZtWaBw6dAgmkwlRUVEwm804e/bspPWtra0wm82IiorCkiVLcPjw4aCahoYGJCUlQavVIikpCY2NjYrnPX78ODIzMxEXFweVSoVLly4FjeHz+bB582bExcVhzpw5+OEPf4jf/e53yjaAiChMKQ6N+vp62O127NixAy6XCxkZGcjOzkZ3d3fI+q6uLuTk5CAjIwMulwvbt29HYWEhGhoa/DVOpxNWqxU2mw0dHR2w2WzYsGEDzp8/r2je27dv43vf+x4qKiomXL/dbkdjYyPq6upw7tw5fPPNN1i3bh3GxsaUbgURUdhRCSGEkg6pqalITk5GTU2Nvy0xMRF5eXkoLy8Pqi8pKUFTUxM6Ozv9bQUFBejo6IDT6QQAWK1WeL1enDx50l+TlZWFmJgYHDt2TPG8169fh8lkgsvlwosvvuhv93g8WLBgAX75y1/CarUCAHp7exEfH4
/m5mZkZmZO+fi9Xi90Oh08Hg+io6OnrKfpWVz66Uwv4ZG5XvHqTC+BwpyS5zVFZxrDw8Nob2+HxWIJaLdYLGhrawvZx+l0BtVnZmbi4sWLGBkZmbRmfMzpzBtKe3s7RkZGAsYxGo1YsWLFhOP4fD54vd6Ag4goXCkKjYGBAYyNjUGv1we06/V6uN3ukH3cbnfI+tHRUQwMDExaMz7mdOadaC2RkZGIiYmRHqe8vBw6nc5/xMfHS89HRPS0mdYH4SqVKuBnIURQ21T1D7bLjKl0XlmTjbNt2zZ4PB7/0dPT8wfPR0T0pFIUGnFxcYiIiAh6Vd7f3x90FjDOYDCErFer1YiNjZ20ZnzM6cw70VqGh4dx69Yt6XG0Wi2io6MDDiKicKUoNCIjI2E2m+FwOALaHQ4H0tPTQ/ZJS0sLqm9paUFKSgo0Gs2kNeNjTmfeUMxmMzQaTcA4fX19+PLLLxWNQ0QUrtRKOxQXF8NmsyElJQVpaWk4cuQIuru7UVBQAOD+2zk3btzA0aNHAdy/UqqqqgrFxcXYtGkTnE4namtr/VdFAcCWLVuwatUqVFZWIjc3FydOnMDp06dx7tw56XkB4ObNm+ju7kZvby8A4Nq1awDun2EYDAbodDps3LgRf/u3f4vY2FjMnz8ff/d3f4c/+ZM/wV/8xV9MY/uIiMKL4tCwWq0YHBzE3r170dfXhxUrVqC5uRmLFi0CcP+V+/+9d8JkMqG5uRlFRUWorq6G0WjEwYMHsX79en9Neno66urqsHPnTuzatQsJCQmor69Hamqq9LwA0NTUhDfffNP/849+9CMAwO7du7Fnzx4AwHvvvQe1Wo0NGzbg7t27eOWVV/Dhhx8iIiJC6VYQEYUdxfdphDvep/Fo8D4Nokfnj3afBhERhTeGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJm1ZoHDp0CCaTCVFRUTCbzTh79uyk9a2trTCbzYiKisKSJUtw+PDhoJqGhgYkJSVBq9UiKSkJjY2NiucVQmDPnj0wGo2YNWsW1qxZgytXrgTUuN1u2Gw2GAwGzJkzB8nJyfj444+nsQtEROFHcWjU19fDbrdjx44dcLlcyMjIQHZ2Nrq7u0PWd3V1IScnBxkZGXC5XNi+fTsKCwvR0NDgr3E6nbBarbDZbOjo6IDNZsOGDRtw/vx5RfPu27cP+/fvR1VVFS5cuACDwYC1a9diaGjIX2Oz2XDt2jU0NTXh8uXL+Ku/+itYrVa4XC6lW0FEFHZUQgihpENqaiqSk5NRU1Pjb0tMTEReXh7Ky8uD6ktKStDU1ITOzk5/W0FBATo6OuB0OgEAVqsVXq8XJ0+e9NdkZWUhJiYGx44dk5pXCAGj0Qi73Y6SkhIAgM/ng16vR2VlJfLz8wEAzzzzDGpqamCz2fzjxMbGYt++fdi4ceOUj9/r9UKn08Hj8SA6Olpqz0i5xaWfzvQSHpnrFa/O9BIozCl5XlN0pjE8PIz29nZYLJaAdovFgra2tpB9nE5nUH1mZiYuXryIkZGRSWvGx5SZt6urC263O6BGq9Vi9erVAWv7/ve/j/r6ety8eRP37t1DXV0dfD4f1qxZE3L9Pp8PXq834CAiCleKQmNgYABjY2PQ6/UB7Xq9Hm63O2Qft9sdsn50dBQDAwOT1oyPKTPv+L9Tra2+vh6jo6OIjY2FVqtFfn4+GhsbkZCQEHL95eXl0Ol0/iM+Pj5kHRFROJjWB+EqlSrgZyFEUNtU9Q+2y4z5MGp27tyJW7du4fTp07h48SKKi4vx13/917h8+XLItW/btg0ej8d/9PT0TPg4iYiedmolxXFxcYiIiAg6q+jv7w96hT/OYDCErFer1YiNjZ20ZnxMmXkNBgOA+2ccCx
cuDFnz1VdfoaqqCl9++SW++93vAgBWrlyJs2fPorq6OuRVXVqtFlqtdpJdISIKH4rONCIjI2E2m+FwOALaHQ4H0tPTQ/ZJS0sLqm9paUFKSgo0Gs2kNeNjysxrMplgMBgCaoaHh9Ha2uqvuXPnzv0H/Z3Ahx0REYF79+5NvQFERGFO0ZkGABQXF8NmsyElJQVpaWk4cuQIuru7UVBQAOD+2zk3btzA0aNHAdy/UqqqqgrFxcXYtGkTnE4namtr/VdFAcCWLVuwatUqVFZWIjc3FydOnMDp06dx7tw56XlVKhXsdjvKysqwbNkyLFu2DGVlZZg9ezZef/11AMALL7yApUuXIj8/H++++y5iY2PxySefwOFw4F/+5V+mv4tERGFCcWhYrVYMDg5i79696Ovrw4oVK9Dc3IxFixYBAPr6+gLunTCZTGhubkZRURGqq6thNBpx8OBBrF+/3l+Tnp6Ouro67Ny5E7t27UJCQgLq6+uRmpoqPS8AbN26FXfv3sVbb72FW7duITU1FS0tLZg7dy4AQKPRoLm5GaWlpfjBD36Ab775BkuXLsUvfvEL5OTkKN89IqIwo/g+jXDH+zQeDd6nQfTo/NHu0yAiovDG0CAiImkMDSIiksbQICIiaQwNIiKSxtAgIiJpDA0iIpLG0CAiImkMDSIiksbQICIiaQwNIiKSxtAgIiJpDA0iIpLG0CAiImkMDSIiksbQICIiaQwNIiKSxtAgIiJpDA0iIpLG0CAiImkMDSIikjat0Dh06BBMJhOioqJgNptx9uzZSetbW1thNpsRFRWFJUuW4PDhw0E1DQ0NSEpKglarRVJSEhobGxXPK4TAnj17YDQaMWvWLKxZswZXrlwJGsfpdOLll1/GnDlzMG/ePKxZswZ3795VuAtEROFHcWjU19fDbrdjx44dcLlcyMjIQHZ2Nrq7u0PWd3V1IScnBxkZGXC5XNi+fTsKCwvR0NDgr3E6nbBarbDZbOjo6IDNZsOGDRtw/vx5RfPu27cP+/fvR1VVFS5cuACDwYC1a9diaGgoYK6srCxYLBb827/9Gy5cuIC3334b3/kOT7qIiKaiEkIIJR1SU1ORnJyMmpoaf1tiYiLy8vJQXl4eVF9SUoKmpiZ0dnb62woKCtDR0QGn0wkAsFqt8Hq9OHnypL8mKysLMTExOHbsmNS8QggYjUbY7XaUlJQAAHw+H/R6PSorK5Gfnw8A+PM//3OsXbsW77zzjpKH7ef1eqHT6eDxeBAdHT2tMWhqi0s/neklPDLXK16d6SVQmFPyvKbo5fXw8DDa29thsVgC2i0WC9ra2kL2cTqdQfWZmZm4ePEiRkZGJq0ZH1Nm3q6uLrjd7oAarVaL1atX+2v6+/tx/vx5PPvss0hPT4der8fq1atx7ty5CR+zz+eD1+sNOIiIwpWi0BgYGMDY2Bj0en1Au16vh9vtDtnH7XaHrB8dHcXAwMCkNeNjysw7/u9kNV9//TUAYM+ePdi0aRNOnTqF5ORkvPLKK/iv//qvkOsvLy+HTqfzH/Hx8SHriIjCwbTeyFepVAE/CyGC2qaqf7BdZsw/tObevXsAgPz8fLz55pt46aWX8N5772H58uX4p3/6p5Br37ZtGzwej//o6emZ8HESET3t1EqK4+LiEBEREXRW0d/fH/QKf5zBYAhZr1arERsbO2nN+Jgy8xoMBgD3zzgWLlwYsma8PSkpKWCcxMTECT/I12q10Gq1IX9HRBRuFJ1pREZGwmw2w+FwBLQ7HA6kp6eH7JOWlhZU39LSgpSUFGg0mklrxseUmddkMsFgMATUDA8Po7W11V+zePFiGI1GXLt2LWCc//zP/8SiRYuk9oCIKJwpOtMAgOLiYthsNqSkpCAtLQ1HjhxBd3c3CgoKANx/O+fGjRs4evQogPtXSlVVVaG4uBibNm2C0+lEbW2t/6ooANiyZQtWrVqFyspK5Obm4sSJEzh9+nTAB9RTzatSqWC321FWVoZly5Zh2b
JlKCsrw+zZs/H666/7a/7+7/8eu3fvxsqVK/Hiiy/iF7/4Bf7jP/4DH3/88fR3kWiG8CozetQUh4bVasXg4CD27t2Lvr4+rFixAs3Nzf5X6n19fQFv9ZhMJjQ3N6OoqAjV1dUwGo04ePAg1q9f769JT09HXV0ddu7ciV27diEhIQH19fVITU2VnhcAtm7dirt37+Ktt97CrVu3kJqaipaWFsydO9dfY7fb8e2336KoqAg3b97EypUr4XA4kJCQoHQriIjCjuL7NMId79N4NPgKWg73iR6GP9p9GkREFN4YGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkbVqhcejQIZhMJkRFRcFsNuPs2bOT1re2tsJsNiMqKgpLlizB4cOHg2oaGhqQlJQErVaLpKQkNDY2Kp5XCIE9e/bAaDRi1qxZWLNmDa5cuRJyTUIIZGdnQ6VS4ZNPPpF/8EREYUxxaNTX18Nut2PHjh1wuVzIyMhAdnY2uru7Q9Z3dXUhJycHGRkZcLlc2L59OwoLC9HQ0OCvcTqdsFqtsNls6OjogM1mw4YNG3D+/HlF8+7btw/79+9HVVUVLly4AIPBgLVr12JoaChoXQcOHIBKpVL68ImIwppKCCGUdEhNTUVycjJqamr8bYmJicjLy0N5eXlQfUlJCZqamtDZ2elvKygoQEdHB5xOJwDAarXC6/Xi5MmT/pqsrCzExMTg2LFjUvMKIWA0GmG321FSUgIA8Pl80Ov1qKysRH5+vr9fR0cH1q1bhwsXLmDhwoVobGxEXl6e1OP3er3Q6XTweDyIjo6W6kPKLS79dKaX8Mhcr3h12n25T/QwKHleU3SmMTw8jPb2dlgsloB2i8WCtra2kH2cTmdQfWZmJi5evIiRkZFJa8bHlJm3q6sLbrc7oEar1WL16tUBa7tz5w5ee+01VFVVwWAwTPmYfT4fvF5vwEFEFK4UhcbAwADGxsag1+sD2vV6Pdxud8g+brc7ZP3o6CgGBgYmrRkfU2be8X+nWltRURHS09ORm5sr9ZjLy8uh0+n8R3x8vFQ/IqKn0bQ+CH/wswAhxKSfD4Sqf7BdZsw/tKapqQlnzpzBgQMHJlzrg7Zt2waPx+M/enp6pPsSET1tFIVGXFwcIiIigs4q+vv7g17hjzMYDCHr1Wo1YmNjJ60ZH1Nm3vG3miarOXPmDL766ivMmzcParUaarUaALB+/XqsWbMm5Pq1Wi2io6MDDiKicKUoNCIjI2E2m+FwOALaHQ4H0tPTQ/ZJS0sLqm9paUFKSgo0Gs2kNeNjysxrMplgMBgCaoaHh9Ha2uqvKS0txRdffIFLly75DwB477338MEHHyjZCiKisKRW2qG4uBg2mw0pKSlIS0vDkSNH0N3djYKCAgD33865ceMGjh49CuD+lVJVVVUoLi7Gpk2b4HQ6UVtb678qCgC2bNmCVatWobKyErm5uThx4gROnz6Nc+fOSc+rUqlgt9tRVlaGZcuWYdmyZSgrK8Ps2bPx+uuvA7h/NhLqw+/nn38eJpNJ6VYQEYUdxaFhtVoxODiIvXv3oq+vDytWrEBzczMWLVoEAOjr6wu4d8JkMqG5uRlFRUWorq6G0WjEwYMHsX79en9Neno66urqsHPnTuzatQsJCQmor69Hamqq9LwAsHXrVty9exdvvfUWbt26hdTUVLS0tGDu3LnT2hwiIgqk+D6NcMf7NB4N3n8gh/tED8Mf7T4NIiIKbwwNIiKSxtAgIiJpDA0iIpLG0CAiImmKL7klInrS8Cqzh4dnGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0hgaREQkjaFBRETSGBpERCSNoUFERNIYGkREJI2hQURE0h
gaREQkjaFBRETSphUahw4dgslkQlRUFMxmM86ePTtpfWtrK8xmM6KiorBkyRIcPnw4qKahoQFJSUnQarVISkpCY2Oj4nmFENizZw+MRiNmzZqFNWvW4MqVK/7f37x5E5s3b8by5csxe/ZsPP/88ygsLITH45nONhARhR3FoVFfXw+73Y4dO3bA5XIhIyMD2dnZ6O7uDlnf1dWFnJwcZGRkwOVyYfv27SgsLERDQ4O/xul0wmq1wmazoaOjAzabDRs2bMD58+cVzbtv3z7s378fVVVVuHDhAgwGA9auXYuhoSEAQG9vL3p7e/Huu+/i8uXL+PDDD3Hq1Cls3LhR6TYQEYUllRBCKOmQmpqK5ORk1NTU+NsSExORl5eH8vLyoPqSkhI0NTWhs7PT31ZQUICOjg44nU4AgNVqhdfrxcmTJ/01WVlZiImJwbFjx6TmFULAaDTCbrejpKQEAODz+aDX61FZWYn8/PyQj+ejjz7Cj3/8Y9y+fRtq9dTffuv1eqHT6eDxeBAdHT1lPU0Pv55TDvdJDvdpckqe1xSdaQwPD6O9vR0WiyWg3WKxoK2tLWQfp9MZVJ+ZmYmLFy9iZGRk0prxMWXm7erqgtvtDqjRarVYvXr1hGsD4N+kiQLD5/PB6/UGHERE4UpRaAwMDGBsbAx6vT6gXa/Xw+12h+zjdrtD1o+OjmJgYGDSmvExZeYd/1fJ2gYHB/HOO+9MeBYCAOXl5dDpdP4jPj5+wloioqfd1O/HhKBSqQJ+FkIEtU1V/2C7zJgPqwa4fzr26quvIikpCbt3755w7du2bUNxcXFAv+kGRzidIgN/2NsJRPR4UhQacXFxiIiICHrl3t/fH/QKf5zBYAhZr1arERsbO2nN+Jgy8xoMBgD3zzgWLlw46dqGhoaQlZWFZ555Bo2NjdBoNBM+Zq1WC61WO+HviYjCiaK3pyIjI2E2m+FwOALaHQ4H0tPTQ/ZJS0sLqm9paUFKSor/yXqimvExZeY1mUwwGAwBNcPDw2htbQ1Ym9frhcViQWRkJJqamhAVFaVkC4iIwprit6eKi4ths9mQkpKCtLQ0HDlyBN3d3SgoKABw/+2cGzdu4OjRowDuXylVVVWF4uJibNq0CU6nE7W1tf6rogBgy5YtWLVqFSorK5Gbm4sTJ07g9OnTOHfunPS8KpUKdrsdZWVlWLZsGZYtW4aysjLMnj0br7/+OoD7ZxgWiwV37tzBr371q4APthcsWICIiIhpbiMRUXhQHBpWqxWDg4PYu3cv+vr6sGLFCjQ3N2PRokUAgL6+voB7J0wmE5qbm1FUVITq6moYjUYcPHgQ69ev99ekp6ejrq4OO3fuxK5du5CQkID6+nqkpqZKzwsAW7duxd27d/HWW2/h1q1bSE1NRUtLC+bOnQsAaG9v99/7sXTp0oDH1dXVhcWLFyvdDiKisKL4Po1w94fcp8EPwuWF015xn+Rwn+Q8VvdpEBFReGNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJm1ZoHDp0CCaTCVFRUTCbzTh79uyk9a2trTCbzYiKisKSJUtw+PDhoJqGhgYkJSVBq9UiKSkJjY2NiucVQmDPnj0wGo2YNWsW1qxZgytXrgTU+Hw+bN68GXFxcZgzZw5++MMf4ne/+900doGIKPwoDo36+nrY7Xbs2LEDLpcLGRkZyM7ORnd3d8j6rq4u5OTkICMjAy6XC9u3b0dhYSEaGhr8NU6nE1arFTabDR0dHbDZbNiwYQPOnz+vaN59+/Zh//79qKqqwoULF2AwGLB27VoMDQ35a+x2OxobG1FXV4dz587hm2++wbp16zA2NqZ0K4iIwo5KCCGUdEhNTUVycjJqamr8bYmJicjLy0N5eXlQfUlJCZqamt
DZ2elvKygoQEdHB5xOJwDAarXC6/Xi5MmT/pqsrCzExMTg2LFjUvMKIWA0GmG321FSUgLg/lmFXq9HZWUl8vPz4fF4sGDBAvzyl7+E1WoFAPT29iI+Ph7Nzc3IzMyc8vF7vV7odDp4PB5ER0cr2TosLv1UUf2T7nrFq9PuG057xX2Sw32SM519UvK8plYy8PDwMNrb21FaWhrQbrFY0NbWFrKP0+mExWIJaMvMzERtbS1GRkag0WjgdDpRVFQUVHPgwAHpebu6uuB2uwPm0mq1WL16Ndra2pCfn4/29naMjIwE1BiNRqxYsQJtbW0hQ8Pn88Hn8/l/9ng8AO5vslL3fHcU93mSTWePxoXTXnGf5HCf5Exnn8b7yJxDKAqNgYEBjI2NQa/XB7Tr9Xq43e6Qfdxud8j60dFRDAwMYOHChRPWjI8pM+/4v6Fqfvvb3/prIiMjERMTI73+8vJy/PSnPw1qj4+PD1lP/5/uwEyv4MnAfZLDfZLzh+zT0NAQdDrdpDWKQmOcSqUK+FkIEdQ2Vf2D7TJjPqyaB01Ws23bNhQXF/t/vnfvHm7evInY2Ngpx30ceL1exMfHo6enR/HbaeGE+ySH+yTvSdorIQSGhoZgNBqnrFUUGnFxcYiIiAh6Vd7f3x/0Cn+cwWAIWa9WqxEbGztpzfiYMvMaDAYA988mFi5cOGHN8PAwbt26FXC20d/fj/T09JDr12q10Gq1AW3z5s0LWfs4i46Ofuz/w30ccJ/kcJ/kPSl7NdUZxjhFV09FRkbCbDbD4XAEtDscjgmfdNPS0oLqW1pakJKSAo1GM2nN+Jgy85pMJhgMhoCa4eFhtLa2+mvMZjM0Gk1ATV9fH7788ssJ109ERP+HUKiurk5oNBpRW1srrl69Kux2u5gzZ464fv26EEKI0tJSYbPZ/PVff/21mD17tigqKhJXr14VtbW1QqPRiI8//thf85vf/EZERESIiooK0dnZKSoqKoRarRaff/659LxCCFFRUSF0Op04fvy4uHz5snjttdfEwoULhdfr9dcUFBSI5557Tpw+fVr8+7//u3j55ZfFypUrxejoqNKteCJ4PB4BQHg8npleymON+ySH+yTvad0rxaEhhBDV1dVi0aJFIjIyUiQnJ4vW1lb/79544w2xevXqgPrPPvtMvPTSSyIyMlIsXrxY1NTUBI350UcfieXLlwuNRiNeeOEF0dDQoGheIYS4d++e2L17tzAYDEKr1YpVq1aJy5cvB9TcvXtXvP3222L+/Pli1qxZYt26daK7u3s62/BE+Pbbb8Xu3bvFt99+O9NLeaxxn+Rwn+Q9rXul+D4NIiIKX/zbU0REJI2hQURE0hgaREQkjaFBRETSGBpERCSNofGUU/rdJ+Hm17/+NX7wgx/AaDRCpVLhk08+meklPZbKy8vxZ3/2Z5g7dy6effZZ5OXl4dq1azO9rMdOTU0N/vRP/9R/F3haWlrAX+9+GjA0nmJKv/skHN2+fRsrV65EVVXVTC/lsdba2oqf/OQn+Pzzz+FwODA6OgqLxYLbt2/P9NIeK8899xwqKipw8eJFXLx4ES+//DJyc3ODvgzuScb7NJ5iSr/7JNypVCo0NjYiLy9vppfy2Puf//kfPPvss2htbcWqVatmejmPtfnz5+NnP/sZNm7cONNLeSh4pvGUGv8Okge/y2Sy7z4hkjX+vTLz58+f4ZU8vsbGxlBXV4fbt28jLS1tppfz0EzrT6PT4286331CJEMIgeLiYnz/+9/HihUrZno5j53Lly8jLS0N3377LZ555hk0NjYiKSlpppf10DA0nnLT+X4Rosm8/fbb+OKLL3Du3LmZXspjafny5bh06RJ+//vfo6GhAW+88QZaW1ufmuBgaDylpvPdJ0RT2bx5M5qamvDrX/8azz333Ewv57EUGRmJpUuXAgBSUlJw4cIF/OM//iN+/vOfz/DKHg5+pvGUms53nxBNRAiBt99+G8ePH8eZM2dgMplmeklPDC
EEfD7fTC/joeGZxlOsuLgYNpsNKSkpSEtLw5EjR9Dd3Y2CgoKZXtpj45tvvsF///d/+3/u6urCpUuXMH/+fDz//PMzuLLHy09+8hP88z//M06cOIG5c+f6z2B1Oh1mzZo1w6t7fGzfvh3Z2dmIj4/H0NAQ6urq8Nlnn+HUqVMzvbSHZ+b+Kjs9ClN9B0m4+9d//VcBIOh44403Znppj5VQewRAfPDBBzO9tMfK3/zN3/j/f1uwYIF45ZVXREtLy0wv66HifRpERCSNn2kQEZE0hgYREUljaBARkTSGBhERSWNoEBGRNIYGERFJY2gQEZE0hgYREUljaBARkTSGBhERSWNoEBGRtP8FKQOerj58EZwAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 400x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# attribution scores for neuron 3 in layer 1 (both indices start from 0)\n",
"# note the y axis scale is really small\n",
"model.attribute(1,3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89d836df",
"metadata": {},
"outputs": [],
"source": [
"model.plot(in_vars=input_vars)"
]
},
{
"cell_type": "markdown",
"id": "6182005a",
"metadata": {},
"source": [
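"Prune inputs: remove the input features with small attribution scores. Note that $x_3$ does not enter $f$ at all (its coefficient is 0), so it is the natural candidate for removal."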
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cac3ea5f",
"metadata": {},
"outputs": [],
"source": [
"model = model.prune_input()\n",
"model.plot(in_vars=input_vars)"
]
},
{
"cell_type": "markdown",
"id": "9e7eaa42",
"metadata": {},
"source": [
"Let's consider a high-dimensional case. When there are many inputs but only a few are important, users may want to prune the unimportant inputs; otherwise, too many inputs make the model hard to interpret."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a5b6ccf",
"metadata": {},
"outputs": [],
"source": [
"# let's construct a 100-dimensional dataset where the contribution of x_i decays as 0.5**i\n",
"# (create_dataset and KAN are already in scope from `from kan import *` in the first cell)\n",
"n_var = 100\n",
"\n",
"def f(x):\n",
"    \"\"\"Target y = sum_i 0.5**i * x_i**2, returned as a (batch, 1) column.\"\"\"\n",
"    y = 0\n",
"    for i in range(n_var):\n",
"        # exponential decay: input i is weighted by 0.5**i, so lower indices matter more\n",
"        y += x[:,[i]]**2*0.5**i\n",
"    return y\n",
"\n",
"dataset = create_dataset(f, n_var=n_var)\n",
"\n",
"# LaTeX labels; braces are needed so multi-digit subscripts like x_{12} render correctly\n",
"input_vars = [r'$x_{'+str(i)+'}$' for i in range(n_var)]\n",
"\n",
"model = KAN(width=[n_var,10,10,1], seed=2)\n",
"model.fit(dataset, steps=50, lamb=1e-3, reg_metric='edge_forward_n');"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd91e538",
"metadata": {},
"outputs": [],
"source": [
"model.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eefc4650",
"metadata": {},
"outputs": [],
"source": [
"model = model.rewind('0.1')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e42f8d6",
"metadata": {},
"outputs": [],
"source": [
"# learned attribution score of each input vs. its true importance rank\n",
"# (x_i is weighted by 0.5**i in f, so feature index i+1 is its ground-truth rank)\n",
"plt.scatter(np.arange(n_var)+1, model.feature_score.detach().numpy())\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"plt.xlabel('rank of input features', fontsize=15)\n",
"plt.ylabel('feature attribution score', fontsize=15)\n",
"plt.title('Feature attribution vs. true importance rank', fontsize=15)\n",
"plt.show()"
]
},
},
{
"cell_type": "markdown",
"id": "7bf0deb1",
"metadata": {},
"source": [
"Since there are 100 inputs, plotting the whole diagram is very time-consuming, and it is hard to read anything meaningful from it. So we first prune the network (both hidden nodes and inputs) and then plot it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e0b3dad",
"metadata": {},
"outputs": [],
"source": [
"model = model.prune()\n",
"model = model.prune_input(threshold=3e-2)\n",
"model.plot(in_vars=input_vars)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd20b031",
"metadata": {},
"outputs": [],
"source": [
"model.fit(dataset, steps=50);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96bf1149",
"metadata": {},
"outputs": [],
"source": [
"model.plot(in_vars=input_vars)"
]
},
{
"cell_type": "markdown",
"id": "293b2a06",
"metadata": {},
"source": [
"After pruning, `model.input_id` holds the indices of the inputs that remain. Mapping them back to `input_vars` shows which variables survived, instead of listing all 100 labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7000447f",
"metadata": {},
"outputs": [],
"source": [
"model.input_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d3648a5",
"metadata": {},
"outputs": [],
"source": [
"# names of the surviving inputs; int() also accepts 0-d integer tensors\n",
"[input_vars[int(i)] for i in model.input_id]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "77a3ae3b",
"metadata": {},
"outputs": [],
"source": [
"model.cache_data.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d883067",
"metadata": {},
"outputs": [],
"source": [
"# manual prune inputs\n",
"model = model.prune_input(active_inputs=[0,1,2,3,4])\n",
"model.plot(in_vars=input_vars)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3452ca73",
"metadata": {},
"outputs": [],
"source": [
"# prune hidden nodes as well, then plot the compact network\n",
"model = model.prune_node()\n",
"model.plot(in_vars=input_vars)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}