diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 80b47f325..d58bae600 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -148,9 +148,9 @@ def train( self.model.to(self.device) # Normalize training and eval data. - experiences_list: Iterable[TExperienceType] = ( - _experiences_parameter_as_iterable(experiences) - ) + experiences_list: Iterable[ + TExperienceType + ] = _experiences_parameter_as_iterable(experiences) if eval_streams is None: eval_streams = [experiences_list] @@ -202,9 +202,9 @@ def eval( self.is_training = False self.model.eval() - experiences_list: Iterable[TExperienceType] = ( - _experiences_parameter_as_iterable(experiences) - ) + experiences_list: Iterable[ + TExperienceType + ] = _experiences_parameter_as_iterable(experiences) self.current_eval_stream = experiences_list self._before_eval(**kwargs)