Skip to content

Commit 9ac5391

Browse files
[python] Use arrow API to cast python tables before sending to C++ (#4359)
* (WIP) Safe-cast pyarrow tables on write Still needs the following: * fix schema names for GeometryDataFrame * test unsafe casting * Fix casting for geometry dataframe outlines * Update history * Switch from deprecated `field_by_name` to `field` * Update error message and remove unneeded type declaration * Take tests from PR #4311 Add test for dictionary casting from #4311 Co-authored-by: XanthosXanthopoulos <[email protected]> * Add xfail to uncovered bug * Remove test that is checking for unsafe cast * Fix syntax for xfail --------- Co-authored-by: XanthosXanthopoulos <[email protected]> (cherry picked from commit a1f6a68)
1 parent 73a4a1d commit 9ac5391

File tree

6 files changed

+124
-81
lines changed

6 files changed

+124
-81
lines changed

apis/python/HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ All notable changes to the Python TileDB-SOMA project will be documented in this
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
66

7+
## [2.3.0]
8+
9+
### Fixed
10+
11+
- \[[#4359](https://github.com/single-cell-data/TileDB-SOMA/pull/4359)\] Fix unsafe casting of data on write when the input data type in a PyArrow table or batch does not match the existing schema.
12+
713
## [Release 2.2.0]
814

915
### Changed

apis/python/src/tiledbsoma/_geometry_dataframe.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_revise_domain_for_extent,
3434
)
3535
from ._exception import DoesNotExistError, SOMAError, is_does_not_exist_error, map_exception_for_create
36+
from ._managed_query import ManagedQuery
3637
from ._read_iters import TableReadIter
3738
from ._spatial_dataframe import SpatialDataFrame
3839
from ._spatial_util import (
@@ -494,8 +495,6 @@ def write(
494495
Experimental.
495496
"""
496497
_util.check_type("values", values, (pa.Table,))
497-
498-
write_options: TileDBCreateOptions | TileDBWriteOptions
499498
if isinstance(platform_config, TileDBCreateOptions):
500499
raise ValueError(
501500
"As of TileDB-SOMA 1.13, the write method takes TileDBWriteOptions instead of TileDBCreateOptions",
@@ -535,13 +534,43 @@ def from_outlines(
535534
Returns: ``self``, to enable method chaining.
536535
537536
"""
538-
outline_transformer = clib.OutlineTransformer(coordinate_space_to_json(self._coord_space))
537+
_util.check_type("values", values, (pa.Table, pa.RecordBatch))
538+
if isinstance(platform_config, TileDBCreateOptions):
539+
raise ValueError(
540+
"As of TileDB-SOMA 1.13, the write method takes TileDBWriteOptions instead of TileDBCreateOptions",
541+
)
542+
write_options = TileDBWriteOptions.from_platform_config(platform_config)
543+
if not write_options.sort_coords:
544+
raise NotImplementedError("Support for writing outline geometries in global order is not yet implemented.")
539545

540-
for batch in values.to_batches():
541-
self.write(
542-
clib.TransformerPipeline(batch).transform(outline_transformer).asTable(),
543-
platform_config=platform_config,
546+
array_schema = self.schema
547+
for name in values.schema.names:
548+
if name not in array_schema.names:
549+
raise ValueError(
550+
f"Cannot write data. Field '{name}' in the input data is not a column in this {self._handle_type.__name__}."
551+
)
552+
batch_schema = pa.schema([
553+
values.schema.field(name) if name == "soma_geometry" else array_schema.field(name)
554+
for name in values.schema.names
555+
])
556+
557+
batches = values.to_batches()
558+
if not batches:
559+
return self
560+
561+
outline_transformer = clib.OutlineTransformer(coordinate_space_to_json(self._coord_space))
562+
for batch in batches:
563+
table = (
564+
clib.TransformerPipeline(batch.cast(batch_schema, safe=True)).transform(outline_transformer).asTable()
544565
)
566+
for subbatch in table.to_batches():
567+
mq = ManagedQuery(self)._handle
568+
mq.set_layout(clib.ResultOrder.unordered)
569+
mq.submit_batch(subbatch)
570+
mq.finalize()
571+
572+
if write_options.consolidate_and_vacuum:
573+
self._handle.consolidate_and_vacuum()
545574

546575
return self
547576

apis/python/src/tiledbsoma/_soma_array.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,20 +142,25 @@ def _write_table(self, values: pa.Table, sort_coords: bool) -> None:
142142
if not batches:
143143
return
144144

145-
layout = clib.ResultOrder.unordered if sort_coords else clib.ResultOrder.globalorder
146-
147-
if layout == clib.ResultOrder.unordered:
148-
# Finalize for each batch
145+
array_schema = self.schema
146+
for name in values.schema.names:
147+
if name not in array_schema.names:
148+
raise ValueError(
149+
f"Cannot write data. Field '{name}' in the input data is not a column in this {self._handle_type.__name__}."
150+
)
151+
batch_schema = pa.schema([array_schema.field(name) for name in values.schema.names])
152+
153+
if sort_coords:
154+
# Finalize each batch as it is written.
149155
for batch in batches:
150156
mq = ManagedQuery(self)._handle
151-
mq.set_layout(layout)
152-
mq.submit_batch(batch)
157+
mq.set_layout(clib.ResultOrder.unordered)
158+
mq.submit_batch(batch.cast(batch_schema, safe=True))
153159
mq.finalize()
154-
155-
else: # globalorder
156-
# Only finalize at the last batch
160+
else:
161+
# Single global order query - only finalize at the end.
157162
mq = ManagedQuery(self)._handle
158-
mq.set_layout(layout)
163+
mq.set_layout(clib.ResultOrder.globalorder)
159164
for batch in batches[:-1]:
160-
mq.submit_batch(batch)
165+
mq.submit_batch(batch.cast(batch_schema, safe=True))
161166
mq.submit_and_finalize_batch(batches[-1])

apis/python/tests/test_dataframe.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,6 @@ def test_dataframe_with_enumeration(tmp_path):
404404
with soma.DataFrame.create(tmp_path.as_posix(), schema=schema, domain=[[0, 5]]) as sdf:
405405
data = {}
406406
data["soma_joinid"] = [0, 1, 2, 3, 4]
407-
data["myint"] = ["a", "bb", "ccc", "bb", "a"]
408-
data["myfloat"] = ["cat", "dog", "cat", "cat", "cat"]
409-
with pytest.raises(soma.SOMAError):
410-
sdf.write(pa.Table.from_pydict(data))
411-
412407
data["myint"] = pd.Categorical(["a", "bb", "ccc", "bb", "a"])
413408
data["myfloat"] = pd.Categorical(["cat", "dog", "cat", "cat", "cat"])
414409
sdf.write(pa.Table.from_pydict(data))
@@ -4072,3 +4067,49 @@ def test_gow_mixed_idxes(tmp_path):
40724067
df = A.read().concat().to_pandas()
40734068

40744069
assert df.equals(expected_df)
4070+
4071+
4072+
def test_write_dictionary_to_non_enum_column(tmp_path):
4073+
written_df = pd.DataFrame(
4074+
{
4075+
"soma_joinid": pd.Series([0, 1, 2, 3, 4, 5], dtype=np.int64),
4076+
"str": pd.Series(["A", "B", "A", "B", "B", None], dtype="category"),
4077+
"byte": pd.Series([b"A", b"B", b"A", b"B", b"B", None], dtype="category"),
4078+
"bool": pd.Series([True, False, True, False, False, None], dtype="category"),
4079+
"int64": pd.Series([0, 1, 2, 0, 1, None], dtype="Int64").astype("category"),
4080+
"uint64": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt64").astype("category"),
4081+
"int32": pd.Series([0, 1, 2, 0, 1, None], dtype="Int32").astype("category"),
4082+
"uint32": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt32").astype("category"),
4083+
"int16": pd.Series([0, 1, 2, 0, 1, None], dtype="Int16").astype("category"),
4084+
"uint16": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt16").astype("category"),
4085+
"int8": pd.Series([0, 1, 2, 0, 1, None], dtype="Int8").astype("category"),
4086+
"uint8": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt8").astype("category"),
4087+
"float32": pd.Series([0, 1.1, 2.1, 0, 1.1, None], dtype="Float32").astype("category"),
4088+
"float64": pd.Series([0, 1.1, 2.1, 0, 1.1, None], dtype="Float64").astype("category"),
4089+
},
4090+
)
4091+
4092+
schema = pa.schema([
4093+
pa.field("soma_joinid", pa.int64()),
4094+
pa.field("str", pa.large_string(), nullable=True),
4095+
pa.field("byte", pa.large_binary(), nullable=True),
4096+
pa.field("bool", pa.bool_(), nullable=True),
4097+
pa.field("int64", pa.int64(), nullable=True),
4098+
pa.field("uint64", pa.uint64(), nullable=True),
4099+
pa.field("int32", pa.int32(), nullable=True),
4100+
pa.field("uint32", pa.uint32(), nullable=True),
4101+
pa.field("int16", pa.int16(), nullable=True),
4102+
pa.field("uint16", pa.uint16(), nullable=True),
4103+
pa.field("int8", pa.int8(), nullable=True),
4104+
pa.field("uint8", pa.uint8(), nullable=True),
4105+
pa.field("float32", pa.float32(), nullable=True),
4106+
pa.field("float64", pa.float64(), nullable=True),
4107+
])
4108+
4109+
with soma.DataFrame.create(str(tmp_path), schema=schema, domain=[[0, 9]]) as soma_dataframe:
4110+
tbl = pa.Table.from_pandas(written_df, preserve_index=False)
4111+
soma_dataframe.write(tbl)
4112+
4113+
with soma.open(str(tmp_path)) as soma_dataframe:
4114+
readback_tbl = soma_dataframe.read().concat()
4115+
assert tbl.to_pylist() == readback_tbl.to_pylist()

apis/python/tests/test_sparse_nd_array.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,42 +1818,6 @@ def test(path, tiledb_config):
18181818
gc.collect()
18191819

18201820

1821-
def test_sparse_nd_array_null(tmp_path):
1822-
uri = tmp_path.as_posix()
1823-
1824-
pydict = {
1825-
"soma_dim_0": pa.array([None, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
1826-
"soma_data": pa.array([None, 0, None, 1, 2, None, None, 3, 4, 5], type=pa.float64()),
1827-
}
1828-
table = pa.Table.from_pydict(pydict)
1829-
1830-
soma.SparseNDArray.create(uri, type=pa.int64(), shape=(10,))
1831-
1832-
# As of version 1.15.6 we were throwing in this case. However, we found
1833-
# a compatibility issue with pyarrow versions below 17. Thus this is
1834-
# now non-fatal.
1835-
# with soma.SparseNDArray.open(uri, "w") as A:
1836-
# with raises_no_typeguard(soma.SOMAError):
1837-
# # soma_joinid cannot be nullable
1838-
# A.write(table[:5])
1839-
# A.write(table[5:])
1840-
1841-
pydict["soma_dim_0"] = pa.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
1842-
table = pa.Table.from_pydict(pydict)
1843-
1844-
with soma.SparseNDArray.open(uri, "w") as A:
1845-
A.write(table[:5])
1846-
A.write(table[5:])
1847-
1848-
with soma.SparseNDArray.open(uri) as A:
1849-
pdf = A.read().tables().concat()
1850-
1851-
# soma_data is a non-nullable attribute. In ManagedQuery.set_array_data,
1852-
# any null values present in non-nullable attributes get casted to
1853-
# fill values. In the case for float64, the fill value is 0
1854-
np.testing.assert_array_equal(pdf["soma_data"], table["soma_data"].fill_null(0))
1855-
1856-
18571821
@pytest.mark.parametrize("ts", (None, 1))
18581822
def test_resize_with_time_travel_61254(tmp_path, ts):
18591823
uri = tmp_path.as_posix()

apis/python/tests/test_update_dataframes.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -307,55 +307,53 @@ def test_update_non_null_to_null(soma_tiledb_context, tmp_path, conftest_pbmc3k_
307307

308308

309309
@pytest.mark.medium_runner
310-
def test_enmr_add_drop_readd(soma_tiledb_context, tmp_path, conftest_pbmc3k_adata):
310+
@pytest.mark.xfail(reason="Bug reported in SOMA-792")
311+
def test_enmr_add_drop_read(soma_tiledb_context, tmp_path, conftest_pbmc3k_adata):
311312
uri = tmp_path.as_posix()
312313

313-
# Add
314+
# Create and check column.
314315
tiledbsoma.io.from_anndata(uri, conftest_pbmc3k_adata, measurement_name="RNA", context=soma_tiledb_context)
315-
316316
with tiledbsoma.Experiment.open(uri, "r") as exp:
317317
schema = exp.obs.schema
318318
assert "louvain" in schema.names
319319
field = schema.field("louvain")
320320
assert pa.types.is_dictionary(field.type)
321321

322-
# Drop
322+
# Create reference data.
323323
with tiledbsoma.Experiment.open(uri, "r") as exp:
324-
obs = exp.obs.read().concat().to_pandas()
325-
obs.drop(columns=["louvain"], inplace=True)
324+
obs_data = exp.obs.read().concat().to_pandas()
325+
obs_no_louvain = obs_data.drop(columns=["louvain"], inplace=False)
326+
obs_diff_type = obs_data.drop(columns=["louvain"], inplace=False)
327+
obs_diff_type["louvain"] = pd.Categorical(np.random.randint(1, 4, size=len(obs_data)))
326328

329+
# Drop data and check column.
327330
with tiledbsoma.Experiment.open(uri, "w") as exp:
328-
tiledbsoma.io.update_obs(exp, obs)
329-
331+
tiledbsoma.io.update_obs(exp, obs_no_louvain)
330332
with tiledbsoma.Experiment.open(uri, "r") as exp:
331333
schema = exp.obs.schema
332334
assert "louvain" not in schema.names
333335

334-
# Add column with same name and same type
336+
# Add column with same name and same type.
335337
with tiledbsoma.Experiment.open(uri, "w") as exp:
336-
# Most importantly, we're implicitly checking for no throw here.
337-
tiledbsoma.io.update_obs(exp, conftest_pbmc3k_adata.obs)
338-
338+
tiledbsoma.io.update_obs(exp, obs_data)
339339
with tiledbsoma.Experiment.open(uri, "r") as exp:
340340
schema = exp.obs.schema
341341
assert "louvain" in schema.names
342+
field = schema.field("louvain")
342343
assert pa.types.is_dictionary(field.type)
343344

344-
# Drop
345-
with tiledbsoma.Experiment.open(uri, "r") as exp:
346-
obs = exp.obs.read().concat().to_pandas()
347-
obs.drop(columns=["louvain"], inplace=True)
348-
345+
# Drop data and check column.
349346
with tiledbsoma.Experiment.open(uri, "w") as exp:
350-
tiledbsoma.io.update_obs(exp, obs)
347+
tiledbsoma.io.update_obs(exp, obs_no_louvain)
348+
with tiledbsoma.Experiment.open(uri, "r") as exp:
349+
schema = exp.obs.schema
350+
assert "louvain" not in schema.names
351351

352352
# Add column with same name but different categorical type (str to int)
353-
obs["louvain"] = pd.Categorical(np.random.randint(1, 4, size=len(obs)))
354353
with tiledbsoma.Experiment.open(uri, "w") as exp:
355-
# Most importantly, we're implicitly checking for no throw here.
356-
tiledbsoma.io.update_obs(exp, obs)
357-
354+
tiledbsoma.io.update_obs(exp, obs_diff_type)
358355
with tiledbsoma.Experiment.open(uri, "r") as exp:
359356
schema = exp.obs.schema
360357
assert "louvain" in schema.names
358+
field = schema.field("louvain")
361359
assert pa.types.is_dictionary(field.type)

0 commit comments

Comments
 (0)