import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from .LBFGS import LBFGS
# `tree` below calls plot_tree, which was never imported; assuming the usual
# pykan package layout, it lives in the sibling hypothesis module.
from .hypothesis import plot_tree

seed = 0
torch.manual_seed(seed)


class MLP(nn.Module):

    def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
        super(MLP, self).__init__()

        torch.manual_seed(seed)

        linears = []
        self.width = width
        self.depth = depth = len(width) - 1
        for i in range(depth):
            linears.append(nn.Linear(width[i], width[i+1]))
        self.linears = nn.ModuleList(linears)

        # only SiLU is wired up so far; `act` is kept for API compatibility
        self.act_fun = torch.nn.SiLU()
        self.save_act = save_act
        self.acts = None

        self.cache_data = None

        self.device = device
        self.to(device)

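    # Construction sketch (reading aid): MLP(width=[2, 5, 5, 1]) builds three
    # Linear layers (2->5, 5->5, 5->1) with SiLU applied after all but the last.
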
    def to(self, device):
        super(MLP, self).to(device)
        self.device = device
        return self

    def get_act(self, x=None):
        if isinstance(x, dict):
            x = x['train_input']
        if x is None:
            if self.cache_data is not None:
                x = self.cache_data
            else:
                raise Exception("missing input data x")
        save_act = self.save_act
        self.save_act = True
        self.forward(x)
        self.save_act = save_act

    @property
    def w(self):
        return [self.linears[l].weight for l in range(self.depth)]

    def forward(self, x):

        # cache data
        self.cache_data = x

        self.acts = []
        self.acts_scale = []
        self.wa_forward = []
        self.a_forward = []

        for i in range(self.depth):

            if self.save_act:
                act = x.clone()
                act_scale = torch.std(x, dim=0)
                wa_forward = act_scale[None, :] * self.linears[i].weight
                self.acts.append(act)
                if i > 0:
                    self.acts_scale.append(act_scale)
                self.wa_forward.append(wa_forward)

            x = self.linears[i](x)
            if i < self.depth - 1:
                x = self.act_fun(x)
            else:
                if self.save_act:
                    act_scale = torch.std(x, dim=0)
                    self.acts_scale.append(act_scale)

        return x

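    # Shapes cached by forward() when save_act=True (reading aid; width=[d_0, ..., d_L],
    # batch size B):
    #   acts[i]       : (B, d_i)        input reaching layer i
    #   wa_forward[i] : (d_{i+1}, d_i)  layer-i weight, each column scaled by the std of its input
    #   acts_scale[k] : (d_{k+1},)      std of activations entering layer k+1 (last entry: output std)
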
    def attribute(self):
        if self.acts is None:
            self.get_act()

        node_scores = []
        edge_scores = []

        # back propagate from the last layer
        node_score = torch.ones(self.width[-1]).requires_grad_(True).to(self.device)
        node_scores.append(node_score)

        for l in range(self.depth, 0, -1):

            edge_score = torch.einsum('ij,i->ij', torch.abs(self.wa_forward[l-1]), node_score/(self.acts_scale[l-1]+1e-4))
            edge_scores.append(edge_score)

            # this might be improper for MLPs (although reasonable for KANs)
            node_score = torch.sum(edge_score, dim=0)/torch.sqrt(torch.tensor(self.width[l-1], device=self.device))
            node_scores.append(node_score)

        self.node_scores = list(reversed(node_scores))
        self.edge_scores = list(reversed(edge_scores))
        self.wa_backward = self.edge_scores

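    # Attribution recursion used above (reading aid): with n_L = 1 at the output,
    # sigma_l the activation stds at node layer l, and W_l the weight mapping layer l-1 to l,
    #   E_l[i, j]  = |W_l[i, j]| * sigma_{l-1}[j] * n_l[i] / (sigma_l[i] + 1e-4)
    #   n_{l-1}[j] = sum_i E_l[i, j] / sqrt(d_{l-1})
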
    def plot(self, beta=3, scale=1., metric='w'):
        # metric = 'w', 'act' or 'fa'

        if metric == 'fa':
            self.attribute()

        depth = self.depth
        y0 = 0.5
        fig, ax = plt.subplots(figsize=(3*scale, 3*y0*depth*scale))
        shp = self.width

        min_spacing = 1/max(self.width)
        for j in range(len(shp)):
            N = shp[j]
            for i in range(N):
                plt.scatter(1/(2*N) + i/N, j*y0, s=min_spacing**2 * 5000 * scale**2, color='black')

        plt.ylim(-0.1*y0, y0*depth + 0.1*y0)
        plt.xlim(-0.02, 1.02)

        linears = self.linears

        for ii in range(len(linears)):
            linear = linears[ii]
            p = linear.weight
            p_shp = p.shape

            if metric == 'w':
                pass
            elif metric == 'act':
                p = self.wa_forward[ii]
            elif metric == 'fa':
                p = self.wa_backward[ii]
            else:
                raise Exception('metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\'.'.format(metric))
            for i in range(p_shp[0]):
                for j in range(p_shp[1]):
                    plt.plot([1/(2*p_shp[0]) + i/p_shp[0], 1/(2*p_shp[1]) + j/p_shp[1]], [y0*(ii+1), y0*ii],
                             lw=0.5*scale,
                             alpha=np.tanh(beta*np.abs(p[i, j].cpu().detach().numpy())),
                             color="blue" if p[i, j] > 0 else "red")

        ax.axis('off')

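    # Usage sketch: after fitting, model.plot(metric='fa') draws neurons on a grid and
    # edges with opacity tanh(beta * |score|), blue for positive values, red for negative.
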
    def reg(self, reg_metric, lamb_l1, lamb_entropy):

        if reg_metric == 'w':
            acts_scale = self.w
        elif reg_metric == 'act':
            acts_scale = self.wa_forward
        elif reg_metric == 'fa':
            acts_scale = self.wa_backward
        elif reg_metric == 'a':
            acts_scale = self.acts_scale
        else:
            raise Exception('reg_metric = \'{}\' not recognized. Choices are \'w\', \'act\', \'fa\', \'a\'.'.format(reg_metric))

        if len(acts_scale[0].shape) == 2:
            reg_ = 0.

            for i in range(len(acts_scale)):
                vec = acts_scale[i]
                vec = torch.abs(vec)

                l1 = torch.sum(vec)
                p_row = vec / (torch.sum(vec, dim=1, keepdim=True) + 1)
                p_col = vec / (torch.sum(vec, dim=0, keepdim=True) + 1)
                entropy_row = - torch.mean(torch.sum(p_row * torch.log2(p_row + 1e-4), dim=1))
                entropy_col = - torch.mean(torch.sum(p_col * torch.log2(p_col + 1e-4), dim=0))
                reg_ += lamb_l1 * l1 + lamb_entropy * (entropy_row + entropy_col)

        elif len(acts_scale[0].shape) == 1:

            reg_ = 0.

            for i in range(len(acts_scale)):
                vec = acts_scale[i]
                vec = torch.abs(vec)

                l1 = torch.sum(vec)
                p = vec / (torch.sum(vec) + 1)
                entropy = - torch.sum(p * torch.log2(p + 1e-4))
                reg_ += lamb_l1 * l1 + lamb_entropy * entropy

        return reg_

    def get_reg(self, reg_metric, lamb_l1, lamb_entropy):
        return self.reg(reg_metric, lamb_l1, lamb_entropy)

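    # Penalty computed by reg() (reading aid): for each matrix/vector A of the chosen
    # metric, with p = |A| / (sum|A| + 1) (only approximately normalized, due to the +1),
    #   reg = sum_A [ lamb_l1 * sum|A| + lamb_entropy * H(p) ],  H(p) = -sum p * log2(p + 1e-4)
    # computed row- and column-wise when A is 2D.
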
    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., loss_fn=None, lr=1., batch=-1,
            metrics=None, in_vars=None, out_vars=None, beta=3, device='cpu', reg_metric='w', display_metrics=None):

        if lamb > 0. and not self.save_act:
            print('setting lamb=0. If you want to set lamb > 0, set save_act=True')

        old_save_act = self.save_act
        if lamb == 0.:
            self.save_act = False

        pbar = tqdm(range(steps), desc='description', ncols=100)

        if loss_fn is None:
            loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
        else:
            loss_fn_eval = loss_fn

        if opt == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        elif opt == "LBFGS":
            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

        results = {}
        results['train_loss'] = []
        results['test_loss'] = []
        results['reg'] = []
        if metrics is not None:
            for i in range(len(metrics)):
                results[metrics[i].__name__] = []

        if batch == -1 or batch > dataset['train_input'].shape[0]:
            batch_size = dataset['train_input'].shape[0]
            batch_size_test = dataset['test_input'].shape[0]
        else:
            batch_size = batch
            batch_size_test = batch

        # shared with closure() below via nonlocal, rather than module-level globals
        train_loss = None
        reg_ = None

        def closure():
            nonlocal train_loss, reg_
            optimizer.zero_grad()
            pred = self.forward(dataset['train_input'][train_id].to(self.device))
            train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
            if self.save_act:
                if reg_metric == 'fa':
                    self.attribute()
                reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
            else:
                reg_ = torch.tensor(0.)
            objective = train_loss + lamb * reg_
            objective.backward()
            return objective

        for step in pbar:

            if step == steps-1 and old_save_act:
                self.save_act = True

            train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
            test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)

            if opt == "LBFGS":
                optimizer.step(closure)

            if opt == "Adam":
                pred = self.forward(dataset['train_input'][train_id].to(self.device))
                train_loss = loss_fn(pred, dataset['train_label'][train_id].to(self.device))
                if self.save_act:
                    # mirror the LBFGS closure: 'fa' scores must be refreshed before get_reg
                    if reg_metric == 'fa':
                        self.attribute()
                    reg_ = self.get_reg(reg_metric, lamb_l1, lamb_entropy)
                else:
                    reg_ = torch.tensor(0.)
                loss = train_loss + lamb * reg_
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(self.device)), dataset['test_label'][test_id].to(self.device))

            if metrics is not None:
                for i in range(len(metrics)):
                    results[metrics[i].__name__].append(metrics[i]().item())

            # losses are logged as RMSE (sqrt of the default MSE loss)
            results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
            results['test_loss'].append(torch.sqrt(test_loss).cpu().detach().numpy())
            results['reg'].append(reg_.cpu().detach().numpy())

            if step % log == 0:
                if display_metrics is None:
                    pbar.set_description("| train_loss: %.2e | test_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(test_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
                else:
                    string = ''
                    data = ()
                    for metric in display_metrics:
                        string += f' {metric}: %.2e |'
                        if metric not in results:
                            raise Exception(f'{metric} not recognized')
                        data += (results[metric][-1],)
                    pbar.set_description(string % data)

        return results

    @property
    def connection_cost(self):

        with torch.no_grad():

            def get_coordinate(n):
                # place n neurons uniformly on [0, 1]
                return torch.linspace(0, 1, steps=n+1, device=self.device)[:n] + 1/(2*n)

            cc = 0.
            for linear in self.linears:
                t = torch.abs(linear.weight)

                # nn.Linear stores weight as (out_features, in_features)
                out_dim = t.shape[0]
                x_out = get_coordinate(out_dim)

                in_dim = t.shape[1]
                x_in = get_coordinate(in_dim)

                dist = torch.abs(x_out[:, None] - x_in[None, :])
                cc += torch.sum(dist * t)

        return cc

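    # connection_cost (reading aid): with neurons of each layer placed uniformly on [0, 1],
    # the cost is sum_l sum_{i,j} |W_l[i, j]| * |x_out[i] - x_in[j]|, so strong long-range
    # connections are expensive; auto_swap() below permutes neurons to reduce it.
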
    def swap(self, l, i1, i2):

        def swap_row(data, i1, i2):
            data[i1], data[i2] = data[i2].clone(), data[i1].clone()

        def swap_col(data, i1, i2):
            data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()

        swap_row(self.linears[l-1].weight.data, i1, i2)
        swap_row(self.linears[l-1].bias.data, i1, i2)
        swap_col(self.linears[l].weight.data, i1, i2)

    def auto_swap_l(self, l):

        num = self.width[l]
        for i in range(num):
            ccs = []
            for j in range(num):
                self.swap(l, i, j)
                self.get_act()
                self.attribute()
                cc = self.connection_cost.detach().clone()
                ccs.append(cc)
                self.swap(l, i, j)  # swap back before trying the next candidate
            best_j = torch.argmin(torch.tensor(ccs))
            self.swap(l, i, best_j)

    def auto_swap(self):
        depth = self.depth
        for l in range(1, depth):
            self.auto_swap_l(l)

    def tree(self, x=None, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
        if x is None:
            x = self.cache_data
        plot_tree(self, x, in_var=in_var, style=style, sym_th=sym_th, sep_th=sep_th, skip_sep_test=skip_sep_test, verbose=verbose)
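

# Minimal end-to-end sketch (not part of the library API). It assumes this file sits
# inside a pykan-style package so the relative imports resolve, e.g. run it as
# `python -m kan.MLP`. The dataset dict layout is the one fit() expects.
if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.rand(1200, 2) * 2 - 1
    y = torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2
    dataset = {'train_input': x[:1000], 'train_label': y[:1000],
               'test_input': x[1000:], 'test_label': y[1000:]}

    model = MLP(width=[2, 5, 1], device='cpu')
    results = model.fit(dataset, opt='Adam', steps=200, lr=1e-2, lamb=1e-4, reg_metric='act')
    print('final test RMSE:', results['test_loss'][-1])
    model.plot(metric='fa')
    plt.show()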