Skip to content

Commit cbe1307

Browse files
authored
Merge pull request #1598 from AntonioCarta/master
fix issue #1597
2 parents 524f70c + ec3f3c6 commit cbe1307

File tree

7 files changed

+130
-90
lines changed

7 files changed

+130
-90
lines changed

avalanche/benchmarks/scenarios/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .dataset_scenario import *
99
from .exmodel_scenario import *
1010
from .online import *
11+
from .validation_scenario import *

avalanche/benchmarks/scenarios/dataset_scenario.py

-86
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import random
1515
from avalanche.benchmarks.utils.data import AvalancheDataset
1616
import torch
17-
from itertools import tee
1817
from typing import (
1918
Callable,
2019
Generator,
@@ -253,94 +252,9 @@ def __iter__(
253252
yield self.split_strategy(new_experience.dataset)
254253

255254

256-
def benchmark_with_validation_stream(
257-
benchmark: CLScenario,
258-
validation_size: Union[int, float] = 0.5,
259-
shuffle: bool = False,
260-
seed: Optional[int] = None,
261-
split_strategy: Optional[
262-
Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]]
263-
] = None,
264-
) -> CLScenario:
265-
"""Helper to obtain a benchmark with a validation stream.
266-
267-
This generator accepts an existing benchmark instance and returns a version
268-
of it in which the train stream has been split into training and validation
269-
streams.
270-
271-
Each train/validation experience will be by splitting the original training
272-
experiences. Patterns selected for the validation experience will be removed
273-
from the training experiences.
274-
275-
The default splitting strategy is a random split as implemented by `split_validation_random`.
276-
If you want to use class balancing you can use `split_validation_class_balanced`, or
277-
use a custom `split_strategy`, as shown in the following example::
278-
279-
validation_size = 0.2
280-
foo = lambda exp: split_dataset_class_balanced(validation_size, exp)
281-
bm = benchmark_with_validation_stream(bm, custom_split_strategy=foo)
282-
283-
:param benchmark: The benchmark to split.
284-
:param validation_size: The size of the validation experience, as an int
285-
or a float between 0 and 1. Ignored if `custom_split_strategy` is used.
286-
:param shuffle: If True, patterns will be allocated to the validation
287-
stream randomly. This will use the default PyTorch random number
288-
generator at its current state. Defaults to False. Ignored if
289-
`custom_split_strategy` is used. If False, the first instances will be
290-
allocated to the training dataset by leaving the last ones to the
291-
validation dataset.
292-
:param split_strategy: A function that implements a custom splitting
293-
strategy. The function must accept an AvalancheDataset and return a tuple
294-
containing the new train and validation dataset. By default, the splitting
295-
strategy will split the data according to `validation_size` and `shuffle`).
296-
A good starting to understand the mechanism is to look at the
297-
implementation of the standard splitting function
298-
:func:`random_validation_split_strategy`.
299-
300-
:return: A benchmark instance in which the validation stream has been added.
301-
"""
302-
303-
if split_strategy is None:
304-
if seed is None:
305-
seed = random.randint(0, 1000000)
306-
307-
# functools.partial is a more compact option
308-
# However, MyPy does not understand what a partial is -_-
309-
def random_validation_split_strategy_wrapper(data):
310-
return split_validation_random(validation_size, shuffle, seed, data)
311-
312-
split_strategy = random_validation_split_strategy_wrapper
313-
else:
314-
split_strategy = split_strategy
315-
316-
stream = benchmark.streams["train"]
317-
if isinstance(stream, EagerCLStream): # eager split
318-
train_exps, valid_exps = [], []
319-
320-
exp: DatasetExperience
321-
for exp in stream:
322-
train_data, valid_data = split_strategy(exp.dataset)
323-
train_exps.append(DatasetExperience(dataset=train_data))
324-
valid_exps.append(DatasetExperience(dataset=valid_data))
325-
else: # Lazy splitting (based on a generator)
326-
split_generator = LazyTrainValSplitter(split_strategy, stream)
327-
train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator)
328-
valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator)
329-
330-
train_stream = make_stream(name="train", exps=train_exps)
331-
valid_stream = make_stream(name="valid", exps=valid_exps)
332-
other_streams = benchmark.streams
333-
334-
del other_streams["train"]
335-
return CLScenario(
336-
streams=[train_stream, valid_stream] + list(other_streams.values())
337-
)
338-
339-
340255
__all__ = [
341256
"_split_dataset_by_attribute",
342257
"benchmark_from_datasets",
343258
"DatasetExperience",
344259
"split_validation_random",
345-
"benchmark_with_validation_stream",
346260
]

