diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 769968d09..2972480ee 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -59,8 +59,6 @@ jobs: echo "Running checkpointing tests..." && bash ./tests/checkpointing/test_checkpointing.sh && echo "Running distributed training tests..." && - cd tests && PYTHONPATH=.. python run_dist_tests.py && - cd .. && echo "While running unit tests, the following datasets were downloaded:" && ls ~/.avalanche/data diff --git a/avalanche/training/supervised/ar1.py b/avalanche/training/supervised/ar1.py index a06370706..41040b9b0 100644 --- a/avalanche/training/supervised/ar1.py +++ b/avalanche/training/supervised/ar1.py @@ -17,7 +17,6 @@ CWRStarPlugin, ) from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType from avalanche.training.utils import ( replace_bn_with_brn, get_last_fc_layer, @@ -27,6 +26,7 @@ LayerAndParameter, ) from avalanche.training.plugins.evaluation import default_evaluator +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class AR1(SupervisedTemplate): @@ -44,7 +44,7 @@ class AR1(SupervisedTemplate): def __init__( self, *, - criterion: CriterionType = None, + criterion: CriterionType = CrossEntropyLoss(), lr: float = 0.001, inc_lr: float = 5e-5, momentum=0.9, @@ -149,9 +149,6 @@ def __init__( optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2) - if criterion is None: - criterion = CrossEntropyLoss() - self.ewc_lambda = ewc_lambda self.freeze_below_layer = freeze_below_layer self.rm_sz = rm_sz diff --git a/avalanche/training/supervised/cumulative.py b/avalanche/training/supervised/cumulative.py index 314966be2..0e073456c 100644 --- a/avalanche/training/supervised/cumulative.py +++ b/avalanche/training/supervised/cumulative.py @@ -8,7 +8,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class Cumulative(SupervisedTemplate): diff --git a/avalanche/training/supervised/deep_slda.py b/avalanche/training/supervised/deep_slda.py index 7bc336d74..c234c44c9 100644 --- a/avalanche/training/supervised/deep_slda.py +++ b/avalanche/training/supervised/deep_slda.py @@ -13,7 +13,7 @@ ) from avalanche.models.dynamic_modules import MultiTaskModule from avalanche.models import FeatureExtractorBackbone -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class StreamingLDA(SupervisedTemplate): diff --git a/avalanche/training/supervised/der.py b/avalanche/training/supervised/der.py index 7df915748..d9263f2a2 100644 --- a/avalanche/training/supervised/der.py +++ b/avalanche/training/supervised/der.py @@ -19,7 +19,7 @@ from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute from avalanche.benchmarks.utils.flat_data import FlatData -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle from avalanche.core import SupervisedPlugin from avalanche.training.plugins.evaluation import ( diff --git a/avalanche/training/supervised/er_ace.py b/avalanche/training/supervised/er_ace.py index d9a260e15..01146c432 100644 --- a/avalanche/training/supervised/er_ace.py +++ b/avalanche/training/supervised/er_ace.py @@ -13,7 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle diff --git a/avalanche/training/supervised/er_aml.py b/avalanche/training/supervised/er_aml.py index a79b574e0..9d44f1f1f 100644 --- a/avalanche/training/supervised/er_aml.py +++ b/avalanche/training/supervised/er_aml.py @@ -13,7 +13,7 @@ ) from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index fd04b7ad7..abb249216 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -22,7 +22,7 @@ from avalanche.training.templates import SupervisedTemplate from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin from avalanche.training.plugins.evaluation import default_evaluator -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class ExpertGateStrategy(SupervisedTemplate): diff --git a/avalanche/training/supervised/feature_replay.py b/avalanche/training/supervised/feature_replay.py index 04e355f4e..e723b1515 100644 --- a/avalanche/training/supervised/feature_replay.py +++ b/avalanche/training/supervised/feature_replay.py @@ -11,7 +11,7 @@ from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator from avalanche.training.storage_policy import ClassBalancedBuffer from avalanche.training.templates import SupervisedTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType from avalanche.training.utils import cycle from avalanche.training.losses import MaskedCrossEntropy diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 9b6ca3dc4..08294cfc0 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -28,7 +28,7 @@ _experiences_parameter_as_iterable, _group_experiences_by_stream, ) -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class AlreadyTrainedError(Exception): diff --git a/avalanche/training/supervised/lamaml.py b/avalanche/training/supervised/lamaml.py index d727a778f..d880df347 100644 --- a/avalanche/training/supervised/lamaml.py +++ b/avalanche/training/supervised/lamaml.py @@ -8,6 +8,8 @@ from torch.optim import Optimizer import math +from avalanche.training.templates.strategy_mixin_protocol import CriterionType + try: import higher except ImportError: diff --git a/avalanche/training/supervised/lamaml_v2.py b/avalanche/training/supervised/lamaml_v2.py index 42fd8bfbf..36905cb80 100644 --- a/avalanche/training/supervised/lamaml_v2.py +++ b/avalanche/training/supervised/lamaml_v2.py @@ -3,6 +3,8 @@ import warnings import torch +from avalanche.training.templates.strategy_mixin_protocol import CriterionType + if parse(torch.__version__) < parse("2.0.0"): warnings.warn(f"LaMAML requires torch >= 2.0.0.") diff --git a/avalanche/training/supervised/mer.py b/avalanche/training/supervised/mer.py index a834a5930..d4d709d17 100644 --- a/avalanche/training/supervised/mer.py +++ b/avalanche/training/supervised/mer.py @@ -11,7 +11,7 @@ from avalanche.training.plugins.evaluation import default_evaluator from avalanche.training.storage_policy import ReservoirSamplingBuffer from avalanche.training.templates import SupervisedMetaLearningTemplate -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class MERBuffer: diff --git a/avalanche/training/supervised/strategy_wrappers.py b/avalanche/training/supervised/strategy_wrappers.py index c8ddcab44..31fbe8f4d 100644 --- a/avalanche/training/supervised/strategy_wrappers.py +++ b/avalanche/training/supervised/strategy_wrappers.py @@ -45,11 +45,11 @@ ) from avalanche.training.templates.base import BaseTemplate from avalanche.training.templates import SupervisedTemplate -from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics +from avalanche.evaluation.metrics import loss_metrics from avalanche.models.generator import MlpVAE, VAE_loss from avalanche.models.expert_gate import AE_loss from avalanche.logging import InteractiveLogger -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType class Naive(SupervisedTemplate): diff --git a/avalanche/training/supervised/strategy_wrappers_online.py b/avalanche/training/supervised/strategy_wrappers_online.py index 0210eb005..ffe6f1575 100644 --- a/avalanche/training/supervised/strategy_wrappers_online.py +++ b/avalanche/training/supervised/strategy_wrappers_online.py @@ -21,7 +21,7 @@ SupervisedTemplate, ) from avalanche._annotations import deprecated -from avalanche.training.templates.base_sgd import CriterionType +from avalanche.training.templates.strategy_mixin_protocol import CriterionType @deprecated( diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index c1d85957a..f491a951a 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -248,6 +248,7 @@ def __init_subclass__(cls, **kwargs): cls.__init__ = _support_legacy_strategy_positional_args( cls.__init__, cls.__name__ ) + super().__init_subclass__(**kwargs) # we need this only for type checking PLUGIN_CLASS = BasePlugin diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index e0133179c..e54f402a6 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union +from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union from typing_extensions import TypeAlias from packaging.version import parse @@ -22,7 +22,10 @@ collate_from_data_or_kwargs, ) -from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol +from avalanche.training.templates.strategy_mixin_protocol import ( + CriterionType, + SGDStrategyProtocol, +) from avalanche.training.utils import trigger_plugins @@ -30,10 +33,9 @@ TMBInput = TypeVar("TMBInput") TMBOutput = TypeVar("TMBOutput") -CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]] - class BaseSGDTemplate( + Generic[TDatasetExperience, TMBInput, TMBOutput], SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput], BaseTemplate[TDatasetExperience], ): @@ -93,21 +95,6 @@ def __init__( `eval_every` epochs or iterations (Default='epoch'). """ - super().__init__( - model=model, - optimizer=optimizer, - criterion=criterion, - train_mb_size=train_mb_size, - train_epochs=train_epochs, - eval_mb_size=eval_mb_size, - device=device, - plugins=plugins, - evaluator=evaluator, - eval_every=eval_every, - peval_mode=peval_mode, - **kwargs - ) - # Call super with all args if sys.version_info >= (3, 11): super().__init__( @@ -127,7 +114,19 @@ def __init__( else: super().__init__() # type: ignore BaseTemplate.__init__( - self=self, model=model, device=device, plugins=plugins + self, + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs ) self.optimizer: Optimizer = optimizer diff --git a/avalanche/training/templates/common_templates.py b/avalanche/training/templates/common_templates.py index 7dbde7e9a..8405c5a6f 100644 --- a/avalanche/training/templates/common_templates.py +++ b/avalanche/training/templates/common_templates.py @@ -1,8 +1,6 @@ import sys -from typing import Any, Callable, Dict, Sequence, Optional, TypeVar, Union -import warnings +from typing import Callable, Sequence, Optional, TypeVar, Union import torch -import inspect from torch.nn import Module, CrossEntropyLoss from torch.optim import Optimizer @@ -13,8 +11,8 @@ EvaluationPlugin, default_evaluator, ) -from avalanche.training.templates.base import PositionalArgumentDeprecatedWarning from avalanche.training.templates.strategy_mixin_protocol import ( + CriterionType, SupervisedStrategyProtocol, TMBOutput, TMBInput, @@ -23,7 +21,7 @@ from .observation_type import * from .problem_type import * from .update_type import * -from .base_sgd import BaseSGDTemplate, CriterionType +from .base_sgd import BaseSGDTemplate TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience) @@ -202,7 +200,7 @@ def __init__( *, model: Module, optimizer: Optimizer, - criterion=CrossEntropyLoss(), + criterion: CriterionType = CrossEntropyLoss(), train_mb_size: int = 1, train_epochs: int = 1, eval_mb_size: Optional[int] = 1, diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index c2635f4ad..4596c9d39 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -1,4 +1,5 @@ -from typing import Iterable, List, Optional, TypeVar, Protocol +from typing import Generic, Iterable, List, Optional, TypeVar, Protocol, Callable, Union +from typing_extensions import TypeAlias from torch import Tensor import torch @@ -17,8 +18,10 @@ TMBInput = TypeVar("TMBInput") TMBOutput = TypeVar("TMBOutput") +CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]] -class BaseStrategyProtocol(Protocol[TExperienceType]): + +class BaseStrategyProtocol(Generic[TExperienceType], Protocol[TExperienceType]): model: Module device: torch.device @@ -33,6 +36,7 @@ class BaseStrategyProtocol(Protocol[TExperienceType]): class SGDStrategyProtocol( + Generic[TSGDExperienceType, TMBInput, TMBOutput], BaseStrategyProtocol[TSGDExperienceType], Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): @@ -52,7 +56,7 @@ class SGDStrategyProtocol( loss: Tensor - _criterion: Module + _criterion: CriterionType def forward(self) -> TMBOutput: ... diff --git a/tests/run_dist_tests.py b/tests/run_dist_tests.py index fa8d9f94d..232e98645 100644 --- a/tests/run_dist_tests.py +++ b/tests/run_dist_tests.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import signal import sys import unittest @@ -34,6 +35,9 @@ def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]: @click.command() @click.argument("test_cases", nargs=-1) def run_distributed_suites(test_cases): + if Path.cwd().name != "tests": + os.chdir(Path.cwd() / "tests") + cases_names = get_distributed_test_cases( unittest.defaultTestLoader.discover(".") ) # Don't change the path!