Skip to content

Commit

Permalink
feat: cast list items to default before write with different item nam…
Browse files Browse the repository at this point in the history
…es (#1959)

# Description
Delta-rs always uses `item` as the list item name when writing lists. If
you read data which is for example written by Spark, the item name is
`element`, in the current implemantation it's not possible to write
RecordBatches with a different item name. This leads for example to the
problem that you cann't optimize tables which are written by Spark and
contain a List column.
In this MR I add condition which will intiate a cast if the list item
name of the record batch is different to the target schema one.
I have also tried to explain this behaviour in the tests, but
unfortunately creating the test data has become complicated (Happy to
get feedback)

This is my first MR in this project 

# Related Issue(s)

https://github.com/delta-io/delta-rs/blob/main/crates/deltalake-core/src/kernel/arrow/mod.rs#L58
https://github.com/delta-io/delta-rs/pull/684/files#r940790524
https://delta-users.slack.com/archives/C013LCAEB98/p1701885637615699

---------

Co-authored-by: Jonas Schmitz <[email protected]>
Co-authored-by: Ion Koutsouris <[email protected]>
Co-authored-by: Robert Pack <[email protected]>
  • Loading branch information
4 people authored Dec 20, 2023
1 parent 11ea2a5 commit bc9253c
Showing 1 changed file with 89 additions and 1 deletion.
90 changes: 89 additions & 1 deletion crates/deltalake-core/src/operations/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn cast_record_batch_columns(
.iter()
.map(|f| {
let col = batch.column_by_name(f.name()).unwrap();

if let (DataType::Struct(_), DataType::Struct(child_fields)) =
(col.data_type(), f.data_type())
{
Expand All @@ -28,7 +29,7 @@ fn cast_record_batch_columns(
child_columns.clone(),
None,
)) as ArrayRef)
} else if !col.data_type().equals_datatype(f.data_type()) {
} else if is_cast_required(col.data_type(), f.data_type()) {
cast_with_options(col, f.data_type(), cast_options)
} else {
Ok(col.clone())
Expand All @@ -37,6 +38,16 @@ fn cast_record_batch_columns(
.collect::<Result<Vec<_>, _>>()
}

fn is_cast_required(a: &DataType, b: &DataType) -> bool {
match (a, b) {
(DataType::List(a_item), DataType::List(b_item)) => {
// If list item name is not the default('item') the list must be casted
!a.equals_datatype(b) || a_item.name() != b_item.name()
}
(_, _) => !a.equals_datatype(b),
}
}

/// Cast recordbatch to a new target_schema, by casting each column array
pub fn cast_record_batch(
batch: &RecordBatch,
Expand All @@ -51,3 +62,80 @@ pub fn cast_record_batch(
let columns = cast_record_batch_columns(batch, target_schema.fields(), &cast_options)?;
Ok(RecordBatch::try_new(target_schema, columns)?)
}

#[cfg(test)]
mod tests {
use crate::operations::cast::{cast_record_batch, is_cast_required};
use arrow::array::ArrayData;
use arrow_array::{Array, ArrayRef, ListArray, RecordBatch};
use arrow_buffer::Buffer;
use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef};
use std::sync::Arc;

#[test]
fn test_cast_record_batch_with_list_non_default_item() {
let array = Arc::new(make_list_array()) as ArrayRef;
let source_schema = Schema::new(vec![Field::new(
"list_column",
array.data_type().clone(),
false,
)]);
let record_batch = RecordBatch::try_new(Arc::new(source_schema), vec![array]).unwrap();

let fields = Fields::from(vec![Field::new_list(
"list_column",
Field::new("item", DataType::Int8, false),
false,
)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;

let result = cast_record_batch(&record_batch, target_schema, false);

let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
} else {
panic!("Not a list");
}
}

fn make_list_array() -> ListArray {
let value_data = ArrayData::builder(DataType::Int32)
.len(8)
.add_buffer(Buffer::from_slice_ref([0, 1, 2, 3, 4, 5, 6, 7]))
.build()
.unwrap();

let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);

let list_data_type = DataType::List(Arc::new(Field::new("element", DataType::Int32, true)));
let list_data = ArrayData::builder(list_data_type)
.len(3)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build()
.unwrap();
ListArray::from(list_data)
}

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(!is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_list_non_default_item() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));

assert!(is_cast_required(&field1, &field2));
}
}

0 comments on commit bc9253c

Please sign in to comment.