Updated auto_symbolic to include configurable threshold & weight

This commit is contained in:
Spyros 2024-08-21 21:14:13 +03:00
parent 5b2af5eb3d
commit 38950437e3

View File

@ -2160,7 +2160,7 @@ class MultKAN(nn.Module):
return best_name, best_fun, best_r2, best_c;
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
'''
automatic symbolic regression for all edges
@ -2174,7 +2174,10 @@ class MultKAN(nn.Module):
library of candidate symbolic functions
verbose : int
larger verbosity => more verbosity
weight_simple : float
a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
r2_threshold : float
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
Returns:
--------
None
@ -2191,17 +2194,19 @@ class MultKAN(nn.Module):
for l in range(len(self.width_in) - 1):
for i in range(self.width_in[l]):
for j in range(self.width_out[l + 1]):
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
print(f'skipping ({l},{i},{j}) since already symbolic')
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
print(f'fixing ({l},{i},{j}) with 0')
else:
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
if r2 >= r2_threshold:
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
if verbose >= 1:
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
else:
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')
self.log_history('auto_symbolic')