-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_benchmark_generators.py
105 lines (93 loc) · 3.12 KB
/
my_benchmark_generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from avalanche.benchmarks.generators import nc_benchmark, ni_benchmark
from avalanche.benchmarks.scenarios import NCScenario, NIScenario
from functools import partial
from itertools import tee
from typing import (
Sequence,
Optional,
Dict,
Union,
Any,
List,
Callable,
Set,
Tuple,
Iterable,
Generator,
)
import torch
from avalanche.benchmarks import (
GenericCLScenario,
ClassificationExperience,
ClassificationStream,
)
from avalanche.benchmarks.scenarios.generic_benchmark_creation import *
from avalanche.benchmarks.scenarios.classification_scenario import (
TStreamsUserDict,
StreamUserDef,
)
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCScenario
from avalanche.benchmarks.scenarios.new_instances.ni_scenario import NIScenario
from my_ni_scenario import myNIScenario
from avalanche.benchmarks.utils import concat_datasets_sequentially
from avalanche.benchmarks.utils.avalanche_dataset import (
SupportedDataset,
AvalancheDataset,
AvalancheDatasetType,
AvalancheSubset,
)
def my_ni_benchmark(
    train_dataset: Union[Sequence[SupportedDataset], SupportedDataset],
    test_dataset: Union[Sequence[SupportedDataset], SupportedDataset],
    n_experiences: int,
    *,
    task_labels: bool = False,
    shuffle: bool = True,
    seed: Optional[int] = None,
    fixed_class_order: Optional[Sequence[int]] = None,
    balance_experiences: bool = False,
    min_class_patterns_in_exp: int = 0,
    fixed_exp_assignment: Optional[Sequence[Sequence[int]]] = None,
    train_transform=None,
    eval_transform=None,
    reproducibility_data: Optional[Dict[str, Any]] = None,
) -> myNIScenario:
    """Build a New Instances benchmark backed by the custom ``myNIScenario``.

    This mirrors Avalanche's ``ni_benchmark`` factory but returns a
    ``myNIScenario`` and additionally forwards ``fixed_class_order``.

    :param train_dataset: A single training dataset, or a list/tuple of
        training datasets. A list/tuple is merged into one dataset with
        ``concat_datasets_sequentially``.
    :param test_dataset: The matching test dataset(s). It must be a list/tuple
        of the same length whenever ``train_dataset`` is a list/tuple.
    :param n_experiences: Number of experiences, forwarded positionally to
        ``myNIScenario``.
    :param task_labels: Forwarded positionally to ``myNIScenario``.
    :param shuffle: Forwarded to ``myNIScenario``.
    :param seed: Forwarded to ``myNIScenario``.
    :param fixed_class_order: Forwarded to ``myNIScenario``. Its exact
        semantics are defined there (presumably an explicit class ordering;
        TODO confirm against ``my_ni_scenario``).
    :param balance_experiences: Forwarded to ``myNIScenario``.
    :param min_class_patterns_in_exp: Forwarded to ``myNIScenario``.
    :param fixed_exp_assignment: Forwarded to ``myNIScenario``.
    :param train_transform: Transform registered under the ``"train"``
        transform group of both wrapped datasets.
    :param eval_transform: Transform registered under the ``"eval"``
        transform group of both wrapped datasets.
    :param reproducibility_data: Forwarded to ``myNIScenario``.
    :return: The ``myNIScenario`` built from the wrapped train/test datasets.
    :raises ValueError: If ``train_dataset`` is a list/tuple and
        ``test_dataset`` is not, or if the two lists differ in length.
    """
    seq_train_dataset, seq_test_dataset = train_dataset, test_dataset

    if isinstance(train_dataset, (list, tuple)):
        # Without this check, len(test_dataset) on a single dataset would
        # count samples rather than datasets. The comparison below would
        # then be meaningless, or a plain dataset would reach the concat
        # helper.
        if not isinstance(test_dataset, (list, tuple)):
            raise ValueError(
                "When train_dataset is a list/tuple of datasets, "
                "test_dataset must be a list/tuple as well"
            )
        if len(train_dataset) != len(test_dataset):
            raise ValueError(
                "Train/test dataset lists must contain the "
                "exact same number of datasets"
            )
        # The third returned value is not needed here.
        seq_train_dataset, seq_test_dataset, _ = concat_datasets_sequentially(
            train_dataset, test_dataset
        )

    # Each group maps to a (transform, target_transform) pair; no target
    # transform is applied.
    transform_groups = {
        "train": (train_transform, None),
        "eval": (eval_transform, None),
    }

    # Datasets should be instances of AvalancheDataset. The training set
    # starts in the "train" group and the test set in the "eval" group.
    seq_train_dataset = AvalancheDataset(
        seq_train_dataset,
        transform_groups=transform_groups,
        initial_transform_group="train",
        dataset_type=AvalancheDatasetType.CLASSIFICATION,
    )
    seq_test_dataset = AvalancheDataset(
        seq_test_dataset,
        transform_groups=transform_groups,
        initial_transform_group="eval",
        dataset_type=AvalancheDatasetType.CLASSIFICATION,
    )

    return myNIScenario(
        seq_train_dataset,
        seq_test_dataset,
        n_experiences,
        task_labels,
        shuffle=shuffle,
        seed=seed,
        fixed_class_order=fixed_class_order,
        balance_experiences=balance_experiences,
        min_class_patterns_in_exp=min_class_patterns_in_exp,
        fixed_exp_assignment=fixed_exp_assignment,
        reproducibility_data=reproducibility_data,
    )