support benchmark without task labels in avalanche.strategies #1652

Merged
14 changes: 11 additions & 3 deletions avalanche/training/templates/base_sgd.py
@@ -1,5 +1,7 @@
import sys
from typing import Any, Callable, Generic, Iterable, Sequence, Optional, TypeVar, Union

from torch.utils.data import DataLoader
from typing_extensions import TypeAlias
from packaging.version import parse

@@ -453,9 +455,15 @@ def make_train_dataloader(
         if "ffcv_args" in kwargs:
             other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

-        self.dataloader = TaskBalancedDataLoader(
-            self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
-        )
+        # use task-balanced dataloader for task-aware benchmarks
+        if hasattr(self.experience, "task_labels"):
+            self.dataloader = TaskBalancedDataLoader(
+                self.adapted_dataset,
+                oversample_small_groups=True,
+                **other_dataloader_args
+            )
+        else:
+            self.dataloader = DataLoader(self.adapted_dataset, **other_dataloader_args)

     def make_eval_dataloader(
         self,
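For readers skimming the diff: the new branch falls back to a plain PyTorch DataLoader whenever the current experience does not expose task_labels. Below is a minimal standalone sketch of that dispatch; the helper name pick_train_dataloader and the TaskBalancedDataLoader import path are assumptions for illustration, not part of this PR.

from torch.utils.data import DataLoader

# Assumed import path, matching the loader already used by base_sgd.py.
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader


def pick_train_dataloader(experience, adapted_dataset, **dataloader_args):
    """Hypothetical helper mirroring the new make_train_dataloader branch."""
    if hasattr(experience, "task_labels"):
        # Task-aware benchmark: balance mini-batches across task groups.
        return TaskBalancedDataLoader(
            adapted_dataset, oversample_small_groups=True, **dataloader_args
        )
    # Benchmark without task labels: a plain DataLoader is enough.
    return DataLoader(adapted_dataset, **dataloader_args)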
@@ -35,7 +35,7 @@ def mb_task_id(self):
         """Current mini-batch task labels."""
         mbatch = self.mbatch
         assert mbatch is not None
-        assert len(mbatch) >= 3
+        assert len(mbatch) >= 3, "Task label not found."
         return mbatch[-1]

     def criterion(self):
@@ -44,13 +44,16 @@ def criterion(self):

     def forward(self):
         """Compute the model's output given the current mini-batch."""
-        return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
+        # use task-aware forward only for task-aware benchmarks
+        if hasattr(self.experience, "task_labels"):
+            return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
+        else:
+            return self.model(self.mb_x)

     def _unpack_minibatch(self):
         """Check if the current mini-batch has 3 components."""
         mbatch = self.mbatch
         assert mbatch is not None
-        assert len(mbatch) >= 3

         if isinstance(mbatch, tuple):
             mbatch = list(mbatch)
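The forward-pass change follows the same rule: with task labels present, avalanche_forward routes the batch to the right head of a multi-task model; without them, the model is called directly. A minimal sketch, assuming avalanche_forward is importable from avalanche.models as in the training templates (forward_minibatch is a hypothetical name):

from avalanche.models import avalanche_forward  # assumed import, as used by the templates


def forward_minibatch(model, mb_x, experience, mb_task_id=None):
    """Hypothetical helper showing the task-aware vs. label-free forward split."""
    if hasattr(experience, "task_labels"):
        # Multi-task models pick the correct output head from the task id.
        return avalanche_forward(model, mb_x, mb_task_id)
    # No task labels: fall back to a standard nn.Module forward.
    return model(mb_x)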
97 changes: 97 additions & 0 deletions examples/custom_datasets.py
@@ -0,0 +1,97 @@
################################################################################
# Copyright (c) 2024 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 31-05-2024 #
# Author(s): Antonio Carta #
# E-mail: [email protected] #
# Website: avalanche.continualai.org #
################################################################################

"""
An example that shows how to create a class-incremental benchmark from a PyTorch dataset.
"""

import torch
import argparse
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torchvision.transforms import Compose, Normalize, ToTensor

from avalanche.benchmarks.datasets import MNIST, default_dataset_location
from avalanche.benchmarks.scenarios import class_incremental_benchmark
from avalanche.benchmarks.utils import (
make_avalanche_dataset,
TransformGroups,
DataAttribute,
)
from avalanche.models import SimpleMLP
from avalanche.training.supervised import Naive


def main(args):
    # Device config
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )

    # create pytorch dataset
    train_data = MNIST(root=default_dataset_location("mnist"), train=True)
    test_data = MNIST(root=default_dataset_location("mnist"), train=False)

    # prepare transformations
    train_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    eval_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    tgroups = TransformGroups({"train": train_transform, "eval": eval_transform})

    # create Avalanche datasets with targets attributes (needed to split by class)
    da = DataAttribute(train_data.targets, "targets")
    train_data = make_avalanche_dataset(
        train_data, data_attributes=[da], transform_groups=tgroups
    )

    da = DataAttribute(test_data.targets, "targets")
    test_data = make_avalanche_dataset(
        test_data, data_attributes=[da], transform_groups=tgroups
    )

    # create benchmark
    bm = class_incremental_benchmark(
        {"train": train_data, "test": test_data}, num_experiences=5
    )

    # Continual learning strategy
    model = SimpleMLP(num_classes=10)
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()
    cl_strategy = Naive(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=32,
        train_epochs=100,
        eval_mb_size=32,
        device=device,
        eval_every=1,
    )

    # train and test loop
    results = []
    for train_task, test_task in zip(bm.train_stream, bm.test_stream):
        print("Current Classes: ", train_task.classes_in_this_experience)
        cl_strategy.train(train_task, eval_streams=[test_task])
        results.append(cl_strategy.eval(bm.test_stream))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cuda",
        type=int,
        default=0,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    args = parser.parse_args()

    main(args)
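Usage note: assuming the script is saved as examples/custom_datasets.py, it can be run with python examples/custom_datasets.py --cuda 0 (or --cuda -1 to stay on CPU). The class-incremental benchmark it builds carries no task labels, which is exactly the case the strategy changes above enable.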