1 change: 1 addition & 0 deletions docs/api.md
@@ -135,6 +135,7 @@ In particular, for pytorch-based models.
 
 experimental.AnnCollection
 experimental.AnnLoader
+experimental.pytorch.batch_dict_converter
 ```
 
 Out of core concatenation
8 changes: 8 additions & 0 deletions docs/release-notes/2135.feature.md
@@ -0,0 +1,8 @@
+Add `batch_converter` parameter and multiprocessing support to {class}`~anndata.experimental.pytorch.AnnLoader`.
+
+- Added `batch_converter` parameter for batch-level post-processing of data batches
+- Added {func}`~anndata.experimental.pytorch.batch_dict_converter` helper function for converting batches to tensor dictionaries
+- Fixed multiprocessing support (`num_workers > 0`) by implementing pickling for `AnnCollectionView` objects
+- Batch converters work with both single-process and multi-process (`num_workers > 0`) data loading
+
+{user}`ronamit`
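A minimal usage sketch of the new parameter (the `adata` object and the `"cell_type"` obs column are hypothetical):

```python
from anndata.experimental.pytorch import AnnLoader, batch_dict_converter

loader = AnnLoader(
    adata,  # an AnnData or AnnCollection
    batch_size=64,
    shuffle=True,
    batch_converter=batch_dict_converter,
    num_workers=2,  # multiprocessing now works
)

for batch in loader:
    x = batch["x"]  # expression matrix as a tensor
    labels = batch["cell_type"]  # each obs column becomes a key
```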
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -164,6 +164,10 @@ filterwarnings_when_strict = [
     "default:Consolidated metadata is:UserWarning",
     "default:.*Structured:zarr.core.dtype.common.UnstableSpecificationWarning",
     "default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning",
+    "default:'oneOf' deprecated - use 'one_of':DeprecationWarning",
+    "default:'parseString' deprecated - use 'parse_string':DeprecationWarning",
+    "default:'resetCache' deprecated - use 'reset_cache':DeprecationWarning",
+    "default:'enablePackrat' deprecated - use 'enable_packrat':DeprecationWarning",
 ]
 python_files = "test_*.py"
 testpaths = [
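For reference, each entry uses pytest's `action:message_regex:category` filter syntax; these four downgrade pyparsing's rename deprecations from strict errors to default handling. A rough plain-Python equivalent of the first entry (illustrative only):

```python
import warnings

warnings.filterwarnings(
    "default",  # report once per call location instead of raising
    message="'oneOf' deprecated - use 'one_of'",
    category=DeprecationWarning,
)
```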
68 changes: 67 additions & 1 deletion src/anndata/experimental/multi_files/_anncollection.py
@@ -291,6 +291,43 @@ def __init__(self, reference, convert, resolved_idx):
         self._convert_X = None
         self.convert = convert
 
+    # ------------------------------------------------------------------
+    # Pickling support (worker-safe)
+    # ------------------------------------------------------------------
+
+    def __getstate__(self):
+        """Return minimal state for safe pickling across processes.
+
+        We only serialise lightweight metadata – the on-disk store is reopened
+        in the child worker to avoid passing an open HDF5 handle.
+        """
+        return {
+            "reference_state": self.reference.__getstate__()
+            if hasattr(self.reference, "__getstate__")
+            else None,
+            "oidx": self.oidx,
+            "vidx": self.vidx,
+            "convert": self.convert,
+        }
+
+    def __setstate__(self, state):
+        from anndata.experimental.multi_files import (
+            AnnCollection,  # local import to avoid a circular import
+        )
+
+        # Rebuild from the saved reference_state (in-memory collections)
+        parent = AnnCollection.__new__(AnnCollection)
+        if state["reference_state"] is not None:
+            parent.__setstate__(state["reference_state"])
+        else:
+            msg = "Cannot restore AnnCollectionView without reference state"
+            raise ValueError(msg)
+
+        # Recreate the view via slicing to reuse internal helpers
+        view = parent[state["oidx"], state["vidx"]]
+        self.__dict__.update(view.__dict__)
+        self.convert = state["convert"]
+
     def _lazy_init_attr(self, attr: str, *, set_vidx: bool = False):
         if getattr(self, f"_{attr}_view") is not None:
             return
@@ -779,7 +816,7 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915
                 ai_attr = getattr(a, attr)
                 a0_attr = getattr(adatas[0], attr)
                 new_keys = []
-                for key in keys:
+                for key in keys or []:
                     if key in ai_attr:
                         a0_ashape = a0_attr[key].shape
                         ai_ashape = ai_attr[key].shape
@@ -806,6 +843,35 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915

         self.indices_strict = indices_strict
 
+    # ------------------------------------------------------------------
+    # Pickling support (worker-safe)
+    # ------------------------------------------------------------------
+
+    def __getstate__(self):
+        """Return state for pickling. For in-memory collections, we serialize all data."""
+        return {
+            "adatas": self.adatas,
+            "join_obs": getattr(self, "_join_obs", "inner"),
+            "join_obsm": getattr(self, "_join_obsm", None),
+            "join_vars": getattr(self, "_join_vars", None),
+            "convert": self._convert,
+            "harmonize_dtypes": getattr(self, "_harmonize_dtypes", True),
+            "indices_strict": self.indices_strict,
+        }
+
+    def __setstate__(self, state):
+        """Restore state when unpickling."""
+        # Reconstruct the AnnCollection from the saved adatas and parameters
+        self.__init__(
+            state["adatas"],
+            join_obs=state["join_obs"],
+            join_obsm=state["join_obsm"],
+            join_vars=state["join_vars"],
+            convert=state["convert"],
+            harmonize_dtypes=state["harmonize_dtypes"],
+            indices_strict=state["indices_strict"],
+        )
+
     def __getitem__(self, index: Index):
         oidx, vidx = _normalize_indices(index, self.obs_names, self.var_names)
         resolved_idx = self._resolve_idx(oidx, vidx)
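A quick sketch of what the new `__getstate__`/`__setstate__` pair enables; the small in-memory `AnnData` objects are hypothetical:

```python
import pickle

import numpy as np
from anndata import AnnData
from anndata.experimental import AnnCollection

adatas = [AnnData(X=np.random.rand(50, 10)) for _ in range(2)]
collection = AnnCollection(adatas)

# Views are what DataLoader workers exchange; both now survive pickling
view = collection[[0, 3, 7]]
restored = pickle.loads(pickle.dumps(view))
assert restored.X.shape == view.X.shape
```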
18 changes: 17 additions & 1 deletion src/anndata/experimental/pytorch/__init__.py
@@ -1,5 +1,21 @@
 from __future__ import annotations
 
+from importlib.util import find_spec
+
+# public re-exports
 from ._annloader import AnnLoader
 
-__all__ = ["AnnLoader"]
+__all__: list[str] = ["AnnLoader"]
+
+# Only import batch_dict_converter if torch is available
+if find_spec("torch"):
+    from .converters import to_tensor_dict as batch_dict_converter
+
+    __all__ += ["batch_dict_converter"]
+else:
+    # Provide a fallback that raises a helpful error
+    def batch_dict_converter(*args, **kwargs):
+        msg = "batch_dict_converter requires PyTorch. Install with: pip install torch"
+        raise ImportError(msg)
+
+    __all__ += ["batch_dict_converter"]
62 changes: 61 additions & 1 deletion src/anndata/experimental/pytorch/_annloader.py
@@ -22,7 +22,7 @@

 if TYPE_CHECKING:
     from collections.abc import Callable, Generator, Sequence
-    from typing import TypeAlias, Union
+    from typing import Any, TypeAlias, Union
 
     from scipy.sparse import spmatrix
 
@@ -118,6 +118,40 @@ def compose_convert(arr):
     return new_convert
 
 
+class _WorkerCollateWrapper:
+    """Picklable wrapper for batch_converter, usable in multiprocessing workers."""
+
+    def __init__(self, batch_converter):
+        self.batch_converter = batch_converter
+
+    def __call__(self, batch):
+        # First, build a single batch from the AnnCollectionView samples by concatenating them
+        if len(batch) == 0:
+            return batch
+
+        # Assume all samples are AnnCollectionView objects
+        first_sample = batch[0]
+        if not hasattr(first_sample, "reference"):
+            # Not an AnnCollectionView, fall back to the default collate
+            from torch.utils.data._utils.collate import default_collate
+
+            return default_collate(batch)
+
+        # Create a batched view by concatenating the per-sample indices
+        reference = first_sample.reference
+        all_oidx = []
+        all_vidx = first_sample.vidx  # assume the same variables for all samples
+
+        for sample in batch:
+            all_oidx.extend(np.atleast_1d(sample.oidx))  # a single-sample oidx is a scalar
+
+        # Create a new view with all the indices
+        batch_view = reference[all_oidx, all_vidx]
+
+        # Apply the batch converter to the combined view
+        return self.batch_converter(batch_view)
+
+
 # AnnLoader has the same arguments as DataLoader, but uses BatchIndexSampler by default
 class AnnLoader(DataLoader):
     """\
@@ -143,6 +177,9 @@ class AnnLoader(DataLoader):
     use_cuda
         Transfer pytorch tensors to the default cuda device after conversion.
         Only works if `use_default_converter=True`
+    batch_converter
+        Optional callable to transform each batch after collation.
+        Works with both single-process and multi-process (`num_workers > 0`) loading.
     **kwargs
         Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also
         arguments for `AnnCollection` initialization.
@@ -157,6 +194,7 @@ def __init__(
         shuffle: bool = False,
         use_default_converter: bool = True,
         use_cuda: bool = False,
+        batch_converter: Callable[[Any], Any] | None = None,
         **kwargs,
     ):
         if isinstance(adatas, AnnData):
@@ -199,6 +237,20 @@
                 dataset.convert, _converter, dict(dataset.attrs_keys, X=[])
             )
 
+        # Remove it in case the user passed it via **kwargs (forward compatibility)
+        batch_converter = kwargs.pop("batch_converter", batch_converter)
+        self._batch_converter = batch_converter
+
+        # With num_workers > 0 and a user-supplied converter, apply it inside each worker via a custom collate
+        num_workers = kwargs.get("num_workers", 0)
+        if (
+            batch_converter is not None
+            and num_workers > 0
+            and "collate_fn" not in kwargs
+        ):
+            kwargs["collate_fn"] = _WorkerCollateWrapper(batch_converter)
+            # Unset it so the main process doesn't apply the converter again
+            self._batch_converter = None
         has_sampler = "sampler" in kwargs
         has_batch_sampler = "batch_sampler" in kwargs
 
@@ -232,3 +284,11 @@
             super().__init__(dataset, batch_size=None, sampler=sampler, **kwargs)
         else:
             super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
+
+    def __iter__(self):  # type: ignore[override]
+        for batch in super().__iter__():
+            yield (
+                self._batch_converter(batch)
+                if self._batch_converter is not None
+                else batch
+            )
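A sketch of a user-defined converter with multiprocessing, assuming an `adata` whose `obs["label"]` column is numeric. With `num_workers > 0` the converter runs inside each worker via the collate wrapper, so it must be picklable (a top-level function, not a lambda):

```python
import torch

from anndata.experimental.pytorch import AnnLoader


def xy_converter(batch_view):
    # batch_view is a single view covering the whole batch
    x = torch.as_tensor(batch_view.X)
    y = torch.as_tensor(batch_view.obs["label"])
    return x, y


loader = AnnLoader(adata, batch_size=128, batch_converter=xy_converter, num_workers=4)
```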
86 changes: 86 additions & 0 deletions src/anndata/experimental/pytorch/converters.py
@@ -0,0 +1,86 @@
"""Helper converters for AnnLoader batches.

This module provides convenience converters that can be passed to the
``batch_converter`` parameter of :pyclass:`~anndata.experimental.pytorch.AnnLoader`.
"""

from __future__ import annotations

from collections.abc import Mapping
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any

import pandas as pd

if find_spec("torch") or TYPE_CHECKING:
import torch
from torch import Tensor
else:
torch = None # type: ignore
Tensor = Any # type: ignore

__all__ = ["to_tensor_dict"]


def _to_tensor(arr) -> Tensor | Any:
"""Best-effort conversion of *arr* to ``torch.Tensor``.

Falls back to returning *arr* unchanged if torch or numpy is not available.
"""
if torch is None:
return arr

if isinstance(arr, torch.Tensor):
return arr
try:
import numpy as np
from scipy.sparse import issparse

if issparse(arr):
arr = arr.toarray()
if isinstance(arr, (np.ndarray, list)):
return torch.tensor(arr)
except ImportError: # pragma: no cover
pass
return arr


def to_tensor_dict(batch: Any) -> dict[str, Any]:
"""Convert an AnnLoader batch to a plain ``dict`` of tensors/arrays.

* ``X`` → ``"x"``
* each column in ``obs`` becomes a key in the output dict
* if *batch* is already a mapping it is returned as a *shallow copy*.
"""
# If user already returns dict-like we preserve it
if isinstance(batch, Mapping):
return dict(batch)

out: dict[str, Any] = {}

# AnnCollectionView has .X and .obs attributes
if hasattr(batch, "X"):
out["x"] = _to_tensor(batch.X)

if hasattr(batch, "obs") and batch.obs is not None:
obs_data = batch.obs
# Handle pandas DataFrame
if isinstance(obs_data, pd.DataFrame):
for col in obs_data.columns:
# ensure unique keys – users can post-process if needed
out[col] = _to_tensor(obs_data[col].to_numpy())
# Handle AnnCollection MapObsView (can be converted to dict directly)
elif hasattr(obs_data, "to_dict"):
obs_dict = obs_data.to_dict()
for key, value in obs_dict.items():
out[key] = _to_tensor(value)
# Handle generic dict-like objects
elif hasattr(obs_data, "keys") and callable(obs_data.keys):
try:
obs_dict = dict(obs_data)
for key, value in obs_dict.items():
out[key] = _to_tensor(value)
except (TypeError, AttributeError, ValueError):
pass # Skip if conversion fails

return out
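The resulting batch structure, assuming an `adata` with hypothetical obs columns `cell_type` and `n_genes`:

```python
from anndata.experimental.pytorch import AnnLoader, batch_dict_converter

loader = AnnLoader(adata, batch_size=32, batch_converter=batch_dict_converter)
batch = next(iter(loader))
print(batch.keys())      # dict_keys(['x', 'cell_type', 'n_genes'])
print(batch["x"].shape)  # torch.Size([32, adata.n_vars])
```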
Empty file added tests/pytorch/__init__.py