Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@
BLOCKED_CLIENT_WARN_TIMEOUT = 30


class _DatasetWrapper:
# A temporary workaround for https://github.com/ray-project/ray/issues/52549

def __init__(self, dataset: "Dataset") -> None:
self._dataset = dataset


class StreamSplitDataIterator(DataIterator):
"""Implements a collection of iterators over a shared data stream."""

Expand All @@ -52,7 +45,7 @@ def create(
scheduling_strategy=NodeAffinitySchedulingStrategy(
ray.get_runtime_context().get_node_id(), soft=False
),
).remote(_DatasetWrapper(base_dataset), n, locality_hints)
).remote(base_dataset, n, locality_hints)

return [
StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n)
Expand Down Expand Up @@ -135,11 +128,10 @@ class SplitCoordinator:

def __init__(
self,
dataset_wrapper: _DatasetWrapper,
dataset: "Dataset",
n: int,
locality_hints: Optional[List[NodeIdStr]],
):
dataset = dataset_wrapper._dataset

# Set current DataContext.
# This needs to be a deep copy so that updates to the base dataset's
Expand Down
3 changes: 1 addition & 2 deletions python/ray/train/v2/tests/test_data_resource_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
)
from ray.data._internal.iterator.stream_split_iterator import (
SplitCoordinator,
_DatasetWrapper,
)
from ray.train.v2._internal.callbacks.datasets import DatasetsSetupCallback
from ray.train.v2._internal.execution.worker_group import (
Expand Down Expand Up @@ -120,7 +119,7 @@ def get_resources_when_updated(requester, prev_requests=None, timeout=3.0):
NUM_SPLITS = 1
dataset = ray.data.range(100)
coord = SplitCoordinator.options(name="test_split_coordinator").remote(
_DatasetWrapper(dataset), NUM_SPLITS, None
dataset, NUM_SPLITS, None
)
ray.get(coord.start_epoch.remote(0))

Expand Down