Skip to content

Commit

Permalink
Fix minor issue. Move declaration of CriterionType.
Browse files Browse the repository at this point in the history
  • Loading branch information
lrzpellegrini committed Feb 2, 2024
1 parent 043effa commit 384d82a
Show file tree
Hide file tree
Showing 20 changed files with 53 additions and 48 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/unit-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ jobs:
echo "Running checkpointing tests..." &&
bash ./tests/checkpointing/test_checkpointing.sh &&
echo "Running distributed training tests..." &&
cd tests &&
PYTHONPATH=.. python run_dist_tests.py &&
cd .. &&
echo "While running unit tests, the following datasets were downloaded:" &&
ls ~/.avalanche/data
7 changes: 2 additions & 5 deletions avalanche/training/supervised/ar1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
CWRStarPlugin,
)
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.utils import (
replace_bn_with_brn,
get_last_fc_layer,
Expand All @@ -27,6 +26,7 @@
LayerAndParameter,
)
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class AR1(SupervisedTemplate):
Expand All @@ -44,7 +44,7 @@ class AR1(SupervisedTemplate):
def __init__(
self,
*,
criterion: CriterionType = None,
criterion: CriterionType = CrossEntropyLoss(),
lr: float = 0.001,
inc_lr: float = 5e-5,
momentum=0.9,
Expand Down Expand Up @@ -149,9 +149,6 @@ def __init__(

optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=l2)

if criterion is None:
criterion = CrossEntropyLoss()

self.ewc_lambda = ewc_lambda
self.freeze_below_layer = freeze_below_layer
self.rm_sz = rm_sz
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class Cumulative(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/deep_slda.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models import FeatureExtractorBackbone
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class StreamingLDA(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/der.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.data_attribute import TensorDataAttribute
from avalanche.benchmarks.utils.flat_data import FlatData
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.core import SupervisedPlugin
from avalanche.training.plugins.evaluation import (
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/er_ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle


Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/er_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle


Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/expert_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin, LwFPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class ExpertGateStrategy(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/feature_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from avalanche.training.plugins.evaluation import EvaluationPlugin, default_evaluator
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.utils import cycle
from avalanche.training.losses import MaskedCrossEntropy

Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/joint_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_experiences_parameter_as_iterable,
_group_experiences_by_stream,
)
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class AlreadyTrainedError(Exception):
Expand Down
2 changes: 2 additions & 0 deletions avalanche/training/supervised/lamaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from torch.optim import Optimizer
import math

from avalanche.training.templates.strategy_mixin_protocol import CriterionType

try:
import higher
except ImportError:
Expand Down
2 changes: 2 additions & 0 deletions avalanche/training/supervised/lamaml_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import warnings
import torch

from avalanche.training.templates.strategy_mixin_protocol import CriterionType

if parse(torch.__version__) < parse("2.0.0"):
warnings.warn(f"LaMAML requires torch >= 2.0.0.")

Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/mer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.training.storage_policy import ReservoirSamplingBuffer
from avalanche.training.templates import SupervisedMetaLearningTemplate
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class MERBuffer:
Expand Down
4 changes: 2 additions & 2 deletions avalanche/training/supervised/strategy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@
)
from avalanche.training.templates.base import BaseTemplate
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.evaluation.metrics import loss_metrics
from avalanche.models.generator import MlpVAE, VAE_loss
from avalanche.models.expert_gate import AE_loss
from avalanche.logging import InteractiveLogger
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


class Naive(SupervisedTemplate):
Expand Down
2 changes: 1 addition & 1 deletion avalanche/training/supervised/strategy_wrappers_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SupervisedTemplate,
)
from avalanche._annotations import deprecated
from avalanche.training.templates.base_sgd import CriterionType
from avalanche.training.templates.strategy_mixin_protocol import CriterionType


@deprecated(
Expand Down
1 change: 1 addition & 0 deletions avalanche/training/templates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init_subclass__(cls, **kwargs):
cls.__init__ = _support_legacy_strategy_positional_args(
cls.__init__, cls.__name__
)
super().__init_subclass__(**kwargs)

# we need this only for type checking
PLUGIN_CLASS = BasePlugin
Expand Down
39 changes: 19 additions & 20 deletions avalanche/training/templates/base_sgd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Any, Callable, Iterable, Sequence, Optional, TypeVar, Union
from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union
from typing_extensions import TypeAlias
from packaging.version import parse

Expand All @@ -22,18 +22,20 @@
collate_from_data_or_kwargs,
)

from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol
from avalanche.training.templates.strategy_mixin_protocol import (
CriterionType,
SGDStrategyProtocol,
)
from avalanche.training.utils import trigger_plugins


TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience)
TMBInput = TypeVar("TMBInput")
TMBOutput = TypeVar("TMBOutput")

CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]]


