Skip to content

Commit 14d71e2

Browse files
Correctly set the validity buffer when casting dictionary to values
1 parent 652281d commit 14d71e2

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

apis/python/tests/test_dataframe.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2784,6 +2784,52 @@ def test_extend_enumerations(tmp_path):
27842784
assert (readback_df[c] == written_df[c]).all()
27852785

27862786

2787+
def test_write_dictionary_to_non_enum_column(tmp_path):
2788+
written_df = pd.DataFrame(
2789+
{
2790+
"soma_joinid": pd.Series([0, 1, 2, 3, 4, 5], dtype=np.int64),
2791+
"str": pd.Series(["A", "B", "A", "B", "B", None], dtype="category"),
2792+
"byte": pd.Series([b"A", b"B", b"A", b"B", b"B", None], dtype="category"),
2793+
"bool": pd.Series([True, False, True, False, False, None], dtype="category"),
2794+
"int64": pd.Series([0, 1, 2, 0, 1, None], dtype="Int64").astype("category"),
2795+
"uint64": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt64").astype("category"),
2796+
"int32": pd.Series([0, 1, 2, 0, 1, None], dtype="Int32").astype("category"),
2797+
"uint32": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt32").astype("category"),
2798+
"int16": pd.Series([0, 1, 2, 0, 1, None], dtype="Int16").astype("category"),
2799+
"uint16": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt16").astype("category"),
2800+
"int8": pd.Series([0, 1, 2, 0, 1, None], dtype="Int8").astype("category"),
2801+
"uint8": pd.Series([0, 1, 2, 0, 1, None], dtype="UInt8").astype("category"),
2802+
"float32": pd.Series([0, 1.1, 2.1, 0, 1.1, None], dtype="Float32").astype("category"),
2803+
"float64": pd.Series([0, 1.1, 2.1, 0, 1.1, None], dtype="Float64").astype("category"),
2804+
},
2805+
)
2806+
2807+
schema = pa.schema([
2808+
pa.field("soma_joinid", pa.int64()),
2809+
pa.field("str", pa.large_string(), nullable=True),
2810+
pa.field("byte", pa.large_binary(), nullable=True),
2811+
pa.field("bool", pa.bool_(), nullable=True),
2812+
pa.field("int64", pa.int64(), nullable=True),
2813+
pa.field("uint64", pa.uint64(), nullable=True),
2814+
pa.field("int32", pa.int32(), nullable=True),
2815+
pa.field("uint32", pa.uint32(), nullable=True),
2816+
pa.field("int16", pa.int16(), nullable=True),
2817+
pa.field("uint16", pa.uint16(), nullable=True),
2818+
pa.field("int8", pa.int8(), nullable=True),
2819+
pa.field("uint8", pa.uint8(), nullable=True),
2820+
pa.field("float32", pa.float32(), nullable=True),
2821+
pa.field("float64", pa.float64(), nullable=True),
2822+
])
2823+
2824+
with soma.DataFrame.create(str(tmp_path), schema=schema, domain=[[0, 9]]) as soma_dataframe:
2825+
tbl = pa.Table.from_pandas(written_df, preserve_index=False)
2826+
soma_dataframe.write(tbl)
2827+
2828+
with soma.open(str(tmp_path)) as soma_dataframe:
2829+
readback_tbl = soma_dataframe.read().concat()
2830+
assert tbl.to_pylist() == readback_tbl.to_pylist()
2831+
2832+
27872833
def test_multiple_writes_with_str_enums(tmp_path):
27882834
uri = tmp_path.as_posix()
27892835

libtiledbsoma/src/soma/managed_query.cc

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -681,12 +681,12 @@ void ManagedQuery::_cast_dictionary_values(ArrowSchema* schema, ArrowArray* arra
681681
"values");
682682
}
683683

684-
setup_write_column(
685-
schema->name,
686-
array->length,
687-
std::move(data_buffer),
688-
(uint64_t*)nullptr,
689-
nullptr); // validities are set by index column
684+
std::unique_ptr<std::uint8_t[]> validity = nullptr;
685+
if (schema->flags & ARROW_FLAG_NULLABLE) {
686+
validity = _cast_validity_buffer_ptr(array);
687+
}
688+
689+
setup_write_column(schema->name, array->length, std::move(data_buffer), (uint64_t*)nullptr, std::move(validity));
690690
}
691691

692692
template <>
@@ -774,12 +774,13 @@ void ManagedQuery::_cast_dictionary_values<std::string>(ArrowSchema* schema, Arr
774774
std::tie(data_buffer, offset_buffer) = extract_values.operator()<uint32_t>(schema, array);
775775
}
776776

777+
std::unique_ptr<std::uint8_t[]> validity = nullptr;
778+
if (schema->flags & ARROW_FLAG_NULLABLE) {
779+
validity = _cast_validity_buffer_ptr(array);
780+
}
781+
777782
setup_write_column(
778-
schema->name,
779-
array->length,
780-
std::move(data_buffer),
781-
std::move(offset_buffer),
782-
nullptr); // validities are set by index column
783+
schema->name, array->length, std::move(data_buffer), std::move(offset_buffer), std::move(validity));
783784
}
784785

785786
template <>
@@ -829,12 +830,12 @@ void ManagedQuery::_cast_dictionary_values<bool>(ArrowSchema* schema, ArrowArray
829830
"values");
830831
}
831832

832-
setup_write_column(
833-
schema->name,
834-
array->length,
835-
std::move(data_buffer),
836-
(uint64_t*)nullptr,
837-
nullptr); // validities are set by index column
833+
std::unique_ptr<std::uint8_t[]> validity = nullptr;
834+
if (schema->flags & ARROW_FLAG_NULLABLE) {
835+
validity = _cast_validity_buffer_ptr(array);
836+
}
837+
838+
setup_write_column(schema->name, array->length, std::move(data_buffer), (uint64_t*)nullptr, std::move(validity));
838839
}
839840

840841
template <>

0 commit comments

Comments
 (0)