{
"cells": [
{
"cell_type": "markdown",
"id": "c982abca",
"metadata": {},
"source": [
"# Interpretability 3: KAN Compiler"
]
},
{
"cell_type": "markdown",
"id": "6b9ec6c4",
"metadata": {},
"source": [
"We have shown in many examples how to extract symbolic formulas from KANs. Here we consider the reverse task: compiling a symbolic formula into a KAN. One use case is when prior knowledge gives an approximate ground truth (empirical or constitutive laws, for example) and we want to build that knowledge into the network, then fine-tune the network on real data."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d8f94f0f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 500x400 with 6 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan.compiler import kanpiler\n",
"from sympy import *\n",
"from kan.utils import create_dataset\n",
"import torch\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)\n",
"\n",
"input_variables = x,y = symbols('x y')\n",
"expr = exp(sin(pi*x)+y**2)\n",
"\n",
"model = kanpiler(input_variables, expr).to(device)\n",
"\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,0]) + x[:,1]**2)\n",
"dataset = create_dataset(f, n_var=2, device=device)\n",
"model.get_act(dataset)\n",
"\n",
"model.plot()"
]
},
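{
"cell_type": "markdown",
"id": "7a2e4d10",
"metadata": {},
"source": [
"As a quick sanity check, the sketch below compares the SymPy expression with the hand-written PyTorch lambda on a few random points, since the compiled network and the dataset are only consistent if the two describe the same function. It uses `sympy.lambdify` with the standard `math` module; the sample size, the sampling range $[-1, 1]$ and the tolerance are arbitrary choices made for illustration, not part of the original workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b93c5e21",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from sympy import symbols, exp, sin, pi, lambdify\n",
"\n",
"# Same formula as above, written once in SymPy and once in PyTorch.\n",
"x, y = symbols('x y')\n",
"expr = exp(sin(pi*x) + y**2)\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,0]) + x[:,1]**2)\n",
"\n",
"# Turn the SymPy expression into a plain Python function of two floats.\n",
"expr_fn = lambdify((x, y), expr, modules='math')\n",
"\n",
"# Evaluate both versions on a few random points in [-1, 1]^2 (arbitrary choice).\n",
"pts = torch.rand(5, 2) * 2 - 1\n",
"ref = torch.tensor([expr_fn(float(a), float(b)) for a, b in pts])\n",
"\n",
"# True if the two definitions agree up to float32 rounding.\n",
"print(torch.allclose(f(pts), ref, atol=1e-5))"
]
},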
{
"cell_type": "markdown",
"id": "535c253f",
"metadata": {},
"source": [
"If you want a more complicated formula, you can load an equation from the Feynman dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e9cf1b61",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x400 with 36 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from kan.feynman import get_feynman_dataset\n",
"import matplotlib.pyplot as plt\n",
"\n",
"problem_id = 36 # problem_id in 1-120\n",
"input_variables, expr, f, ranges = get_feynman_dataset(problem_id)\n",
"n_var = len(input_variables)\n",
"model = kanpiler(input_variables, expr)\n",
"\n",
"dataset = create_dataset(f, n_var=n_var, ranges=ranges)\n",
"model.get_act(dataset)\n",
"#model.plot(in_vars=input_variables, out_vars=[expr], beta=10000, title='P{}'.format(problem_id))\n",
"model.plot(in_vars=input_variables, out_vars=[symbols('omega')], beta=10000)\n",
"#plt.savefig('./fig1.pdf', bbox_inches='tight', dpi=200)"
]
},
{
"cell_type": "markdown",
"id": "d1db913e",
"metadata": {},
"source": [
"We can check that the model indeed achieves zero loss (up to machine precision) on the data."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "910c99a9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.5383e-15, grad_fn=<MeanBackward0>)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.mean((model(dataset['train_input'])-dataset['train_label'])**2)"
]
},
{
"cell_type": "markdown",
"id": "35c347d2",
"metadata": {},
"source": [
"Suppose the symbolic formula is only an approximate ground truth for our dataset, and we want to fine-tune the model on the real data. The current model has the symbolic front turned on and the spline front turned off, so only the affine parameters in the symbolic equations are trainable. Depending on how much expressive power you need:\n",
"\n",
"* If you want to keep the symbolic functions and train only the affine parameters, there is nothing more to do.\n",
"* If you want the functions themselves to be trainable, call `model.perturb()`. To make only the currently active functions trainable while the currently dead functions stay dead, use `mode='minimal'`. To also allow the currently dead functions to become active, use `mode='all'` (the default).\n",
"* If you think the ground truth is more complicated than the current network, first expand the network using `expand_width` and/or `expand_depth`, then call `model.perturb()`.\n",
"\n",
"In the following, we present the most complicated case, where the network is expanded first."
]
},
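{
"cell_type": "markdown",
"id": "c41d8f32",
"metadata": {},
"source": [
"For comparison with the expanded case, here is a minimal sketch of the second option on the small two-variable formula from the first example. It only uses calls that appear elsewhere in this notebook (`kanpiler`, `create_dataset`, `get_act`, `perturb`, `plot`); the names `small`, `g` and `small_data` are made up here, and `mode='minimal'` is assumed to behave as described in the list above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d52e9a43",
"metadata": {},
"outputs": [],
"source": [
"from kan.compiler import kanpiler\n",
"from kan.utils import create_dataset\n",
"from sympy import symbols, exp, sin, pi\n",
"import torch\n",
"\n",
"# Hypothetical names: `small`, `g` and `small_data` are used only in this sketch.\n",
"xy = x, y = symbols('x y')\n",
"small = kanpiler(xy, exp(sin(pi*x) + y**2))\n",
"g = lambda t: torch.exp(torch.sin(torch.pi*t[:,0]) + t[:,1]**2)\n",
"small_data = create_dataset(g, n_var=2)\n",
"small.get_act(small_data)\n",
"\n",
"# Make only the currently active functions trainable; dead ones stay dead.\n",
"small.perturb(mag=0.1, mode='minimal')\n",
"small.get_act(small_data)\n",
"small.plot()"
]
},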
{
"cell_type": "markdown",
"id": "63af424e",
"metadata": {},
"source": [
"Step 1: expand the depth by adding an extra linear function at the end."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "381b8a03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x600 with 38 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.expand_depth()\n",
"model.get_act(dataset)\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "27a934fe",
"metadata": {},
"source": [
"Step 2: add two addition nodes in layer 1."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5c5f92c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x600 with 52 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.expand_width(1, 2)\n",
"model.get_act(dataset)\n",
"model.plot()"
]
},
{
"cell_type": "markdown",
"id": "3459fc85",
"metadata": {},
"source": [
"Step 3: add two multiplication nodes in layer 2, with arities 2 and 3."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ec1bfb11",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x600 with 86 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.expand_width(2, 2, sum_bool=False, mult_arity=[2,3])\n",
"model.get_act(dataset)\n",
"model.plot()"
]
},
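{
"cell_type": "markdown",
"id": "e63fab54",
"metadata": {},
"source": [
"To make the notion of arity concrete, the following is a conceptual sketch in plain PyTorch, not the library's implementation. It assumes that a multiplication node of arity $k$ outputs the product of $k$ incoming values; the helper `mult_nodes` and the sample numbers are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f74abc65",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def mult_nodes(subnodes, arities):\n",
"    # subnodes: (batch, sum(arities)) tensor of incoming values.\n",
"    # Split the columns into consecutive groups, one group per multiplication node.\n",
"    groups = torch.split(subnodes, arities, dim=1)\n",
"    # Each node outputs the product of the values in its group.\n",
"    return torch.stack([g.prod(dim=1) for g in groups], dim=1)\n",
"\n",
"# One sample with five incoming values, grouped as (2, 3) and (1, 4, 5).\n",
"vals = torch.tensor([[2., 3., 1., 4., 5.]])\n",
"print(mult_nodes(vals, [2, 3]))  # two outputs: 2*3 = 6 and 1*4*5 = 20"
]
},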
{
"cell_type": "markdown",
"id": "038ea175",
"metadata": {},
"source": [
"Step 4: perturb the edges (`mode='minimal'` perturbs only the currently active edges; `mode='all'` perturbs all edges)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "45c8e738",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x600 with 86 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.perturb(mag=0.1, mode='all')\n",
"model.get_act(dataset)\n",
"model.plot(metric='forward_n')\n",
"# purple means both symbolic front (red) and spline front (black) are active"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6feae91b",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAHiCAYAAAAkiYF/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3hU1dbG3ymZSe+9dzqpkEAIHelNVERUql69Fq5d9LMX1GsDxWtBBZEAgiJIU6lSQigJoZNKeu99ylnfH8MMmWSSTJJJZibZv+fh8d7MmXPW3mfPec/eq2weEREYDAaDwdAhfH0bwGAwGIy+BxMXBoPBYOgcJi4MBoPB0DlMXBgMBoOhc5i4MBgMBkPnMHFhMBgMhs5h4sJgMBgMncPEhcFgMBg6h4kLg8FgMHQOExcGg8Fg6BwmLgwGg8HQOUxcGAwGg6FzmLgwGAwGQ+cwcWEwGAyGzmHiwmAwGAydw8SFwWAwGDqHiQuDwWAwdA4TFwaDwWDoHKG+DWAwjIGioiJs374dRUVFcHFxwcKFC+Hi4qJvsxgMg4VHRKRvIxgMQ4XjOLzyyiv45JNPwHEcBAIB5HI5+Hw+nnvuObz//vvg89kCAIPREiYuDEY7vPzyy/jwww/b/Pyll17CBx980IsWMRjGARMXBqMNioqK4OnpCZlM1uYxQqEQubm5bImMwWgBm88zGG2wfft2cBzX7jEcx2H79u29ZBGDYTwwcWEw2qCoqAgCgaDdYwQCAYqKinrJIgbDeGDiwmC0gYuLC+RyebvHyOVytiTGYGiA+VwYjDZgPhcGo+uwmQuD0QYuLi547rnn2j3mueeeY8LCYGiAJVEyGO3w/vvvAwA++eQTyOVyEBF4PB4EAoEqz4XBYLSGLYsxGFpQVFSEjz/+GB9//DGef/55PP/882zGwmC0AxMXBkNLEhMTERERgQsXLiA8PFzf5jAYBg3zuTAYDAZD5zBxYTAYDIbOYeLCYDAYDJ3DxIXBYDAYOoeJC4PBYDB0DhMXBoPBYOgcJi4MBoPB0DlMXBgMBoOhc5i4MBgMBkPnMHFhMBgMhs5h4sJgMBgMncPEhcFgMBg6h4kLg8FgMHQOExcGg8Fg6BwmLgwGg8HQOUxcGAwGg6FzmLgwGAwGQ+cwcWEwGAyGzmHiwmAwGAydw8SFwWAwGDqHiQuDwWAwdA4TFwaDwWDoHCYuDAaDwdA5TFwYjHYgIly6dAkvv/wynnzySZiYmOCNN97Arl270NDQoG/zGAyDhUdEpG8jGAxDRCaTYf369fj++++xYMECREVFQSQSobi4GL/88gsEAgHWrVsHNzc3fZvKYBgcbObCYLTBzz//jC1btiAuLg6vvfYadu7ciZUrV8Lf3x9xcXGIjIzEU089hfr6en2bymAYHExcGAwNlJSU4LPPPsPHH3+MIUOGAABKS0uRk5ODpqYmmJqa4plnnoGZmRn27NmjZ2sZDMODiQuDoYH4+Hg4OTnBx8cHaWlpSE1NRW1tLYgIubm5SE1NRU5ODh588EHs2rULHMfp22QGw6AQ6tsABsMQSU1NxaBBg/Dyyy9j7969AICGhgbI5XKsWLECAoEA1tbWiIuLQ1FREaRSKcRisZ6tZjAMByYuDIYGRCIRmpqa4OHhgYEDBwIA0tLSUFlZCR8fH1haWsLS0hI8Hg98Ph88Hk/PFjMYhgVbFmMwNOBj44OkC0l49dVXcfLkSZw4cQKxsbEQCoVYv349Tp48iYMHDyI/Px/BwcEQiUT6NpnBMCiYuDAYt5HUSZD0YxJ+iPkBCUsSUHG1Art+2wUTExOIRCLw+Yqfi/L/19bWYsOGDVi8eLGeLWcwDA+2LMbo1xAR8s/nI3FDIq5svQJJrQQBdwVg8Y7FmO48Hf96/F8Qm4px9913IzAwEJGRkbC0tEROTg5efPFFREdHY/To0fpuBoNhcLAkSka/pKGiAZe3XEbihkQUJRfB2tMaoctDEbY8DL
Y+tqrjTp06hdWrV8PU1BSjRo2ChYUFMjIycOHCBSxcuBBPPfUUc+QzGBpg4sLoNxARso5nIXFDIq7tvAaSEwbMGYCwlWEIuCsAfIHmVeKamhrEx8fj/PnzaKhvQGBQIMaPHw9vb2/myGcw2oCJC6PPU1NQg+RNyUj6PgnlaeVwCHZA2MowhDwcAksXS63PQ0QgOYEn4DFRYTA6gIkLo0/CyTik/ZmGxO8SkbI3BQITAQbfOxjhK8PhHdu1GQcTFwZDe5i4MPoUFZkVSPohCRd/vIiavBq4hroi/JFwDHtgGExtTbt1biYuDIb2sGgxhtEja5Lhxu83kLQhCRmHMiC2FmPY4mEIXxkOt3BWsZjB0Ads5sIwWoqvFiPp+yQk/5SMhrIGeI/xRtjKMAy+ZzBEFrpPamQzFwZDe9jMhWFUSGoluPrLVSRuSERufC7MncwRuiwU4SvC4TjQUd/mMRiM27CZC8PgUSU6fnc70bFOkegYvjIcA+YMgEAk6DU72MyFwdAOJi4Mg6WhvAGXtlxC0oYkFF0qgrWXNcKWhyF0WahaomNvwcSFwdAeJi4Mg4I4wq3jt5C0IQnXfr2d6Dh3AMJXhsN/in+biY69YhsTFwZDa5i4MAwCZaJj4oZEVKRXdDnRsSdh4sJgaA8TF4be4GQc0g6mIXHDnUTHIfcNQdjKMHiPMbzSKkxcGAztYeLC6HUqMpolOubXwDXsdqLjou4nOvYkTFwYDO1hociMXoElOjIY/Qs2c2H0KMVXi5G4IRGXfrqEhvIGeMd6I3xlOAbfMxgm5ib6Nq9TsJkLg6E9bObC0DmqRMfvEpF75nai43KW6Mhg9CfYzIWhE4gI+eea7ehYJ0Hg1ECErQzDgNm9l+jYk7CZC4OhPUxcGN2iobwBl36+hMQNiSi+XKxIdFwRhrBlYbDxttG3eTqFiQuDoT1MXBidxpATHXsSJi4MhvYwcWFoTU1BDS5uvIik75MUiY4DHBC+MhwhD4fAwtlC3+b1OExcGAztYeLCaBdOxiH1QCqSNiQhZV8KBCIBhtw7BOGPhMMrxqtfPWSZuDAY2sPEhaGRlomObuFuCFsZZvCJjj0JExcGQ3tYKDJDhaxRkeiYuCERmYczIba5nei4giU6MhiMzsFmLgwUX7md6LjZ+BMdexI2c2EwtIfNXPopkloJrmy/gqQNSapEx7AVYQhbEQbHASzRkcFgdA82c+lHEBHyzuYhcUMirm67qkh0nBaI8JXhCJ4V3CcSHXsSNnNhMLSHiUs/oKG8Acmbk5G0IQnFV4ph422D0OWhfTLRsSdh4sJgaA8Tlz4KcYRbx24hcUMirv92HcQRBs4diLCVYfCf3HcTHXsSJi4MhvYwcelj1OQ3S3TMuJ3o+Eg4Qh7qH4mOPQkTFwZDe5i49AGUiY6J3yUidX+qItHxviEIX9n/Eh17EiYuDIb2MHExYsrTy1WJjrUFtXCLcEP4ynAMXTQUpjb9M9GxJ2HiwmBoDwtFNjJkjTJc33UdSRuSkHmkWaLjynC4hbFERwaDYRiwmYuR0DLR0WesD8JWhmHwApbo2FuwmQuDoT1s5mLANNU04er2q0jckIi8hDxYOFsgbGUYwpazREcGg2HYsJmLgaFKdPwuEVe2XYG0XsoSHQ0ENnNhMLSHiYuBUF9Wj0s/X1JLdAxbEYbQpaEs0dFAYOLCYGgPExc9Qhwh82gmkjYkKRIdSZHoGP5IOPwm+bFERwODiQuDoT1MXPRAy0RHx4GOCFsZxhIdDRwmLgyG9jBx6SU4GYfU/alI3JCI1H2pEIgFGLpwKMJWhsFrNEt0NAaYuDAY2sPEpYcpTy9H0vdJuLiRJToaO0xcGAztYaHIPYCmRMfhDw5H2IowlujIYDD6BWzmokOKLhepEh0bKxrhM9YH4Y+EY9CCQTAxY4mOxg6buTAY2sNmLt2kqaYJV7YpdnTMO6tIdAx/JBzhK8LhEOygb/MYDAZDL7CZSxcgIuQlKHZ0vL
LtCmQNMgROC0TYyjBFoqMJS3Tsi7CZC4OhPUxcOkF9WT0ubb6ExA2JKLlacifRcVkobLxYomNfh4kLg6E9TFy0oDq
"text/plain": [
"<Figure size 500x600 with 86 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model.plot(beta=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eabf7aa3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}