Skip to content

Commit 91841ac

Browse files
committed
support benchmark without task labels in avalanche.strategies
1 parent bbd0778 commit 91841ac

File tree

3 files changed

+118
-6
lines changed

3 files changed

+118
-6
lines changed

avalanche/training/templates/base_sgd.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import sys
22
from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union
3+
4+
from torch.utils.data import DataLoader
35
from typing_extensions import TypeAlias
46
from packaging.version import parse
57

@@ -453,9 +455,17 @@ def make_train_dataloader(
453455
if "ffcv_args" in kwargs:
454456
other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]
455457

456-
self.dataloader = TaskBalancedDataLoader(
457-
self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
458-
)
458+
# use task-balanced dataloader for task-aware benchmarks
459+
if hasattr(self.experience, "task_label") or hasattr(
460+
self.experience, "task_labels"
461+
):
462+
self.dataloader = TaskBalancedDataLoader(
463+
self.adapted_dataset,
464+
oversample_small_groups=True,
465+
**other_dataloader_args
466+
)
467+
else:
468+
self.dataloader = DataLoader(self.adapted_dataset, **other_dataloader_args)
459469

460470
def make_eval_dataloader(
461471
self,

avalanche/training/templates/problem_type/supervised_problem.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def mb_task_id(self):
3535
"""Current mini-batch task labels."""
3636
mbatch = self.mbatch
3737
assert mbatch is not None
38-
assert len(mbatch) >= 3
38+
assert len(mbatch) >= 3, "Task label not found."
3939
return mbatch[-1]
4040

4141
def criterion(self):
@@ -44,13 +44,18 @@ def criterion(self):
4444

4545
def forward(self):
4646
"""Compute the model's output given the current mini-batch."""
47-
return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
47+
# use task-aware forward only for task-aware benchmarks
48+
if hasattr(self.experience, "task_labels") or hasattr(
49+
self.experience, "task_label"
50+
):
51+
return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
52+
else:
53+
return self.model(self.mb_x)
4854

4955
def _unpack_minibatch(self):
5056
"""Check if the current mini-batch has 3 components."""
5157
mbatch = self.mbatch
5258
assert mbatch is not None
53-
assert len(mbatch) >= 3
5459

5560
if isinstance(mbatch, tuple):
5661
mbatch = list(mbatch)

examples/custom_datasets.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
################################################################################
2+
# Copyright (c) 2024 ContinualAI. #
3+
# Copyrights licensed under the MIT License. #
4+
# See the accompanying LICENSE file for terms. #
5+
# #
6+
# Date: 31-05-2024 #
7+
# Author(s): Antonio Carta #
8+
# E-mail: [email protected] #
9+
# Website: avalanche.continualai.org #
10+
################################################################################
11+
12+
"""
13+
An exmaple that shows how to create a class-incremental benchmark from a pytorch dataset.
14+
"""
15+
16+
import torch
17+
import argparse
18+
from torch.nn import CrossEntropyLoss
19+
from torch.optim import SGD
20+
from torchvision.transforms import Compose, Normalize, ToTensor
21+
22+
from avalanche.benchmarks.datasets import MNIST, default_dataset_location
23+
from avalanche.benchmarks.scenarios import class_incremental_benchmark
24+
from avalanche.benchmarks.utils import (
25+
make_avalanche_dataset,
26+
TransformGroups,
27+
DataAttribute,
28+
)
29+
from avalanche.models import SimpleMLP
30+
from avalanche.training.supervised import Naive
31+
32+
33+
def main(args):
34+
# Device config
35+
device = torch.device(
36+
f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
37+
)
38+
39+
# create pytorch dataset
40+
train_data = MNIST(root=default_dataset_location("mnist"), train=True)
41+
test_data = MNIST(root=default_dataset_location("mnist"), train=False)
42+
43+
# prepare transformations
44+
train_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
45+
eval_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
46+
tgroups = TransformGroups({"train": train_transform, "eval": eval_transform})
47+
48+
# create Avalanche datasets with targets attributes (needed to split by class)
49+
da = DataAttribute(train_data.targets, "targets")
50+
train_data = make_avalanche_dataset(
51+
train_data, data_attributes=[da], transform_groups=tgroups
52+
)
53+
54+
da = DataAttribute(test_data.targets, "targets")
55+
test_data = make_avalanche_dataset(
56+
test_data, data_attributes=[da], transform_groups=tgroups
57+
)
58+
59+
# create benchmark
60+
bm = class_incremental_benchmark(
61+
{"train": train_data, "test": test_data}, num_experiences=5
62+
)
63+
64+
# Continual learning strategy
65+
model = SimpleMLP(num_classes=10)
66+
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
67+
criterion = CrossEntropyLoss()
68+
cl_strategy = Naive(
69+
model=model,
70+
optimizer=optimizer,
71+
criterion=criterion,
72+
train_mb_size=32,
73+
train_epochs=100,
74+
eval_mb_size=32,
75+
device=device,
76+
eval_every=1,
77+
)
78+
79+
# train and test loop
80+
results = []
81+
for train_task, test_task in zip(bm.train_stream, bm.test_stream):
82+
print("Current Classes: ", train_task.classes_in_this_experience)
83+
cl_strategy.train(train_task, eval_streams=[test_task])
84+
results.append(cl_strategy.eval(bm.test_stream))
85+
86+
87+
if __name__ == "__main__":
88+
parser = argparse.ArgumentParser()
89+
parser.add_argument(
90+
"--cuda",
91+
type=int,
92+
default=0,
93+
help="Select zero-indexed cuda device. -1 to use CPU.",
94+
)
95+
args = parser.parse_args()
96+
97+
main(args)

0 commit comments

Comments
 (0)