Skip to content

Commit 652281d

Browse files
Fix bool dist to values casting
1 parent 29d0b36 commit 652281d

File tree

1 file changed

+39
-8
lines changed

1 file changed

+39
-8
lines changed

libtiledbsoma/src/soma/managed_query.cc

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -786,22 +786,53 @@ template <>
786786
void ManagedQuery::_cast_dictionary_values<bool>(ArrowSchema* schema, ArrowArray* array) {
787787
// Boolean types require special handling due to bit vs uint8_t
788788
// representation in Arrow vs TileDB respectively
789+
std::unique_ptr<std::byte[]> data_buffer = std::make_unique_for_overwrite<std::byte[]>(array->length);
790+
std::span<uint8_t> data_view(reinterpret_cast<uint8_t*>(data_buffer.get()), array->length);
789791

790-
auto value_array = array->dictionary;
792+
auto extract_values = [&]<typename T>() {
793+
std::span<const T> indices(reinterpret_cast<const T*>(array->buffers[1]) + array->offset, array->length);
791794

792-
std::vector<int64_t> indexes = _get_index_vector(schema, array);
795+
for (int64_t i = 0; i < array->length; ++i) {
796+
data_view[i] = static_cast<uint8_t>(
797+
ArrowBitGet((const uint8_t*)array->dictionary->buffers[1], indices[i] + array->dictionary->offset));
798+
}
799+
};
793800

794-
std::unique_ptr<std::byte[]> values = std::make_unique_for_overwrite<std::byte[]>(array->length);
795-
std::span<uint8_t> values_view(reinterpret_cast<uint8_t*>(values.get()), array->length);
796-
for (int64_t i = 0; i < value_array->length; ++i) {
797-
values_view[i] = static_cast<uint8_t>(
798-
ArrowBitGet((const uint8_t*)value_array->buffers[1] + value_array->offset, indexes[i]));
801+
switch (ArrowAdapter::to_tiledb_format(schema->format)) {
802+
case TILEDB_INT8:
803+
extract_values.template operator()<int8_t>();
804+
break;
805+
case TILEDB_UINT8:
806+
extract_values.template operator()<uint8_t>();
807+
break;
808+
case TILEDB_INT16:
809+
extract_values.template operator()<int16_t>();
810+
break;
811+
case TILEDB_UINT16:
812+
extract_values.template operator()<uint16_t>();
813+
break;
814+
case TILEDB_INT32:
815+
extract_values.template operator()<int32_t>();
816+
break;
817+
case TILEDB_UINT32:
818+
extract_values.template operator()<uint32_t>();
819+
break;
820+
case TILEDB_INT64:
821+
extract_values.template operator()<int64_t>();
822+
break;
823+
case TILEDB_UINT64:
824+
extract_values.template operator()<uint64_t>();
825+
break;
826+
default:
827+
throw TileDBSOMAError(
828+
"Saw invalid index type when trying to promote indexes to "
829+
"values");
799830
}
800831

801832
setup_write_column(
802833
schema->name,
803834
array->length,
804-
std::move(values),
835+
std::move(data_buffer),
805836
(uint64_t*)nullptr,
806837
nullptr); // validities are set by index column
807838
}

0 commit comments

Comments
 (0)