import torch
import torch.nn as nn
import numpy as np
from .KANLayer import *
from .Symbolic_KANLayer import *
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
import sympy
class KAN(nn.Module):
'''
KAN class
Attributes:
-----------
biases: a list of nn.Linear()
biases are added on nodes (in principle, biases can be absorbed into activation functions. However, we still have them for better optimization)
act_fun: a list of KANLayer
KANLayers
depth: int
depth of KAN
width: list
number of neurons in each layer, e.g., [2,5,5,3] means 2D inputs, 3D outputs, and 2 hidden layers of 5 neurons each.
grid: int
the number of grid intervals
k: int
the order of piecewise polynomial
base_fun: fun
residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
symbolic_fun: a list of Symbolic_KANLayer
Symbolic_KANLayers
symbolic_enabled: bool
If False, the symbolic front is not computed (to save time). Default: True.
Methods:
--------
__init__():
initialize a KAN
initialize_from_another_model():
initialize a KAN from another KAN (with the same shape, but potentially different grids)
update_grid_from_samples():
update spline grids based on samples
initialize_grid_from_another_model():
initialize KAN grids from another KAN
forward():
forward
set_mode():
set the mode of an activation function: 'n' for numeric, 's' for symbolic, 'ns' for combined (note they are visualized differently in plot(): 'n' as black, 's' as red, 'ns' as purple).
fix_symbolic():
fix an activation function to be symbolic
suggest_symbolic():
suggest symbolic candidates for a numeric spline-based activation function
lock():
lock activation functions to share parameters
unlock():
unlock locked activations
get_range():
get the input and output ranges of an activation function
plot():
plot the diagram of KAN
train():
train KAN
prune():
prune KAN
remove_edge():
remove some edge of KAN
remove_node():
remove some node of KAN
auto_symbolic():
automatically fit all splines to be symbolic functions
symbolic_formula():
obtain the symbolic formula of the KAN network
'''
def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(), symbolic_enabled=True, bias_trainable=True, grid_eps=1.0, grid_range=[-1,1], sp_trainable=True, sb_trainable=True, device='cpu', seed=0):
'''
initialize a KAN model
Args:
-----
width : list of int
:math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
grid : int
number of grid intervals. Default: 3.
k : int
order of piecewise polynomial. Default: 3.
noise_scale : float
the scale of noise injected into the spline initialization. Default: 0.1.
noise_scale_base : float
the scale of noise injected into the initialization of the residual (base) function scales. Default: 0.1.
base_fun : fun
the residual function b(x). Default: torch.nn.SiLU().
symbolic_enabled : bool
compute or skip symbolic computations (for efficiency). By default: True.
bias_trainable : bool
bias parameters are updated or not. By default: True
grid_eps : float
When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 1.0.
grid_range : list/np.array of shape (2,)
setting the range of grids. Default: [-1,1].
sp_trainable : bool
If true, scale_sp is trainable. Default: True.
sb_trainable : bool
If true, scale_base is trainable. Default: True.
device : str
device
seed : int
random seed
Returns:
--------
self
Example
-------
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), (model.act_fun[1].in_dim, model.act_fun[1].out_dim)
((2, 5), (5, 1))
'''
super(KAN, self).__init__()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
### initializing the numerical front ###
self.biases = []
self.act_fun = []
self.depth = len(width) - 1
self.width = width
for l in range(self.depth):
# splines
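# descriptive note: the residual (base) scale starts near 1/sqrt(in_dim) and is
# perturbed by random noise whose magnitude is set by noise_scale_base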
scale_base = 1/np.sqrt(width[l]) + (torch.randn(width[l]*width[l+1],)*2-1) * noise_scale_base
sp_batch = KANLayer(in_dim=width[l],out_dim=width[l+1],num=grid,k=k,noise_scale=noise_scale,scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable)
self.act_fun.append(sp_batch)
# bias
bias = nn.Linear(width[l+1],1,bias=False).requires_grad_(bias_trainable)
bias.weight.data *= 0.
self.biases.append(bias)
self.biases = nn.ModuleList(self.biases)
self.act_fun = nn.ModuleList(self.act_fun)
self.grid = grid
self.k = k
self.base_fun = base_fun
### initializing the symbolic front ###
self.symbolic_fun = []
for l in range(self.depth):
sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l+1])
self.symbolic_fun.append(sb_batch)
self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
self.symbolic_enabled = symbolic_enabled
def initialize_from_another_model(self, another_model, x):
'''
initialize from a parent model. The parent has the same width as the current model but may have different grids.
Args:
-----
another_model : KAN
the parent model used to initialize the current model
x : 2D torch.float
inputs, shape (batch, input dimension)
Returns:
--------
self : KAN
Example
-------
>>> model_coarse = KAN(width=[2,5,1], grid=5, k=3)
>>> model_fine = KAN(width=[2,5,1], grid=10, k=3)
>>> print(model_fine.act_fun[0].coef[0][0].data)
>>> x = torch.normal(0,1,size=(100,2))
>>> model_fine.initialize_from_another_model(model_coarse, x);
>>> print(model_fine.act_fun[0].coef[0][0].data)
tensor(-0.0030)
tensor(0.0506)
'''
another_model(x) # get activations
batch = x.shape[0]
self.initialize_grid_from_another_model(another_model, x)
for l in range(self.depth):
spb = self.act_fun[l]
spb_parent = another_model.act_fun[l]
#spb = spb_parent
preacts = another_model.spline_preacts[l]
postsplines = another_model.spline_postsplines[l]
self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch,spb.size).permute(1,0), postsplines.reshape(batch,spb.size).permute(1,0), spb.grid, k=spb.k)
spb.scale_base.data = spb_parent.scale_base.data
spb.scale_sp.data = spb_parent.scale_sp.data
spb.mask.data = spb_parent.mask.data
#print(spb.mask.data, self.act_fun[l].mask.data)
for l in range(self.depth):
self.biases[l].weight.data = another_model.biases[l].weight.data
for l in range(self.depth):
self.symbolic_fun[l] = another_model.symbolic_fun[l]
return self
def update_grid_from_samples(self, x):
'''
update grid from samples
Args:
-----
x : 2D torch.float
inputs, shape (batch, input dimension)
Returns:
--------
None
Example
-------
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> print(model.act_fun[0].grid[0].data)
>>> x = torch.rand(100,2)*5
>>> model.update_grid_from_samples(x)
>>> print(model.act_fun[0].grid[0].data)
tensor([-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000])
tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
'''
for l in range(self.depth):
self.forward(x)
self.act_fun[l].update_grid_from_samples(self.acts[l])
def initialize_grid_from_another_model(self, model, x):
'''
initialize grid from a parent model
Args:
-----
model : KAN
parent model
x : 2D torch.float
inputs, shape (batch, input dimension)
Returns:
--------
None
Example
-------
>>> model_parent = KAN(width=[1,1], grid=5, k=3)
>>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)[None,:]
>>> x = torch.linspace(-2,2,steps=1001)[:,None]
>>> model = KAN(width=[1,1], grid=5, k=3)
>>> print(model.act_fun[0].grid.data)
>>> model = model.initialize_from_another_model(model_parent, x)
>>> print(model.act_fun[0].grid.data)
tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
tensor([[-2.0000, -1.2000, -0.4000, 0.4000, 1.2000, 2.0000]])
'''
model(x)
for l in range(self.depth):
self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])
def forward(self, x):
'''
KAN forward
Args:
-----
x : 2D torch.float
inputs, shape (batch, input dimension)
Returns:
--------
y : 2D torch.float
outputs, shape (batch, output dimension)
Example
-------
>>> model = KAN(width=[2,5,3], grid=5, k=3)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x).shape
torch.Size([100, 3])
'''
self.acts = [] # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
self.spline_preacts = []
self.spline_postsplines = []
self.spline_postacts = []
self.acts_scale = []
self.acts_scale_std = []
#self.neurons_scale = []
self.acts.append(x) # acts shape: (batch, width[l])
for l in range(self.depth):
x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
if self.symbolic_enabled == True:
x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
else:
x_symbolic = 0.
postacts_symbolic = 0.
x = x_numerical + x_symbolic
postacts = postacts_numerical + postacts_symbolic
#self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
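# activation scale of edge (l,i,j): mean |postact| over the batch divided by the grid
# (input) range of that edge; used by the l1/entropy regularization in train() and to
# set edge transparency in plot()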
grid_reshape = self.act_fun[l].grid.reshape(self.width[l+1],self.width[l],-1)
input_range = grid_reshape[:,:,-1] - grid_reshape[:,:,0] + 1e-4
output_range = torch.mean(torch.abs(postacts), dim=0)
self.acts_scale.append(output_range/input_range)
self.acts_scale_std.append(torch.std(postacts, dim=0))
self.spline_preacts.append(preacts.detach())
self.spline_postacts.append(postacts.detach())
self.spline_postsplines.append(postspline.detach())
x = x + self.biases[l].weight
self.acts.append(x)
return x
def set_mode(self, l, i, j, mode, mask_n=None):
'''
set (l,i,j) activation to have mode
Args:
-----
l : int
layer index
i : int
input neuron index
j : int
output neuron index
mode : str
'n' (numeric) or 's' (symbolic) or 'ns' (combined)
mask_n : None or float
magnitude of the numeric front
Returns:
--------
None
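Example
-------
>>> # a minimal sketch (indices chosen arbitrarily): switch edge (0,1,3) to symbolic mode
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.set_mode(0,1,3,mode='s')
>>> model.act_fun[0].mask[3*2+1] # numeric mask of (0,1,3) is now 0.
>>> model.symbolic_fun[0].mask[3,1] # symbolic mask of (0,1,3) is now 1.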
'''
if mode == "s":
mask_n = 0.; mask_s = 1.
elif mode == "n":
mask_n = 1.; mask_s = 0.
elif mode == "sn" or mode == "ns":
if mask_n == None:
mask_n = 1.
else:
mask_n = mask_n
mask_s = 1.
else:
mask_n = 0.; mask_s = 0.
self.act_fun[l].mask.data[j*self.act_fun[l].in_dim+i] = mask_n
self.symbolic_fun[l].mask.data[j,i] = mask_s
def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10,10), b_range=(-10,10), verbose=True, random=False):
'''
set (l,i,j) activation to be symbolic (specified by fun_name)
Args:
-----
l : int
layer index
i : int
input neuron index
j : int
output neuron index
fun_name : str
function name
fit_params_bool : bool
obtaining affine parameters through fitting (True) or setting default values (False)
a_range : tuple
sweeping range of a
b_range : tuple
sweeping range of b
verbose : bool
If True, more information is printed.
random : bool
initialize affine parameters randomly or as [1,0,1,0]
Returns:
--------
None or r2 (coefficient of determination)
Example 1
---------
>>> # when fit_params_bool = False
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
tensor([[1., 1., 1., 1., 1.],
[1., 1., 0., 1., 1.]])
tensor([[0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
Example 2
---------
>>> # when fit_params_bool = True
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # obtain activations (otherwise model does not have attributes acts)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
r2 is 0.8131332993507385
r2 is not very high, please double check if you are choosing the correct symbolic function.
tensor([[1., 1., 1., 1., 1.],
[1., 1., 0., 1., 1.]])
tensor([[0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
'''
self.set_mode(l,i,j,mode="s")
if not fit_params_bool:
self.symbolic_fun[l].fix_symbolic(i,j,fun_name, verbose=verbose, random=random)
return None
else:
x = self.acts[l][:,i]
y = self.spline_postacts[l][:,j,i]
r2 = self.symbolic_fun[l].fix_symbolic(i,j,fun_name,x,y,a_range=a_range,b_range=b_range, verbose=verbose)
return r2
def unfix_symbolic(self, l, i, j):
'''
unfix the (l,i,j) activation function.
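Example
-------
>>> # a minimal sketch: fix an activation to 'sin', then revert it to numeric mode
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> model.unfix_symbolic(0,1,3)
>>> model.symbolic_fun[0].mask[3,1] # back to 0. (numeric mode)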
'''
self.set_mode(l,i,j,mode="n")
def unfix_symbolic_all(self):
'''
unfix all activation functions.
'''
for l in range(len(self.width)-1):
for i in range(self.width[l]):
for j in range(self.width[l+1]):
self.unfix_symbolic(l,i,j)
def lock(self, l, ids):
'''
lock ids in the l-th layer to be the same function
Args:
-----
l : int
layer index
ids : 2D list
:math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_1,j_1), (l,i_2,j_2), ...` to be the same function
Returns:
--------
None
Example
-------
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> print(model.act_fun[0].weight_sharing.reshape(3,2))
>>> model.lock(0,[[1,0],[1,1]])
>>> print(model.act_fun[0].weight_sharing.reshape(3,2))
tensor([[0, 1],
[2, 3],
[4, 5]])
tensor([[0, 1],
[2, 1],
[4, 5]])
'''
self.act_fun[l].lock(ids)
def unlock(self, l, ids):
'''
unlock ids in the l-th layer to be the same function
Args:
-----
l : int
layer index
ids : 2D list
[[i1,j1],[i2,j2],...] set (l,i1,j1), (l,i2,j2), ... to be unlocked
Example:
--------
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> model.lock(0,[[1,0],[1,1]])
>>> print(model.act_fun[0].weight_sharing.reshape(3,2))
>>> model.unlock(0,[[1,0],[1,1]])
>>> print(model.act_fun[0].weight_sharing.reshape(3,2))
tensor([[0, 1],
[2, 1],
[4, 5]])
tensor([[0, 1],
[2, 3],
[4, 5]])
'''
self.act_fun[l].unlock(ids)
def get_range(self, l, i, j, verbose=True):
'''
Get the input range and output range of the (l,i,j) activation
Args:
-----
l : int
layer index
i : int
input neuron index
j : int
output neuron index
Returns:
--------
x_min : float
minimum of input
x_max : float
maximum of input
y_min : float
minimum of output
y_max : float
maximum of output
Example
-------
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.get_range(0,0,0)
x range: [-2.13 , 2.75 ]
y range: [-0.50 , 1.83 ]
(tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275))
'''
x = self.spline_preacts[l][:,j,i]
y = self.spline_postacts[l][:,j,i]
x_min = torch.min(x)
x_max = torch.max(x)
y_min = torch.min(y)
y_max = torch.max(y)
if verbose:
print('x range: ['+'%.2f'%x_min,',','%.2f'%x_max,']')
print('y range: ['+'%.2f'%y_min,',','%.2f'%y_max,']')
return x_min, x_max, y_min, y_max
def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None):
'''
plot KAN
Args:
-----
folder : str
the folder to store pngs
beta : float
positive number. control the transparency of each activation. transparency = tanh(beta*l1).
mask : bool
If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
mode : str
"supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting the mean); if "unsupervised", l1 is measured by standard deviation (subtracting the mean).
scale : float
control the size of the diagram
in_vars: None or list of str
the name(s) of input variables
out_vars: None or list of str
the name(s) of output variables
title: None or str
title
Returns:
--------
Figure
Example
-------
>>> # see more interactive examples in demos
>>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.plot()
'''
if not os.path.exists(folder):
os.makedirs(folder)
#matplotlib.use('Agg')
depth = len(self.width) - 1
for l in range(depth):
w_large = 2.0
for i in range(self.width[l]):
for j in range(self.width[l+1]):
rank = torch.argsort(self.acts[l][:,i])
fig, ax = plt.subplots(figsize=(w_large,w_large))
num = rank.shape[0]
symbol_mask = self.symbolic_fun[l].mask[j][i]
numerical_mask = self.act_fun[l].mask.reshape(self.width[l+1], self.width[l])[j][i]
if symbol_mask > 0. and numerical_mask > 0.:
color = 'purple'
alpha_mask = 1
if symbol_mask > 0. and numerical_mask == 0.:
color = "red"
alpha_mask = 1
if symbol_mask == 0. and numerical_mask > 0.:
color = "black"
alpha_mask = 1
if symbol_mask == 0. and numerical_mask == 0.:
color = "white"
alpha_mask = 0
if tick == True:
ax.tick_params(axis="y",direction="in", pad=-22, labelsize=50)
ax.tick_params(axis="x",direction="in", pad=-15, labelsize=50)
x_min, x_max, y_min, y_max = self.get_range(l,i,j,verbose=False)
plt.xticks([x_min, x_max],['%2.f'%x_min, '%2.f'%x_max])
plt.yticks([y_min, y_max],['%2.f'%y_min, '%2.f'%y_max])
else:
plt.xticks([])
plt.yticks([])
if alpha_mask == 1:
plt.gca().patch.set_edgecolor('black')
else:
plt.gca().patch.set_edgecolor('white')
plt.gca().patch.set_linewidth(1.5)
#plt.axis('off')
plt.plot(self.acts[l][:,i][rank].cpu().detach().numpy(), self.spline_postacts[l][:,j,i][rank].cpu().detach().numpy(), color=color, lw=5)
if sample == True:
plt.scatter(self.acts[l][:,i][rank].cpu().detach().numpy(), self.spline_postacts[l][:,j,i][rank].cpu().detach().numpy(), color=color, s=400*scale**2)
plt.gca().spines[:].set_color(color)
lock_id = self.act_fun[l].lock_id[j*self.width[l]+i].long().item()
if lock_id > 0:
im = plt.imread(f'{folder}/lock.png')
newax = fig.add_axes([0.15,0.7,0.15,0.15])
plt.text(500,400, lock_id, fontsize=15)
newax.imshow(im)
newax.axis('off')
plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
plt.close()
def score2alpha(score):
return np.tanh(beta*score)
if mode == "supervised":
alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale]
elif mode == "unsupervised":
alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale_std]
# draw skeleton
width = np.array(self.width)
A = 1
y0 = 0.4 # 0.4
#plt.figure(figsize=(5,5*(neuron_depth-1)*y0))
neuron_depth = len(width)
min_spacing = A/np.maximum(np.max(width),5)
max_neuron = np.max(width)
max_num_weights = np.max(width[:-1] * width[1:])
y1 = 0.4/np.maximum(max_num_weights,3)
fig, ax = plt.subplots(figsize=(10*scale,10*scale*(neuron_depth-1)*y0))
#fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))
# plot scatters and lines
for l in range(neuron_depth):
n = width[l]
spacing = A/n
for i in range(n):
plt.scatter(1/(2*n)+i/n, l*y0, s=min_spacing**2*10000*scale**2, color='black')
if l < neuron_depth - 1:
# plot connections
n_next = width[l+1]
N = n * n_next
for j in range(n_next):
id_ = i*n_next + j
symbol_mask = self.symbolic_fun[l].mask[j][i]
numerical_mask = self.act_fun[l].mask.reshape(self.width[l+1], self.width[l])[j][i]
if symbol_mask == 1. and numerical_mask == 1.:
color = 'purple'
alpha_mask = 1.
if symbol_mask == 1. and numerical_mask == 0.:
color = "red"
alpha_mask = 1.
if symbol_mask == 0. and numerical_mask == 1.:
color = "black"
alpha_mask = 1.
if symbol_mask == 0. and numerical_mask == 0.:
color = "white"
alpha_mask = 0.
if mask == True:
plt.plot([1/(2*n)+i/n, 1/(2*N)+id_/N], [l*y0, (l+1/2)*y0-y1], color=color, lw=2*scale, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
plt.plot([1/(2*N)+id_/N, 1/(2*n_next)+j/n_next], [(l+1/2)*y0+y1, (l+1)*y0], color=color, lw=2*scale, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
else:
plt.plot([1/(2*n)+i/n, 1/(2*N)+id_/N], [l*y0, (l+1/2)*y0-y1], color=color, lw=2*scale, alpha=alpha[l][j][i]*alpha_mask)
plt.plot([1/(2*N)+id_/N, 1/(2*n_next)+j/n_next], [(l+1/2)*y0+y1, (l+1)*y0], color=color, lw=2*scale, alpha=alpha[l][j][i]*alpha_mask)
plt.xlim(0,1)
plt.ylim(-0.1*y0, (neuron_depth-1+0.1)*y0)
# -- Transformation functions
DC_to_FC = ax.transData.transform
FC_to_NFC = fig.transFigure.inverted().transform
# -- Take data coordinates and transform them to normalized figure coordinates
DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))
plt.axis('off')
# plot splines
for l in range(neuron_depth-1):
n = width[l]
for i in range(n):
n_next = width[l+1]
N = n * n_next
for j in range(n_next):
id_ = i*n_next + j
im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
left = DC_to_NFC([1/(2*N)+id_/N-y1,0])[0]
right = DC_to_NFC([1/(2*N)+id_/N+y1,0])[0]
bottom = DC_to_NFC([0,(l+1/2)*y0-y1])[1]
up = DC_to_NFC([0,(l+1/2)*y0+y1])[1]
newax = fig.add_axes([left,bottom,right-left,up-bottom])
#newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
if mask == False:
newax.imshow(im, alpha=alpha[l][j][i])
else:
### make sure to run model.prune() first to compute mask ###
newax.imshow(im, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
newax.axis('off')
if in_vars != None:
n = self.width[0]
for i in range(n):
plt.gcf().get_axes()[0].text(1/(2*(n))+i/(n),-0.1,in_vars[i], fontsize=40*scale,horizontalalignment='center', verticalalignment='center')
if out_vars != None:
n = self.width[-1]
for i in range(n):
plt.gcf().get_axes()[0].text(1/(2*(n))+i/(n),y0*(len(self.width)-1)+0.1,out_vars[i], fontsize=40*scale,horizontalalignment='center',verticalalignment='center')
if title != None:
plt.gcf().get_axes()[0].text(0.5,y0*(len(self.width)-1)+0.2,title, fontsize=40*scale,horizontalalignment='center',verticalalignment='center')
def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1 = 1., lamb_entropy = 2., lamb_coef = 0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn = None, lr=1., stop_grid_update_step=50, batch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'):
'''
training
Args:
-----
dataset : dic
contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
opt : str
"LBFGS" or "Adam"
steps : int
training steps
log : int
logging frequency
lamb : float
overall penalty strength
lamb_l1 : float
l1 penalty strength
lamb_entropy : float
entropy penalty strength
lamb_coef : float
coefficient magnitude penalty strength
lamb_coefdiff : float
penalty strength on the difference of nearby coefficients (smoothness)
update_grid : bool
If True, update grid regularly before stop_grid_update_step
grid_update_num : int
the number of grid updates before stop_grid_update_step
stop_grid_update_step : int
no grid updates after this training step
batch : int
batch size; if -1, use the full dataset. Default: -1.
loss_fn : function
loss function. If None, mean squared error is used. Default: None.
lr : float
learning rate. Default: 1.
small_mag_threshold : float
threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers)
small_reg_factor : float
penalty strength applied to small factors relative to large factors
device : str
device
save_fig_freq : int
save figure every (save_fig_freq) step
Returns:
--------
results : dic
results['train_loss'], 1D array of training losses (RMSE)
results['test_loss'], 1D array of test losses (RMSE)
results['reg'], 1D array of regularization
Example
-------
>>> # for interactive examples, please see demos
>>> from utils import create_dataset
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.plot()
'''
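# regularization: an l1 term on the activation scales plus an entropy term on their
# normalized distribution, and optional penalties on the magnitude and smoothness
# (first differences) of the spline coefficients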
def reg(acts_scale):
def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
return (x < th) * x * factor + (x > th) * (x + (factor-1)*th)
reg_ = 0.
for i in range(len(acts_scale)):
vec = acts_scale[i].reshape(-1,)
p = vec/torch.sum(vec)
l1 = torch.sum(nonlinear(vec))
entropy = - torch.sum(p*torch.log2(p+1e-4))
reg_ += lamb_l1 * l1 + lamb_entropy * entropy # both l1 and entropy
# regularize coefficient to encourage spline to be zero
for i in range(len(self.act_fun)):
coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)),dim=1))
reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
return reg_
pbar = tqdm(range(steps), desc='description', ncols=100)
if loss_fn == None:
loss_fn = loss_fn_eval = lambda x,y: torch.mean((x-y)**2)
else:
loss_fn = loss_fn_eval = loss_fn
grid_update_freq = int(stop_grid_update_step/grid_update_num)
if opt == "Adam":
optimizer = torch.optim.Adam(self.parameters(), lr=lr)
elif opt == "LBFGS":
optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
results = {}
results['train_loss'] = []
results['test_loss'] = []
results['reg'] = []
if metrics != None:
for i in range(len(metrics)):
results[metrics[i].__name__] = []
if batch == -1 or batch > dataset['train_input'].shape[0]:
batch_size = dataset['train_input'].shape[0]
batch_size_test = dataset['test_input'].shape[0]
else:
batch_size = batch
batch_size_test = batch
global train_loss, reg_
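# the closure re-evaluates the model and the regularized loss; LBFGS may call it
# multiple times per optimization step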
def closure():
global train_loss, reg_
optimizer.zero_grad()
pred = self.forward(dataset['train_input'][train_id].to(device))
if sglr_avoid == True:
id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
else:
train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
reg_ = reg(self.acts_scale)
objective = train_loss + lamb*reg_
objective.backward()
return objective
if save_fig:
if not os.path.exists(img_folder):
os.makedirs(img_folder)
for _ in pbar:
train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
if opt == "LBFGS":
optimizer.step(closure)
if opt == "Adam":
pred = self.forward(dataset['train_input'][train_id].to(device))
if sglr_avoid == True:
id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
else:
train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
reg_ = reg(self.acts_scale)
loss = train_loss + lamb*reg_
optimizer.zero_grad()
loss.backward()
optimizer.step()
test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
if _ % log == 0:
pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
if metrics != None:
for i in range(len(metrics)):
results[metrics[i].__name__].append(metrics[i]().item())
results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
results['reg'].append(reg_.cpu().detach().numpy())
if save_fig and _ % save_fig_freq == 0:
self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
plt.savefig(img_folder+'/'+str(_)+'.jpg', bbox_inches='tight', dpi=200)
plt.close()
return results
def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None):
'''
pruning KAN on the node level. If a node has small incoming or outgoing connection, it will be pruned away.
Args:
-----
threshold : float
the threshold used to determine whether a node is small enough
mode : str
"auto" or "manual". If "auto", the thresold will be used to automatically prune away nodes. If "manual", active_neuron_id is needed to specify which neurons are kept (others are thrown away).
active_neuron_id : list of id lists
For example, [[0,1],[0,2,3]] means keeping the 0/1 neuron in the 1st hidden layer and the 0/2/3 neuron in the 2nd hidden layer. Pruning input and output neurons is not supported yet.
Returns:
--------
model2 : KAN
pruned model
Example
-------
>>> # for more interactive examples, please see demos
>>> from utils import create_dataset
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.prune()
>>> model.plot(mask=True)
'''
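# in "auto" mode, a hidden neuron is kept only if both its largest incoming activation
# scale and its largest outgoing activation scale exceed the threshold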
mask = [torch.ones(self.width[0],)]
active_neurons = [list(range(self.width[0]))]
for i in range(len(self.acts_scale)-1):
if mode == "auto":
in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold
out_important = torch.max(self.acts_scale[i+1], dim=0)[0] > threshold
overall_important = in_important * out_important
elif mode == "manual":
overall_important = torch.zeros(self.width[i+1], dtype=torch.bool)
overall_important[active_neurons_id[i+1]] = True
mask.append(overall_important.float())
active_neurons.append(torch.where(overall_important==True)[0])
active_neurons.append(list(range(self.width[-1])))
mask.append(torch.ones(self.width[-1],))
self.mask = mask # this is neuron mask for the whole model
# update act_fun[l].mask
for l in range(len(self.acts_scale)-1):
for i in range(self.width[l+1]):
if i not in active_neurons[l+1]:
self.remove_node(l+1, i)
model2 = KAN(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun)
model2.load_state_dict(self.state_dict())
for i in range(len(self.acts_scale)):
if i < len(self.acts_scale) - 1:
model2.biases[i].weight.data = model2.biases[i].weight.data[:,active_neurons[i+1]]
model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons[i], active_neurons[i+1])
model2.width[i] = len(active_neurons[i])
model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons[i], active_neurons[i+1])
return model2
def remove_edge(self, l, i, j):
'''
remove activation phi(l,i,j) (set its mask to zero)
Args:
-----
l : int
layer index
i : int
input neuron index
j : int
output neuron index
Returns:
--------
None
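Example
-------
>>> # a minimal sketch: silence edge (0,1,3) and inspect its numeric mask
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.remove_edge(0,1,3)
>>> model.act_fun[0].mask[3*2+1] # mask of phi(0,1,3) is now 0.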
'''
self.act_fun[l].mask[j*self.width[l]+i] = 0.
def remove_node(self, l, i):
'''
remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
Args:
-----
l : int
layer index
i : int
neuron index
Returns:
--------
None
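Example
-------
>>> # a minimal sketch: manually prune hidden neuron (1,2)
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.remove_node(1,2)
>>> model.act_fun[0].mask # edges into neuron (1,2) are masked
>>> model.act_fun[1].mask # edges out of neuron (1,2) are masked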
'''
self.act_fun[l-1].mask[i*self.width[l-1]+torch.arange(self.width[l-1])] = 0.
self.act_fun[l].mask[torch.arange(self.width[l+1])*self.width[l]+i] = 0.
self.symbolic_fun[l-1].mask[i,:] *= 0.
self.symbolic_fun[l].mask[:,i] *= 0.
def suggest_symbolic(self, l, i, j, a_range=(-10,10), b_range=(-10,10), lib = None, topk=5, verbose=True):
'''suggest the symbolic candidates of phi(l,i,j)
Args:
-----
l : int
layer index
i : int
input neuron index
j : int
output neuron index
lib : dic
library of symbolic bases. If lib = None, the global default library will be used.
topk : int
display the top k symbolic functions (according to r2)
verbose : bool
If True, more information will be printed.
Returns:
--------
None
Example
-------
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.suggest_symbolic(0,0,0)
function , r2
sin , 0.9994412064552307
gaussian , 0.9196369051933289
tanh , 0.8608126044273376
sigmoid , 0.8578218817710876
arctan , 0.842217743396759
'''
r2s = []
if lib == None:
symbolic_lib = SYMBOLIC_LIB
else:
symbolic_lib = {}
for item in lib:
symbolic_lib[item] = SYMBOLIC_LIB[item]
for (name, fun) in symbolic_lib.items():
r2 = self.fix_symbolic(l,i,j,name,a_range=a_range,b_range=b_range,verbose=False)
r2s.append(r2.item())
self.unfix_symbolic(l,i,j)
sorted_ids = np.argsort(r2s)[::-1][:topk]
r2s = np.array(r2s)[sorted_ids][:topk]
topk = np.minimum(topk, len(symbolic_lib))
if verbose == True:
print('function',',','r2')
for i in range(topk):
print(list(symbolic_lib.items())[sorted_ids[i]][0],',',r2s[i])
best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
best_r2 = r2s[0]
return best_name, best_fun, best_r2
def auto_symbolic(self, a_range=(-10,10), b_range=(-10,10), lib=None, verbose=1):
'''
automatic symbolic regression: using top 1 suggestion from suggest_symbolic to replace splines with symbolic activations
Args:
-----
lib : None or a list of function names
the symbolic library
verbose : int
verbosity
Returns:
--------
None (print suggested symbolic formulas)
Example 1
---------
>>> # default library
>>> from utils import create_dataset
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic()
fixing (0,0,0) with sin, r2=0.9994837045669556
fixing (0,1,0) with cosh, r2=0.9978033900260925
fixing (1,0,0) with arctan, r2=0.9997088313102722
Example 2
---------
>>> # customized library
>>> from utils import create_dataset
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic(lib=['exp','sin','x^2'])
fixing (0,0,0) with sin, r2=0.999411404132843
fixing (0,1,0) with x^2, r2=0.9962921738624573
fixing (1,0,0) with exp, r2=0.9980258941650391
'''
for l in range(len(self.width)-1):
for i in range(self.width[l]):
for j in range(self.width[l+1]):
if self.symbolic_fun[l].mask[j,i] > 0.:
print(f'skipping ({l},{i},{j}) since already symbolic')
else:
name, fun, r2 = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
self.fix_symbolic(l,i,j,name,verbose=verbose>1)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}')
def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify=False):
'''
obtain the symbolic formula
Args:
-----
floating_digit : int
the number of digits to display
var : list of str
the name of variables (if not provided, by default using ['x_1', 'x_2', ...])
normalizer : [mean array (floats), standard deviation array (floats)]
the normalization applied to inputs
simplify : bool
If True, simplify the equation at each step (usually quite slow), so it is set to False by default.
Returns:
--------
symbolic formula : sympy function
Example
-------
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic(lib=['exp','sin','x^2'])
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False);
>>> model.symbolic_formula()
'''
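# propagate sympy expressions layer by layer: each output of a layer is
# sum_i c*f(a*x_i + b) + d over its inputs, plus the node bias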
symbolic_acts = []
x = []
def ex_round(ex1, floating_digit=floating_digit):
ex2 = ex1
for a in sympy.preorder_traversal(ex1):
if isinstance(a, sympy.Float):
ex2 = ex2.subs(a, round(a, floating_digit))
return ex2
# define variables
if var == None:
for ii in range(1,self.width[0]+1):
exec(f"x{ii} = sympy.Symbol('x_{ii}')")
exec(f"x.append(x{ii})")
else:
x = [sympy.symbols(var_) for var_ in var]
x0 = x
if normalizer != None:
mean = normalizer[0]
std = normalizer[1]
x = [(x[i] - mean[i])/std[i] for i in range(len(x))]
symbolic_acts.append(x)
for l in range(len(self.width)-1):
y = []
for j in range(self.width[l+1]):
yj = 0.
for i in range(self.width[l]):
a, b, c, d = self.symbolic_fun[l].affine[j,i]
sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
try:
yj += c*sympy_fun(a*x[i]+b)+d
except:
print('make sure all activations are converted to symbolic formulas first!')
return
if simplify == True:
y.append(sympy.simplify(yj + self.biases[l].weight.data[0,j]))
else:
y.append(yj + self.biases[l].weight.data[0,j])
x = y
symbolic_acts.append(x)
self.symbolic_acts = [[ex_round(symbolic_acts[l][i]) for i in range(len(symbolic_acts[l]))] for l in range(len(symbolic_acts))]
out_dim = len(symbolic_acts[-1])
return [ex_round(symbolic_acts[-1][i]) for i in range(len(symbolic_acts[-1]))], x0
def clear_ckpts(self, folder='./model_ckpt'):
'''
clear all checkpoints
Args:
-----
folder : str
the folder that stores checkpoints
Returns:
--------
None
'''
if os.path.exists(folder):
files = glob.glob(folder+'/*')
for f in files:
os.remove(f)
else:
os.makedirs(folder)
def save_ckpt(self, name, folder='./model_ckpt'):
'''
save the current model as checkpoint
Args:
-----
name: str
the name of the checkpoint to be saved
folder : str
the folder that stores checkpoints
Returns:
--------
None
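Example
-------
>>> # a minimal sketch of the checkpoint round trip (the name 'ckpt1' is arbitrary)
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.save_ckpt('ckpt1')
>>> model.load_ckpt('ckpt1')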
'''
if not os.path.exists(folder):
os.makedirs(folder)
torch.save(self.state_dict(), folder+'/'+name)
print('save this model to', folder+'/'+name)
def load_ckpt(self, name, folder='./model_ckpt'):
'''
load a checkpoint to the current model
Args:
-----
name: str
the name of the checkpoint to be loaded
folder : str
the folder that stores checkpoints
Returns:
--------
None
'''
self.load_state_dict(torch.load(folder+'/'+name))