diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index e54f402a6..b51046edf 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -1,5 +1,7 @@ import sys from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union + +from torch.utils.data import DataLoader from typing_extensions import TypeAlias from packaging.version import parse @@ -453,9 +455,15 @@ def make_train_dataloader( if "ffcv_args" in kwargs: other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] - self.dataloader = TaskBalancedDataLoader( - self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args - ) + # use task-balanced dataloader for task-aware benchmarks + if hasattr(self.experience, "task_labels"): + self.dataloader = TaskBalancedDataLoader( + self.adapted_dataset, + oversample_small_groups=True, + **other_dataloader_args + ) + else: + self.dataloader = DataLoader(self.adapted_dataset, **other_dataloader_args) def make_eval_dataloader( self, diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index 1de5639ab..d7c8b7d54 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -35,7 +35,7 @@ def mb_task_id(self): """Current mini-batch task labels.""" mbatch = self.mbatch assert mbatch is not None - assert len(mbatch) >= 3 + assert len(mbatch) >= 3, "Task label not found." 
return mbatch[-1] def criterion(self): @@ -44,13 +44,16 @@ def criterion(self): def forward(self): """Compute the model's output given the current mini-batch.""" - return avalanche_forward(self.model, self.mb_x, self.mb_task_id) + # use task-aware forward only for task-aware benchmarks + if hasattr(self.experience, "task_labels"): + return avalanche_forward(self.model, self.mb_x, self.mb_task_id) + else: + return self.model(self.mb_x) def _unpack_minibatch(self): """Check if the current mini-batch has 3 components.""" mbatch = self.mbatch assert mbatch is not None - assert len(mbatch) >= 3 if isinstance(mbatch, tuple): mbatch = list(mbatch) diff --git a/examples/custom_datasets.py b/examples/custom_datasets.py new file mode 100644 index 000000000..fd531b835 --- /dev/null +++ b/examples/custom_datasets.py @@ -0,0 +1,97 @@ +################################################################################ +# Copyright (c) 2024 ContinualAI. # +# Copyrights licensed under the MIT License. # +# See the accompanying LICENSE file for terms. # +# # +# Date: 31-05-2024 # +# Author(s): Antonio Carta # +# E-mail: contact@continualai.org # +# Website: avalanche.continualai.org # +################################################################################ + +""" +An example that shows how to create a class-incremental benchmark from a PyTorch dataset. 
+""" + +import torch +import argparse +from torch.nn import CrossEntropyLoss +from torch.optim import SGD +from torchvision.transforms import Compose, Normalize, ToTensor + +from avalanche.benchmarks.datasets import MNIST, default_dataset_location +from avalanche.benchmarks.scenarios import class_incremental_benchmark +from avalanche.benchmarks.utils import ( + make_avalanche_dataset, + TransformGroups, + DataAttribute, +) +from avalanche.models import SimpleMLP +from avalanche.training.supervised import Naive + + +def main(args): + # Device config + device = torch.device( + f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu" + ) + + # create pytorch dataset + train_data = MNIST(root=default_dataset_location("mnist"), train=True) + test_data = MNIST(root=default_dataset_location("mnist"), train=False) + + # prepare transformations + train_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) + eval_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) + tgroups = TransformGroups({"train": train_transform, "eval": eval_transform}) + + # create Avalanche datasets with targets attributes (needed to split by class) + da = DataAttribute(train_data.targets, "targets") + train_data = make_avalanche_dataset( + train_data, data_attributes=[da], transform_groups=tgroups + ) + + da = DataAttribute(test_data.targets, "targets") + test_data = make_avalanche_dataset( + test_data, data_attributes=[da], transform_groups=tgroups + ) + + # create benchmark + bm = class_incremental_benchmark( + {"train": train_data, "test": test_data}, num_experiences=5 + ) + + # Continual learning strategy + model = SimpleMLP(num_classes=10) + optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9) + criterion = CrossEntropyLoss() + cl_strategy = Naive( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=32, + train_epochs=100, + eval_mb_size=32, + device=device, + eval_every=1, + ) + + # train and test loop + results 
= [] + for train_task, test_task in zip(bm.train_stream, bm.test_stream): + print("Current Classes: ", train_task.classes_in_this_experience) + cl_strategy.train(train_task, eval_streams=[test_task]) + results.append(cl_strategy.eval(bm.test_stream)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--cuda", + type=int, + default=0, + help="Select zero-indexed cuda device. -1 to use CPU.", + ) + args = parser.parse_args() + + main(args)