|
| 1 | +################################################################################ |
| 2 | +# Copyright (c) 2024 ContinualAI. # |
| 3 | +# Copyrights licensed under the MIT License. # |
| 4 | +# See the accompanying LICENSE file for terms. # |
| 5 | +# # |
| 6 | +# Date: 31-05-2024 # |
| 7 | +# Author(s): Antonio Carta # |
| 8 | + |
| 9 | +# Website: avalanche.continualai.org # |
| 10 | +################################################################################ |
| 11 | + |
| 12 | +""" |
| 13 | +An example that shows how to create a class-incremental benchmark from a PyTorch dataset. |
| 14 | +""" |
| 15 | + |
| 16 | +import torch |
| 17 | +import argparse |
| 18 | +from torch.nn import CrossEntropyLoss |
| 19 | +from torch.optim import SGD |
| 20 | +from torchvision.transforms import Compose, Normalize, ToTensor |
| 21 | + |
| 22 | +from avalanche.benchmarks.datasets import MNIST, default_dataset_location |
| 23 | +from avalanche.benchmarks.scenarios import class_incremental_benchmark |
| 24 | +from avalanche.benchmarks.utils import ( |
| 25 | + make_avalanche_dataset, |
| 26 | + TransformGroups, |
| 27 | + DataAttribute, |
| 28 | +) |
| 29 | +from avalanche.models import SimpleMLP |
| 30 | +from avalanche.training.supervised import Naive |
| 31 | + |
| 32 | + |
def main(args):
    """Train a Naive strategy on a class-incremental benchmark built from MNIST.

    The plain MNIST datasets are wrapped into Avalanche datasets carrying a
    ``targets`` attribute, split into a 5-experience class-incremental
    benchmark, and a ``SimpleMLP`` is trained on each experience in turn.

    Args:
        args: Parsed command-line arguments. Only ``args.cuda`` is read: the
            zero-indexed CUDA device to use, or a negative value to force CPU.

    Returns:
        A list with one entry per training experience; each entry is the
        value returned by ``cl_strategy.eval`` on the whole test stream right
        after training on that experience.
    """
    # Device config: fall back to CPU when CUDA is unavailable or disabled.
    use_cuda = torch.cuda.is_available() and args.cuda >= 0
    device = torch.device(f"cuda:{args.cuda}" if use_cuda else "cpu")

    # create pytorch dataset (both splits share the same root folder)
    mnist_root = default_dataset_location("mnist")
    train_data = MNIST(root=mnist_root, train=True)
    test_data = MNIST(root=mnist_root, train=False)

    # prepare transformations
    # NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean/std - confirm.
    train_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    eval_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    tgroups = TransformGroups({"train": train_transform, "eval": eval_transform})

    # create Avalanche datasets with targets attributes (needed to split by class)
    train_data = make_avalanche_dataset(
        train_data,
        data_attributes=[DataAttribute(train_data.targets, "targets")],
        transform_groups=tgroups,
    )
    test_data = make_avalanche_dataset(
        test_data,
        data_attributes=[DataAttribute(test_data.targets, "targets")],
        transform_groups=tgroups,
    )

    # create benchmark
    bm = class_incremental_benchmark(
        {"train": train_data, "test": test_data}, num_experiences=5
    )

    # Continual learning strategy
    model = SimpleMLP(num_classes=10)
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()
    cl_strategy = Naive(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_mb_size=32,
        train_epochs=100,
        eval_mb_size=32,
        device=device,
        eval_every=1,
    )

    # train and test loop: train on one experience (passing the matching test
    # experience as eval stream), then evaluate on the whole test stream.
    results = []
    for train_task, test_task in zip(bm.train_stream, bm.test_stream):
        print("Current Classes: ", train_task.classes_in_this_experience)
        cl_strategy.train(train_task, eval_streams=[test_task])
        results.append(cl_strategy.eval(bm.test_stream))

    # Hand the collected metrics back instead of silently discarding them.
    return results
| 85 | + |
| 86 | + |
if __name__ == "__main__":
    # Command-line entry point: the only option is the CUDA device index.
    arg_parser = argparse.ArgumentParser()
    cuda_option = {
        "type": int,
        "default": 0,
        "help": "Select zero-indexed cuda device. -1 to use CPU.",
    }
    arg_parser.add_argument("--cuda", **cuda_option)

    # Parse the CLI and run the example with the resulting namespace.
    main(arg_parser.parse_args())
0 commit comments