Skip to content

Commit 91fc24e

Browse files
authored
Merge pull request #415 from srigas/master
Updated auto_symbolic to include a configurable threshold & simplicity weight
2 parents 5b2af5e + 3895043 commit 91fc24e

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

kan/MultKAN.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -2160,7 +2160,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No
21602160

21612161
return best_name, best_fun, best_r2, best_c;
21622162

2163-
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
2163+
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
21642164
'''
21652165
automatic symbolic regression for all edges
21662166
@@ -2174,7 +2174,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
21742174
library of candidate symbolic functions
21752175
verbose : int
21762176
larger verbosity => more verbosity
2177-
2177+
weight_simple : float
2178+
a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
2179+
r2_threshold : float
2180+
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
21782181
Returns:
21792182
--------
21802183
None
@@ -2191,17 +2194,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
21912194
for l in range(len(self.width_in) - 1):
21922195
for i in range(self.width_in[l]):
21932196
for j in range(self.width_out[l + 1]):
2194-
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
21952197
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
21962198
print(f'skipping ({l},{i},{j}) since already symbolic')
21972199
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
21982200
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
21992201
print(f'fixing ({l},{i},{j}) with 0')
22002202
else:
2201-
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
2202-
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2203-
if verbose >= 1:
2204-
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2203+
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
2204+
if r2 >= r2_threshold:
2205+
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2206+
if verbose >= 1:
2207+
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2208+
else:
2209+
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')
22052210

22062211
self.log_history('auto_symbolic')
22072212

0 commit comments

Comments
 (0)