import copy
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
import sympy
import torch
from sklearn.cluster import AgglomerativeClustering
from sympy.printing import latex
from sympy.utilities.lambdify import lambdify

from .utils import batch_jacobian, batch_hessian


def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
    detect function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        mode : str
            mode = 'add' or mode = 'mul'
        score_th : float
            threshold of score
        res_th : float
            threshold of residue
        n_clusters : None or int
            the number of clusters
        bias : float
            bias (for multiplicative separability)
        verbose : bool
            if True, print the residual ratio for each number of clusters

    Returns:
    --------
        results (dictionary)

    Example1
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='add')

    Example2
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='mul')
    '''
    results = {}

    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x + bias, model), x)
    else:
        raise ValueError(f"mode should be 'add' or 'mul', got {mode}")

    # normalize the Hessian by input scales so that scores are scale-invariant
    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None, :] * std[:, None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    # hard "distance" matrix: 1 where a pair of variables looks separable,
    # so clustering groups the non-separable variables together
    dist_hard = (score_mat < score_th).float()

    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, list):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1, x.shape[1]]

    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1] + 1))

    for n_cluster in n_cluster_try:

        clustering = AgglomerativeClustering(
          metric='precomputed',
          n_clusters=n_cluster,
          linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:, groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        total_sum = torch.sum(score_mat)
        residual_sum = total_sum - block_sum
        residual_ratio = residual_sum / total_sum

        if verbose:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')

        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups

    if results.get('n_groups', 1) > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')

    return results


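# A minimal sketch (illustrative only; the name `_demo_mul_mode` is ours, not
# part of the original API) of why mode='mul' reuses the additive Hessian test:
# if f(x) = g(x_A) * h(x_B), then log|f| = log|g(x_A)| + log|h(x_B)| is
# additively separable, so the cross-block Hessian entries of log|f| vanish.
def _demo_mul_mode():
    f = lambda x: x[:, [0]] ** 2 * torch.exp(x[:, [1]])
    x = torch.normal(0, 1, size=(100, 2))
    hessian = batch_hessian(lambda x: torch.log(torch.abs(f(x))), x)
    score = torch.median(torch.abs(hessian), dim=0)[0]
    print(score)  # the off-diagonal (0,1) entry should be near zero

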
def batch_grad_normgrad(model, x, group, create_graph=False):
    # x in shape (Batch, Length)
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))

    def jac(x):
        # normalized gradient of the model w.r.t. the group-A variables;
        # only its *direction* matters, hence the normalization
        input_grad = batch_jacobian(model, x, create_graph=True)
        input_grad_A = input_grad[:, group_A]
        norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6
        input_grad_A_normalized = input_grad_A / norm
        return input_grad_A_normalized

    def _jac_sum(x):
        return jac(x).sum(dim=0)

    # Jacobian of the normalized group-A gradient w.r.t. the group-B inputs;
    # summing over the batch lets a single jacobian call recover per-sample rows
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1, 0, 2)[:, :, group_B]


def get_dependence(model, x, group):
    # measure how much the (normalized) gradient direction within `group`
    # depends on the variables outside `group`; near-zero dependence means
    # the model depends on `group` only through a single scalar function
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))
    grad_normgrad = batch_grad_normgrad(model, x, group=group)
    std = torch.std(x, dim=0)
    dependence = grad_normgrad * std[None, group_A, None] * std[None, None, group_B]
    dependence = torch.median(torch.abs(dependence), dim=0)[0]
    return dependence


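# Sketch of the idea behind the symmetry test (our illustration; the name
# `_demo_symmetry` is not part of the original API): if f(x) = g(h(x_A), x_B),
# then grad_A f = (dg/dh) * grad h(x_A), so the *direction* of grad_A f depends
# on x_A only; get_dependence differentiates that direction w.r.t. x_B and
# should return ~0 for a symmetric group.
def _demo_symmetry():
    f = lambda x: torch.exp(x[:, [0]] ** 2 + x[:, [1]] ** 2) + x[:, [2]]
    x = torch.normal(0, 1, size=(100, 3))
    print(get_dependence(f, x, [0, 1]))  # ~0: direction of grad wrt (x0,x1) ignores x2
    print(get_dependence(f, x, [1, 2]))  # not ~0: direction of grad wrt (x1,x2) depends on x0

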
def test_symmetry(model, x, group, dependence_th=1e-3):
    '''
    test whether the model depends on the variables in `group` only through
    a single scalar function of them (symmetry)

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        group : a list of indices
        dependence_th : float
            threshold of dependence

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_symmetry(model, x, [1,2])) # True
    >>> print(test_symmetry(model, x, [0,2])) # False
    '''
    if len(group) == x.shape[1] or len(group) == 0:
        return True

    dependence = get_dependence(model, x, group)
    max_dependence = torch.max(dependence)
    return max_dependence < dependence_th


def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
    test function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of indices
            hypothesized grouping of the input variables
        mode : str
            mode = 'add' or mode = 'mul'
        threshold : float
            threshold of score
        bias : float
            bias (for multiplicative separability)

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
    >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x + bias, model), x)
    else:
        raise ValueError(f"mode should be 'add' or 'mul', got {mode}")

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None, :] * std[:, None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]

    sep_bool = True

    # internal test: cross-block Hessian entries between every pair of groups
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i + 1, n_groups):
            sep_bool *= torch.max(score_mat[groups[i]][:, groups[j]]) < threshold

    # external test: entries between grouped and ungrouped variables
    group_id = [x for xs in groups for x in xs]
    nongroup_id = list(set(range(x.shape[1])) - set(group_id))
    if len(nongroup_id) > 0 and len(group_id) > 0:
        sep_bool *= torch.max(score_mat[group_id][:, nongroup_id]) < threshold

    return sep_bool


def test_general_separability(model, x, groups, threshold=1e-2):
    '''
    test general separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of indices
            hypothesized grouping of the input variables
        threshold : float
            threshold of score

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
    >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    gensep_bool = True

    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i + 1, n_groups):
            group_A = groups[i]
            group_B = groups[j]
            for member_A in group_A:
                for member_B in group_B:
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:, [member_B]] / grad[:, [member_A]]
                    # general separability holds iff the gradient ratio is
                    # multiplicatively separable across the groups
                    gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)

    return gensep_bool


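# Sketch of the identity behind test_general_separability (our illustration;
# `_demo_general_separability` is not part of the original API): if
# f(x) = F(g(x_A) + h(x_B)), then
#     (df/dx_b) / (df/dx_a) = F'(g+h) h'(x_b) / (F'(g+h) g'(x_a)) = h'(x_b) / g'(x_a),
# a product of a function of x_A and a function of x_B, hence 'mul'-separable.
def _demo_general_separability():
    f = lambda x: (x[:, [0]] ** 2 + (x[:, [1]] + x[:, [2]]) ** 2) ** 3
    x = torch.normal(0, 1, size=(100, 3))
    print(test_general_separability(f, x, [[0], [1, 2]]))  # expected: True (up to threshold effects)

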
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
    how variables are combined hierarchically

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sym_th : float
            threshold of symmetry
        verbose : bool
            if True, print intermediate assembly steps

    Returns:
    --------
        list of lists of groups, from the finest level (atoms) to the coarsest

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> get_molecule(model, x, verbose=False)
    [[[0], [1], [2], [3], [4], [5], [6], [7]],
     [[0, 1], [2, 3], [4, 5], [6, 7]],
     [[0, 1, 2, 3], [4, 5, 6, 7]],
     [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    already_full = False
    n_layer = 0
    last_n_molecule = n

    while True:

        pointer = 0
        current_molecule = []
        remove_atoms = []
        n_atom = 0

        while len(atoms) > 0:

            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try to assemble the atom into the current molecule
                if len(current_molecule + atom) == x.shape[1] and not already_full and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule + atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1

            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove assembled atoms before starting the next molecule
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0

        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))

        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        atoms = molecules
        molecules = []

        n_layer += 1

    # sort: reorder each level so that it is consistent with the level above
    depth = len(moleculess) - 1

    for l in list(range(depth, 0, -1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l - 1]

        for molecule_l in molecules_l:
            start = 0
            for i in range(1, len(molecule_l) + 1):
                if molecule_l[start:i] in molecules_lm1:
                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l - 1] = molecules_sorted

    return moleculess


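# Tiny usage sketch (our addition; `_demo_get_molecule` is not part of the
# original API): the greedy bottom-up assembly above merges atoms as long as
# test_symmetry keeps passing, producing one hierarchy level per outer loop.
def _demo_get_molecule():
    f = lambda x: (x[:, [0]] ** 2 + x[:, [1]] ** 2) ** 2 + x[:, [2]] ** 2
    x = torch.normal(0, 1, size=(100, 3))
    print(get_molecule(f, x, verbose=False))
    # expected hierarchy: [[[0], [1], [2]], [[0, 1], [2]], [[0, 1, 2]]]

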
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
    get tree nodes

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        moleculess : list
            hierarchical grouping returned by get_molecule
        sep_th : float
            threshold of separability
        skip_test : bool
            if True, don't test the property of each module (to save time)

    Returns:
    --------
        arities : list of numbers
        properties : list of strings

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> moleculess = get_molecule(model, x, verbose=False)
    >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []

    depth = len(moleculess) - 1

    for l in range(depth):
        molecules_l = copy.deepcopy(moleculess[l])
        molecules_lp1 = copy.deepcopy(moleculess[l + 1])
        arity_l = []
        property_l = []

        for molecule in molecules_lp1:
            # split the parent molecule into its children from the level below
            start = 0
            arity = 0
            groups = []
            for i in range(1, len(molecule) + 1):
                if molecule[start:i] in molecules_l:
                    groups.append(molecule[start:i])
                    start = i
                    arity += 1
            arity_l.append(arity)

            if arity == 1:
                property = 'Id'
            else:
                property = ''
                # test general separability of this node
                if skip_test:
                    gensep_bool = False
                else:
                    gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)

                if gensep_bool:
                    property = 'GS'
                # the root node can be refined further into 'Add' or 'Mul'
                if l == depth - 1:
                    if skip_test:
                        add_bool = False
                        mul_bool = False
                    else:
                        add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                        mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        property = 'Add'
                    if mul_bool:
                        property = 'Mul'

            property_l.append(property)

        arities.append(arity_l)
        properties.append(property_l)

    return arities, properties


def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
    plot the tree graph of how variables are combined

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        in_var : list of symbols
            input variables
        style : str
            'tree' or 'box'
        sym_th : float
            threshold of symmetry
        sep_th : float
            threshold of separability
        skip_sep_test : bool
            if True, don't test the property of each module (to save time)
        verbose : bool

    Returns:
    --------
        a tree graph

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]

    if in_var is None:
        in_vars = [sympy.Symbol(f'x_{ii}') for ii in range(1, n + 1)]
    elif isinstance(in_var[0], sympy.Symbol):
        in_vars = in_var
    else:
        in_vars = [sympy.symbols(var_) for var_ in in_var]

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k')  # up
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k')  # down
        plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k')  # left
        plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k')  # right

    depth = len(moleculess)

    delta = 1/n
    a = 0.3
    b = 0.15
    y0 = 0.5
    width_y = 2*b

    # draw rectangles
    for l in range(depth - 1):
        molecules = moleculess[l + 1]
        n_molecule = len(molecules)

        centers = []

        acc_arity = 0

        for i in range(n_molecule):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l + 1/2) * y0
            width_x = (end_id - start_id - 1 + 2*a) * delta

            if style == 'box':
                # rectangle with the property as a text label
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, properties[l][i], fontsize=15,
                         horizontalalignment='center', verticalalignment='center')
            elif style == 'tree':
                # if 'GS', 'Add' or 'Mul': no rectangle, n=arity tilted lines
                # ('Add'/'Mul' additionally get a "+" or "*" label)
                # if 'Id': no rectangle, one vertical line
                # if '': rectangle
                property = properties[l][i]
                if property == 'GS' or property == 'Add' or property == 'Mul':
                    color = 'blue'
                    arity = arities[l][i]
                    for j in range(arity):
                        if l == 0:
                            # connect directly to the input positions
                            plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
                        else:
                            # connect to the centers of the children in the last layer
                            plt.plot([last_centers[acc_arity + j], center_x], [center_y - b, center_y + b], color=color)
                    acc_arity += arity

                    if property == 'Add' or property == 'Mul':
                        symbol = '+' if property == 'Add' else '*'
                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                                 verticalalignment='center', color='red', fontsize=40)
                if property == 'Id':
                    plt.plot([center_x, center_x], [center_y - width_y/2, center_y + width_y/2], color='black')

                if property == '':
                    myrectangle(center_x, center_y, width_x, width_y)

            # connections to the next layer
            plt.plot([center_x, center_x], [center_y + width_y/2, center_y + y0 - width_y/2], color='k')
            centers.append(center_x)
        last_centers = copy.deepcopy(centers)

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        plt.plot([x_, x_], [0, y0/2 - width_y/2], color='k')
        plt.text(x_, -0.05 * (depth - 1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0, 1)
    plt.axis('off')
    plt.show()


def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
    test symmetry with respect to a symbolic combination of input variables

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        input_vars : list of sympy symbols
        symmetry_var : sympy expression

    Returns:
    --------
        cosine similarity

    Example
    -------
    >>> from kan.hypothesis import *
    >>> from sympy import *
    >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> input_vars = a, b, c = symbols('a b c')
    >>> symmetry_var = b + c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    >>> symmetry_var = b * c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var

    # gradients of the model w.r.t. the inputs
    input_grad = batch_jacobian(model, x)

    # gradients of the symmetry variable w.r.t. the inputs
    func = lambdify(orig_vars, sym_var, 'numpy')  # returns a numpy-ready function
    func2 = lambda x: func(*[x[:, [i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)

    # indices of the variables that appear in the symmetry variable
    idx = []
    sym_symbols = list(sym_var.free_symbols)
    for sym_symbol in sym_symbols:
        for j in range(len(orig_vars)):
            if sym_symbol == orig_vars[j]:
                idx.append(j)

    input_grad_part = input_grad[:, idx]
    sym_grad_part = sym_grad[:, idx]

    # the symmetry holds where the two gradients are (anti-)parallel
    cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1) / (torch.norm(input_grad_part, dim=1) * torch.norm(sym_grad_part, dim=1)))

    ratio = torch.sum(cossim > 0.9) / len(cossim)

    print(f'{100*ratio}% of data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')

    return cossim
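

# Minimal smoke test (our addition, not part of the original module); the
# expected outputs follow the docstring examples above.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = lambda x: x[:, [0]] ** 2 * (x[:, [1]] + x[:, [2]])
    x = torch.normal(0, 1, size=(100, 3))
    detect_separability(model, x, mode='mul')                       # expect: mul separability detected
    print(test_symmetry(model, x, [1, 2]))                          # expect: True
    print(test_separability(model, x, [[0], [1, 2]], mode='mul'))   # expect: True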