Skip to content

Commit

Permalink
Merge pull request #1633 from AntonioCarta/updatable_objects
Browse files Browse the repository at this point in the history
Updatable objects
  • Loading branch information
AntonioCarta authored Apr 30, 2024
2 parents 6e5e3b2 + 876140c commit 5bc33e1
Show file tree
Hide file tree
Showing 24 changed files with 1,101 additions and 63 deletions.
3 changes: 1 addition & 2 deletions avalanche/_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,12 @@ def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
warnings.simplefilter("always", DeprecationWarning)
warnings.simplefilter("once", DeprecationWarning)
warnings.warn(
msg.format(name=func.__name__, version=version, reason=reason),
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func(*args, **kwargs)

return wrapper
Expand Down
14 changes: 11 additions & 3 deletions avalanche/benchmarks/scenarios/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
shuffle: bool = True,
drop_last: bool = False,
access_task_boundaries: bool = False,
seed: int = None,
) -> None:
"""Returns a lazy stream generated by splitting an experience into
smaller ones.
Expand All @@ -181,7 +182,8 @@ def __init__(
:param experience_size: The experience size (number of instances).
:param shuffle: If True, instances will be shuffled before splitting.
:param drop_last: If True, the last mini-experience will be dropped if
not of size `experience_size`
not of size `experience_size`.
:param seed: random seed for shuffling the data if `shuffle == True`.
:return: The list of datasets that will be used to create the
mini-experiences.
"""
Expand All @@ -190,10 +192,12 @@ def __init__(
self.shuffle = shuffle
self.drop_last = drop_last
self.access_task_boundaries = access_task_boundaries
self.seed = seed

# we need to fix the seed because repeated calls to the generator
# must return the same order every time.
self.seed = random.randint(0, 2**32 - 1)
if seed is None:
self.seed = random.randint(0, 2**32 - 1)

def __iter__(self) -> Generator[OnlineCLExperience, None, None]:
exp_dataset = self.experience.dataset
Expand Down Expand Up @@ -250,13 +254,15 @@ def _default_online_split(
access_task_boundaries: bool,
exp: DatasetExperience,
size: int,
seed: int,
):
return FixedSizeExperienceSplitter(
experience=exp,
experience_size=size,
shuffle=shuffle,
drop_last=drop_last,
access_task_boundaries=access_task_boundaries,
seed=seed,
)


Expand All @@ -272,6 +278,7 @@ def split_online_stream(
]
] = None,
access_task_boundaries: bool = False,
seed: int = None,
) -> CLStream[DatasetExperience[TCLDataset]]:
"""Split a stream of large batches to create an online stream of small
mini-batches.
Expand Down Expand Up @@ -300,6 +307,7 @@ def split_online_stream(
A good starting to understand the mechanism is to look at the
implementation of the standard splitting function
:func:`fixed_size_experience_split`.
:param seed: random seed used for shuffling by the default splitter.
:return: A lazy online stream with experiences of size `experience_size`.
"""