class BaseSGDTemplate(
Generic[TDatasetExperience, TMBInput, TMBOutput],
SGDStrategyProtocol[TDatasetExperience, TMBInput, TMBOutput],
BaseTemplate[TDatasetExperience],
):
Expand Down Expand Up @@ -93,21 +95,6 @@ def __init__(
`eval_every` epochs or iterations (Default='epoch').
"""

super().__init__(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
peval_mode=peval_mode,
**kwargs
)

# Call super with all args
if sys.version_info >= (3, 11):
super().__init__(
Expand All @@ -127,7 +114,19 @@ def __init__(
else:
super().__init__() # type: ignore
BaseTemplate.__init__(
self=self, model=model, device=device, plugins=plugins
self,
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=train_mb_size,
train_epochs=train_epochs,
eval_mb_size=eval_mb_size,
device=device,
plugins=plugins,
evaluator=evaluator,
eval_every=eval_every,
peval_mode=peval_mode,
**kwargs
)

self.optimizer: Optimizer = optimizer
Expand Down
10 changes: 4 additions & 6 deletions avalanche/training/templates/common_templates.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys
from typing import Any, Callable, Dict, Sequence, Optional, TypeVar, Union
import warnings
from typing import Callable, Sequence, Optional, TypeVar, Union
import torch
import inspect

from torch.nn import Module, CrossEntropyLoss
from torch.optim import Optimizer
Expand All @@ -13,8 +11,8 @@
EvaluationPlugin,
default_evaluator,
)
from avalanche.training.templates.base import PositionalArgumentDeprecatedWarning
from avalanche.training.templates.strategy_mixin_protocol import (
CriterionType,
SupervisedStrategyProtocol,
TMBOutput,
TMBInput,
Expand All @@ -23,7 +21,7 @@
from .observation_type import *
from .problem_type import *
from .update_type import *
from .base_sgd import BaseSGDTemplate, CriterionType
from .base_sgd import BaseSGDTemplate


TDatasetExperience = TypeVar("TDatasetExperience", bound=DatasetExperience)
Expand Down Expand Up @@ -202,7 +200,7 @@ def __init__(
*,
model: Module,
optimizer: Optimizer,
criterion=CrossEntropyLoss(),
criterion: CriterionType = CrossEntropyLoss(),
train_mb_size: int = 1,
train_epochs: int = 1,
eval_mb_size: Optional[int] = 1,
Expand Down
10 changes: 7 additions & 3 deletions avalanche/training/templates/strategy_mixin_protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable, List, Optional, TypeVar, Protocol
from typing import Generic, Iterable, List, Optional, TypeVar, Protocol, Callable, Union
from typing_extensions import TypeAlias

from torch import Tensor
import torch
Expand All @@ -17,8 +18,10 @@
TMBInput = TypeVar("TMBInput")
TMBOutput = TypeVar("TMBOutput")

CriterionType: TypeAlias = Union[Module, Callable[[Tensor, Tensor], Tensor]]

class BaseStrategyProtocol(Protocol[TExperienceType]):

class BaseStrategyProtocol(Generic[TExperienceType], Protocol[TExperienceType]):
model: Module

device: torch.device
Expand All @@ -33,6 +36,7 @@ class BaseStrategyProtocol(Protocol[TExperienceType]):


class SGDStrategyProtocol(
Generic[TSGDExperienceType, TMBInput, TMBOutput],
BaseStrategyProtocol[TSGDExperienceType],
Protocol[TSGDExperienceType, TMBInput, TMBOutput],
):
Expand All @@ -52,7 +56,7 @@ class SGDStrategyProtocol(

loss: Tensor

_criterion: Module
_criterion: CriterionType

def forward(self) -> TMBOutput: ...

Expand Down
4 changes: 4 additions & 0 deletions tests/run_dist_tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import signal
import sys
import unittest
Expand Down Expand Up @@ -34,6 +35,9 @@ def get_distributed_test_cases(suite: Union[TestCase, TestSuite]) -> Set[str]:
@click.command()
@click.argument("test_cases", nargs=-1)
def run_distributed_suites(test_cases):
if Path.cwd().name != "tests":
os.chdir(Path.cwd() / "tests")

cases_names = get_distributed_test_cases(
unittest.defaultTestLoader.discover(".")
) # Don't change the path!
Expand Down

0 comments on commit 384d82a

Please sign in to comment.