Skip to content

Commit

Permalink
[benchmarks] typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jneuendorf committed Jun 7, 2024
1 parent 1486573 commit 7cc252f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 4 additions & 5 deletions avalanche/benchmarks/utils/classification_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
labels automatically. Concatenation and subsampling operations are optimized
to be used frequently, as is common in replay strategies.
"""

from functools import partial
from typing import (
List,
Expand All @@ -29,7 +28,7 @@
Dict,
Tuple,
Mapping,
overload,
overload, Self,
)

import torch
Expand Down Expand Up @@ -64,11 +63,11 @@
)

T_co = TypeVar("T_co", covariant=True)
TAvalancheDataset = TypeVar("TAvalancheDataset", bound="AvalancheDataset")
TAvalancheDataset = TypeVar("TAvalancheDataset", bound=AvalancheDataset)
TTargetType = int

TClassificationDataset = TypeVar(
"TClassificationDataset", bound="ClassificationDataset"
"TClassificationDataset", bound=IDatasetWithTargets
)


Expand Down Expand Up @@ -114,7 +113,7 @@ def task_pattern_indices(self) -> Dict[int, Sequence[int]]:
return self.targets_task_labels.val_to_idx # type: ignore

@property
def task_set(self: TClassificationDataset) -> TaskSet[TClassificationDataset]:
def task_set(self) -> TaskSet[Self]:
"""Returns the dataset's ``TaskSet``, which is a mapping <task-id,
task-dataset>."""
return TaskSet(self)
Expand Down
4 changes: 3 additions & 1 deletion avalanche/benchmarks/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,13 +653,15 @@ class TaskSet(Mapping[int, TAvalancheDataset], Generic[TAvalancheDataset]):
"""

data: TAvalancheDataset

def __init__(self, data: TAvalancheDataset):
"""Constructor.
:param data: original data
"""
super().__init__()
self.data: TAvalancheDataset = data
self.data = data

def __iter__(self) -> Iterator[int]:
t_labels = self._get_task_labels_field()
Expand Down

0 comments on commit 7cc252f

Please sign in to comment.