diff --git a/avalanche/training/templates/observation_type/batch_observation.py b/avalanche/training/templates/observation_type/batch_observation.py index fe076de75..e71bcf61e 100644 --- a/avalanche/training/templates/observation_type/batch_observation.py +++ b/avalanche/training/templates/observation_type/batch_observation.py @@ -8,7 +8,7 @@ from avalanche.models.utils import avalanche_model_adaptation from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol from avalanche.models.dynamic_optimizers import reset_optimizer, update_optimizer -from avalanche.training.utils import at_task_boundary +from avalanche.training.utils import _at_task_boundary class BatchObservation(SGDStrategyProtocol): @@ -73,7 +73,7 @@ def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs): if self.optimized_param_id is None: self.make_optimizer(reset_optimizer_state=True, **kwargs) - if at_task_boundary(self.experience): + if _at_task_boundary(self.experience, before=True): self.model = self.model_adaptation() self.make_optimizer(reset_optimizer_state=reset_optimizer_state, **kwargs) diff --git a/avalanche/training/utils.py b/avalanche/training/utils.py index 50c121ab3..4b4a6d748 100644 --- a/avalanche/training/utils.py +++ b/avalanche/training/utils.py @@ -26,7 +26,7 @@ from avalanche.models.batch_renorm import BatchRenorm2D -def at_task_boundary(training_experience) -> bool: +def _at_task_boundary(training_experience, before=True) -> bool: """ Given a training experience, returns true if the experience is at the task boundary @@ -41,14 +41,17 @@ def at_task_boundary(training_experience) -> bool: - If the experience is not an online experience, returns True + :param before: If used in before_training_exp, + set to True, otherwise set + to False + """ if isinstance(training_experience, OnlineCLExperience): if training_experience.access_task_boundaries: - if ( - training_experience.is_first_subexp - or training_experience.is_last_subexp - ): + if before and training_experience.is_first_subexp: + return True + elif (not before) and training_experience.is_last_subexp: return True else: return True @@ -484,5 +487,4 @@ def __str__(self): "examples_per_class", "ParamData", "cycle", - "at_task_boundary", ]