Updated auto_symbolic to include configurable threshold & weight
This commit is contained in:
parent
5b2af5eb3d
commit
38950437e3
@ -2160,7 +2160,7 @@ class MultKAN(nn.Module):
|
||||
|
||||
return best_name, best_fun, best_r2, best_c;
|
||||
|
||||
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
|
||||
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
|
||||
'''
|
||||
automatic symbolic regression for all edges
|
||||
|
||||
@ -2174,7 +2174,10 @@ class MultKAN(nn.Module):
|
||||
library of candidate symbolic functions
|
||||
verbose : int
|
||||
larger verbosity => more verbosity
|
||||
|
||||
weight_simple : float
|
||||
a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
|
||||
r2_threshold : float
|
||||
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
|
||||
Returns:
|
||||
--------
|
||||
None
|
||||
@ -2191,17 +2194,19 @@ class MultKAN(nn.Module):
|
||||
for l in range(len(self.width_in) - 1):
|
||||
for i in range(self.width_in[l]):
|
||||
for j in range(self.width_out[l + 1]):
|
||||
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
|
||||
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
|
||||
print(f'skipping ({l},{i},{j}) since already symbolic')
|
||||
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
|
||||
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
|
||||
print(f'fixing ({l},{i},{j}) with 0')
|
||||
else:
|
||||
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
|
||||
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
|
||||
if verbose >= 1:
|
||||
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
|
||||
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
|
||||
if r2 >= r2_threshold:
|
||||
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
|
||||
if verbose >= 1:
|
||||
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
|
||||
else:
|
||||
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')
|
||||
|
||||
self.log_history('auto_symbolic')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user