@@ -2160,7 +2160,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No
2160
2160
2161
2161
return best_name , best_fun , best_r2 , best_c ;
2162
2162
2163
- def auto_symbolic (self , a_range = (- 10 , 10 ), b_range = (- 10 , 10 ), lib = None , verbose = 1 ):
2163
+ def auto_symbolic (self , a_range = (- 10 , 10 ), b_range = (- 10 , 10 ), lib = None , verbose = 1 , weight_simple = 0.8 , r2_threshold = 0.0 ):
2164
2164
'''
2165
2165
automatic symbolic regression for all edges
2166
2166
@@ -2174,7 +2174,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
2174
2174
library of candidate symbolic functions
2175
2175
verbose : int
2176
2176
larger verbosity => more verbosity
2177
-
2177
+ weight_simple : float
2178
+ a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
2179
+ r2_threshold : float
2180
+ If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
2178
2181
Returns:
2179
2182
--------
2180
2183
None
@@ -2191,17 +2194,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
2191
2194
for l in range (len (self .width_in ) - 1 ):
2192
2195
for i in range (self .width_in [l ]):
2193
2196
for j in range (self .width_out [l + 1 ]):
2194
- #if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
2195
2197
if self .symbolic_fun [l ].mask [j , i ] > 0. and self .act_fun [l ].mask [i ][j ] == 0. :
2196
2198
print (f'skipping ({ l } ,{ i } ,{ j } ) since already symbolic' )
2197
2199
elif self .symbolic_fun [l ].mask [j , i ] == 0. and self .act_fun [l ].mask [i ][j ] == 0. :
2198
2200
self .fix_symbolic (l , i , j , '0' , verbose = verbose > 1 , log_history = False )
2199
2201
print (f'fixing ({ l } ,{ i } ,{ j } ) with 0' )
2200
2202
else :
2201
- name , fun , r2 , c = self .suggest_symbolic (l , i , j , a_range = a_range , b_range = b_range , lib = lib , verbose = False )
2202
- self .fix_symbolic (l , i , j , name , verbose = verbose > 1 , log_history = False )
2203
- if verbose >= 1 :
2204
- print (f'fixing ({ l } ,{ i } ,{ j } ) with { name } , r2={ r2 } , c={ c } ' )
2203
+ name , fun , r2 , c = self .suggest_symbolic (l , i , j , a_range = a_range , b_range = b_range , lib = lib , verbose = False , weight_simple = weight_simple )
2204
+ if r2 >= r2_threshold :
2205
+ self .fix_symbolic (l , i , j , name , verbose = verbose > 1 , log_history = False )
2206
+ if verbose >= 1 :
2207
+ print (f'fixing ({ l } ,{ i } ,{ j } ) with { name } , r2={ r2 } , c={ c } ' )
2208
+ else :
2209
+ print (f'For ({ l } ,{ i } ,{ j } ) the best fit was { name } , but r^2 = { r2 } and this is lower than { r2_threshold } . This edge was omitted, keep training or try a different threshold.' )
2205
2210
2206
2211
self .log_history ('auto_symbolic' )
2207
2212
0 commit comments