Expand All @@ -308,7 +316,7 @@ def split_online_stream(
# However, MyPy does not understand what a partial is -_-
def default_online_split_wrapper(e, e_sz):
return _default_online_split(
shuffle, drop_last, access_task_boundaries, e, e_sz
shuffle, drop_last, access_task_boundaries, e, e_sz, seed=seed
)

split_strategy = default_online_split_wrapper
Expand Down
1 change: 1 addition & 0 deletions avalanche/benchmarks/scenarios/validation_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def random_validation_split_strategy_wrapper(data):

# don't drop classes-timeline for compatibility with old API
e0 = next(iter(train_stream))

if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"):
train_stream = with_classes_timeline(train_stream)
valid_stream = with_classes_timeline(valid_stream)
Expand Down
149 changes: 148 additions & 1 deletion avalanche/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,160 @@
"""
This module contains Protocols for some of the main components of Avalanche,
such as strategy plugins and the agent state.
Most of these protocols are checked dynamically at runtime, so it is often not
necessary to inherit explicit from them or implement all the methods.
"""

from abc import ABC
from typing import Any, TypeVar, Generic
from typing import Any, TypeVar, Generic, Protocol, runtime_checkable
from typing import TYPE_CHECKING

from avalanche.benchmarks import CLExperience

if TYPE_CHECKING:
from avalanche.training.templates.base import BaseTemplate

Template = TypeVar("Template", bound="BaseTemplate")


class Agent:
"""Avalanche Continual Learning Agent.
The agent stores the state needed by continual learning training methods,
such as optimizers, models, regularization losses.
You can add any objects as attributes dynamically:
.. code-block::
agent = Agent()
agent.replay = ReservoirSamplingBuffer(max_size=200)
agent.loss = MaskedCrossEntropy()
agent.reg_loss = LearningWithoutForgetting(alpha=1, temperature=2)
agent.model = my_model
agent.opt = SGD(agent.model.parameters(), lr=0.001)
agent.scheduler = ExponentialLR(agent.opt, gamma=0.999)
Many CL objects will need to perform some operation before or
after training on each experience. This is supported via the `Adaptable`
Protocol, which requires the `pre_adapt` and `post_adapt` methods.
To call the pre/post adaptation you can implement your training loop
like in the following example:
.. code-block::
def train(agent, exp):
agent.pre_adapt(exp)
# do training here
agent.post_adapt(exp)
Objects that implement the `Adaptable` Protocol will be called by the Agent.
You can also add additional functionality to the adaptation phases with
hooks. For example:
.. code-block::
agent.add_pre_hooks(lambda a, e: update_optimizer(a.opt, new_params={}, optimized_params=dict(a.model.named_parameters())))
# we update the lr scheduler after each experience (not every epoch!)
agent.add_post_hooks(lambda a, e: a.scheduler.step())
"""

def __init__(self, verbose=False):
"""Init.
:param verbose: If True, print every time an adaptable object or hook
is called during the adaptation. Useful for debugging.
"""
self._updatable_objects = []
self.verbose = verbose
self._pre_hooks = []
self._post_hooks = []

def __setattr__(self, name, value):
super().__setattr__(name, value)
if hasattr(value, "pre_adapt") or hasattr(value, "post_adapt"):
self._updatable_objects.append(value)
if self.verbose:
print("Added updatable object ", value)

def pre_adapt(self, exp):
"""Pre-adaptation.
Remember to call this before training on a new experience.
:param exp: current experience
"""
for uo in self._updatable_objects:
if hasattr(uo, "pre_adapt"):
uo.pre_adapt(self, exp)
if self.verbose:
print("pre_adapt ", uo)
for foo in self._pre_hooks:
if self.verbose:
print("pre_adapt hook ", foo)
foo(self, exp)

def post_adapt(self, exp):
"""Post-adaptation.
Remember to call this after training on a new experience.
:param exp: current experience
"""
for uo in self._updatable_objects:
if hasattr(uo, "post_adapt"):
uo.post_adapt(self, exp)
if self.verbose:
print("post_adapt ", uo)
for foo in self._post_hooks:
if self.verbose:
print("post_adapt hook ", foo)
foo(self, exp)

def add_pre_hooks(self, foo):
"""Add a pre-adaptation hooks
Hooks take two arguments: `<agent, experience>`.
:param foo: the hook function
"""
self._pre_hooks.append(foo)

def add_post_hooks(self, foo):
"""Add a post-adaptation hooks
Hooks take two arguments: `<agent, experience>`.
:param foo: the hook function
"""
self._post_hooks.append(foo)


class Adaptable(Protocol):
"""Adaptable objects Protocol.
These class documents the Adaptable objects API but it is not necessary
for an object to inherit from it since the `Agent` will search for the methods
dynamically.
Adaptable objects are objects that require to run their `pre_adapt` and
`post_adapt` methods before (and after, respectively) training on each
experience.
Adaptable objects can implement only the method that they need since the
`Agent` will look for the methods dynamically and call it only if it is
implemented.
"""

def pre_adapt(self, agent: Agent, exp: CLExperience):
pass

def post_adapt(self, agent: Agent, exp: CLExperience):
pass


class BasePlugin(Generic[Template], ABC):
"""ABC for BaseTemplate plugins.
Expand Down
Loading

0 comments on commit 5bc33e1

Please sign in to comment.