GitHub_collection_pykan/tutorials/Interp_5_test_symmetry.ipynb
2024-08-11 13:11:59 -04:00

{
"cells": [
{
"cell_type": "markdown",
"id": "f8ba3161",
"metadata": {},
"source": [
"# Interpretability 5: Test symmetries"
]
},
{
"cell_type": "markdown",
"id": "6535c1f2",
"metadata": {},
"source": [
"Figuring out the symbolic formula represented by a model is ideal but sometimes too challenging. In that case, we might be content with identifying modular structures or symmetries instead. These hypothesis tests are partially inspired by AI Feynman."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1416f4c8",
"metadata": {},
"outputs": [],
"source": [
"from kan.hypothesis import *\n",
"import torch"
]
},
{
"cell_type": "markdown",
"id": "6ee16e29",
"metadata": {},
"source": [
"Case 1: detect separability.\n",
"* Additive separability: $f(x_1, x_2, ...) = g_1(x_1,x_2) + g_2(x_3) + g_3(x_4,x_5,x_6) + ...$\n",
"* Multiplicative separability: $f(x_1, x_2, ...) = g_1(x_1,x_2)g_2(x_3)g_3(x_4,x_5,x_6)...$\n",
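"* General separability: $f(x_1, x_2, x_3, ...) = h(p(x_1,x_2)+q(x_3,\\cdots))$. (Note that general additive separability is equivalent to general multiplicative separability: for positive $p, q$, $h(pq) = \\tilde{h}(\\ln p + \\ln q)$ with $\\tilde{h} = h \\circ \\exp$.)"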
]
},
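{
"cell_type": "markdown",
"id": "sketch-hessian-md",
"metadata": {},
"source": [
"The `'hessian'` entry returned by `detect_separability` hints at how additive separability can be detected: if $f = g_1(x_1,x_2) + g_2(x_3,x_4) + \\cdots$, every mixed second derivative $\\partial^2 f / \\partial x_i \\partial x_j$ with $x_i$ and $x_j$ in different groups vanishes. The cell below is a minimal sketch of that idea in plain PyTorch, not pykan's actual implementation: it averages $|\\partial^2 f / \\partial x_i \\partial x_j|$ over the samples and reads off the groups as connected components of the non-negligible entries. The helpers `mean_abs_hessian` and `separable_groups` and the tolerance `tol` are hypothetical names chosen for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-hessian-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical helpers: a sketch of Hessian-based grouping, not pykan's implementation.\n",
"def mean_abs_hessian(f, x):\n",
"    # f maps a (batch, n) tensor to a (batch, 1) tensor, as in this notebook\n",
"    def f_single(xi):\n",
"        return f(xi.unsqueeze(0)).squeeze()\n",
"    hessians = [torch.autograd.functional.hessian(f_single, xi) for xi in x]\n",
"    return torch.stack(hessians).abs().mean(dim=0)\n",
"\n",
"def separable_groups(h, tol=1e-6):\n",
"    # i and j are linked when the averaged |d^2 f / dx_i dx_j| exceeds tol;\n",
"    # the groups are the connected components of that graph\n",
"    n = h.shape[0]\n",
"    groups, seen = [], set()\n",
"    for start in range(n):\n",
"        if start in seen:\n",
"            continue\n",
"        stack, comp = [start], []\n",
"        seen.add(start)\n",
"        while stack:\n",
"            i = stack.pop()\n",
"            comp.append(i)\n",
"            for j in range(n):\n",
"                if j not in seen and h[i, j] > tol:\n",
"                    seen.add(j)\n",
"                    stack.append(j)\n",
"        groups.append(sorted(comp))\n",
"    return groups\n",
"\n",
"f = lambda x: x[:,[0]] * x[:,[1]] + x[:,[2]] * x[:,[3]] + x[:,[4]] * x[:,[5]]\n",
"x = torch.rand(100, 6, dtype=torch.float64) * 2 - 1\n",
"print(separable_groups(mean_abs_hessian(f, x)))  # [[0, 1], [2, 3], [4, 5]]"
]
},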
{
"cell_type": "code",
"execution_count": 2,
"id": "87f1e596",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"add separability detected\n"
]
},
{
"data": {
"text/plain": [
"{'hessian': tensor([[0.0000, 0.3609, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.3609, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000, 0.3217, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.3217, 0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3472],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.3472, 0.0000]]),\n",
" 'n_groups': 3,\n",
" 'labels': [2, 2, 1, 1, 0, 0],\n",
" 'groups': [[4, 5], [2, 3], [0, 1]]}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = lambda x: x[:,[0]] * x[:,[1]] + x[:,[2]] * x[:,[3]] + x[:,[4]] * x[:,[5]]\n",
"x = torch.rand(100,6) * 2 - 1\n",
"detect_separability(f, x, 'add')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0b63eed4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mul separability detected\n"
]
}
],
"source": [
"f = lambda x: (x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]])\n",
"x = torch.rand(100,6) * 2 - 1\n",
"detect_separability(f, x, 'mul');"
]
},
{
"cell_type": "markdown",
"id": "3933b0dd",
"metadata": {},
"source": [
"We can also test separability against a specific group partition by passing it as an argument."
]
},
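{
"cell_type": "markdown",
"id": "sketch-partition-md",
"metadata": {},
"source": [
"Testing a given partition can be sketched with the same Hessian idea (pykan's own `test_separability` may work differently): for additive separability, every mixed second derivative between variables in different groups should vanish; for multiplicative separability, $f = \\prod_k g_k$ gives $\\ln|f| = \\sum_k \\ln|g_k|$, so the same check is applied to $\\ln|f|$. The helper `check_partition` and its tolerance `tol` are hypothetical names used only in this sketch."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-partition-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical helper: a sketch of testing a given partition, not pykan's implementation.\n",
"def check_partition(f, x, groups, mode='add', tol=1e-6):\n",
"    # multiplicative separability of f is additive separability of log|f|\n",
"    g = (lambda z: torch.log(torch.abs(f(z)))) if mode == 'mul' else f\n",
"    def g_single(xi):\n",
"        return g(xi.unsqueeze(0)).squeeze()\n",
"    hessians = [torch.autograd.functional.hessian(g_single, xi) for xi in x]\n",
"    h = torch.stack(hessians).abs().mean(dim=0)\n",
"    group_of = {i: k for k, grp in enumerate(groups) for i in grp}\n",
"    n = x.shape[1]\n",
"    # largest averaged |d^2 g / dx_i dx_j| over pairs (i, j) in different groups\n",
"    cross = [h[i, j].item() for i in range(n) for j in range(n) if group_of[i] != group_of[j]]\n",
"    return max(cross, default=0.0) < tol\n",
"\n",
"f = lambda x: (x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]])\n",
"x = torch.rand(100, 6, dtype=torch.float64) * 2 - 1\n",
"print(check_partition(f, x, [[0,1],[2,3],[4,5]], 'mul'))  # expected: True\n",
"print(check_partition(f, x, [[0,1],[2,4],[3,5]], 'mul'))  # expected: False"
]
},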
{
"cell_type": "code",
"execution_count": 4,
"id": "96110a32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = lambda x: (x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]])\n",
"x = torch.rand(100,6) * 2 - 1\n",
"groups = [[0,1],[2,3],[4,5]]\n",
"test_separability(f, x, groups, 'mul')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e81778e9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(False)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_separability(f, x, [[0,1],[2,4],[3,5]], 'mul')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3c088092",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(False)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = lambda x: torch.sin((x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]]))\n",
"x = torch.rand(100,6) * 2 - 1\n",
"test_separability(f, x, [[0,1],[2,3],[4,5]], 'mul')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b42e3b47",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(True)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_general_separability(f, x, [[0,1],[2,3],[4,5]])"
]
},
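{
"cell_type": "markdown",
"id": "sketch-general-md",
"metadata": {},
"source": [
"For $f = h(p(x_A) + q(x_B))$ and variables $a \\in A$, $b \\in B$, the gradient ratio $(\\partial f/\\partial x_a)/(\\partial f/\\partial x_b) = (\\partial p/\\partial x_a)/(\\partial q/\\partial x_b)$ no longer involves $h$ and factors into a function of $x_A$ times a function of $x_B$. The cell below is a minimal sketch of a check built on this necessary condition, not pykan's implementation of `test_general_separability`: it swaps the $B$-coordinates between pairs of samples and compares log-ratios. `check_general_separability` and `tol` are hypothetical names, and passing the check is evidence for general separability rather than proof."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-general-code",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import torch\n",
"\n",
"# Hypothetical helpers: a sketch of a general-separability check, not pykan's implementation.\n",
"def grad(f, x):\n",
"    # per-sample gradients: each output row depends only on its own input row\n",
"    x = x.clone().requires_grad_(True)\n",
"    return torch.autograd.grad(f(x).sum(), x)[0]\n",
"\n",
"def check_general_separability(f, x, groups, tol=1e-6):\n",
"    half = x.shape[0] // 2\n",
"    xi, xj = x[:half], x[half:2 * half]\n",
"    for A, B in itertools.combinations(groups, 2):\n",
"        # mixed samples: B-coordinates swapped between the two halves\n",
"        xij, xji = xi.clone(), xj.clone()\n",
"        xij[:, B] = xj[:, B]\n",
"        xji[:, B] = xi[:, B]\n",
"        grads = [grad(f, z) for z in (xi, xj, xij, xji)]\n",
"        for a in A:\n",
"            for b in B:\n",
"                # log|r| with r = (df/dx_a) / (df/dx_b)\n",
"                lr = [torch.log(g[:, a].abs()) - torch.log(g[:, b].abs()) for g in grads]\n",
"                # r = u(x_A) * v(x_B) implies r_ii * r_jj = r_ij * r_ji\n",
"                if (lr[0] + lr[1] - lr[2] - lr[3]).abs().max() > tol:\n",
"                    return False\n",
"    return True\n",
"\n",
"f = lambda x: torch.sin((x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]]))\n",
"x = torch.rand(100, 6, dtype=torch.float64) * 2 - 1\n",
"print(check_general_separability(f, x, [[0,1],[2,3],[4,5]]))  # expected: True\n",
"print(check_general_separability(f, x, [[0,2],[1,3],[4,5]]))  # expected: False"
]
},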
{
"cell_type": "markdown",
"id": "fbe1d482",
"metadata": {},
"source": [
"Case 2: test symmetry.\n",
"* Symmetry means the output $y$ depends on a few variables only through a scalar function of them, and gains no further information from the individual values of these variables.\n",
"* For example, we say a function has a symmetry $h(x_1, x_2)$ if $f(x_1,x_2,x_3,\\cdots)= g(h(x_1, x_2), x_3,\\cdots)$.\n",
"* To test whether a given set of variables forms a symmetry, use `test_symmetry`; to test a specific candidate $h$, use `test_symmetry_var`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "29640f8f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0,1]: tensor(True)\n",
"[0,2]: tensor(False)\n",
"[2,3]: tensor(True)\n"
]
}
],
"source": [
"f = lambda x: (x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]])\n",
"x = torch.rand(100,6) * 2 - 1\n",
"print('[0,1]:', test_symmetry(f, x, [0,1]))\n",
"print('[0,2]:', test_symmetry(f, x, [0,2]))\n",
"print('[2,3]:', test_symmetry(f, x, [2,3]))"
]
},
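{
"cell_type": "markdown",
"id": "sketch-symmetry-md",
"metadata": {},
"source": [
"If the variables in a group $G$ form a symmetry, $f = g(h(x_G), x_{\\text{rest}})$, then $\\nabla_{x_G} f = (\\partial g/\\partial h)\\,\\nabla h(x_G)$, so the direction of the gradient restricted to $G$ (up to sign) depends only on $x_G$. The cell below is a minimal sketch of a test built on this, not necessarily what `test_symmetry` does internally: it keeps $x_G$ fixed, resamples the remaining variables uniformly from $[-1, 1]$ (an assumption matching the sampling used in this notebook), and checks that the restricted gradient keeps its direction. `check_symmetry`, `tol` and `frac` are hypothetical names."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-symmetry-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical helpers: a sketch of a gradient-direction symmetry test, not pykan's implementation.\n",
"def group_grad(f, x, group):\n",
"    x = x.clone().requires_grad_(True)\n",
"    return torch.autograd.grad(f(x).sum(), x)[0][:, group]\n",
"\n",
"def check_symmetry(f, x, group, tol=1e-6, frac=0.95):\n",
"    rest = [i for i in range(x.shape[1]) if i not in group]\n",
"    x2 = x.clone()\n",
"    if rest:\n",
"        # resample the variables outside the group, keep x_group fixed\n",
"        x2[:, rest] = torch.rand(x.shape[0], len(rest), dtype=x.dtype) * 2 - 1\n",
"    cos = torch.nn.functional.cosine_similarity(group_grad(f, x, group), group_grad(f, x2, group), dim=1)\n",
"    # a symmetry keeps the gradient direction (up to sign) for almost every sample\n",
"    return bool((cos.abs() > 1 - tol).float().mean() >= frac)\n",
"\n",
"torch.manual_seed(0)\n",
"f = lambda x: (x[:,[0]] + x[:,[1]]) * (x[:,[2]] + x[:,[3]]) * (x[:,[4]] + x[:,[5]])\n",
"x = torch.rand(100, 6, dtype=torch.float64) * 2 - 1\n",
"# expected: True for [0, 1] and [2, 3], False for [0, 2]\n",
"for group in ([0, 1], [0, 2], [2, 3]):\n",
"    print(group, check_symmetry(f, x, group))"
]
},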
{
"cell_type": "code",
"execution_count": 9,
"id": "a392089f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100.0% data have more than 0.9 cosine similarity\n",
"suggesting symmetry\n"
]
}
],
"source": [
"from sympy import symbols\n",
"\n",
"# f depends on b and c only through the ratio b/c, not on their individual values.\n",
"f = lambda x: x[:,[0]] * torch.sqrt(1 + (x[:,[1]]/x[:,[2]])**2)\n",
"input_vars = a, b, c = symbols('a b c')\n",
"symmetry_var = b/c\n",
"x = torch.rand(100,3) * 2 - 1\n",
"test_symmetry_var(f, x, input_vars, symmetry_var);"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b8212789",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24.0% data have more than 0.9 cosine similarity\n",
"not suggesting symmetry\n"
]
}
],
"source": [
"not_symmetry_var = b * c\n",
"test_symmetry_var(f, x, input_vars, not_symmetry_var);"
]
},
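{
"cell_type": "markdown",
"id": "sketch-symvar-md",
"metadata": {},
"source": [
"The cosine-similarity printouts above suggest how a candidate symmetry variable can be checked: if $f = g(h(b, c), a)$, then $\\nabla_{(b,c)} f = (\\partial g/\\partial h)\\,\\nabla_{(b,c)} h$, so the two gradients are parallel (up to sign) at every sample. The cell below is a minimal sketch that computes the fraction of samples above a cosine threshold using plain PyTorch callables instead of SymPy expressions. The helper `symmetry_var_score`, taking the absolute value of the cosine, and passing the column indices `var_idx` explicitly are assumptions of this sketch, not pykan's API. With the candidate $b/c$ the score should be at or near 1, and with $bc$ noticeably lower."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-symvar-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical helpers: a sketch of a gradient cosine-similarity check, not pykan's implementation.\n",
"def grad(fn, x):\n",
"    x = x.clone().requires_grad_(True)\n",
"    return torch.autograd.grad(fn(x).sum(), x)[0]\n",
"\n",
"def symmetry_var_score(f, h, x, var_idx, threshold=0.9):\n",
"    # fraction of samples where grad f and grad h, restricted to the\n",
"    # variables h depends on, are parallel up to sign\n",
"    gf = grad(f, x)[:, var_idx]\n",
"    gh = grad(h, x)[:, var_idx]\n",
"    cos = torch.nn.functional.cosine_similarity(gf, gh, dim=1).abs()\n",
"    return (cos > threshold).float().mean().item()\n",
"\n",
"# columns: 0 -> a, 1 -> b, 2 -> c\n",
"f = lambda x: x[:,[0]] * torch.sqrt(1 + (x[:,[1]]/x[:,[2]])**2)\n",
"h_ratio = lambda x: x[:,[1]] / x[:,[2]]    # candidate b/c\n",
"h_product = lambda x: x[:,[1]] * x[:,[2]]  # candidate b*c\n",
"x = torch.rand(100, 3, dtype=torch.float64) * 2 - 1\n",
"print('b/c:', symmetry_var_score(f, h_ratio, x, [1, 2]))\n",
"print('b*c:', symmetry_var_score(f, h_product, x, [1, 2]))"
]
},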
{
"cell_type": "markdown",
"id": "8c782f62",
"metadata": {},
"source": [
"Case 3: plot the tree graph. By applying the hypothesis tests above iteratively, we can recover the tree graph of the function."
]
},
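{
"cell_type": "markdown",
"id": "sketch-tree-md",
"metadata": {},
"source": [
"The iterative procedure can be sketched as a greedy bottom-up merge: start with one group per variable, repeatedly merge the first pair of groups whose union passes a symmetry test, and record each merge as an internal node of the tree. The cell below is an illustration under stated assumptions, not how `plot_tree` is implemented: the symmetry test checks that the gradient restricted to the group keeps its direction when the remaining variables are resampled uniformly from $[-1, 1]$, and `is_symmetry`, `build_tree`, `tol` and `frac` are hypothetical names. For the five-variable example in the cell, one would expect the merges to follow the nesting of the formula: $(x_0, x_1)$ and $(x_2, x_3)$ first, then their union, then $x_4$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-tree-code",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import torch\n",
"\n",
"# Hypothetical helpers: a sketch of greedy tree building, not pykan's implementation.\n",
"def group_grad(f, x, group):\n",
"    x = x.clone().requires_grad_(True)\n",
"    return torch.autograd.grad(f(x).sum(), x)[0][:, group]\n",
"\n",
"def is_symmetry(f, x, group, tol=1e-6, frac=0.95):\n",
"    rest = [i for i in range(x.shape[1]) if i not in group]\n",
"    if not rest:\n",
"        return True  # the set of all variables trivially forms the root\n",
"    x2 = x.clone()\n",
"    x2[:, rest] = torch.rand(x.shape[0], len(rest), dtype=x.dtype) * 2 - 1\n",
"    cos = torch.nn.functional.cosine_similarity(group_grad(f, x, group), group_grad(f, x2, group), dim=1)\n",
"    return bool((cos.abs() > 1 - tol).float().mean() >= frac)\n",
"\n",
"def build_tree(f, x):\n",
"    # returns the list of merges (left group, right group) in the order they happen\n",
"    groups = [[i] for i in range(x.shape[1])]\n",
"    merges = []\n",
"    while len(groups) > 1:\n",
"        for i, j in itertools.combinations(range(len(groups)), 2):\n",
"            union = sorted(groups[i] + groups[j])\n",
"            if is_symmetry(f, x, union):\n",
"                merges.append((groups[i], groups[j]))\n",
"                groups = [g for k, g in enumerate(groups) if k not in (i, j)] + [union]\n",
"                break\n",
"        else:\n",
"            break  # no pair of groups passes the test; stop merging\n",
"    return merges\n",
"\n",
"torch.manual_seed(0)\n",
"f = lambda x: ((x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2) ** 2 + x[:,[4]]**2\n",
"x = torch.rand(100, 5, dtype=torch.float64) * 2 - 1\n",
"for left, right in build_tree(f, x):\n",
"    print(left, '+', right)"
]
},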
{
"cell_type": "code",
"execution_count": 11,
"id": "42003070",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f = lambda x: ((x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2) ** 2 + ((x[:,[4]]**2 + x[:,[5]]**2) ** 2 + (x[:,[6]]**2 + x[:,[7]]**2) ** 2) ** 2\n",
"x = torch.rand(100,8) * 2 - 1\n",
"plot_tree(f, x, style='tree') # by default, style = 'tree'"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8104aede",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_tree(f, x, style='box')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a2136344",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f = lambda x: ((x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2) ** 2 + x[:,[4]]**2\n",
"x = torch.rand(100,5) * 2 - 1\n",
"plot_tree(f, x, style='tree') # by default, style = 'tree'"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "8b0c7563",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_tree(f, x, style='box')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1333bed5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}