Skip to content

Commit

Permalink
switch at_task_boundary to private and add before arg to detect one s…
Browse files Browse the repository at this point in the history
…ide of the boundary or another
  • Loading branch information
AlbinSou committed Oct 10, 2023
1 parent 7a92ffc commit a91704b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from avalanche.models.utils import avalanche_model_adaptation
from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol
from avalanche.models.dynamic_optimizers import reset_optimizer, update_optimizer
from avalanche.training.utils import at_task_boundary
from avalanche.training.utils import _at_task_boundary


class BatchObservation(SGDStrategyProtocol):
Expand Down Expand Up @@ -73,7 +73,7 @@ def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
if self.optimized_param_id is None:
self.make_optimizer(reset_optimizer_state=True, **kwargs)

if at_task_boundary(self.experience):
if _at_task_boundary(self.experience, before=True):
self.model = self.model_adaptation()
self.make_optimizer(reset_optimizer_state=reset_optimizer_state, **kwargs)

Expand Down
14 changes: 8 additions & 6 deletions avalanche/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from avalanche.models.batch_renorm import BatchRenorm2D


def at_task_boundary(training_experience) -> bool:
def _at_task_boundary(training_experience, before=True) -> bool:
"""
Given a training experience,
returns true if the experience is at the task boundary
Expand All @@ -41,14 +41,17 @@ def at_task_boundary(training_experience) -> bool:
- If the experience is not an online experience, returns True
:param before: If used in before_training_exp,
set to True, otherwise set
to False
"""

if isinstance(training_experience, OnlineCLExperience):
if training_experience.access_task_boundaries:
if (
training_experience.is_first_subexp
or training_experience.is_last_subexp
):
if before and training_experience.is_first_subexp:
return True
elif (not before) and training_experience.is_last_subexp:
return True
else:
return True
Expand Down Expand Up @@ -484,5 +487,4 @@ def __str__(self):
"examples_per_class",
"ParamData",
"cycle",
"at_task_boundary",
]

0 comments on commit a91704b

Please sign in to comment.