Skip to content

Commit 071d813

Browse files
authored
Merge pull request #1595 from MattiaFanan/fix-EarlyStoppingPlugin
Fix early stopping plugin
2 parents cbe1307 + 7dc09dc commit 071d813

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

avalanche/training/plugins/early_stopping.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,20 @@ def _update_best(self, strategy):
115115
f"Metric {self.metric_name} used by the EarlyStopping plugin "
116116
f"is not computed yet. EarlyStopping will not be triggered."
117117
)
118-
if self.best_val is None or self.operator(val_acc, self.best_val):
118+
119+
if self.best_val is None:
120+
self.best_state = deepcopy(strategy.model.state_dict())
121+
self.best_val = val_acc
122+
self.best_step = self._get_strategy_counter(strategy)
123+
return None
124+
125+
delta_val = float(val_acc - self.best_val)
126+
if self.operator(delta_val, 0) and abs(delta_val) >= self.margin:
119127
self.best_state = deepcopy(strategy.model.state_dict())
120-
if self.best_val is None:
121-
self.best_val = val_acc
122-
self.best_step = 0
123-
return None
124-
125-
if self.operator(float(val_acc - self.best_val), self.margin):
126-
self.best_step = self._get_strategy_counter(strategy)
127-
self.best_val = val_acc
128-
if self.verbose:
129-
print("EarlyStopping: new best value:", val_acc)
128+
self.best_val = val_acc
129+
self.best_step = self._get_strategy_counter(strategy)
130+
if self.verbose:
131+
print("EarlyStopping: new best value:", val_acc)
130132

131133
return self.best_val
132134

0 commit comments

Comments
 (0)