import torch
import torch.nn as nn
import numpy as np
import sympy
from .KANLayer import *
from .Symbolic_KANLayer import *
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy


class KAN(nn.Module):
    '''
    KAN class
    
    Attributes:
    -----------
        biases: a list of nn.Linear()
            biases are added on nodes (in principle, biases can be absorbed into activation functions. However, we still have them for better optimization)
        act_fun: a list of KANLayer
            KANLayers
        depth: int
            depth of KAN
        width: list
            number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
        grid: int
            the number of grid intervals
        k: int
            the order of piecewise polynomial
        base_fun: fun
            residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
        symbolic_fun: a list of Symbolic_KANLayer
            Symbolic_KANLayers
        symbolic_enabled: bool
            If False, the symbolic front is not computed (to save time). Default: True.
    
    Methods:
    --------
        __init__():
            initialize a KAN
        initialize_from_another_model():
            initialize a KAN from another KAN (with the same shape, but potentially different grids)
        update_grid_from_samples():
            update spline grids based on samples
        initialize_grid_from_another_model():
            initialize KAN grids from another KAN
        forward():
            forward
        set_mode():
            set the mode of an activation function: 'n' for numeric, 's' for symbolic, 'ns' for combined (note they are visualized differently in plot(): 'n' as black, 's' as red, 'ns' as purple).
        fix_symbolic():
            fix an activation function to be symbolic
        suggest_symbolic():
            suggest the symbolic candidates of a numeric spline-based activation function
        lock():
            lock activation functions to share parameters
        unlock():
            unlock locked activations
        get_range():
            get the input and output ranges of an activation function
        plot():
            plot the diagram of KAN
        train():
            train KAN
        prune():
            prune KAN
        remove_edge():
            remove some edge of KAN
        remove_node():
            remove some node of KAN
        auto_symbolic():
            automatically fit all splines to be symbolic functions
        symbolic_formula():
            obtain the symbolic formula of the KAN network
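    
    Example
    -------
        A minimal sketch of the typical workflow, assembled from the method
        docstrings below (create_dataset is assumed to come from utils):
    
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic()
        >>> model.symbolic_formula()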
    '''
    def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0.1, base_fun=torch.nn.SiLU(), symbolic_enabled=True, bias_trainable=True, grid_eps=1.0, grid_range=[-1,1], sp_trainable=True, sb_trainable=True, device='cpu', seed=0):
        '''
        initialize a KAN model
        
        Args:
        -----
            width : list of int
                :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
            grid : int
                number of grid intervals. Default: 3.
            k : int
                order of piecewise polynomial. Default: 3.
            noise_scale : float
                initial injected noise to spline. Default: 0.1.
            noise_scale_base : float
                initial injected noise to the residual (base) scale. Default: 0.1.
            base_fun : fun
                the residual function b(x). Default: torch.nn.SiLU().
            symbolic_enabled : bool
                compute or skip symbolic computations (for efficiency). By default: True.
            bias_trainable : bool
                bias parameters are updated or not. By default: True.
            grid_eps : float
                When grid_eps = 0, the grid is uniform; when grid_eps = 1, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. Default: 1.0.
            grid_range : list/np.array of shape (2,)
                setting the range of grids. Default: [-1,1].
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            device : str
                device
            seed : int
                random seed
        
        Returns:
        --------
            self
        
        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> (model.act_fun[0].in_dim, model.act_fun[0].out_dim), (model.act_fun[1].in_dim, model.act_fun[1].out_dim)
        ((2, 5), (5, 1))
        '''
        super(KAN, self).__init__()
        
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        
        ### initializing the numerical front ###
        
        self.biases = []
        self.act_fun = []
        self.depth = len(width) - 1
        self.width = width
        
        for l in range(self.depth):
            # splines
            scale_base = 1/np.sqrt(width[l]) + (torch.randn(width[l]*width[l+1],)*2-1) * noise_scale_base
            sp_batch = KANLayer(in_dim=width[l], out_dim=width[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable)
            self.act_fun.append(sp_batch)
            
            # bias
            bias = nn.Linear(width[l+1], 1, bias=False).requires_grad_(bias_trainable)
            bias.weight.data *= 0.
            self.biases.append(bias)
        
        self.biases = nn.ModuleList(self.biases)
        self.act_fun = nn.ModuleList(self.act_fun)
        
        self.grid = grid
        self.k = k
        self.base_fun = base_fun
        
        ### initializing the symbolic front ###
        self.symbolic_fun = []
        for l in range(self.depth):
            sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l+1])
            self.symbolic_fun.append(sb_batch)
        
        self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
        self.symbolic_enabled = symbolic_enabled
    
    def initialize_from_another_model(self, another_model, x):
        '''
        initialize from a parent model. The parent has the same width as the current model but may have different grids.
        
        Args:
        -----
            another_model : KAN
                the parent model used to initialize the current model
            x : 2D torch.float
                inputs, shape (batch, input dimension)
        
        Returns:
        --------
            self : KAN
        
        Example
        -------
        >>> model_coarse = KAN(width=[2,5,1], grid=5, k=3)
        >>> model_fine = KAN(width=[2,5,1], grid=10, k=3)
        >>> print(model_fine.act_fun[0].coef[0][0].data)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model_fine.initialize_from_another_model(model_coarse, x);
        >>> print(model_fine.act_fun[0].coef[0][0].data)
        tensor(-0.0030)
        tensor(0.0506)
        '''
        another_model(x)  # get activations
        batch = x.shape[0]
        
        self.initialize_grid_from_another_model(another_model, x)
        
        for l in range(self.depth):
            spb = self.act_fun[l]
            spb_parent = another_model.act_fun[l]
            
            #spb = spb_parent
            preacts = another_model.spline_preacts[l]
            postsplines = another_model.spline_postsplines[l]
            self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1,0), postsplines.reshape(batch, spb.size).permute(1,0), spb.grid, k=spb.k)
            spb.scale_base.data = spb_parent.scale_base.data
            spb.scale_sp.data = spb_parent.scale_sp.data
            spb.mask.data = spb_parent.mask.data
            #print(spb.mask.data, self.act_fun[l].mask.data)
        
        for l in range(self.depth):
            self.biases[l].weight.data = another_model.biases[l].weight.data
        
        for l in range(self.depth):
            self.symbolic_fun[l] = another_model.symbolic_fun[l]
        
        return self
    
    def update_grid_from_samples(self, x):
        '''
        update grid from samples
        
        Args:
        -----
            x : 2D torch.float
                inputs, shape (batch, input dimension)
        
        Returns:
        --------
            None
        
        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> print(model.act_fun[0].grid[0].data)
        >>> x = torch.rand(100,2)*5
        >>> model.update_grid_from_samples(x)
        >>> print(model.act_fun[0].grid[0].data)
        tensor([-1.0000, -0.6000, -0.2000,  0.2000,  0.6000,  1.0000])
        tensor([0.0128, 1.0064, 2.0000, 2.9937, 3.9873, 4.9809])
        '''
        for l in range(self.depth):
            self.forward(x)
            self.act_fun[l].update_grid_from_samples(self.acts[l])
    
    def initialize_grid_from_another_model(self, model, x):
        '''
        initialize grid from a parent model
        
        Args:
        -----
            model : KAN
                parent model
            x : 2D torch.float
                inputs, shape (batch, input dimension)
        
        Returns:
        --------
            None
        
        Example
        -------
        >>> model_parent = KAN(width=[1,1], grid=5, k=3)
        >>> model_parent.act_fun[0].grid.data = torch.linspace(-2,2,steps=6)[None,:]
        >>> x = torch.linspace(-2,2,steps=1001)[:,None]
        >>> model = KAN(width=[1,1], grid=5, k=3)
        >>> print(model.act_fun[0].grid.data)
        >>> model = model.initialize_from_another_model(model_parent, x)
        >>> print(model.act_fun[0].grid.data)
        tensor([[-1.0000, -0.6000, -0.2000,  0.2000,  0.6000,  1.0000]])
        tensor([[-2.0000, -1.2000, -0.4000,  0.4000,  1.2000,  2.0000]])
        '''
        model(x)
        for l in range(self.depth):
            self.act_fun[l].initialize_grid_from_parent(model.act_fun[l], model.acts[l])
    
    def forward(self, x):
        '''
        KAN forward
        
        Args:
        -----
            x : 2D torch.float
                inputs, shape (batch, input dimension)
        
        Returns:
        --------
            y : 2D torch.float
                outputs, shape (batch, output dimension)
        
        Example
        -------
        >>> model = KAN(width=[2,5,3], grid=5, k=3)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x).shape
        torch.Size([100, 3])
        '''
        self.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
        self.spline_preacts = []
        self.spline_postsplines = []
        self.spline_postacts = []
        self.acts_scale = []
        self.acts_scale_std = []
        #self.neurons_scale = []
        
        self.acts.append(x)  # acts shape: (batch, width[l])
        
        for l in range(self.depth):
            
            x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
            
            if self.symbolic_enabled == True:
                x_symbolic, postacts_symbolic = self.symbolic_fun[l](x)
            else:
                x_symbolic = 0.
                postacts_symbolic = 0.
            
            x = x_numerical + x_symbolic
            postacts = postacts_numerical + postacts_symbolic
            
            #self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
            grid_reshape = self.act_fun[l].grid.reshape(self.width[l+1], self.width[l], -1)
            input_range = grid_reshape[:,:,-1] - grid_reshape[:,:,0] + 1e-4
            output_range = torch.mean(torch.abs(postacts), dim=0)
            self.acts_scale.append(output_range / input_range)
            self.acts_scale_std.append(torch.std(postacts, dim=0))
            self.spline_preacts.append(preacts.detach())
            self.spline_postacts.append(postacts.detach())
            self.spline_postsplines.append(postspline.detach())
            
            x = x + self.biases[l].weight
            self.acts.append(x)
        
        return x
    
    def set_mode(self, l, i, j, mode, mask_n=None):
        '''
        set (l,i,j) activation to have mode
        
        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            mode : str
                'n' (numeric) or 's' (symbolic) or 'ns' (combined)
            mask_n : None or float
                magnitude of the numeric front
        
        Returns:
        --------
            None
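        
        Example
        -------
        >>> # a minimal sketch: switch activation (0,0,0) to symbolic-only mode,
        >>> # then back to numeric mode
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.set_mode(0, 0, 0, mode='s')
        >>> model.set_mode(0, 0, 0, mode='n')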
        '''
        if mode == "s":
            mask_n = 0.; mask_s = 1.
        elif mode == "n":
            mask_n = 1.; mask_s = 0.
        elif mode == "sn" or mode == "ns":
            if mask_n == None:
                mask_n = 1.
            else:
                mask_n = mask_n
            mask_s = 1.
        else:
            mask_n = 0.; mask_s = 0.
        
        self.act_fun[l].mask.data[j*self.act_fun[l].in_dim+i] = mask_n
        self.symbolic_fun[l].mask.data[j,i] = mask_s
    
    def fix_symbolic(self, l, i, j, fun_name, fit_params_bool=True, a_range=(-10,10), b_range=(-10,10), verbose=True, random=False):
        '''
        set (l,i,j) activation to be symbolic (specified by fun_name)
        
        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            fun_name : str
                function name
            fit_params_bool : bool
                obtaining affine parameters through fitting (True) or setting default values (False)
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of b
            verbose : bool
                If True, more information is printed.
            random : bool
                initialize affine parameters randomly or as [1,0,1,0]
        
        Returns:
        --------
            None or r2 (coefficient of determination)
        
        Example 1
        ---------
        >>> # when fit_params_bool = False
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
        >>> print(model.act_fun[0].mask.reshape(2,5))
        >>> print(model.symbolic_fun[0].mask.reshape(2,5))
        tensor([[1., 1., 1., 1., 1.],
                [1., 1., 0., 1., 1.]])
        tensor([[0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0.]])
        
        Example 2
        ---------
        >>> # when fit_params_bool = True
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x)  # obtain activations (otherwise model does not have attribute acts)
        >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
        >>> print(model.act_fun[0].mask.reshape(2,5))
        >>> print(model.symbolic_fun[0].mask.reshape(2,5))
        r2 is 0.8131332993507385
        r2 is not very high, please double check if you are choosing the correct symbolic function.
        tensor([[1., 1., 1., 1., 1.],
                [1., 1., 0., 1., 1.]])
        tensor([[0., 0., 0., 0., 0.],
                [0., 0., 1., 0., 0.]])
        '''
        self.set_mode(l,i,j,mode="s")
        if not fit_params_bool:
            self.symbolic_fun[l].fix_symbolic(i,j,fun_name, verbose=verbose, random=random)
            return None
        else:
            x = self.acts[l][:,i]
            y = self.spline_postacts[l][:,j,i]
            r2 = self.symbolic_fun[l].fix_symbolic(i,j,fun_name,x,y,a_range=a_range,b_range=b_range, verbose=verbose)
            return r2
    
    def unfix_symbolic(self, l, i, j):
        '''
        unfix the (l,i,j) activation function.
        '''
        self.set_mode(l,i,j,mode="n")
    
    def unfix_symbolic_all(self):
        '''
        unfix all activation functions.
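        
        Example
        -------
        >>> # a minimal sketch: fix one activation symbolically, then revert
        >>> # everything to numeric splines
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
        >>> model.unfix_symbolic_all()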
        '''
        for l in range(len(self.width)-1):
            for i in range(self.width[l]):
                for j in range(self.width[l+1]):
                    self.unfix_symbolic(l,i,j)
    
    def lock(self, l, ids):
        '''
        lock ids in the l-th layer to be the same function
        
        Args:
        -----
            l : int
                layer index
            ids : 2D list
                :math:`[[i_1,j_1],[i_2,j_2],...]` set :math:`(l,i_1,j_1), (l,i_2,j_2), ...` to be the same function
        
        Returns:
        --------
            None
        
        Example
        -------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        >>> model.lock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        tensor([[0, 1],
                [2, 3],
                [4, 5]])
        tensor([[0, 1],
                [2, 1],
                [4, 5]])
        '''
        self.act_fun[l].lock(ids)
    
    def unlock(self, l, ids):
        '''
        unlock ids in the l-th layer (undo weight sharing)
        
        Args:
        -----
            l : int
                layer index
            ids : 2D list
                [[i1,j1],[i2,j2],...] set (l,i1,j1), (l,i2,j2), ... to be unlocked
        
        Example:
        --------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> model.lock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        >>> model.unlock(0,[[1,0],[1,1]])
        >>> print(model.act_fun[0].weight_sharing.reshape(3,2))
        tensor([[0, 1],
                [2, 1],
                [4, 5]])
        tensor([[0, 1],
                [2, 3],
                [4, 5]])
        '''
        self.act_fun[l].unlock(ids)
    
    def get_range(self, l, i, j, verbose=True):
        '''
        Get the input range and output range of the (l,i,j) activation
        
        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
        
        Returns:
        --------
            x_min : float
                minimum of input
            x_max : float
                maximum of input
            y_min : float
                minimum of output
            y_max : float
                maximum of output
        
        Example
        -------
        >>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x)  # do a forward pass to obtain model.acts
        >>> model.get_range(0,0,0)
        x range: [-2.13 , 2.75 ]
        y range: [-0.50 , 1.83 ]
        (tensor(-2.1288), tensor(2.7498), tensor(-0.5042), tensor(1.8275))
        '''
        x = self.spline_preacts[l][:,j,i]
        y = self.spline_postacts[l][:,j,i]
        x_min = torch.min(x)
        x_max = torch.max(x)
        y_min = torch.min(y)
        y_max = torch.max(y)
        if verbose:
            print('x range: ['+'%.2f'%x_min, ',', '%.2f'%x_max, ']')
            print('y range: ['+'%.2f'%y_min, ',', '%.2f'%y_max, ']')
        return x_min, x_max, y_min, y_max
    
    def plot(self, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None):
        '''
        plot KAN
        
        Args:
        -----
            folder : str
                the folder to store pngs
            beta : float
                positive number. control the transparency of each activation. transparency = tanh(beta*l1).
            mask : bool
                If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
            mode : str
                "supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).
            scale : float
                control the size of the diagram
            tick : bool
                If True, show the input/output ranges of each activation as axis ticks.
            sample : bool
                If True, scatter the samples on top of each activation curve.
            in_vars : None or list of str
                the name(s) of input variables
            out_vars : None or list of str
                the name(s) of output variables
            title : None or str
                title
        
        Returns:
        --------
            Figure
        
        Example
        -------
        >>> # see more interactive examples in demos
        >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
        >>> x = torch.normal(0,1,size=(100,2))
        >>> model(x)  # do a forward pass to obtain model.acts
        >>> model.plot()
        '''
        if not os.path.exists(folder):
            os.makedirs(folder)
        #matplotlib.use('Agg')
        depth = len(self.width) - 1
        for l in range(depth):
            w_large = 2.0
            for i in range(self.width[l]):
                for j in range(self.width[l+1]):
                    rank = torch.argsort(self.acts[l][:,i])
                    fig, ax = plt.subplots(figsize=(w_large, w_large))
                    
                    num = rank.shape[0]
                    
                    symbol_mask = self.symbolic_fun[l].mask[j][i]
                    numerical_mask = self.act_fun[l].mask.reshape(self.width[l+1], self.width[l])[j][i]
                    if symbol_mask > 0. and numerical_mask > 0.:
                        color = 'purple'
                        alpha_mask = 1
                    if symbol_mask > 0. and numerical_mask == 0.:
                        color = "red"
                        alpha_mask = 1
                    if symbol_mask == 0. and numerical_mask > 0.:
                        color = "black"
                        alpha_mask = 1
                    if symbol_mask == 0. and numerical_mask == 0.:
                        color = "white"
                        alpha_mask = 0
                    
                    if tick == True:
                        ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
                        ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
                        x_min, x_max, y_min, y_max = self.get_range(l,i,j,verbose=False)
                        plt.xticks([x_min, x_max], ['%2.f'%x_min, '%2.f'%x_max])
                        plt.yticks([y_min, y_max], ['%2.f'%y_min, '%2.f'%y_max])
                    else:
                        plt.xticks([])
                        plt.yticks([])
                    if alpha_mask == 1:
                        plt.gca().patch.set_edgecolor('black')
                    else:
                        plt.gca().patch.set_edgecolor('white')
                    plt.gca().patch.set_linewidth(1.5)
                    #plt.axis('off')
                    
                    plt.plot(self.acts[l][:,i][rank].cpu().detach().numpy(), self.spline_postacts[l][:,j,i][rank].cpu().detach().numpy(), color=color, lw=5)
                    if sample == True:
                        plt.scatter(self.acts[l][:,i][rank].cpu().detach().numpy(), self.spline_postacts[l][:,j,i][rank].cpu().detach().numpy(), color=color, s=400*scale**2)
                    plt.gca().spines[:].set_color(color)
                    
                    lock_id = self.act_fun[l].lock_id[j*self.width[l]+i].long().item()
                    if lock_id > 0:
                        im = plt.imread(f'{folder}/lock.png')
                        newax = fig.add_axes([0.15, 0.7, 0.15, 0.15])
                        plt.text(500, 400, lock_id, fontsize=15)
                        newax.imshow(im)
                        newax.axis('off')
                    
                    plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
                    plt.close()
        
        def score2alpha(score):
            return np.tanh(beta*score)
        
        if mode == "supervised":
            alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale]
        elif mode == "unsupervised":
            alpha = [score2alpha(score.cpu().detach().numpy()) for score in self.acts_scale_std]
        
        # draw skeleton
        width = np.array(self.width)
        A = 1
        y0 = 0.4  # 0.4
        
        #plt.figure(figsize=(5,5*(neuron_depth-1)*y0))
        neuron_depth = len(width)
        min_spacing = A/np.maximum(np.max(width), 5)
        
        max_neuron = np.max(width)
        max_num_weights = np.max(width[:-1] * width[1:])
        y1 = 0.4/np.maximum(max_num_weights, 3)
        
        fig, ax = plt.subplots(figsize=(10*scale, 10*scale*(neuron_depth-1)*y0))
        #fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))
        
        # plot scatters and lines
        for l in range(neuron_depth):
            n = width[l]
            spacing = A/n
            for i in range(n):
                plt.scatter(1/(2*n)+i/n, l*y0, s=min_spacing**2*10000*scale**2, color='black')
                
                if l < neuron_depth - 1:
                    # plot connections
                    n_next = width[l+1]
                    N = n * n_next
                    for j in range(n_next):
                        id_ = i*n_next + j
                        
                        symbol_mask = self.symbolic_fun[l].mask[j][i]
                        numerical_mask = self.act_fun[l].mask.reshape(self.width[l+1], self.width[l])[j][i]
                        if symbol_mask == 1. and numerical_mask == 1.:
                            color = 'purple'
                            alpha_mask = 1.
                        if symbol_mask == 1. and numerical_mask == 0.:
                            color = "red"
                            alpha_mask = 1.
                        if symbol_mask == 0. and numerical_mask == 1.:
                            color = "black"
                            alpha_mask = 1.
                        if symbol_mask == 0. and numerical_mask == 0.:
                            color = "white"
                            alpha_mask = 0.
                        if mask == True:
                            plt.plot([1/(2*n)+i/n, 1/(2*N)+id_/N], [l*y0, (l+1/2)*y0-y1], color=color, lw=2*scale, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
                            plt.plot([1/(2*N)+id_/N, 1/(2*n_next)+j/n_next], [(l+1/2)*y0+y1, (l+1)*y0], color=color, lw=2*scale, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
                        else:
                            plt.plot([1/(2*n)+i/n, 1/(2*N)+id_/N], [l*y0, (l+1/2)*y0-y1], color=color, lw=2*scale, alpha=alpha[l][j][i]*alpha_mask)
                            plt.plot([1/(2*N)+id_/N, 1/(2*n_next)+j/n_next], [(l+1/2)*y0+y1, (l+1)*y0], color=color, lw=2*scale, alpha=alpha[l][j][i]*alpha_mask)
        
        plt.xlim(0, 1)
        plt.ylim(-0.1*y0, (neuron_depth-1+0.1)*y0)
        
        # -- Transformation functions
        DC_to_FC = ax.transData.transform
        FC_to_NFC = fig.transFigure.inverted().transform
        # -- Take data coordinates and transform them to normalized figure coordinates
        DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))
        
        plt.axis('off')
        
        # plot splines
        for l in range(neuron_depth-1):
            n = width[l]
            for i in range(n):
                n_next = width[l+1]
                N = n * n_next
                for j in range(n_next):
                    id_ = i*n_next + j
                    im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
                    left = DC_to_NFC([1/(2*N)+id_/N-y1, 0])[0]
                    right = DC_to_NFC([1/(2*N)+id_/N+y1, 0])[0]
                    bottom = DC_to_NFC([0, (l+1/2)*y0-y1])[1]
                    up = DC_to_NFC([0, (l+1/2)*y0+y1])[1]
                    newax = fig.add_axes([left, bottom, right-left, up-bottom])
                    #newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
                    if mask == False:
                        newax.imshow(im, alpha=alpha[l][j][i])
                    else:
                        ### make sure to run model.prune() first to compute mask ###
                        newax.imshow(im, alpha=alpha[l][j][i]*self.mask[l][i].item()*self.mask[l+1][j].item())
                    newax.axis('off')
        
        if in_vars != None:
            n = self.width[0]
            for i in range(n):
                plt.gcf().get_axes()[0].text(1/(2*(n))+i/(n), -0.1, in_vars[i], fontsize=40*scale, horizontalalignment='center', verticalalignment='center')
        
        if out_vars != None:
            n = self.width[-1]
            for i in range(n):
                plt.gcf().get_axes()[0].text(1/(2*(n))+i/(n), y0*(len(self.width)-1)+0.1, out_vars[i], fontsize=40*scale, horizontalalignment='center', verticalalignment='center')
        
        if title != None:
            plt.gcf().get_axes()[0].text(0.5, y0*(len(self.width)-1)+0.2, title, fontsize=40*scale, horizontalalignment='center', verticalalignment='center')
    
    def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, batch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'):
        '''
        training
        
        Args:
        -----
            dataset : dict
                contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
            opt : str
                "LBFGS" or "Adam"
            steps : int
                training steps
            log : int
                logging frequency
            lamb : float
                overall penalty strength
            lamb_l1 : float
                l1 penalty strength
            lamb_entropy : float
                entropy penalty strength
            lamb_coef : float
                coefficient magnitude penalty strength
            lamb_coefdiff : float
                difference of nearby coefficients (smoothness) penalty strength
            update_grid : bool
                If True, update grid regularly before stop_grid_update_step
            grid_update_num : int
                the number of grid updates before stop_grid_update_step
            loss_fn : function
                loss function. Default: None (mean squared error).
            lr : float
                learning rate. Default: 1.
            stop_grid_update_step : int
                no grid updates after this training step
            batch : int
                batch size, if -1 then full.
            small_mag_threshold : float
                threshold to determine large or small numbers (may want to apply larger penalty to smaller numbers)
            small_reg_factor : float
                penalty strength applied to small factors relative to large factors
            metrics : None or list of functions
                extra metrics to be evaluated and logged every step
            sglr_avoid : bool
                If True, samples with NaN predictions are dropped from the loss.
            save_fig : bool
                If True, save a plot of the model every save_fig_freq steps to img_folder.
            in_vars : None or list of str
                input variable names passed to plot() when save_fig is True
            out_vars : None or list of str
                output variable names passed to plot() when save_fig is True
            beta : float
                transparency parameter passed to plot() when save_fig is True
            save_fig_freq : int
                save figure every (save_fig_freq) step
            img_folder : str
                the folder that stores saved figures
            device : str
                device
        
        Returns:
        --------
            results : dict
                results['train_loss'], 1D array of training losses (RMSE)
                results['test_loss'], 1D array of test losses (RMSE)
                results['reg'], 1D array of regularization
        
        Example
        -------
        >>> # for interactive examples, please see demos
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model.plot()
        '''
        
        def reg(acts_scale):
            
            def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
                return (x < th) * x * factor + (x > th) * (x + (factor-1)*th)
            
            reg_ = 0.
            for i in range(len(acts_scale)):
                vec = acts_scale[i].reshape(-1,)
                
                p = vec / torch.sum(vec)
                l1 = torch.sum(nonlinear(vec))
                entropy = - torch.sum(p * torch.log2(p + 1e-4))
                reg_ += lamb_l1 * l1 + lamb_entropy * entropy  # both l1 and entropy
            
            # regularize coefficient to encourage spline to be zero
            for i in range(len(self.act_fun)):
                coeff_l1 = torch.sum(torch.mean(torch.abs(self.act_fun[i].coef), dim=1))
                coeff_diff_l1 = torch.sum(torch.mean(torch.abs(torch.diff(self.act_fun[i].coef)), dim=1))
                reg_ += lamb_coef * coeff_l1 + lamb_coefdiff * coeff_diff_l1
            
            return reg_
        
        pbar = tqdm(range(steps), desc='description', ncols=100)
        
        if loss_fn == None:
            loss_fn = loss_fn_eval = lambda x, y: torch.mean((x-y)**2)
        else:
            loss_fn = loss_fn_eval = loss_fn
        
        grid_update_freq = int(stop_grid_update_step/grid_update_num)
        
        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        elif opt == "LBFGS":
            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
        
        results = {}
        results['train_loss'] = []
        results['test_loss'] = []
        results['reg'] = []
        if metrics != None:
            for i in range(len(metrics)):
                results[metrics[i].__name__] = []
        
        if batch == -1 or batch > dataset['train_input'].shape[0]:
            batch_size = dataset['train_input'].shape[0]
            batch_size_test = dataset['test_input'].shape[0]
        else:
            batch_size = batch
            batch_size_test = batch
        
        global train_loss, reg_
        
        def closure():
            global train_loss, reg_
            optimizer.zero_grad()
            pred = self.forward(dataset['train_input'][train_id].to(device))
            if sglr_avoid == True:
                id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
                train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
            else:
                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
            reg_ = reg(self.acts_scale)
            objective = train_loss + lamb * reg_
            objective.backward()
            return objective
        
        if save_fig:
            if not os.path.exists(img_folder):
                os.makedirs(img_folder)
        
        for _ in pbar:
            
            train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
            
            if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid:
                self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
            
            if opt == "LBFGS":
                optimizer.step(closure)
            
            if opt == "Adam":
                pred = self.forward(dataset['train_input'][train_id].to(device))
                if sglr_avoid == True:
                    id_ = torch.where(torch.isnan(torch.sum(pred, dim=1)) == False)[0]
                    train_loss = loss_fn(pred[id_], dataset['train_label'][train_id][id_].to(device))
                else:
                    train_loss = loss_fn(pred, dataset['train_label'][train_id].to(device))
                reg_ = reg(self.acts_scale)
                loss = train_loss + lamb * reg_
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
            
            if _ % log == 0:
                pbar.set_description("train loss: %.2e | test loss: %.2e | reg: %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
            
            if metrics != None:
                for i in range(len(metrics)):
                    results[metrics[i].__name__].append(metrics[i]().item())
            
            results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
            results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
            results['reg'].append(reg_.cpu().detach().numpy())
            
            if save_fig and _ % save_fig_freq == 0:
                self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
                plt.savefig(img_folder+'/'+str(_)+'.jpg', bbox_inches='tight', dpi=200)
                plt.close()
        
        return results
    
    def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None):
        '''
        pruning KAN on the node level. If a node has small incoming or outgoing connection, it will be pruned away.
        
        Args:
        -----
            threshold : float
                the threshold used to determine whether a node is small enough
            mode : str
                "auto" or "manual". If "auto", the threshold will be used to automatically prune away nodes. If "manual", active_neurons_id is needed to specify which neurons are kept (others are thrown away).
            active_neurons_id : list of id lists
                For example, [[0,1],[0,2,3]] means keeping the 0/1 neuron in the 1st hidden layer and the 0/2/3 neuron in the 2nd hidden layer. Pruning input and output neurons is not supported yet.
        
        Returns:
        --------
            model2 : KAN
                pruned model
        
        Example
        -------
        >>> # for more interactive examples, please see demos
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model.prune()
        >>> model.plot(mask=True)
        '''
        mask = [torch.ones(self.width[0],)]
        active_neurons = [list(range(self.width[0]))]
        for i in range(len(self.acts_scale)-1):
            if mode == "auto":
                in_important = torch.max(self.acts_scale[i], dim=1)[0] > threshold
                out_important = torch.max(self.acts_scale[i+1], dim=0)[0] > threshold
                overall_important = in_important * out_important
            elif mode == "manual":
                overall_important = torch.zeros(self.width[i+1], dtype=torch.bool)
                overall_important[active_neurons_id[i+1]] = True
            mask.append(overall_important.float())
            active_neurons.append(torch.where(overall_important == True)[0])
        active_neurons.append(list(range(self.width[-1])))
        mask.append(torch.ones(self.width[-1],))
        
        self.mask = mask  # this is neuron mask for the whole model
        
        # update act_fun[l].mask
        for l in range(len(self.acts_scale)-1):
            for i in range(self.width[l+1]):
                if i not in active_neurons[l+1]:
                    self.remove_node(l+1, i)
        
        model2 = KAN(copy.deepcopy(self.width), self.grid, self.k, base_fun=self.base_fun)
        model2.load_state_dict(self.state_dict())
        for i in range(len(self.acts_scale)):
            if i < len(self.acts_scale) - 1:
                model2.biases[i].weight.data = model2.biases[i].weight.data[:,active_neurons[i+1]]
            
            model2.act_fun[i] = model2.act_fun[i].get_subset(active_neurons[i], active_neurons[i+1])
            model2.width[i] = len(active_neurons[i])
            model2.symbolic_fun[i] = self.symbolic_fun[i].get_subset(active_neurons[i], active_neurons[i+1])
        
        return model2
    
    def remove_edge(self, l, i, j):
        '''
        remove activation phi(l,i,j) (set its mask to zero)
        
        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
        
        Returns:
        --------
            None
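        
        Example
        -------
        >>> # a minimal sketch: mask out the activation connecting input 0 to
        >>> # hidden neuron 2
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.remove_edge(0, 0, 2)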
        '''
        self.act_fun[l].mask[j*self.width[l]+i] = 0.
    
    def remove_node(self, l, i):
        '''
        remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
        
        Args:
        -----
            l : int
                layer index
            i : int
                neuron index
        
        Returns:
        --------
            None
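        
        Example
        -------
        >>> # a minimal sketch: deactivate hidden neuron 3 in layer 1
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.remove_node(1, 3)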
        '''
        self.act_fun[l-1].mask[i*self.width[l-1]+torch.arange(self.width[l-1])] = 0.
        self.act_fun[l].mask[torch.arange(self.width[l+1])*self.width[l]+i] = 0.
        self.symbolic_fun[l-1].mask[i,:] *= 0.
        self.symbolic_fun[l].mask[:,i] *= 0.
    
    def suggest_symbolic(self, l, i, j, a_range=(-10,10), b_range=(-10,10), lib=None, topk=5, verbose=True):
        '''
        suggest the symbolic candidates of phi(l,i,j)
        
        Args:
        -----
            l : int
                layer index
            i : int
                input neuron index
            j : int
                output neuron index
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of b
            lib : dict
                library of symbolic bases. If lib = None, the global default library will be used.
            topk : int
                display the top k symbolic functions (according to r2)
            verbose : bool
                If True, more information will be printed.
        
        Returns:
        --------
            best_name : str
                the name of the best-fitting symbolic function
            best_fun : fun
                the best-fitting symbolic function
            best_r2 : float
                the r2 of the best fit
        
        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.suggest_symbolic(0,0,0)
        function , r2
        sin , 0.9994412064552307
        gaussian , 0.9196369051933289
        tanh , 0.8608126044273376
        sigmoid , 0.8578218817710876
        arctan , 0.842217743396759
        '''
        r2s = []
        
        if lib == None:
            symbolic_lib = SYMBOLIC_LIB
        else:
            symbolic_lib = {}
            for item in lib:
                symbolic_lib[item] = SYMBOLIC_LIB[item]
        
        for (name, fun) in symbolic_lib.items():
            r2 = self.fix_symbolic(l,i,j,name,a_range=a_range,b_range=b_range,verbose=False)
            r2s.append(r2.item())
        
        self.unfix_symbolic(l,i,j)
        
        sorted_ids = np.argsort(r2s)[::-1][:topk]
        r2s = np.array(r2s)[sorted_ids][:topk]
        topk = np.minimum(topk, len(symbolic_lib))
        if verbose == True:
            print('function', ',', 'r2')
            for i in range(topk):
                print(list(symbolic_lib.items())[sorted_ids[i]][0], ',', r2s[i])
        
        best_name = list(symbolic_lib.items())[sorted_ids[0]][0]
        best_fun = list(symbolic_lib.items())[sorted_ids[0]][1]
        best_r2 = r2s[0]
        return best_name, best_fun, best_r2
    
    def auto_symbolic(self, a_range=(-10,10), b_range=(-10,10), lib=None, verbose=1):
        '''
        automatic symbolic regression: use the top-1 suggestion from suggest_symbolic to replace splines with symbolic activations
        
        Args:
        -----
            lib : None or a list of function names
                the symbolic library
            verbose : int
                verbosity
        
        Returns:
        --------
            None (print suggested symbolic formulas)
        
        Example 1
        ---------
        >>> # default library
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic()
        fixing (0,0,0) with sin, r2=0.9994837045669556
        fixing (0,1,0) with cosh, r2=0.9978033900260925
        fixing (1,0,0) with arctan, r2=0.9997088313102722
        
        Example 2
        ---------
        >>> # customized library
        >>> from utils import create_dataset
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic(lib=['exp','sin','x^2'])
        fixing (0,0,0) with sin, r2=0.999411404132843
        fixing (0,1,0) with x^2, r2=0.9962921738624573
        fixing (1,0,0) with exp, r2=0.9980258941650391
        '''
        for l in range(len(self.width)-1):
            for i in range(self.width[l]):
                for j in range(self.width[l+1]):
                    if self.symbolic_fun[l].mask[j,i] > 0.:
                        print(f'skipping ({l},{i},{j}) since already symbolic')
                    else:
                        name, fun, r2 = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
                        self.fix_symbolic(l, i, j, name, verbose=verbose > 1)
                        if verbose >= 1:
                            print(f'fixing ({l},{i},{j}) with {name}, r2={r2}')
    
    def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify=False):
        '''
        obtain the symbolic formula
        
        Args:
        -----
            floating_digit : int
                the number of digits to display
            var : list of str
                the name of variables (if not provided, by default using ['x_1', 'x_2', ...])
            normalizer : [mean array (floats), variance array (floats)]
                the normalization applied to inputs
            simplify : bool
                If True, simplify the equation at each step (usually quite slow), so False by default.
        
        Returns:
        --------
            symbolic formula : sympy function
        
        Example
        -------
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
        >>> model = model.prune()
        >>> model(dataset['train_input'])
        >>> model.auto_symbolic(lib=['exp','sin','x^2'])
        >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False);
        >>> model.symbolic_formula()
        '''
        symbolic_acts = []
        x = []
        
        def ex_round(ex1, floating_digit=floating_digit):
            ex2 = ex1
            for a in sympy.preorder_traversal(ex1):
                if isinstance(a, sympy.Float):
                    ex2 = ex2.subs(a, round(a, floating_digit))
            return ex2
        
        # define variables
        if var == None:
            for ii in range(1, self.width[0]+1):
                exec(f"x{ii} = sympy.Symbol('x_{ii}')")
                exec(f"x.append(x{ii})")
        else:
            x = [sympy.symbols(var_) for var_ in var]
        
        x0 = x
        
        if normalizer != None:
            mean = normalizer[0]
            std = normalizer[1]
            x = [(x[i] - mean[i]) / std[i] for i in range(len(x))]
        
        symbolic_acts.append(x)
        
        for l in range(len(self.width)-1):
            y = []
            for j in range(self.width[l+1]):
                yj = 0.
                for i in range(self.width[l]):
                    a, b, c, d = self.symbolic_fun[l].affine[j,i]
                    sympy_fun = self.symbolic_fun[l].funs_sympy[j][i]
                    try:
                        yj += c*sympy_fun(a*x[i]+b)+d
                    except:
                        print('make sure all activations are converted to symbolic formulas first!')
                        return
                if simplify == True:
                    y.append(sympy.simplify(yj + self.biases[l].weight.data[0,j]))
                else:
                    y.append(yj + self.biases[l].weight.data[0,j])
            
            x = y
            symbolic_acts.append(x)
        
        self.symbolic_acts = [[ex_round(symbolic_acts[l][i]) for i in range(len(symbolic_acts[l]))] for l in range(len(symbolic_acts))]
        
        out_dim = len(symbolic_acts[-1])
        return [ex_round(symbolic_acts[-1][i]) for i in range(len(symbolic_acts[-1]))], x0
    
    def clear_ckpts(self, folder='./model_ckpt'):
        '''
        clear all checkpoints
        
        Args:
        -----
            folder : str
                the folder that stores checkpoints
        
        Returns:
        --------
            None
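        
        Example
        -------
        >>> # a minimal sketch (this removes every file in the checkpoint folder)
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.clear_ckpts()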
        '''
        if os.path.exists(folder):
            files = glob.glob(folder+'/*')
            for f in files:
                os.remove(f)
        else:
            os.makedirs(folder)
    
    def save_ckpt(self, name, folder='./model_ckpt'):
        '''
        save the current model as checkpoint
        
        Args:
        -----
            name : str
                the name of the checkpoint to be saved
            folder : str
                the folder that stores checkpoints
        
        Returns:
        --------
            None
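        
        Example
        -------
        >>> # a minimal sketch: save a checkpoint, then restore it with load_ckpt
        >>> model = KAN(width=[2,5,1], grid=5, k=3)
        >>> model.save_ckpt('ckpt1')
        >>> model.load_ckpt('ckpt1')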
        '''
        if not os.path.exists(folder):
            os.makedirs(folder)
        
        torch.save(self.state_dict(), folder+'/'+name)
        print('save this model to', folder+'/'+name)
    
    def load_ckpt(self, name, folder='./model_ckpt'):
        '''
        load a checkpoint to the current model
        
        Args:
        -----
            name : str
                the name of the checkpoint to be loaded
            folder : str
                the folder that stores checkpoints
        
        Returns:
        --------
            None
        '''
        self.load_state_dict(torch.load(folder+'/'+name))