
Commit 75f9aff (2 parents: fa8e734 + 0b79dd4)
Merge pull request #1016 from activeloopai/fix/2.0/pytorch_old
Fixes issues that prevented pytorch_old from running with workers > 0.
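At a glance: the pre-Python-3.8 ("pytorch_old") integration used to raise NotImplementedError whenever num_workers > 0; this commit removes that guard and instead re-opens the dataset lazily inside each DataLoader worker, so the object handed to workers is always picklable. A minimal usage sketch (the dataset path and creation call are illustrative, not taken from this commit):

import hub

ds = hub.Dataset("./my_dataset")  # hypothetical local dataset; MemoryProvider storage is still rejected
dl = ds.pytorch(num_workers=2, batch_size=1)  # workers > 0 now also works on Python < 3.8

for batch in dl:
    ...  # each batch is a dict keyed by tensor name (samples are plain dicts after this commit)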

File tree: 4 files changed, +38 -85 lines

hub/api/dataset.py (+4 -10)
@@ -4,17 +4,13 @@
 import numpy as np
 
 from hub.api.tensor import Tensor
-from hub.constants import (
-    DEFAULT_MEMORY_CACHE_SIZE,
-    DEFAULT_LOCAL_CACHE_SIZE,
-    MB,
-)
+from hub.constants import DEFAULT_MEMORY_CACHE_SIZE, DEFAULT_LOCAL_CACHE_SIZE, MB
 
 from hub.core.meta.dataset_meta import DatasetMeta
 
 from hub.core.typing import StorageProvider
 from hub.core.index import Index
-from hub.integrations import dataset_to_pytorch, dataset_to_tensorflow
+from hub.integrations import dataset_to_tensorflow
 from hub.util.keys import dataset_exists, get_dataset_meta_key, tensor_exists
 from hub.util.bugout_reporter import hub_reporter
 from hub.util.cache_chain import generate_chain
@@ -263,7 +259,6 @@ def pytorch(
         self,
         transform: Optional[Callable] = None,
         num_workers: int = 1,
-        tensors: Optional[List[str]] = None,
         batch_size: Optional[int] = 1,
         drop_last: Optional[bool] = False,
         collate_fn: Optional[Callable] = None,
@@ -278,8 +273,6 @@ def pytorch(
         Args:
             transform (Callable, optional) : Transformation function to be applied to each sample.
             num_workers (int): The number of workers to use for fetching data in parallel.
-            tensors (List, optional): Optionally provide a list of tensor names in the ordering that your training script expects.
-                For example, if the dataset that has "image" and "label" tensors and `tensors=["image", "label"]`, your training script should expect each batch will be provided as a tuple of (image, label).
             batch_size (int, optional): Number of samples per batch to load. Default value is 1.
             drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
                 If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. Default value is False.
@@ -292,11 +285,12 @@
         Returns:
             A torch.utils.data.DataLoader object.
         """
+        from hub.integrations import dataset_to_pytorch
+
         return dataset_to_pytorch(
             self,
             transform,
             num_workers=num_workers,
-            tensors=tensors,
             batch_size=batch_size,
             drop_last=drop_last,
             collate_fn=collate_fn,
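The import of dataset_to_pytorch moves from module scope into the pytorch() method body. Deferring an import like this is the standard way to break an import-time dependency between two modules (whether the motivation here is dodging a circular import or simply postponing torch-related machinery is my reading of the diff, not stated in it). A minimal sketch of the pattern:

class Dataset:
    def pytorch(self, transform=None, num_workers=1, **kwargs):
        # imported on call rather than at module load, so hub.api.dataset
        # no longer needs hub.integrations' pytorch code just to be imported
        from hub.integrations import dataset_to_pytorch

        return dataset_to_pytorch(self, transform, num_workers=num_workers, **kwargs)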

hub/integrations/pytorch.py (+8 -21)
@@ -6,7 +6,6 @@
 from hub.core.storage import StorageProvider, S3Provider, MemoryProvider
 from hub.core.meta.tensor_meta import TensorMeta
 from hub.util.remove_cache import get_base_storage
-from hub.util.subscript_namedtuple import subscript_namedtuple as namedtuple
 from itertools import repeat
 from collections import defaultdict
 from typing import Any, Callable, List, Optional, Set, Dict, Union, Tuple, Sequence
@@ -75,15 +74,16 @@ def dataset_to_pytorch(
     dataset,
     transform: Optional[Callable] = None,
     num_workers: int = 1,
-    tensors: Optional[Sequence[str]] = None,
     batch_size: Optional[int] = 1,
     drop_last: Optional[bool] = False,
     collate_fn: Optional[Callable] = None,
     pin_memory: Optional[bool] = False,
 ):
     dataset.flush()
     _import_torch()
-    pytorch_ds = TorchDataset(dataset, transform, num_workers, tensors)
+    # TODO new pytorch approach doesn't support 0 workers currently
+    num_workers = max(num_workers, 1)
+    pytorch_ds = TorchDataset(dataset, transform, num_workers)
     return torch.utils.data.DataLoader(  # type: ignore
         pytorch_ds,
         batch_size=batch_size,
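The TODO above marks the one behavioral caveat of the new streaming path: samples are fetched through a pathos ProcessPool, so there is no in-process (0-worker) mode yet, and a request for num_workers=0 is silently promoted rather than rejected. A sketch of the clamp's effect (the helper name is mine, not from the diff):

def effective_num_workers(requested: int) -> int:
    # presumably ProcessPool(nodes=0) would give the loader nothing to map
    # work onto, so at least one worker process is always kept
    return max(requested, 1)

assert effective_num_workers(0) == 1  # promoted, rather than raising
assert effective_num_workers(4) == 4  # explicit worker counts pass through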
@@ -99,25 +99,12 @@ def __init__(
         dataset,
         transform: Optional[Callable] = None,
         num_workers: int = 1,
-        tensors: Optional[Sequence[str]] = None,
     ):
         self.transform = transform
         self.num_workers: int = num_workers
         self.map = ProcessPool(nodes=num_workers).map
         self.length = len(dataset)
-        self.keys = list(dataset.tensors)
-
-        self.tensor_keys: List[str]
-        if tensors is not None:
-            for t in tensors:
-                if t not in dataset.tensors:
-                    raise TensorDoesNotExistError(t)
-            self.tensor_keys = list(tensors)
-        else:
-            self.tensor_keys = list(dataset.tensors)
-
-        self._return_type = namedtuple("Tensors", self.tensor_keys)
-
+        self.tensor_keys = list(dataset.tensors)
         self.storage = get_base_storage(dataset.storage)
         if isinstance(self.storage, MemoryProvider):
             raise DatasetUnsupportedPytorch(
@@ -199,7 +186,7 @@ def _load_all_chunk_engines(self):
         # creating a cache around base storage to pass to ChunkEngine
         return {
             key: ChunkEngine(key, LRUCache(MemoryProvider(), self.storage, 16 * MB))
-            for key in self.keys
+            for key in self.tensor_keys
         }
 
     def _load_all_meta(self):
@@ -313,9 +300,9 @@ def _process_samples(self):
         last_index = min(self.last_index_meta[key] for key in self.tensor_keys)
         samples = []
         for i in range(first_index, last_index + 1):
-            sample = self._return_type(
-                **{key: self.all_index_value_maps[key][i] for key in self.tensor_keys}
-            )
+            sample = {
+                key: self.all_index_value_maps[key][i] for key in self.tensor_keys
+            }
             samples.append(sample)
         self.processed_samples = samples
         self.processed_range = slice(first_index, last_index)
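Dropping the dynamically generated subscript_namedtuple return type in favor of plain dicts matters for multiprocessing: classes created at runtime generally cannot be pickled into worker processes, while dicts always can, and PyTorch's default collate_fn already batches dict samples key by key. A self-contained sketch of dict batching (the class and tensor names are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset


class DictSamples(Dataset):
    """Each sample is a plain dict, like the samples built in _process_samples."""

    def __len__(self):
        return 4

    def __getitem__(self, i):
        return {"image": torch.full((2, 2), float(i)), "label": torch.tensor(i)}


for batch in DataLoader(DictSamples(), batch_size=2):
    # the default collate_fn stacks values per key into batched tensors
    print(batch["image"].shape, batch["label"].shape)  # torch.Size([2, 2, 2]) torch.Size([2])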

hub/integrations/pytorch_old.py (+18 -23)
@@ -7,14 +7,13 @@
     ModuleNotInstalledException,
     TensorDoesNotExistError,
 )
-from hub.util.subscript_namedtuple import subscript_namedtuple as namedtuple
+import hub
 
 
 def dataset_to_pytorch(
     dataset,
     transform: Optional[Callable] = None,
     num_workers: int = 1,
-    tensors: Optional[Sequence[str]] = None,
     batch_size: Optional[int] = 1,
     drop_last: Optional[bool] = False,
     collate_fn: Optional[Callable] = None,
@@ -33,14 +32,9 @@ def dataset_to_pytorch(
     pytorch_ds = TorchDataset(
         dataset,
         transform,
-        tensors,
         python_version_warning=python_version_warning,
     )
-    # TODO add pytorch for num_workers > 1
-    if num_workers > 0:
-        raise NotImplementedError(
-            "Multiproccessed pytorch training is not support for Python version < 3.8. Please set num_workers equal to 0 or upgrade to python 3.8."
-        )
+
     return torch.utils.data.DataLoader(  # type: ignore
         pytorch_ds,
         num_workers=num_workers,
@@ -56,7 +50,6 @@ def __init__(
         self,
         dataset,
         transform: Optional[Callable] = None,
-        tensors: Optional[Sequence[str]] = None,
         python_version_warning: bool = True,
     ):
 
@@ -65,36 +58,38 @@ def __init__(
                 "Python version<3.8 detected. Pytorch iteration speeds will be slow. Use newer Python versions for faster data streaming to Pytorch."
             )
 
-        self.dataset = dataset
+        self.dataset = None
 
-        base_storage = get_base_storage(dataset.storage)
-        if isinstance(base_storage, MemoryProvider):
+        self.storage = get_base_storage(dataset.storage)
+        self.index = dataset.index
+        if isinstance(self.storage, MemoryProvider):
             raise DatasetUnsupportedPytorch(
                 "Datasets whose underlying storage is MemoryProvider are not supported for Pytorch iteration."
             )
 
         self.transform = transform
-        self.tensor_keys: List[str]
-        if tensors is not None:
-            for t in tensors:
-                if t not in dataset.tensors:
-                    raise TensorDoesNotExistError(t)
-            self.tensor_keys = list(tensors)
-        else:
-            self.tensor_keys = list(dataset.tensors)
-        self._return_type = namedtuple("Tensors", self.tensor_keys)
+        self.tensor_keys = list(dataset.tensors)
 
     def _apply_transform(self, sample: Union[Dict, Tuple]):
         return self.transform(sample) if self.transform else sample
 
+    def _init_ds(self):
+        """
+        For each process, dataset should be independently loaded
+        """
+        if self.dataset is None:
+            self.dataset = hub.Dataset(storage=self.storage, index=self.index)
+
     def __len__(self):
+        self._init_ds()
         return len(self.dataset)
 
     def __getitem__(self, index: int):
-        sample = self._return_type()
+        self._init_ds()
+        sample = {}
         # pytorch doesn't support certain dtypes, which are type casted to another dtype below
         for key in self.tensor_keys:
-            item = self.dataset[key][index].numpy()
+            item = self.dataset[key][index].numpy()  # type: ignore
             if item.dtype == "uint16":
                 item = item.astype("int32")
             elif item.dtype in ["uint32", "uint64"]:
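This hunk is the heart of the fix: a live dataset handle does not pickle reliably into DataLoader worker processes, so TorchDataset now keeps only the base storage provider and the index (both picklable), sets self.dataset = None, and lets every process rebuild its own handle on first use. A standalone sketch of the same lazy per-process initialization, with a stub open_dataset standing in for hub.Dataset(storage=..., index=...):

from torch.utils.data import Dataset


def open_dataset(storage, index):
    # stand-in for hub.Dataset(storage=storage, index=index); here "storage"
    # is just a list and "index" a slice, purely for illustration
    return storage[index]


class LazyInitDataset(Dataset):
    def __init__(self, storage, index):
        self.storage = storage  # picklable description of where the data lives
        self.index = index      # picklable view information
        self.dataset = None     # the real handle; never crosses a process boundary

    def _init_ds(self):
        # each worker receives a pickled copy with dataset=None and opens
        # its own handle here, instead of inheriting a shared one
        if self.dataset is None:
            self.dataset = open_dataset(self.storage, self.index)

    def __len__(self):
        self._init_ds()
        return len(self.dataset)

    def __getitem__(self, i):
        self._init_ds()
        return self.dataset[i]

Because self.dataset is None at pickle time, the object sent to each worker stays small and picklable; the price is one dataset open per worker process.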

hub/integrations/tests/test_pytorch.py (+8 -31)
@@ -12,6 +12,10 @@
 from hub.core.tests.common import parametrize_all_dataset_storages
 
 
+def to_tuple(sample):
+    return sample["image"], sample["image2"]
+
+
 @requires_torch
 @parametrize_all_dataset_storages
 def test_pytorch_small(ds):
@@ -26,11 +30,6 @@ def test_pytorch_small(ds):
             dl = ds.pytorch(num_workers=2)
         return
 
-    if sys.version_info < (3, 8):
-        with pytest.raises(NotImplementedError):
-            dl = ds.pytorch(num_workers=2)
-        return
-
     dl = ds.pytorch(num_workers=2, batch_size=1)
 
     for i, batch in enumerate(dl):
@@ -87,19 +86,11 @@ def test_pytorch_transform(ds):
     ds.create_tensor("image2")
     ds.image2.extend(np.array([i * np.ones((100, 100)) for i in range(256)]))
 
-    def to_tuple(sample):
-        return sample["image"], sample["image2"]
-
     if isinstance(get_base_storage(ds.storage), MemoryProvider):
         with pytest.raises(DatasetUnsupportedPytorch):
             dl = ds.pytorch(num_workers=2)
         return
 
-    if sys.version_info < (3, 8):
-        with pytest.raises(NotImplementedError):
-            dl = ds.pytorch(num_workers=2)
-        return
-
     dl = ds.pytorch(num_workers=2, transform=to_tuple, batch_size=1)
 
     for i, batch in enumerate(dl):
@@ -127,11 +118,6 @@ def test_pytorch_with_compression(ds: Dataset):
             dl = ds.pytorch(num_workers=2)
         return
 
-    if sys.version_info < (3, 8):
-        with pytest.raises(NotImplementedError):
-            dl = ds.pytorch(num_workers=2)
-        return
-
     dl = ds.pytorch(num_workers=2, batch_size=1)
 
     for batch in dl:
@@ -153,13 +139,13 @@ def test_pytorch_small_old(ds):
     if isinstance(get_base_storage(ds.storage), MemoryProvider):
         with pytest.raises(DatasetUnsupportedPytorch):
             dl = dataset_to_pytorch(
-                ds, num_workers=0, batch_size=1, python_version_warning=False
+                ds, num_workers=2, batch_size=1, python_version_warning=False
             )
         return
 
     # .pytorch will automatically switch depending on version, this syntax is being used to ensure testing of old code on Python 3.8
     dl = dataset_to_pytorch(
-        ds, num_workers=0, batch_size=1, python_version_warning=False
+        ds, num_workers=2, batch_size=1, python_version_warning=False
     )
 
     for i, batch in enumerate(dl):
@@ -173,11 +159,7 @@ def test_pytorch_small_old(ds):
 
 @requires_torch
 @parametrize_all_dataset_storages
-@pytest.mark.xfail(
-    sys.version_info < (3, 8),
-    raises=NotImplementedError,
-    reason="requires python3.8 or higher",
-)
+@pytest.mark.skip(reason="future")
 def test_custom_tensor_order(ds):
     with ds:
         tensors = ["a", "b", "c", "d"]
@@ -187,17 +169,12 @@ def test_custom_tensor_order(ds):
 
     if isinstance(get_base_storage(ds.storage), MemoryProvider):
         with pytest.raises(DatasetUnsupportedPytorch):
-            ptds = ds.pytorch(num_workers=2)
-        return
-
-    if sys.version_info < (3, 8):
-        with pytest.raises(NotImplementedError):
             dl = ds.pytorch(num_workers=2)
         return
 
     dl_new = ds.pytorch(num_workers=2, tensors=["c", "d", "a"])
     dl_old = dataset_to_pytorch(
-        ds, num_workers=0, tensors=["c", "d", "a"], python_version_warning=False
+        ds, num_workers=2, tensors=["c", "d", "a"], python_version_warning=False
     )
     for dl in [dl_new, dl_old]:
        for i, batch in enumerate(dl):
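Two details in these tests line up with the multiprocessing changes. First, to_tuple moved from a nested function inside test_pytorch_transform to module level: when DataLoader workers are started via spawn (the default on Windows and macOS), the transform reaches them by pickling, and pickle can only resolve functions defined at module level. Second, the old-path tests now pass num_workers=2, exercising exactly the configuration that used to raise NotImplementedError. A quick demonstration of the pickling constraint:

import pickle


def module_level(sample):  # picklable: resolvable by its qualified name
    return sample


def make_nested():
    def nested(sample):  # not picklable: exists only in make_nested's scope
        return sample

    return nested


pickle.dumps(module_level)  # works
try:
    pickle.dumps(make_nested())
except (AttributeError, pickle.PicklingError) as err:
    print(err)  # Can't pickle local object 'make_nested.<locals>.nested'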
