# Paste/export metadata (not code), preserved as a comment:
#   2024-08-17 18:19:10 -04:00 · 498 lines · 18 KiB · Python
from sympy import *
import sympy
import numpy as np
from kan.MultKAN import MultKAN
import torch
def next_nontrivial_operation(expr, scale=1, bias=0):
    '''
    remove the affine part of an expression

    Strips numeric summands/factors off the outermost Add/Mul so that the
    invariant  original == scale * expr + bias  holds, with `expr` starting
    at a non-trivial (non-affine) operation.

    Args:
    -----
        expr : sympy expression
        scale : float
            scale accumulated so far (default 1)
        bias : float
            bias accumulated so far (default 0)

    Returns:
    --------
        expr : sympy expression
        scale : float
        bias : float

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
    >>> next_nontrivial_operation(expression)
    '''
    if expr.func == Add or expr.func == Mul:
        # split the arguments into numeric (affine) and symbolic parts
        num_args = [arg for arg in expr.args if arg.is_number]
        var_args = [arg for arg in expr.args if not arg.is_number]
        if len(num_args) > 0:
            # trivial (affine) part present: fold it into scale/bias and recurse.
            # Maintain the invariant  total == scale*expr + bias :
            if expr.func == Add:
                # total == scale*(num_sum + rest) + bias
                #       == scale*rest + (scale*num_sum + bias)
                bias = scale * Add(*num_args) + bias
            if expr.func == Mul:
                # total == scale*num_prod*rest + bias
                scale = scale * Mul(*num_args)
            return next_nontrivial_operation(expr.func(*var_args), scale, bias)
        else:
            return expr, scale, bias
    else:
        return expr, scale, bias
