Skip to content

Commit

Permalink
Fix ascii/binary issues in TileDB->Arrow and TileDB->Pandas->Arrows t…
Browse files Browse the repository at this point in the history
…est cases
  • Loading branch information
johnkerl committed Oct 11, 2022
1 parent fb2c60a commit 46b4ae3
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 26 deletions.
2 changes: 2 additions & 0 deletions apis/python/src/tiledbsoma/soma_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def write_from_pandas(
dfc = dataframe[column_name]
if len(dfc) > 0 and type(dfc[0]) == str:
column_types[column_name] = "ascii"
if len(dfc) > 0 and type(dfc[0]) == bytes:
column_types[column_name] = "bytes"

tiledb.from_pandas(
uri=self.uri,
Expand Down
20 changes: 13 additions & 7 deletions apis/python/src/tiledbsoma/util_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
of representing full type semantics, and correctly performing a
round trip conversion (eg, T == to_arrow(to_tiledb(T)))
Most primitive types are simple - eg, uint8. Of particular challenge
Most primitive types are simple -- e.g., uint8. Of particular challenge
are datetime/timestamps as TileDB has no distinction between a "datetime" and
a "timedelta". The best Arrow match is TimestampType, as long as that
TimestampType instance does NOT have a timezone set.
Because of our round-trip requirement, all other Arrow temporal types
are unsupported (even though they are just int64 under the covers).
We auto-promote Arrow's string and binary to large_string and large_binary,
respectively, as this is what TileDB stores -- a sequence of bytes preceded
by a 64-bit (not 32-bit) length int.
"""
ARROW_TO_TDB = {
# Dict of types unsupported by to_pandas_dtype, which require overrides.
Expand All @@ -25,8 +29,8 @@
#
pa.string(): "ascii", # XXX TODO: temporary work-around until UTF8 support is native. GH #338.
pa.large_string(): "ascii", # XXX TODO: temporary work-around until UTF8 support is native. GH #338.
pa.binary(): np.dtype("S"),
pa.large_binary(): np.dtype("S"),
pa.binary(): "bytes", # XXX TODO: temporary work-around until UTF8 support is native. GH #338.
pa.large_binary(): "bytes", # XXX TODO: temporary work-around until UTF8 support is native. GH #338.
pa.timestamp("s"): "datetime64[s]",
pa.timestamp("ms"): "datetime64[ms]",
pa.timestamp("us"): "datetime64[us]",
Expand Down Expand Up @@ -63,8 +67,9 @@ def tiledb_type_from_arrow_type(t: pa.DataType) -> Union[type, np.dtype, str]:
raise arrow_type
if arrow_type == "ascii":
return arrow_type
else:
return np.dtype(arrow_type)
if arrow_type == "bytes":
return arrow_type # np.int8()
return np.dtype(arrow_type)

if not pa.types.is_primitive(t):
raise TypeError(f"Type {str(t)} - unsupported type")
Expand All @@ -90,11 +95,12 @@ def get_arrow_type_from_tiledb_dtype(tiledb_dtype: Union[str, np.dtype]) -> pa.D
"""
TODO: COMMENT
"""
if tiledb_dtype == "bytes":
return pa.large_binary()
if tiledb_dtype == "ascii" or tiledb_dtype.name == "bytes":
# XXX TODO: temporary work-around until UTF8 support is native. GH #338.
return pa.large_string()
else:
return pa.from_numpy_dtype(tiledb_dtype)
return pa.from_numpy_dtype(tiledb_dtype)


def get_arrow_schema_from_tiledb_uri(
Expand Down
34 changes: 27 additions & 7 deletions apis/python/tests/test_soma_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,37 +174,57 @@ def test_SOMADataFrame_read_column_names(simple_soma_data_frame, ids, col_names)
schema, sdf, n_data = simple_soma_data_frame
assert sdf.exists()

def _check_tbl(tbl, col_names, ids):
def _check_tbl(tbl, col_names, ids, *, demote):
assert tbl.num_columns == (
len(schema.names) if col_names is None else len(col_names)
)
assert tbl.num_rows == (n_data if ids is None else len(ids))
assert tbl.schema == pa.schema(
[
schema.field(f)
for f in (col_names if col_names is not None else schema.names)
]
)

if demote:
assert tbl.schema == pa.schema(
[
pa.field(schema.field(f).name, pa.string())
if schema.field(f).type == pa.large_string()
else schema.field(f)
for f in (col_names if col_names is not None else schema.names)
]
)
else:
assert tbl.schema == pa.schema(
[
schema.field(f)
for f in (col_names if col_names is not None else schema.names)
]
)

# TileDB ASCII -> Arrow large_string
_check_tbl(
sdf.read_all(ids=ids, column_names=col_names),
col_names,
ids,
demote=False,
)

_check_tbl(
sdf.read_all(column_names=col_names),
col_names,
None,
demote=False,
)

# TileDB ASCII -> Pandas string -> Arrow string (not large_string)
_check_tbl(
pa.Table.from_pandas(
pd.concat(sdf.read_as_pandas(ids=ids, column_names=col_names))
),
col_names,
ids,
demote=True,
)

_check_tbl(
pa.Table.from_pandas(sdf.read_as_pandas_all(column_names=col_names)),
col_names,
None,
demote=True,
)
52 changes: 40 additions & 12 deletions apis/python/tests/test_type_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,20 @@
pa.timestamp("ms"),
pa.timestamp("us"),
pa.timestamp("ns"),
pa.string(),
# We use Arrow's large_string for ASCII and, ultimately, for Unicode as well
# https://github.com/single-cell-data/TileDB-SOMA/issues/99
# https://github.com/single-cell-data/TileDB-SOMA/pull/359
# https://github.com/single-cell-data/TileDB-SOMA/issues/274
pa.large_string(),
pa.binary(),
pa.large_binary(),
]

"""Arrow types we expect to auto-promote"""
PROMOTED_ARROW_TYPES = [
(pa.string(), pa.large_string()),
# XXX (pa.binary(), pa.large_binary()),
]


"""Arrow types we expect to fail"""
UNSUPPORTED_ARROW_TYPES = [
Expand All @@ -46,10 +54,13 @@
pa.duration("us"),
pa.duration("ns"),
pa.month_day_nano_interval(),
# We use Arrow's large_string for ASCII and, ultimately, for Unicode as well
# https://github.com/single-cell-data/TileDB-SOMA/issues/99
# https://github.com/single-cell-data/TileDB-SOMA/pull/359
# https://github.com/single-cell-data/TileDB-SOMA/issues/274
pa.string(),
pa.binary(),
pa.binary(10),
pa.large_binary(),
pa.large_string(),
pa.decimal128(1),
pa.decimal128(38),
pa.list_(pa.int8()),
Expand All @@ -61,20 +72,37 @@


@pytest.mark.parametrize("arrow_type", SUPPORTED_ARROW_TYPES)
def test_supported_types_supported(arrow_type):
def test_arrow_types_supported(arrow_type):
"""Verify round-trip conversion of types"""
if pa.types.is_binary(arrow_type):
pytest.xfail("Awaiting UTF-8 support - see issue #338")
# if pa.types.is_binary(arrow_type):
# pytest.xfail("Awaiting UTF-8 support - see issue #274")

tdb_dtype = tiledb_type_from_arrow_type(arrow_type)
assert isinstance(tdb_dtype, np.dtype) or tdb_dtype == "ascii"
rt_arrow_type = get_arrow_type_from_tiledb_dtype(tdb_dtype)
assert isinstance(rt_arrow_type, pa.DataType)
assert arrow_type == rt_arrow_type
assert (
isinstance(tdb_dtype, np.dtype) or tdb_dtype == "ascii" or tdb_dtype == "bytes"
)
arrow_rt_type = get_arrow_type_from_tiledb_dtype(tdb_dtype)
assert isinstance(arrow_rt_type, pa.DataType)
assert arrow_type == arrow_rt_type


@pytest.mark.parametrize("arrow_from_to_pair", PROMOTED_ARROW_TYPES)
def test_arrow_types_promoted(arrow_from_to_pair):
"""Verify round-trip conversion of types"""
arrow_from_type = arrow_from_to_pair[0]
arrow_to_type = arrow_from_to_pair[1]

tdb_dtype = tiledb_type_from_arrow_type(arrow_from_type)
assert (
isinstance(tdb_dtype, np.dtype) or tdb_dtype == "ascii" or tdb_dtype == "bytes"
)
arrow_rt_type = get_arrow_type_from_tiledb_dtype(tdb_dtype)
assert isinstance(arrow_rt_type, pa.DataType)
assert arrow_to_type == arrow_rt_type


@pytest.mark.parametrize("arrow_type", UNSUPPORTED_ARROW_TYPES)
def test_supported_types_unsupported(arrow_type):
def test_arrow_types_unsupported(arrow_type):
"""Verify correct error for unsupported types"""
with pytest.raises(TypeError):
tiledb_type_from_arrow_type(arrow_type, match=r".*unsupported type.*")

0 comments on commit 46b4ae3

Please sign in to comment.