Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/2159.perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a `write_csr_csc_indices_with_min_possible_dtype` option to {attr}`anndata.settings` that, when enabled, downcasts the `indices` of csr and csc matrices on write to the smallest unsigned integer dtype that can hold the size of the minor axis. For example, if your csr matrix has only 30000 columns, the `indices` of that matrix can be written as `uint16` instead of `int64`. {user}`ilan-gold`
18 changes: 18 additions & 0 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,24 @@ def write_sparse_compressed(
for attr_name in ["data", "indices", "indptr"]:
attr = getattr(value, attr_name)
dtype = indptr_dtype if attr_name == "indptr" else attr.dtype
if (
attr_name == "indices"
and settings.write_csr_csc_indices_with_min_possible_dtype
):
# np.min_scalar_type can return things like np.ulonglong which zarr doesn't understand
# and I find this clearer as to what the result type is i.e., unsigned or signed.
# For example `np.iinfo(np.uint16).max + 1` could be either `uint32` or `int32`,
# and there's nothing in numpy's docs disallowing this output to change.
if (minor_axis_size := value.shape[value.format == "csr"]) <= np.iinfo(
np.uint8
).max:
dtype = np.dtype("uint8")
elif minor_axis_size <= np.iinfo(np.uint16).max:
dtype = np.dtype("uint16")
elif minor_axis_size <= np.iinfo(np.uint32).max:
dtype = np.dtype("uint32")
elif minor_axis_size <= np.iinfo(np.uint64).max:
dtype = np.dtype("uint64")
if isinstance(f, H5Group) or is_zarr_v2():
g.create_dataset(
attr_name, data=attr, shape=attr.shape, dtype=dtype, **dataset_kwargs
Expand Down
8 changes: 8 additions & 0 deletions src/anndata/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,14 @@ def validate_sparse_settings(val: Any, settings: SettingsManager) -> None:
get_from_env=check_and_get_bool,
)

# Opt-in (off by default) switch consulted when writing csr/csc matrices:
# when enabled, `indices` is stored with an unsigned integer dtype chosen
# from the size of the minor axis.
# NOTE(review): presumably `get_from_env` lets the value be supplied through
# an environment variable, like the neighbouring boolean settings - confirm.
settings.register(
    "write_csr_csc_indices_with_min_possible_dtype",
    description="Write a csr or csc matrix with the minimum possible data type for `indices`, always unsigned integer.",
    default_value=False,
    get_from_env=check_and_get_bool,
    validate=validate_bool,
)

settings.register(
"auto_shard_zarr_v3",
default_value=False,
Expand Down
1 change: 1 addition & 0 deletions src/anndata/_settings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class _AnnDataSettingsManager(SettingsManager):
use_sparse_array_on_read: bool = False
min_rows_for_chunked_h5_copy: int = 1000
disallow_forward_slash_in_h5ad: bool = False
write_csr_csc_indices_with_min_possible_dtype: bool = False
auto_shard_zarr_v3: bool = False

settings: _AnnDataSettingsManager
107 changes: 107 additions & 0 deletions tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,113 @@ def test_write_indptr_dtype_override(store, sparse_format):
np.testing.assert_array_equal(store["X/indptr"][...], X.indptr)


@pytest.mark.parametrize(
    ("num_minor_axis", "expected_dtype"),
    [
        # Smallest possible minor axis.
        pytest.param(1, np.dtype("uint8"), id="one_col-expected_uint8_on_disk"),
        # Minor-axis sizes equal to the maximum of a signed/unsigned integer type.
        pytest.param(
            np.iinfo(np.uint8).max,
            np.dtype("uint8"),
            id="max_np.uint8-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int8).max,
            np.dtype("uint8"),
            id="max_np.int8-uint8_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max,
            np.dtype("uint16"),
            id="max_np.uint16-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int16).max,
            np.dtype("uint16"),
            id="max_np.int16-uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max,
            np.dtype("uint32"),
            id="max_np.uint32-matching_dtype_on_disk",
        ),
        pytest.param(
            np.iinfo(np.int32).max,
            np.dtype("uint32"),
            id="max_np.int32-uint32_on_disk",
        ),
        # One past an unsigned maximum: the next wider dtype is expected.
        pytest.param(
            np.iinfo(np.uint8).max + 1,
            np.dtype("uint16"),
            id="max_np.uint8_plus_one_cols-expected_uint16_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint16).max + 1,
            np.dtype("uint32"),
            id="max_np.uint16_plus_one_cols-expected_uint32_on_disk",
        ),
        pytest.param(
            np.iinfo(np.uint32).max + 1,
            np.dtype("uint64"),
            id="max_np.uint32_plus_one_cols-expected_uint64_on_disk",
        ),
        # Sizes beyond what scipy sparse can index: expected to fail.
        pytest.param(
            np.iinfo(np.int64).max + 1,
            np.dtype("uint64"),
            id="max_np.int64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
        pytest.param(
            np.iinfo(np.uint64).max + 1,
            np.dtype("uint64"),
            id="max_np.uint64_plus_one_cols-expected_uint64_on_disk",
            marks=pytest.mark.xfail(
                reason="scipy sparse does not support bigger than max(int64) values in indices and there is no uint128."
            ),
        ),
    ],
)
@pytest.mark.parametrize("format", ["csr", "csc"])
def test_write_indices_min(
    store: H5Group | ZarrGroup,
    num_minor_axis: int,
    expected_dtype: np.dtype,
    format: Literal["csr", "csc"],
):
    """Check the on-disk dtype of ``indices`` when min-dtype writing is enabled.

    Builds a sparse array with a single stored value placed at the last
    position of the minor axis (``num_minor_axis - 1``), writes it with
    ``write_csr_csc_indices_with_min_possible_dtype`` switched on, asserts that
    ``X/indices`` in the store has ``expected_dtype``, and then verifies that
    reading the element back reproduces the original array.
    """
    minor_pos = np.array([num_minor_axis - 1])
    major_pos = np.array([10])
    # The minor axis is the row axis for csc and the column axis for csr.
    if format == "csc":
        coords = (minor_pos, major_pos)
        shape = (num_minor_axis, 20)
    else:
        coords = (major_pos, minor_pos)
        shape = (20, num_minor_axis)
    sparse_cls = getattr(sparse, f"{format}_array")
    X = sparse_cls((np.array([10]), coords), shape=shape)
    assert X.nnz == 1

    with ad.settings.override(write_csr_csc_indices_with_min_possible_dtype=True):
        write_elem(store, "X", X)

    assert store["X/indices"].dtype == expected_dtype

    with ad.settings.override(use_sparse_array_on_read=True):
        roundtripped = read_elem(store["X"])
    assert_equal(roundtripped.data, X.data)
    assert_equal(roundtripped.indices, X.indices)
    assert_equal(roundtripped.indptr, X.indptr)
    assert X.format == roundtripped.format
    assert roundtripped.shape == X.shape

    # `!=` converts its operands to csr. For a csc input with a very large
    # minor axis, that axis becomes the csr major axis, so the conversion
    # either needs a huge indptr or fails with:
    # ValueError: array is too big; `arr.size * arr.dtype.itemsize` is larger than the maximum possible size.
    # The component-wise checks above already cover equality; this one is an
    # extra safeguard and is only run where it is feasible.
    # See https://github.com/scipy/scipy/issues/23826
    elementwise_compare_is_feasible = (
        format != "csc" or num_minor_axis <= np.iinfo(np.uint16).max + 1
    )
    if elementwise_compare_is_feasible:
        assert (roundtripped != X).nnz == 0


def test_io_spec_raw(store):
adata = gen_adata((3, 2), **GEN_ADATA_NO_XARRAY_ARGS)
adata.raw = adata.copy()
Expand Down