avalanche/benchmarks/scenarios/supervised.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,11 @@
2626
from avalanche.benchmarks.utils.classification_dataset import (
2727
ClassificationDataset,
2828
_as_taskaware_supervised_classification_dataset,
29-
TaskAwareSupervisedClassificationDataset,
3029
)
3130
from avalanche.benchmarks.utils.data import AvalancheDataset
3231
from avalanche.benchmarks.utils.data_attribute import DataAttribute
3332
from .dataset_scenario import _split_dataset_by_attribute, DatasetExperience
34-
from .. import CLScenario, CLStream, EagerCLStream
33+
from .generic_scenario import CLScenario, CLStream, EagerCLStream
3534

3635

3736
def class_incremental_benchmark(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from typing import (
2+
Callable,
3+
Generator,
4+
Generic,
5+
List,
6+
Sequence,
7+
TypeVar,
8+
Union,
9+
Tuple,
10+
Optional,
11+
Iterable,
12+
Dict,
13+
)
14+
15+
import random
16+
from avalanche.benchmarks.utils.data import AvalancheDataset
17+
from .generic_scenario import EagerCLStream, CLScenario, CLExperience, make_stream
18+
from .dataset_scenario import (
19+
LazyTrainValSplitter,
20+
DatasetExperience,
21+
split_validation_random,
22+
)
23+
from .supervised import with_classes_timeline
24+
25+
26+
def benchmark_with_validation_stream(
27+
benchmark: CLScenario,
28+
validation_size: Union[int, float] = 0.5,
29+
shuffle: bool = False,
30+
seed: Optional[int] = None,
31+
split_strategy: Optional[
32+
Callable[[AvalancheDataset], Tuple[AvalancheDataset, AvalancheDataset]]
33+
] = None,
34+
) -> CLScenario:
35+
"""Helper to obtain a benchmark with a validation stream.
36+
37+
This generator accepts an existing benchmark instance and returns a version
38+
of it in which the train stream has been split into training and validation
39+
streams.
40+
41+
Each train/validation experience will be by splitting the original training
42+
experiences. Patterns selected for the validation experience will be removed
43+
from the training experiences.
44+
45+
The default splitting strategy is a random split as implemented by `split_validation_random`.
46+
If you want to use class balancing you can use `split_validation_class_balanced`, or
47+
use a custom `split_strategy`, as shown in the following example::
48+
49+
validation_size = 0.2
50+
foo = lambda exp: split_dataset_class_balanced(validation_size, exp)
51+
bm = benchmark_with_validation_stream(bm, custom_split_strategy=foo)
52+
53+
:param benchmark: The benchmark to split.
54+
:param validation_size: The size of the validation experience, as an int
55+
or a float between 0 and 1. Ignored if `custom_split_strategy` is used.
56+
:param shuffle: If True, patterns will be allocated to the validation
57+
stream randomly. This will use the default PyTorch random number
58+
generator at its current state. Defaults to False. Ignored if
59+
`custom_split_strategy` is used. If False, the first instances will be
60+
allocated to the training dataset by leaving the last ones to the
61+
validation dataset.
62+
:param split_strategy: A function that implements a custom splitting
63+
strategy. The function must accept an AvalancheDataset and return a tuple
64+
containing the new train and validation dataset. By default, the splitting
65+
strategy will split the data according to `validation_size` and `shuffle`).
66+
A good starting to understand the mechanism is to look at the
67+
implementation of the standard splitting function
68+
:func:`random_validation_split_strategy`.
69+
70+
:return: A benchmark instance in which the validation stream has been added.
71+
"""
72+
73+
if split_strategy is None:
74+
if seed is None:
75+
seed = random.randint(0, 1000000)
76+
77+
# functools.partial is a more compact option
78+
# However, MyPy does not understand what a partial is -_-
79+
def random_validation_split_strategy_wrapper(data):
80+
return split_validation_random(validation_size, shuffle, seed, data)
81+
82+
split_strategy = random_validation_split_strategy_wrapper
83+
else:
84+
split_strategy = split_strategy
85+
86+
stream = benchmark.streams["train"]
87+
if isinstance(stream, EagerCLStream): # eager split
88+
train_exps, valid_exps = [], []
89+
90+
exp: DatasetExperience
91+
for exp in stream:
92+
train_data, valid_data = split_strategy(exp.dataset)
93+
train_exps.append(DatasetExperience(dataset=train_data))
94+
valid_exps.append(DatasetExperience(dataset=valid_data))
95+
else: # Lazy splitting (based on a generator)
96+
split_generator = LazyTrainValSplitter(split_strategy, stream)
97+
train_exps = (DatasetExperience(dataset=a) for a, _ in split_generator)
98+
valid_exps = (DatasetExperience(dataset=b) for _, b in split_generator)
99+
100+
train_stream = make_stream(name="train", exps=train_exps)
101+
valid_stream = make_stream(name="valid", exps=valid_exps)
102+
other_streams = benchmark.streams
103+
104+
# don't drop classes-timeline for compatibility with old API
105+
e0 = next(iter(train_stream))
106+
if hasattr(e0, "dataset") and hasattr(e0.dataset, "targets"):
107+
train_stream = with_classes_timeline(train_stream)
108+
valid_stream = with_classes_timeline(valid_stream)
109+
110+
del other_streams["train"]
111+
return CLScenario(
112+
streams=[train_stream, valid_stream] + list(other_streams.values())
113+
)
114+
115+
116+
__all__ = ["benchmark_with_validation_stream"]

avalanche/models/dynamic_modules.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def adaptation(self, experience: CLExperience):
246246
self.active_units[: old_act_units.shape[0]] = old_act_units
247247
# update with new active classes
248248
if self.training:
249-
self.active_units[curr_classes] = 1
249+
self.active_units[list(curr_classes)] = 1
250250

251251
# update classifier weights
252252
if old_nclasses == new_nclasses:

tests/benchmarks/scenarios/test_dataset_scenario.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
from avalanche.benchmarks import (
88
benchmark_from_datasets,
9-
benchmark_with_validation_stream,
109
CLScenario,
1110
CLStream,
1211
split_validation_random,
1312
task_incremental_benchmark,
1413
)
14+
from avalanche.benchmarks.scenarios.validation_scenario import (
15+
benchmark_with_validation_stream,
16+
)
1517
from avalanche.benchmarks.scenarios.dataset_scenario import (
1618
DatasetExperience,
1719
split_validation_class_balanced,
@@ -383,3 +385,9 @@ def test_gen():
383385
mb = get_mbatch(dd, len(dd))
384386
self.assertTrue(torch.equal(test_x, mb[0]))
385387
self.assertTrue(torch.equal(test_y, mb[1]))
388+
389+
def test_regressioni1597(args):
390+
# regression test for issue #1597
391+
bm = get_fast_benchmark()
392+
for exp in bm.train_stream:
393+
assert hasattr(exp, "classes_in_this_experience")

tests/training/test_plugins.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from avalanche.benchmarks import (
1717
nc_benchmark,
1818
GenericCLScenario,
19+
)
20+
from avalanche.benchmarks.scenarios.validation_scenario import (
1921
benchmark_with_validation_stream,
2022
)
2123
from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader

0 commit comments

Comments
 (0)