Commit aff0c91

[Datasets] Add from_torch (ray-project#29588)
Co-authored-by: Clark Zinzow <[email protected]>
1 parent e466a9e
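
In short, this commit replaces the documented SimpleTorchDatasource workflow with a ray.data.from_torch() helper. A minimal sketch of the new call, pieced together from the documentation changes below (it assumes ray, torch, and torchvision are installed; MNIST mirrors the example in the updated creating-datasets guide):

    import ray
    import torchvision

    # Any in-memory Torch dataset works; the updated docs use MNIST as the example.
    mnist = torchvision.datasets.MNIST("data", download=True)
    ds = ray.data.from_torch(mnist)
    print(ds.take(1))
    # e.g. [(<PIL.Image.Image image mode=L size=28x28 at 0x...>, 5)]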

11 files changed: +83, -175 lines

Binary file changed (44.9 MB); contents not shown.

doc/requirements-doc.txt

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ werkzeug
 wandb
 tensorflow; sys_platform != 'darwin' or platform_machine != 'arm64'
 tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'
+torch
+torchvision
 transformers
 
 # Ray libraries

doc/source/data/api/input_output.rst

Lines changed: 5 additions & 3 deletions
@@ -137,6 +137,11 @@ Mars
 .. automethod:: ray.data.Dataset.to_mars
     :noindex:
 
+Torch
+-----
+
+.. autofunction:: ray.data.from_torch
+
 HuggingFace
 ------------
 
@@ -193,9 +198,6 @@ Built-in Datasources
 .. autoclass:: ray.data.datasource.SimpleTensorFlowDatasource
     :members:
 
-.. autoclass:: ray.data.datasource.SimpleTorchDatasource
-    :members:
-
 .. autoclass:: ray.data.datasource.TFRecordDatasource
     :members:
 

doc/source/data/creating-datasets.rst

Lines changed: 5 additions & 8 deletions
@@ -501,23 +501,20 @@ From Torch and TensorFlow
 .. tabbed:: PyTorch
 
     If you already have a Torch dataset available, you can create a Ray Dataset using
-    :py:class:`~ray.data.datasource.SimpleTorchDatasource`.
+    :class:`~ray.data.from_torch`.
 
     .. warning::
-        :py:class:`~ray.data.datasource.SimpleTorchDatasource` doesn't support parallel
+        :class:`~ray.data.from_torch` doesn't support parallel
         reads. You should only use this datasource for small datasets like MNIST or
         CIFAR.
 
     .. code-block:: python
 
-        import ray.data
-        from ray.data.datasource import SimpleTorchDatasource
+        import ray
         import torchvision
 
-        dataset_factory = lambda: torchvision.datasets.MNIST("data", download=True)
-        dataset = ray.data.read_datasource(
-            SimpleTorchDatasource(), parallelism=1, dataset_factory=dataset_factory
-        )
+        dataset = torchvision.datasets.MNIST("data", download=True)
+        dataset = ray.data.from_torch(dataset)
         dataset.take(1)
         # (<PIL.Image.Image image mode=L size=28x28 at 0x1142CCA60>, 5)
 
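
Both the removed and the added warning make the same point: the read itself is not parallelized, so this path is meant for small datasets. If downstream steps need more parallelism, one option, not introduced by this commit, is the existing Dataset.repartition call (the same one the old warning output in the notebook below points at). A sketch:

    import ray
    import torchvision

    mnist = torchvision.datasets.MNIST("data", download=True)
    ds = ray.data.from_torch(mnist)

    # from_torch performs a single, serial read; repartitioning afterwards gives
    # later map_batches()/training steps more blocks to work on concurrently.
    ds = ds.repartition(8)  # 8 is an arbitrary example block count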

doc/source/ray-air/examples/torch_image_example.ipynb

Lines changed: 17 additions & 35 deletions
@@ -63,67 +63,49 @@
    "metadata": {},
    "outputs": [
     {
-     "name": "stderr",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2022-08-30 15:30:36,678\tINFO worker.py:1510 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n",
-      "2022-08-30 15:30:37,791\tWARNING read_api.py:291 -- ⚠️ The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
-      "\u001b[2m\u001b[36m(_get_read_tasks pid=3958)\u001b[0m 2022-08-30 15:30:37,789\tWARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n"
+      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz\n"
      ]
     },
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Using downloaded and verified file: ./data/cifar-10-python.tar.gz\n",
-      "\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Extracting ./data/cifar-10-python.tar.gz to ./data\n"
+      "100%|██████████| 170498071/170498071 [00:21<00:00, 7792736.24it/s]\n"
      ]
     },
     {
-     "name": "stderr",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2022-08-30 15:30:44,508\tWARNING read_api.py:291 -- ⚠️ The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
-      "\u001b[2m\u001b[36m(_get_read_tasks pid=3958)\u001b[0m 2022-08-30 15:30:44,507\tWARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n"
+      "Extracting data/cifar-10-python.tar.gz to data\n",
+      "Files already downloaded and verified\n"
      ]
     },
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Files already downloaded and verified\n"
+      "2022-10-23 10:33:48,403\tINFO worker.py:1518 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n"
      ]
     }
    ],
    "source": [
     "import ray\n",
-    "from ray.data.datasource import SimpleTorchDatasource\n",
     "import torchvision\n",
     "import torchvision.transforms as transforms\n",
     "\n",
     "transform = transforms.Compose(\n",
     "    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
     ")\n",
     "\n",
+    "train_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=True, transform=transform)\n",
+    "test_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=False, transform=transform)\n",
     "\n",
-    "def train_dataset_factory():\n",
-    "    return torchvision.datasets.CIFAR10(\n",
-    "        root=\"./data\", download=True, train=True, transform=transform\n",
-    "    )\n",
-    "\n",
-    "\n",
-    "def test_dataset_factory():\n",
-    "    return torchvision.datasets.CIFAR10(\n",
-    "        root=\"./data\", download=True, train=False, transform=transform\n",
-    "    )\n",
-    "\n",
-    "\n",
-    "train_dataset: ray.data.Dataset = ray.data.read_datasource(\n",
-    "    SimpleTorchDatasource(), dataset_factory=train_dataset_factory\n",
-    ")\n",
-    "test_dataset: ray.data.Dataset = ray.data.read_datasource(\n",
-    "    SimpleTorchDatasource(), dataset_factory=test_dataset_factory\n",
-    ")"
+    "train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)\n",
+    "test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)"
    ]
   },
   {
@@ -156,7 +138,7 @@
    "id": "a89b59e8",
    "metadata": {},
    "source": [
-    "{py:class}`SimpleTorchDatasource <ray.data.datasource.SimpleTorchDatasource>` doesn't parallelize reads, so you shouldn't use it with larger datasets.\n",
+    "{py:class}`from_torch <ray.data.from_torch>` doesn't parallelize reads, so you shouldn't use it with larger datasets.\n",
     "\n",
     "Next, let's represent our data using a dictionary of ndarrays instead of tuples. This lets us call {py:meth}`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>` later in the tutorial."
    ]
@@ -828,7 +810,7 @@
   ],
   "metadata": {
    "kernelspec": {
-    "display_name": "Python 3.9.12 ('.venv': venv)",
+    "display_name": "Python 3.10.8 ('.venv': venv)",
    "language": "python",
    "name": "python3"
   },
@@ -842,11 +824,11 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.12"
+  "version": "3.10.8"
  },
  "vscode": {
   "interpreter": {
-   "hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5"
+   "hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
   }
  }
 },
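
For readers skimming the notebook JSON above, the updated cell boils down to the following standalone sketch (the same code as the new "source" lines, just without the JSON escaping; it assumes torchvision can download CIFAR-10 locally):

    import ray
    import torchvision
    import torchvision.transforms as transforms

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    train_dataset = torchvision.datasets.CIFAR10("data", download=True, train=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10("data", download=True, train=False, transform=transform)

    # Wrap the Torch datasets as Ray Datasets. As the notebook text notes,
    # from_torch doesn't parallelize reads, so it targets small datasets.
    train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)
    test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)

The net effect of the change is that the datasource-plus-factory indirection (read_datasource(SimpleTorchDatasource(), dataset_factory=...)) collapses into a single from_torch(dataset) call.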
