2024-07-17 10:45:25 -04:00

55 lines
1.7 KiB
Python

import torch
from .MultKAN import *
def runner1(width, dataset, grids=[5,10,20], steps=20, lamb=0.001, prune_round=3, refine_round=3, edge_th=1e-2, node_th=1e-2, metrics=None, seed=1):
result = {}
result['test_loss'] = []
result['c'] = []
result['G'] = []
result['id'] = []
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__] = []
def collect(evaluation):
result['test_loss'].append(evaluation['test_loss'])
result['c'].append(evaluation['n_edge'])
result['G'].append(evaluation['n_grid'])
result['id'].append(f'{model.round}.{model.state_id}')
if metrics != None:
for i in range(len(metrics)):
result[metrics[i].__name__].append(metrics[i](model, dataset).item())
for i in range(prune_round):
# train and prune
if i == 0:
model = KAN(width=width, grid=grids[0], seed=seed)
else:
model = model.rewind(f'{i-1}.{2*i}')
model.fit(dataset, steps=steps, lamb=lamb)
model = model.prune(edge_th=edge_th, node_th=node_th)
evaluation = model.evaluate(dataset)
collect(evaluation)
for j in range(refine_round):
model = model.refine(grids[j])
model.fit(dataset, steps=steps)
evaluation = model.evaluate(dataset)
collect(evaluation)
for key in list(result.keys()):
result[key] = np.array(result[key])
return result
def pareto_frontier(x,y):
pf_id = np.where(np.sum((x[:,None] <= x[None,:]) * (y[:,None] <= y[None,:]), axis=0) == 1)[0]
x_pf = x[pf_id]
y_pf = y[pf_id]
return x_pf, y_pf, pf_id