Merge pull request #1652 from AntonioCarta/remove_task_label_strategy
support benchmark without task labels in avalanche.strategies
AntonioCarta authored Jun 3, 2024
2 parents bbd0778 + 4dedf6e commit d752103
Showing 3 changed files with 114 additions and 6 deletions.
14 changes: 11 additions & 3 deletions avalanche/training/templates/base_sgd.py
@@ -1,5 +1,7 @@
 import sys
 from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union
 
+from torch.utils.data import DataLoader
 from typing_extensions import TypeAlias
 from packaging.version import parse
@@ -453,9 +455,15 @@ def make_train_dataloader(
         if "ffcv_args" in kwargs:
             other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]
 
-        self.dataloader = TaskBalancedDataLoader(
-            self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
-        )
+        # use task-balanced dataloader for task-aware benchmarks
+        if hasattr(self.experience, "task_labels"):
+            self.dataloader = TaskBalancedDataLoader(
+                self.adapted_dataset,
+                oversample_small_groups=True,
+                **other_dataloader_args
+            )
+        else:
+            self.dataloader = DataLoader(self.adapted_dataset, **other_dataloader_args)
 
     def make_eval_dataloader(
         self,
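In short, make_train_dataloader now falls back to a plain torch.utils.data.DataLoader whenever the current experience exposes no task_labels attribute, instead of always requiring a TaskBalancedDataLoader. Below is a minimal sketch of the same dispatch on a hand-built dataset; the PlainExperience stand-in and the batch size are hypothetical, not library code:

import torch
from torch.utils.data import DataLoader, TensorDataset


class PlainExperience:
    """Stand-in for a task-free experience: it has no task_labels attribute."""


experience = PlainExperience()
adapted_dataset = TensorDataset(torch.randn(8, 2), torch.zeros(8, dtype=torch.long))

# same dispatch as the patched make_train_dataloader
if hasattr(experience, "task_labels"):
    loader = None  # the task-aware path would build a TaskBalancedDataLoader here
else:
    loader = DataLoader(adapted_dataset, batch_size=4)  # task-free path

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 2]) torch.Size([4])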
@@ -35,7 +35,7 @@ def mb_task_id(self):
         """Current mini-batch task labels."""
         mbatch = self.mbatch
         assert mbatch is not None
-        assert len(mbatch) >= 3
+        assert len(mbatch) >= 3, "Task label not found."
         return mbatch[-1]
 
     def criterion(self):
@@ -44,13 +44,16 @@ def criterion(self):
 
     def forward(self):
         """Compute the model's output given the current mini-batch."""
-        return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
+        # use task-aware forward only for task-aware benchmarks
+        if hasattr(self.experience, "task_labels"):
+            return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
+        else:
+            return self.model(self.mb_x)
 
     def _unpack_minibatch(self):
         """Check if the current mini-batch has 3 components."""
         mbatch = self.mbatch
         assert mbatch is not None
-        assert len(mbatch) >= 3
 
         if isinstance(mbatch, tuple):
             mbatch = list(mbatch)
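The companion change here means mini-batches no longer have to carry a task-label component: a task-aware benchmark yields (x, y, t) triplets, while a task-free one may yield plain (x, y) pairs, which is why the unconditional length check was dropped from _unpack_minibatch and kept, with a clearer message, only in mb_task_id. A small illustration of the two batch layouts (hypothetical tensors, not library code):

import torch

x = torch.randn(4, 1, 28, 28)          # inputs
y = torch.randint(0, 10, (4,))         # targets
t = torch.zeros(4, dtype=torch.long)   # task labels (all task 0)

task_aware_batch = (x, y, t)  # len >= 3: mb_task_id returns mbatch[-1] == t
task_free_batch = (x, y)      # len == 2: reading mb_task_id would now fail with
                              # "Task label not found." instead of a bare assert
assert task_aware_batch[-1] is t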
97 changes: 97 additions & 0 deletions examples/custom_datasets.py
@@ -0,0 +1,97 @@
################################################################################
# Copyright (c) 2024 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 31-05-2024 #
# Author(s): Antonio Carta #
# E-mail: [email protected] #
# Website: avalanche.continualai.org #
################################################################################

"""
An exmaple that shows how to create a class-incremental benchmark from a pytorch dataset.
"""

import torch
import argparse
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision.transforms import Compose, Normalize, ToTensor

from avalanche.benchmarks.datasets import MNIST, default_dataset_location
from avalanche.benchmarks.scenarios import class_incremental_benchmark
from avalanche.benchmarks.utils import (
make_avalanche_dataset,
TransformGroups,
DataAttribute,
)
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive


def main(args):
    # Device config
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )

    # create pytorch dataset
    train_data = MNIST(root=default_dataset_location("mnist"), train=True)
    test_data = MNIST(root=default_dataset_location("mnist"), train=False)

    # prepare transformations
    train_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    eval_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    tgroups = TransformGroups({"train": train_transform, "eval": eval_transform})

    # create Avalanche datasets with targets attributes (needed to split by class)
    da = DataAttribute(train_data.targets, "targets")
    train_data = make_avalanche_dataset(
        train_data, data_attributes=[da], transform_groups=tgroups
    )

    da = DataAttribute(test_data.targets, "targets")
    test_data = make_avalanche_dataset(
        test_data, data_attributes=[da], transform_groups=tgroups
    )

    # create benchmark
    bm = class_incremental_benchmark(
        {"train": train_data, "test": test_data}, num_experiences=5
    )

    # Continual learning strategy
    model = SimpleMLP(num_classes=10)
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()
    cl_strategy = Naive(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=32,
        train_epochs=100,
        eval_mb_size=32,
        device=device,
        eval_every=1,
    )

    # train and test loop
    results = []
    for train_task, test_task in zip(bm.train_stream, bm.test_stream):
        print("Current Classes: ", train_task.classes_in_this_experience)
        cl_strategy.train(train_task, eval_streams=[test_task])
        results.append(cl_strategy.eval(bm.test_stream))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cuda",
        type=int,
        default=0,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    args = parser.parse_args()

    main(args)
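Assuming the repository root as the working directory, the example can be run with, e.g.:

python examples/custom_datasets.py --cuda -1

where -1 keeps training on CPU; on first run MNIST should be downloaded under default_dataset_location("mnist"). Because the benchmark is built directly from a plain PyTorch dataset, its experiences carry no task labels, exercising the new task-free code paths above.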
