Skip to content

Commit ea6e595

Browse files
committed
added parameter for early stopping
1 parent e84cf87 commit ea6e595

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

DAG_search/eliminations.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ class EliminationRegressor(sklearn.base.BaseEstimator, sklearn.base.RegressorMix
516516
Sklearn interface.
517517
'''
518518

519-
def __init__(self, symb_regr, only_input:bool = False, positives:list = None, expr = None, exec_func = None, score_fkt = codec_coefficient, **kwargs):
519+
def __init__(self, symb_regr, only_input:bool = False, positives:list = None, expr = None, exec_func = None, score_fkt = codec_coefficient, early_stop_thresh = 0.99999, **kwargs):
520520
'''
521521
@Params:
522522
symb_regr... symbolic regressor (has .fit(X, y), .predict(X), .model() function)
@@ -529,6 +529,7 @@ def __init__(self, symb_regr, only_input:bool = False, positives:list = None, ex
529529
self.exec_func = exec_func
530530
self.score_fkt = score_fkt
531531
self.only_input = only_input
532+
self.early_stop_thresh = early_stop_thresh
532533

533534
def fit(self, X:np.ndarray, y:np.ndarray, verbose:int = 1):
534535
'''
@@ -540,7 +541,7 @@ def fit(self, X:np.ndarray, y:np.ndarray, verbose:int = 1):
540541
assert len(y.shape) == 1, f'y must be 1-dimensional (current shape: {y.shape})'
541542

542543

543-
r2_thresh = 1-1e-5 # if solution is found with higher r2 score than this: early stop
544+
r2_thresh = self.early_stop_thresh # if solution is found with higher r2 score than this: early stop
544545
x_symbs = [f'x_{i}' for i in range(X.shape[1])]
545546

546547
self.positives = np.all(X > 0, axis = 0)

0 commit comments

Comments
 (0)