Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements of the checkpoint functionality and memory occupation of DER #1567

Merged
merged 13 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions avalanche/benchmarks/classic/cimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def SplitImageNet(
class_ids_from_zero_in_each_exp: bool = False,
class_ids_from_zero_from_first_exp: bool = False,
train_transform: Optional[Any] = _default_train_transform,
eval_transform: Optional[Any] = _default_eval_transform
eval_transform: Optional[Any] = _default_eval_transform,
meta_root: Optional[Union[str, Path]] = None,
):
"""
Creates a CL benchmark using the ImageNet dataset.
Expand Down Expand Up @@ -130,11 +131,19 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
comprehensive list of possible transformations).
If no transformation is passed, the default test transformation
will be used.
:param meta_root: Directory where the `ILSVRC2012_devkit_t12.tar.gz`
file can be found. The first time you use this dataset, the meta file will be
extracted from the archive and a `meta.bin` file will be created in the `meta_root`
directory. Defaults to None, which means that the meta file is expected to be
in the path provied in the `root` argument.
This is an additional argument not found in the original ImageNet class
from the torchvision package. For more info, see the `meta_root` argument
in the :class:`AvalancheImageNet` class.

:returns: A properly initialized :class:`NCScenario` instance.
"""

train_set, test_set = _get_imagenet_dataset(dataset_root)
train_set, test_set = _get_imagenet_dataset(dataset_root, meta_root=meta_root)

return nc_benchmark(
train_dataset=train_set,
Expand All @@ -152,10 +161,10 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
)


def _get_imagenet_dataset(root):
train_set = ImageNet(root, split="train")
def _get_imagenet_dataset(root, meta_root=None):
train_set = ImageNet(root, split="train", meta_root=meta_root)

test_set = ImageNet(root, split="val")
test_set = ImageNet(root, split="val", meta_root=meta_root)

return train_set, test_set

Expand Down
21 changes: 21 additions & 0 deletions avalanche/benchmarks/datasets/core50/core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import glob
import os
import pickle as pkl
import dill
from pathlib import Path
from typing import List, Optional, Tuple, Union
from warnings import warn
Expand All @@ -26,6 +27,7 @@
from avalanche.benchmarks.datasets.downloadable_dataset import (
DownloadableDataset,
)
from avalanche.checkpointing import constructor_based_serialization


class CORe50Dataset(DownloadableDataset):
Expand Down Expand Up @@ -247,6 +249,25 @@ def CORe50(*args, **kwargs):
return CORe50Dataset(*args, **kwargs)


@dill.register(CORe50Dataset)
def checkpoint_CORe50Dataset(pickler, obj: CORe50Dataset):
constructor_based_serialization(
pickler,
obj,
CORe50Dataset,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
loader=obj.loader,
mini=obj.mini,
object_level=obj.object_level,
),
)


if __name__ == "__main__":
# this litte example script can be used to visualize the first image
# leaded from the dataset.
Expand Down
19 changes: 19 additions & 0 deletions avalanche/benchmarks/datasets/cub200/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import gdown
import os
import dill
from collections import OrderedDict
from torchvision.datasets.folder import default_loader

Expand All @@ -31,6 +32,7 @@
DownloadableDataset,
)
from avalanche.benchmarks.utils import PathsDataset
from avalanche.checkpointing import constructor_based_serialization


class CUB200(PathsDataset, DownloadableDataset):
Expand Down Expand Up @@ -178,6 +180,23 @@ def _load_metadata(self):
return True


@dill.register(CUB200)
def checkpoint_CUB200(pickler, obj: CUB200):
constructor_based_serialization(
pickler,
obj,
CUB200,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
loader=obj.loader,
),
)


if __name__ == "__main__":
"""Simple test that will start if you run this script directly"""

Expand Down
41 changes: 25 additions & 16 deletions avalanche/benchmarks/datasets/external_datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torchvision.datasets import CIFAR100, CIFAR10

from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


def get_cifar10_dataset(dataset_root):
Expand Down Expand Up @@ -31,26 +32,34 @@ def load_CIFAR100(root, train, transform, target_transform):


@dill.register(CIFAR100)
def save_CIFAR100(pickler, obj: CIFAR100):
pickler.save_reduce(
load_CIFAR100,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
)


def load_CIFAR10(root, train, transform, target_transform):
return CIFAR10(
root=root, train=train, transform=transform, target_transform=target_transform
def checkpoint_CIFAR100(pickler, obj: CIFAR100):
constructor_based_serialization(
pickler,
obj,
CIFAR100,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


@dill.register(CIFAR10)
def save_CIFAR10(pickler, obj: CIFAR10):
pickler.save_reduce(
load_CIFAR10,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
def checkpoint_CIFAR10(pickler, obj: CIFAR10):
constructor_based_serialization(
pickler,
obj,
CIFAR10,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
24 changes: 13 additions & 11 deletions avalanche/benchmarks/datasets/external_datasets/fmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torchvision.datasets import FashionMNIST

from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


def get_fmnist_dataset(dataset_root):
Expand All @@ -13,18 +14,19 @@ def get_fmnist_dataset(dataset_root):
return train_set, test_set


def load_FashionMNIST(root, train, transform, target_transform):
return FashionMNIST(
root=root, train=train, transform=transform, target_transform=target_transform
)


@dill.register(FashionMNIST)
def save_FashionMNIST(pickler, obj: FashionMNIST):
pickler.save_reduce(
load_FashionMNIST,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
def checkpoint_FashionMNIST(pickler, obj: FashionMNIST):
constructor_based_serialization(
pickler,
obj,
FashionMNIST,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
22 changes: 13 additions & 9 deletions avalanche/benchmarks/datasets/external_datasets/mnist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dill
from torchvision.datasets import MNIST
from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


class TensorMNIST(MNIST):
Expand Down Expand Up @@ -35,16 +36,19 @@ def get_mnist_dataset(dataset_root):
return train_set, test_set


def load_MNIST(root, train, transform, target_transform):
return TensorMNIST(
root=root, train=train, transform=transform, target_transform=target_transform
)


@dill.register(TensorMNIST)
def save_MNIST(pickler, obj: TensorMNIST):
pickler.save_reduce(
load_MNIST, (obj.root, obj.train, obj.transform, obj.target_transform), obj=obj
def checkpoint_TensorMNIST(pickler, obj: TensorMNIST):
constructor_based_serialization(
pickler,
obj,
TensorMNIST,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
1 change: 1 addition & 0 deletions avalanche/benchmarks/datasets/imagenet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .imagenet import *
Loading
Loading