60 changes: 60 additions & 0 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -4,7 +4,9 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
from packaging.version import parse as parse_version
from pandas.core.dtypes.dtypes import BaseMaskedDtype

from ray._private.arrow_utils import get_pyarrow_version
from ray._private.ray_constants import env_integer
@@ -1049,3 +1051,61 @@ def _try_combine_chunks_safe(
new_chunks.append(pa.concat_arrays(cur_chunk_group))

return pa.chunked_array(new_chunks)


def convert_pandas_dtype_to_pyarrow(
dtype: Union[np.dtype, "pd.ArrowDtype", "pd.StringDtype", "BaseMaskedDtype"]
) -> "pyarrow.DataType":
"""Convert a pandas dtype to a PyArrow DataType.

Handles pandas extension dtypes (Int32, Int64, StringDtype, etc.),
ArrowDtype, and regular numpy dtypes.

Args:
dtype: A pandas dtype (numpy dtype, ArrowDtype, or extension dtype).

Returns:
The equivalent PyArrow DataType.
"""
import pandas as pd
from pandas.core.dtypes.dtypes import BaseMaskedDtype

from ray.data.extensions import TensorDtype

if isinstance(dtype, pd.ArrowDtype):
return dtype.pyarrow_dtype
elif isinstance(dtype, pd.StringDtype):
# StringDtype is not a BaseMaskedDtype; handle it separately
return pyarrow.string()
elif isinstance(dtype, BaseMaskedDtype):
# Nullable integer types like Int32, Int64
dtype = dtype.numpy_dtype
elif isinstance(dtype, TensorDtype):
# Convert TensorDtype to Arrow tensor type.
# For variable-shaped tensors, use Ray's extension type.
# For fixed-shape tensors, use Arrow's native fixed_shape_tensor.
element_dtype = convert_pandas_dtype_to_pyarrow(dtype._dtype)

if dtype.is_variable_shaped:
# Use Ray's extension type for variable-shaped tensors
from ray.data.extensions import ArrowVariableShapedTensorType

# Extract ndim from shape tuple length
ndim = len(dtype._shape)
return ArrowVariableShapedTensorType(
dtype=element_dtype,
ndim=ndim,
)
else:
# Use Arrow's native fixed_shape_tensor for fixed-shape tensors
return pyarrow.fixed_shape_tensor(
element_type=element_dtype,
shape=dtype._shape,
)
elif hasattr(dtype, "kind") and dtype.kind == "O":
# Numpy object dtype - assume string
return pyarrow.string()
elif dtype is object:
# Python object type - assume string
return pyarrow.string()
return pyarrow.from_numpy_dtype(dtype)
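
For reference, here is a small usage sketch of the new helper. The expected results follow directly from the branches above; the only assumptions are that pandas, numpy, and pyarrow are importable and that the helper is reachable at the module path shown in this diff.

import numpy as np
import pandas as pd
import pyarrow as pa

from ray.data._internal.arrow_ops.transform_pyarrow import (
    convert_pandas_dtype_to_pyarrow,
)

# Nullable extension dtypes resolve through their numpy counterparts.
assert convert_pandas_dtype_to_pyarrow(pd.Int64Dtype()) == pa.int64()
# ArrowDtype passes its wrapped Arrow type straight through.
assert convert_pandas_dtype_to_pyarrow(pd.ArrowDtype(pa.timestamp("us"))) == pa.timestamp("us")
# StringDtype and numpy object dtypes both map to Arrow strings.
assert convert_pandas_dtype_to_pyarrow(pd.StringDtype()) == pa.string()
assert convert_pandas_dtype_to_pyarrow(np.dtype("O")) == pa.string()
# Plain numpy dtypes go through pyarrow.from_numpy_dtype.
assert convert_pandas_dtype_to_pyarrow(np.dtype("float32")) == pa.float32()
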
7 changes: 5 additions & 2 deletions python/ray/data/_internal/datasource/bigquery_datasink.py
@@ -3,10 +3,13 @@
import tempfile
import time
import uuid
from typing import Iterable, Optional
from typing import TYPE_CHECKING, Iterable, Optional

import pyarrow.parquet as pq

if TYPE_CHECKING:
import pyarrow as pa

import ray
from ray.data._internal.datasource import bigquery_datasource
from ray.data._internal.execution.interfaces import TaskContext
@@ -38,7 +41,7 @@ def __init__(
self.max_retry_cnt = max_retry_cnt
self.overwrite_table = overwrite_table

def on_write_start(self) -> None:
def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
from google.api_core import exceptions

if self.project_id is None or self.dataset is None:
python/ray/data/_internal/datasource/clickhouse_datasink.py
@@ -10,6 +10,7 @@
)

import pyarrow
import pyarrow as pa
import pyarrow.types as pat

from ray.data._internal.execution.interfaces import TaskContext
@@ -300,7 +301,7 @@ def _get_existing_order_by(self, client) -> Optional[str]:
def supports_distributed_writes(self) -> bool:
return True

def on_write_start(self) -> None:
def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
client = None
try:
client = self._init_client()
4 changes: 2 additions & 2 deletions python/ray/data/_internal/datasource/iceberg_datasink.py
@@ -13,10 +13,10 @@
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import pyarrow as pa
from pyiceberg.catalog import Catalog
from pyiceberg.manifest import DataFile


logger = logging.getLogger(__name__)


@@ -83,7 +83,7 @@ def _get_catalog(self) -> "Catalog":

return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)

def on_write_start(self) -> None:
def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
"""Prepare for the transaction"""
import pyiceberg
from pyiceberg.table import TableProperties
2 changes: 1 addition & 1 deletion python/ray/data/_internal/datasource/lance_datasink.py
@@ -97,7 +97,7 @@ def __init__(
def supports_distributed_writes(self) -> bool:
return True

def on_write_start(self):
def on_write_start(self, schema: Optional["pa.Schema"] = None) -> None:
_check_import(self, module="lance", package="pylance")

import lance
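
The same signature change applies to user-defined sinks. Below is a minimal sketch of a custom sink that picks up the schema through the updated hook; it assumes the public ray.data.Datasink interface with write(blocks, ctx) as the required method, and the class name and stored attribute are illustrative rather than part of this PR.

from typing import Any, Iterable, Optional

import pyarrow as pa

from ray.data import Datasink
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block


class SchemaRecordingDatasink(Datasink):
    """Illustrative sink that records the schema it was started with."""

    def __init__(self):
        self.started_schema: Optional[pa.Schema] = None

    def on_write_start(self, schema: Optional[pa.Schema] = None) -> None:
        # With this change, the write operator passes the first input bundle's
        # schema (or None if no schema was available) before submitting any
        # write tasks.
        self.started_schema = schema

    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> Any:
        # A real sink would persist the blocks here; omitted in this sketch.
        return "ok"
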
python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
@@ -5,7 +5,20 @@
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Union,
)

if TYPE_CHECKING:
import pyarrow as pa

import ray
from ray.actor import ActorHandle
@@ -83,6 +96,7 @@ def __init__(
ray_remote_args: Optional[Dict[str, Any]] = None,
ray_actor_task_remote_args: Optional[Dict[str, Any]] = None,
target_max_block_size_override: Optional[int] = None,
on_start: Optional[Callable[[Optional["pa.Schema"]], None]] = None,
):
"""Create an ActorPoolMapOperator instance.

@@ -112,6 +126,8 @@ def __init__(
ray_actor_task_remote_args: Ray Core options passed to map actor tasks.
target_max_block_size_override: The target maximum number of bytes to
include in an output block.
on_start: Optional callback invoked with the schema from the first input
bundle before any tasks are submitted.
"""
super().__init__(
map_transformer,
@@ -125,6 +141,7 @@
map_task_kwargs,
ray_remote_args_fn,
ray_remote_args,
on_start,
)

self._min_rows_per_bundle = min_rows_per_bundle
71 changes: 71 additions & 0 deletions python/ray/data/_internal/execution/operators/map_operator.py
@@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
@@ -18,6 +19,9 @@
Union,
)

if TYPE_CHECKING:
import pyarrow as pa

import ray
from ray import ObjectRef
from ray._raylet import ObjectRefGenerator
@@ -130,6 +134,7 @@ def __init__(
map_task_kwargs: Optional[Dict[str, Any]],
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]],
ray_remote_args: Optional[Dict[str, Any]],
on_start: Optional[Callable[[Optional["pa.Schema"]], None]] = None,
):
# NOTE: This constructor should not be called directly; use MapOperator.create()
# instead.
@@ -169,13 +174,62 @@ def __init__(
# Callback functions that generate additional task kwargs
# for the map task.
self._map_task_kwargs_fns: List[Callable[[], Dict[str, Any]]] = []
# Callback invoked when the first input bundle is ready (before task submission).
# It receives the schema from the first bundle for deferred initialization
# (e.g., schema evolution for Iceberg writes via on_write_start).
self._on_start: Optional[Callable[[Optional["pa.Schema"]], None]] = on_start
self._start_called = False

def add_map_task_kwargs_fn(self, map_task_kwargs_fn: Callable[[], Dict[str, Any]]):
"""Add a callback function that generates additional kwargs for the map tasks.
In the map tasks, the kwargs can be accessible via `TaskContext.kwargs`.
"""
self._map_task_kwargs_fns.append(map_task_kwargs_fn)

def _notify_first_input(self, bundled_input: RefBundle) -> None:
"""Invoke on_start callback with schema if registered and not yet invoked.

Used for deferred initialization that needs schema from the first bundle
(e.g., schema evolution for Iceberg writes via on_write_start).
"""
if not self._start_called and self._on_start is not None:
schema = self._get_schema_from_bundle(bundled_input)
self._on_start(schema)
self._start_called = True

def _get_schema_from_bundle(self, bundle: RefBundle) -> Optional["pa.Schema"]:
"""Extract PyArrow schema from a RefBundle without fetching block data."""
import pyarrow as pa

from ray.data._internal.arrow_ops.transform_pyarrow import (
convert_pandas_dtype_to_pyarrow,
)
from ray.data._internal.pandas_block import PandasBlockSchema
from ray.data.dataset import Schema

if bundle.schema is None:
return None

schema = bundle.schema

# Unwrap Schema wrapper if present
if isinstance(schema, Schema):
schema = schema.base_schema

# Already a PyArrow schema - use directly
if isinstance(schema, pa.Schema):
return schema

# PandasBlockSchema - convert to PyArrow
if isinstance(schema, PandasBlockSchema):
fields = []
for name, dtype in zip(schema.names, schema.types):
pa_type = convert_pandas_dtype_to_pyarrow(dtype)
fields.append(pa.field(name, pa_type))
return pa.schema(fields)

return None

def get_map_task_kwargs(self) -> Dict[str, Any]:
"""Get the kwargs for the map task.
Subclasses should pass the returned kwargs to the map tasks.
@@ -245,6 +299,7 @@ def create(
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
per_block_limit: Optional[int] = None,
on_start: Optional[Callable[[Optional["pa.Schema"]], None]] = None,
) -> "MapOperator":
"""Create a MapOperator.

@@ -276,6 +331,10 @@
advanced, experimental feature.
ray_remote_args: Customize the :func:`ray.remote` args for this op's tasks.
per_block_limit: Maximum number of rows to process per block, for early termination.
on_start: Optional callback invoked with the schema from the first input
bundle before any tasks are submitted. Used for deferred initialization
that requires schema from actual data (e.g., schema evolution for
Iceberg writes).
"""
if (ref_bundler is not None and min_rows_per_bundle is not None) or (
min_rows_per_bundle is not None and ref_bundler is not None
@@ -311,6 +370,7 @@
map_task_kwargs=map_task_kwargs,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
on_start=on_start,
)
elif isinstance(compute_strategy, ActorPoolStrategy):
from ray.data._internal.execution.operators.actor_pool_map_operator import (
@@ -330,6 +390,7 @@
map_task_kwargs=map_task_kwargs,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
on_start=on_start,
)
else:
raise ValueError(f"Unsupported execution strategy {compute_strategy}")
@@ -415,6 +476,10 @@ def _add_input_inner(self, refs: RefBundle, input_index: int):
for bundle in input_refs:
self._metrics.on_input_dequeued(bundle)

# Invoke first-input callback before task submission (for deferred init).
# This is used by write operators to call on_write_start with schema.
self._notify_first_input(bundled_input)

# If the bundler has a full bundle, add it to the operator's task submission
# queue
self._add_bundled_input(bundled_input)
@@ -551,6 +616,12 @@ def all_inputs_done(self):
_,
bundled_input,
) = self._block_ref_bundler.get_next_bundle()

# Invoke first-input callback before task submission (for deferred init).
# This handles small datasets where bundles never met the threshold during
# normal processing and were deferred to all_inputs_done().
self._notify_first_input(bundled_input)

self._add_bundled_input(bundled_input)
super().all_inputs_done()

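
To make the schema extraction concrete, here is a standalone sketch of the PandasBlockSchema branch of _get_schema_from_bundle. Building a PandasBlockSchema by hand from names and types is an assumption made only for illustration; in the operator it comes from the first RefBundle.

import pandas as pd
import pyarrow as pa

from ray.data._internal.arrow_ops.transform_pyarrow import (
    convert_pandas_dtype_to_pyarrow,
)
from ray.data._internal.pandas_block import PandasBlockSchema

# A pandas-backed bundle reports a PandasBlockSchema (column names plus
# pandas dtypes); the operator rebuilds an equivalent pyarrow.Schema
# field by field before invoking on_start.
pandas_schema = PandasBlockSchema(
    names=["id", "name"],
    types=[pd.Int64Dtype(), pd.StringDtype()],
)
arrow_schema = pa.schema(
    [
        pa.field(name, convert_pandas_dtype_to_pyarrow(dtype))
        for name, dtype in zip(pandas_schema.names, pandas_schema.types)
    ]
)
# arrow_schema is now: id: int64, name: string, which is what on_start receives.
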
python/ray/data/_internal/execution/operators/task_pool_map_operator.py
@@ -1,5 +1,8 @@
import warnings
from typing import Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

if TYPE_CHECKING:
import pyarrow as pa

from ray.data._internal.execution.interfaces import (
ExecutionResources,
@@ -34,6 +37,7 @@ def __init__(
map_task_kwargs: Optional[Dict[str, Any]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
on_start: Optional[Callable[[Optional["pa.Schema"]], None]] = None,
):
"""Create an TaskPoolMapOperator instance.

@@ -59,6 +63,8 @@
always override the args in ``ray_remote_args``. Note: this is an
advanced, experimental feature.
ray_remote_args: Customize the :func:`ray.remote` args for this op's tasks.
on_start: Optional callback invoked with the schema from the first input
bundle before any tasks are submitted.
"""
super().__init__(
map_transformer,
@@ -72,6 +78,7 @@
map_task_kwargs,
ray_remote_args_fn,
ray_remote_args,
on_start,
)

if max_concurrency is not None and max_concurrency <= 0: