From 38950437e33df9b041b1338caecdd5035a4a1f60 Mon Sep 17 00:00:00 2001 From: Spyros Date: Wed, 21 Aug 2024 21:14:13 +0300 Subject: [PATCH] Updated auto_symbolic to include configurable threshold & weight --- kan/MultKAN.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/kan/MultKAN.py b/kan/MultKAN.py index 9955e9b4..c615cc27 100644 --- a/kan/MultKAN.py +++ b/kan/MultKAN.py @@ -2160,7 +2160,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No return best_name, best_fun, best_r2, best_c; - def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1): + def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0): ''' automatic symbolic regression for all edges @@ -2174,7 +2174,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose= library of candidate symbolic functions verbose : int larger verbosity => more verbosity - + weight_simple : float + a weight that prioritizies simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity + r2_threshold : float + If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold Returns: -------- None @@ -2191,17 +2194,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose= for l in range(len(self.width_in) - 1): for i in range(self.width_in[l]): for j in range(self.width_out[l + 1]): - #if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.: if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.: print(f'skipping ({l},{i},{j}) since already symbolic') elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.: self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False) print(f'fixing ({l},{i},{j}) with 0') else: - name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False) - self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False) - if verbose >= 1: - print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}') + name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple) + if r2 >= r2_threshold: + self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False) + if verbose >= 1: + print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}') + else: + print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.') self.log_history('auto_symbolic')