def expr2kan(input_variables, expr, grid: int = 5, k: int = 3, auto_save: bool = False):
    '''
    compile a symbolic formula to a MultKAN

    The formula is first parsed (twice) into a tree of Node / SubNode /
    Connection objects that close over the shared lists below, then the tree
    is converted into MultKAN layer/width indices and the model parameters
    (node/subnode affines, symbolic activation functions) are written in.

    Args:
    -----
        input_variables : a list of sympy symbols
            the leaf variables; their order fixes the KAN input ordering
        expr : sympy expression
        grid : int
            the number of grid intervals
        k : int
            spline order
        auto_save : bool
            if auto_save = True, models are automatically saved

    Returns:
    --------
        MultKAN

    Example
    -------
    >>> from kan.compiler import *
    >>> from sympy import *
    >>> input_vars = a, b = symbols('a b')
    >>> expression = exp(sin(pi*a) + b**2)
    >>> model = kanpiler(input_vars, expression)
    >>> x = torch.rand(100,2) * 2 - 1
    >>> model(x)
    >>> model.plot()
    '''
    class Node:
        # A KAN node (post-activation). mult_bool=True marks a multiplication
        # node; registration appends self to the per-depth Nodes list.
        def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
            self.expr = expr
            self.mult_bool = mult_bool
            if self.mult_bool:
                self.mult_arity = mult_arity
            self.depth = depth

            # register in Nodes, creating the depth level on demand
            if len(Nodes) <= depth:
                Nodes.append([])
                index = 0
            else:
                index = len(Nodes[depth])
            Nodes[depth].append(self)

            self.index = index
            if parent == None:
                self.parent_index = None
            else:
                self.parent_index = parent.index
            self.child_index = []

            # update parent's child_index
            if parent != None:
                parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class SubNode:
        # A KAN subnode (pre-activation), always attached to one parent Node.
        def __init__(self, expr, depth, scale, bias, parent=None):
            self.expr = expr
            self.depth = depth

            # register in SubNodes, creating the depth level on demand
            if len(SubNodes) <= depth:
                SubNodes.append([])
                index = 0
            else:
                index = len(SubNodes[depth])
            SubNodes[depth].append(self)

            self.index = index
            self.parent_index = None # shape: (2,)
            self.child_index = [] # shape: (n, 2)

            # update parent's child_index
            # NOTE(review): parent defaults to None but is dereferenced
            # unconditionally — every call site below passes parent explicitly.
            parent.child_index.append(self.index)

            self.scale = scale
            self.bias = bias

    class Connection:
        # Edge from a SubNode (parent) to a next-layer Node (child);
        # registers itself in the Connections dict on construction.
        def __init__(self, affine, fun, fun_name, parent=None, child=None, power_exponent=None):
            # connection = activation function that connects a subnode to a node in the next layer node
            self.affine = affine #[1,0,1,0] # (a,b,c,d)
            self.fun = fun # y = c*fun(a*x+b)+d
            self.fun_name = fun_name
            self.parent_index = parent.index
            self.depth = parent.depth
            self.child_index = child.index
            self.power_exponent = power_exponent # if fun == Pow
            Connections[(self.depth,self.parent_index,self.child_index)] = self

    def create_node(expr, parent=None, n_layer=None):
        # Recursively build the Node/SubNode/Connection tree for `expr`.
        # If n_layer is given, input symbols are padded with identity links
        # down to depth n_layer so that all leaves sit in the same layer.
        #print('before', expr)
        expr, scale, bias = next_nontrivial_operation(expr)
        #print('after', expr)
        if parent == None:
            depth = 0
        else:
            depth = parent.depth

        if expr.func == Mul:
            # multiplication node: one SubNode per factor
            mult_arity = len(expr.args)
            node = Node(expr, True, depth, scale, bias, parent=parent, mult_arity=mult_arity)
            # create mult_arity SubNodes, + 1
            for i in range(mult_arity):
                # create SubNode
                expr_i, scale, bias = next_nontrivial_operation(expr.args[i])
                subnode = SubNode(expr_i, node.depth+1, scale, bias, parent=node)
                if expr_i.func == Add:
                    for j in range(len(expr_i.args)):
                        expr_ij, scale, bias = next_nontrivial_operation(expr_i.args[j])
                        # expr_ij is impossible to be Add, should be Mul or 1D
                        if expr_ij.func == Mul:
                            #print(expr_ij)
                            # create a node with expr_ij
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            # create a connection which is a linear function
                            c = Connection([1,0,float(scale),float(bias)], lambda x: x, 'x', parent=subnode, child=new_node)
                        elif expr_ij.func == Symbol:
                            #print(expr_ij)
                            new_node = create_node(expr_ij, parent=subnode, n_layer=n_layer)
                            c = Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
                        else:
                            # 1D function case
                            # create a node with expr_ij.args[0]
                            new_node = create_node(expr_ij.args[0], parent=subnode, n_layer=n_layer)
                            # create 1D function expr_ij.func
                            if expr_ij.func == Pow:
                                power_exponent = expr_ij.args[1]
                            else:
                                power_exponent = None
                            Connection([1,0,float(scale),float(bias)], expr_ij.func, fun_name = expr_ij.func, parent=subnode, child=new_node, power_exponent=power_exponent)
                elif expr_i.func == Mul:
                    # create a node with expr_i
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    # create 1D function, linear
                    Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
                else:
                    # 1D functions
                    # create a node with expr_i.args[0]
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    # create 1D function expr_i.func
                    if expr_i.func == Pow:
                        power_exponent = expr_i.args[1]
                    else:
                        power_exponent = None
                    Connection([1,0,1,0], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Add:
            # addition node: one shared SubNode collecting all summands
            node = Node(expr, False, depth, scale, bias, parent=parent)
            subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)

            for i in range(len(expr.args)):
                expr_i, scale, bias = next_nontrivial_operation(expr.args[i])
                if expr_i.func == Mul:
                    # create a node with expr_i
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    # create a connection which is a linear function
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
                elif expr_i.func == Symbol:
                    new_node = create_node(expr_i, parent=subnode, n_layer=n_layer)
                    Connection([1,0,float(scale),float(bias)], lambda x: x, fun_name = 'x', parent=subnode, child=new_node)
                else:
                    # 1D function case
                    # create a node with expr_ij.args[0]
                    new_node = create_node(expr_i.args[0], parent=subnode, n_layer=n_layer)
                    # create 1D function expr_i.func
                    if expr_i.func == Pow:
                        power_exponent = expr_i.args[1]
                    else:
                        power_exponent = None
                    Connection([1,0,float(scale),float(bias)], expr_i.func, fun_name = expr_i.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        elif expr.func == Symbol:
            # expr.func is a symbol (one of input variables)
            if n_layer == None:
                node = Node(expr, False, depth, scale, bias, parent=parent)
            else:
                node = Node(expr, False, depth, scale, bias, parent=parent)
                return_node = node
                # pad with identity connections until the input layer depth
                for i in range(n_layer - depth):
                    subnode = SubNode(expr, node.depth+1, 1, 0, parent=node)
                    node = Node(expr, False, subnode.depth, 1, 0, parent=subnode)
                    Connection([1,0,1,0], lambda x: x, fun_name = 'x', parent=subnode, child=node)
                node = return_node

            Start_Nodes.append(node)

        else:
            # expr.func is 1D function
            #print(expr, scale, bias)
            node = Node(expr, False, depth, scale, bias, parent=parent)
            # NOTE(review): expr_i/scale/bias computed here are used only for
            # the SubNode's expr; the recursion below re-extracts the affine
            # part from the raw argument expr.args[0] itself.
            expr_i, scale, bias = next_nontrivial_operation(expr.args[0])
            subnode = SubNode(expr_i, node.depth+1, 1, 0, parent=node)
            # create a node with expr_i.args[0]
            new_node = create_node(expr.args[0], parent=subnode, n_layer=n_layer)
            # create 1D function expr_i.func
            if expr.func == Pow:
                power_exponent = expr.args[1]
            else:
                power_exponent = None
            Connection([1,0,1,0], expr.func, fun_name = expr.func, parent=subnode, child=new_node, power_exponent=power_exponent)

        return node

    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    # first pass: build the tree only to measure its depth
    create_node(expr, n_layer=None)

    n_layer = len(Nodes) - 1

    # second pass: rebuild from scratch, now padding every input symbol with
    # identity links down to the input layer (depth n_layer)
    Nodes = [[]]
    SubNodes = [[]]
    Connections = {}
    Start_Nodes = []

    create_node(expr, n_layer=n_layer)

    # move affine parameters in leaf nodes to connections
    for node in Start_Nodes:
        c = Connections[(node.depth,node.parent_index,node.index)]
        c.affine[0] = float(node.scale)
        c.affine[1] = float(node.bias)
        node.scale = 1.
        node.bias = 0.

    #input_variables = symbol
    # map each leaf node to the index of its input variable
    node2var = []
    for node in Start_Nodes:
        for i in range(len(input_variables)):
            if node.expr == input_variables[i]:
                node2var.append(i)

    # Nodes
    # count sum nodes and mult nodes per tree layer
    n_mult = []
    n_sum = []
    for layer in Nodes:
        n_mult.append(0)
        n_sum.append(0)
        for node in layer:
            if node.mult_bool == True:
                n_mult[-1] += 1
            else:
                n_sum[-1] += 1

    # depth
    n_layer = len(Nodes) - 1

    # converter
    # input tree node id, output kan node id (distinguish sum and mult node)
    # input tree subnode id, output tree subnode id
    # node id
    subnode_index_convert = {}
    node_index_convert = {}
    connection_index_convert = {}
    mult_arities = []
    for layer_id in range(n_layer+1):
        mult_arity = []
        i_sum = 0
        i_mult = 0
        for i in range(len(Nodes[layer_id])):
            node = Nodes[layer_id][i]
            if node.mult_bool == True:
                # mult nodes are placed after all sum nodes in a KAN layer
                kan_node_id = n_sum[layer_id] + i_mult
                arity = len(node.child_index)
                # NOTE(review): the inner loop reuses the name `i`; harmless
                # because the outer `for` rebinds `i` on its next iteration.
                for i in range(arity):
                    subnode = SubNodes[node.depth+1][node.child_index[i]]
                    kan_subnode_id = n_sum[layer_id] + np.sum(mult_arity) + i
                    subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id))
                i_mult += 1
                mult_arity.append(arity)
            else:
                kan_node_id = i_sum
                if len(node.child_index) > 0:
                    subnode = SubNodes[node.depth+1][node.child_index[0]]
                    kan_subnode_id = i_sum
                    subnode_index_convert[(subnode.depth,subnode.index)] = (int(n_layer-subnode.depth),int(kan_subnode_id))
                i_sum += 1

            # tree depth runs root->leaves; KAN layers run inputs->output,
            # hence the n_layer - depth flip
            if layer_id == n_layer:
                # input layer
                node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(node2var[kan_node_id]))
            else:
                node_index_convert[(node.depth,node.index)] = (int(n_layer-node.depth),int(kan_node_id))

        # node: depth (node.depth -> n_layer - node.depth)
        #       width (node.index -> kan_node_id)
        # subnode: depth (subnode.depth -> n_layer - subnode.depth)
        #          width (subnote.index -> kan_subnode_id)
        mult_arities.append(mult_arity)

    for index in list(Connections.keys()):
        depth, subnode_id, node_id = index
        # to int(n_layer-depth),
        _, kan_subnode_id = subnode_index_convert[(depth, subnode_id)]
        _, kan_node_id = node_index_convert[(depth, node_id)]
        connection_index_convert[(depth, subnode_id, node_id)] = (n_layer-depth, kan_subnode_id, kan_node_id)

    # tree layers were collected root-first; KAN widths go input-first
    n_sum.reverse()
    n_mult.reverse()
    mult_arities.reverse()

    width = [[n_sum[i], n_mult[i]] for i in range(len(n_sum))]
    width[0][0] = len(input_variables)

    # allow pass in other parameters (probably as a dictionary) in sf2kan, including grid k etc.
    model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)

    # clean the graph: zero out every activation, then fix only the real ones
    for l in range(model.depth):
        for i in range(model.width_in[l]):
            for j in range(model.width_out[l+1]):
                model.fix_symbolic(l,i,j,'0',fit_params_bool=False)

    # Nodes
    Nodes_flat = [x for xs in Nodes for x in xs]

    # `self` is just an alias for model (mirrors method-style code elsewhere)
    self = model

    # write node affine parameters into the model
    for node in Nodes_flat:
        node_depth = node.depth
        node_index = node.index
        kan_node_depth, kan_node_index = node_index_convert[(node_depth,node_index)]
        #print(kan_node_depth, kan_node_index)
        if kan_node_depth > 0: # skip input layer
            self.node_scale[kan_node_depth-1].data[kan_node_index] = float(node.scale)
            self.node_bias[kan_node_depth-1].data[kan_node_index] = float(node.bias)


    # SubNodes
    SubNodes_flat = [x for xs in SubNodes for x in xs]

    # write subnode affine parameters into the model
    for subnode in SubNodes_flat:
        subnode_depth = subnode.depth
        subnode_index = subnode.index
        kan_subnode_depth, kan_subnode_index = subnode_index_convert[(subnode_depth,subnode_index)]
        #print(kan_subnode_depth, kan_subnode_index)
        self.subnode_scale[kan_subnode_depth].data[kan_subnode_index] = float(subnode.scale)
        self.subnode_bias[kan_subnode_depth].data[kan_subnode_index] = float(subnode.bias)

    # Connections
    Connections_flat = list(Connections.values())

    for connection in Connections_flat:
        c_depth = connection.depth
        c_j = connection.parent_index
        c_i = connection.child_index
        kc_depth, kc_j, kc_i = connection_index_convert[(c_depth, c_j, c_i)]

        # get symbolic fun_name: translate sympy function objects to the
        # string names MultKAN.fix_symbolic understands
        fun_name = connection.fun_name
        #if fun_name == Pow:
        #    print(connection.power_exponent)
        if fun_name == 'x':
            kfun_name = 'x'
        elif fun_name == exp:
            kfun_name = 'exp'
        elif fun_name == sin:
            kfun_name = 'sin'
        elif fun_name == cos:
            kfun_name = 'cos'
        elif fun_name == tan:
            kfun_name = 'tan'
        elif fun_name == sqrt:
            kfun_name = 'x^0.5'
        elif fun_name == log:
            kfun_name = 'log'
        elif fun_name == tanh:
            kfun_name = 'tanh'
        elif fun_name == asin:
            kfun_name = 'arcsin'
        elif fun_name == acos:
            kfun_name = 'arccos'
        elif fun_name == atan:
            kfun_name = 'arctan'
        elif fun_name == atanh:
            kfun_name = 'arctanh'
        elif fun_name == sign:
            kfun_name = 'sgn'
        elif fun_name == Pow:
            # powers: half-integer exponents are special-cased, everything
            # else is cast to int (positive -> x^n, negative -> 1/x^n)
            alpha = connection.power_exponent
            if alpha == Rational(1,2):
                kfun_name = 'x^0.5'
            elif alpha == - Rational(1,2):
                kfun_name = '1/x^0.5'
            elif alpha == Rational(3,2):
                kfun_name = 'x^1.5'
            else:
                alpha = int(connection.power_exponent)
                if alpha > 0:
                    if alpha == 1:
                        kfun_name = 'x'
                    else:
                        kfun_name = f'x^{alpha}'
                else:
                    if alpha == -1:
                        kfun_name = '1/x'
                    else:
                        kfun_name = f'1/x^{-alpha}'

        model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
        # write the (a,b,c,d) affine of y = c*fun(a*x+b)+d into the symbolic layer
        model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)

    return model


# public aliases for the compiler entry point
sf2kan = kanpiler = expr2kan