From 7cc252f56a3ab67a4d5e8b255927e8b0fece165c Mon Sep 17 00:00:00 2001 From: Jim Neuendorf Date: Fri, 7 Jun 2024 10:39:47 +0200 Subject: [PATCH] [benchmarks] typing --- avalanche/benchmarks/utils/classification_dataset.py | 9 ++++----- avalanche/benchmarks/utils/utils.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/avalanche/benchmarks/utils/classification_dataset.py b/avalanche/benchmarks/utils/classification_dataset.py index 0da356f1c..c781da8ca 100644 --- a/avalanche/benchmarks/utils/classification_dataset.py +++ b/avalanche/benchmarks/utils/classification_dataset.py @@ -16,7 +16,6 @@ labels automatically. Concatenation and subsampling operations are optimized to be used frequently, as is common in replay strategies. """ - from functools import partial from typing import ( List, @@ -29,7 +28,7 @@ Dict, Tuple, Mapping, - overload, + overload, Self, ) import torch @@ -64,11 +63,11 @@ ) T_co = TypeVar("T_co", covariant=True) -TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset") +TAvalancheDataset = TypeVar("TAvalancheDataset", bound=AvalancheDataset) TTargetType = int TClassificationDataset = TypeVar( - "TClassificationDataset", bound="ClassificationDataset" + "TClassificationDataset", bound=IDatasetWithTargets ) @@ -114,7 +113,7 @@ def task_pattern_indices(self) -> Dict[int, Sequence[int]]: return self.targets_task_labels.val_to_idx # type: ignore @property - def task_set(self: TClassificationDataset) -> TaskSet[TClassificationDataset]: + def task_set(self) -> TaskSet[Self]: """Returns the dataset's ``TaskSet``, which is a mapping .""" return TaskSet(self) diff --git a/avalanche/benchmarks/utils/utils.py b/avalanche/benchmarks/utils/utils.py index 1f206f028..26e3797ef 100644 --- a/avalanche/benchmarks/utils/utils.py +++ b/avalanche/benchmarks/utils/utils.py @@ -653,13 +653,15 @@ class TaskSet(Mapping[int, TAvalancheDataset], Generic[TAvalancheDataset]): """ + data: TAvalancheDataset + def __init__(self, data: TAvalancheDataset): """Constructor. :param data: original data """ super().__init__() - self.data: TAvalancheDataset = data + self.data = data def __iter__(self) -> Iterator[int]: t_labels = self._get_task_labels_field()