Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.delta.kernel.annotation.Evolving;
import java.util.Objects;
import java.util.function.Predicate;

/**
* Represent {@code array} data type
Expand Down Expand Up @@ -85,6 +86,11 @@ public boolean isNested() {
return true;
}

/**
 * Checks {@code predicate} against this array type first and, only if that does not match,
 * recurses into the element type.
 *
 * @param predicate the condition to look for
 * @return {@code true} as soon as this type or the element type reports a match
 */
@Override
public boolean existsRecursively(Predicate<DataType> predicate) {
  if (super.existsRecursively(predicate)) {
    return true;
  }
  return getElementType().existsRecursively(predicate);
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.delta.kernel.types;

import io.delta.kernel.annotation.Evolving;
import java.util.function.Predicate;

/**
* Base class for all data types.
Expand Down Expand Up @@ -61,6 +62,17 @@ public boolean isWriteCompatible(DataType dataType) {
*/
public abstract boolean isNested();

/**
 * Tests whether {@code predicate} matches this type or, for types that override this method,
 * any type nested inside it.
 *
 * <p>This base implementation evaluates the predicate against {@code this} only; it does not
 * descend into any children. A {@code null} predicate never matches.
 *
 * @param predicate the condition to evaluate; may be {@code null}
 * @return {@code true} if {@code predicate} is non-null and accepts this type, {@code false}
 *     otherwise
 */
public boolean existsRecursively(Predicate<DataType> predicate) {
  if (predicate == null) {
    return false;
  }
  return predicate.test(this);
}

@Override
public abstract int hashCode();

Expand Down
30 changes: 30 additions & 0 deletions kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
*/
package io.delta.kernel.types;

import static io.delta.kernel.internal.util.Preconditions.checkArgument;

import io.delta.kernel.annotation.Evolving;
import java.util.Objects;
import java.util.function.Predicate;

/**
* Data type representing a {@code map} type.
Expand All @@ -33,11 +36,13 @@ public class MapType extends DataType {
public static final String MAP_VALUE_NAME = "value";

public MapType(DataType keyType, DataType valueType, boolean valueContainsNull) {
validateKeyType(keyType);
this.keyField = new StructField(MAP_KEY_NAME, keyType, false);
this.valueField = new StructField(MAP_VALUE_NAME, valueType, valueContainsNull);
}

/**
 * Builds a map type from already-constructed key and value fields.
 *
 * @param keyField field describing the map key; its data type is checked by {@code
 *     validateKeyType} before the field is stored
 * @param valueField field describing the map value
 */
public MapType(StructField keyField, StructField valueField) {
  DataType keyType = keyField.getDataType();
  validateKeyType(keyType);
  this.keyField = keyField;
  this.valueField = valueField;
}
Expand Down Expand Up @@ -115,6 +120,13 @@ public boolean isNested() {
return true;
}

/**
 * Checks {@code predicate} against this map type, then the key type, then the value type,
 * stopping at the first match.
 *
 * @param predicate the condition to look for
 * @return {@code true} if this type, the key type or the value type reports a match
 */
@Override
public boolean existsRecursively(Predicate<DataType> predicate) {
  if (super.existsRecursively(predicate)) {
    return true;
  }
  if (getKeyType().existsRecursively(predicate)) {
    return true;
  }
  return getValueType().existsRecursively(predicate);
}

@Override
public int hashCode() {
return Objects.hash(keyField, valueField);
Expand All @@ -124,4 +136,22 @@ public int hashCode() {
public String toString() {
  // Rendered as "map[<key type>, <value type>]".
  DataType keyType = getKeyType();
  DataType valueType = getValueType();
  return String.format("map[%s, %s]", keyType, valueType);
}

/**
 * Asserts that the given {@code keyType} is valid as a map's key type. A {@code StringType}
 * whose collation is not SPARK.UTF8_BINARY is rejected wherever it appears in the key type,
 * including inside nested complex types.
 *
 * @param keyType the candidate key type to inspect
 */
private void validateKeyType(DataType keyType) {
  // Matches only string types whose collation identifier is not the Spark UTF8 binary one.
  Predicate<DataType> isNonBinaryCollatedString =
      dataType ->
          dataType instanceof StringType
              && !((StringType) dataType).getCollationIdentifier().isSparkUTF8BinaryCollation();
  boolean containsCollatedString = keyType.existsRecursively(isNonBinaryCollatedString);
  checkArgument(
      !containsCollatedString,
      "Map key type cannot contain StringType with non-SPARK.UTF8_BINARY collation");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.delta.kernel.internal.types.DataTypeJsonSerDe;
import io.delta.kernel.internal.util.Tuple2;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -214,6 +215,19 @@ public boolean isNested() {
return true;
}

/**
 * Checks {@code predicate} against this struct type and then against the data type of each
 * field in iteration order, stopping at the first match.
 *
 * @param predicate the condition to look for
 * @return {@code true} if this type or any field's data type reports a match
 */
@Override
public boolean existsRecursively(Predicate<DataType> predicate) {
  boolean matched = super.existsRecursively(predicate);
  if (!matched) {
    for (StructField child : fields) {
      if (child.getDataType().existsRecursively(predicate)) {
        matched = true;
        break;
      }
    }
  }
  return matched;
}

@Override
public String toString() {
return String.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite {
metadataJson = Some(
s"""{
|"$COLLATIONS_METADATA_KEY": {
| "tags.key": "SPARK.UTF8_LCASE",
| "tags.value": "ICU.UNICODE"
|},
|"delta.typeChanges": [
Expand All @@ -282,7 +281,7 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite {
.add(
"tags",
new MapType(
new StructField("key", new StringType("SPARK.UTF8_LCASE"), false),
new StructField("key", new StringType("SPARK.UTF8_BINARY"), false),
new StructField("value", new StringType("ICU.UNICODE"), false)
.withTypeChanges(Seq(new TypeChange(BinaryType.BINARY, StringType.STRING)).asJava)),
true)
Expand Down Expand Up @@ -607,9 +606,7 @@ object DataTypeJsonSerDeSuite {
mapTypeJson("\"string\"", "\"string\"", true),
false,
metadataJson = Some(
s"""{"$COLLATIONS_METADATA_KEY"
| : {"b2.key" : "ICU.UNICODE_CI",
| "b2.value" : "SPARK.UTF8_LCASE"}}""".stripMargin)),
s"""{"$COLLATIONS_METADATA_KEY" : {"b2.value" : "SPARK.UTF8_LCASE"}}""")),
structFieldJson("b3", arrayTypeJson("\"string\"", false), true),
structFieldJson("b4", mapTypeJson("\"string\"", "\"string\"", false), false))),
true),
Expand All @@ -618,10 +615,7 @@ object DataTypeJsonSerDeSuite {
structTypeJson(Seq(
structFieldJson("b1", "\"string\"", false),
structFieldJson("b2", arrayTypeJson("\"integer\"", false), true))),
false,
metadataJson = Some(
s"""{"$COLLATIONS_METADATA_KEY"
| : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))),
false))),
new StructType()
.add(
"a1",
Expand All @@ -635,7 +629,7 @@ object DataTypeJsonSerDeSuite {
.add(
"b2",
new MapType(
new StringType("ICU.UNICODE_CI"),
StringType.STRING,
new StringType("SPARK.UTF8_LCASE"),
true),
false)
Expand Down Expand Up @@ -681,10 +675,7 @@ object DataTypeJsonSerDeSuite {
| : {\"c2\" : \"ICU.UNICODE\"}}""".stripMargin)),
structFieldJson("c3", "\"string\"", true))),
true),
true,
metadataJson = Some(
s"""{"$COLLATIONS_METADATA_KEY"
| : {"b1.key.element.element" : "SPARK.UTF8_LCASE"}}""".stripMargin)),
true),
structFieldJson("b2", "\"long\"", true))),
true),
structFieldJson(
Expand All @@ -702,10 +693,7 @@ object DataTypeJsonSerDeSuite {
| : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))),
false),
false),
true,
metadataJson = Some(
s"""{"$COLLATIONS_METADATA_KEY"
| : {"a3.element.key" : "ICU.UNICODE_CI"}}""".stripMargin)),
true),
structFieldJson(
"a4",
arrayTypeJson(
Expand All @@ -722,6 +710,7 @@ object DataTypeJsonSerDeSuite {
structFieldJson(
"a5",
mapTypeJson(
"\"string\"",
structTypeJson(Seq(
structFieldJson(
"b1",
Expand All @@ -730,7 +719,6 @@ object DataTypeJsonSerDeSuite {
metadataJson = Some(
s"""{"$COLLATIONS_METADATA_KEY"
| : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))),
"\"string\"",
false),
false))),
new StructType()
Expand All @@ -743,7 +731,7 @@ object DataTypeJsonSerDeSuite {
new MapType(
new ArrayType(
new ArrayType(
new StringType("SPARK.UTF8_LCASE"),
StringType.STRING,
true),
true),
new StructType()
Expand All @@ -757,7 +745,7 @@ object DataTypeJsonSerDeSuite {
"a3",
new ArrayType(
new MapType(
new StringType("ICU.UNICODE_CI"),
StringType.STRING,
new StructType()
.add("b1", new StringType("SPARK.UTF8_LCASE"), false),
false),
Expand All @@ -773,9 +761,9 @@ object DataTypeJsonSerDeSuite {
.add(
"a5",
new MapType(
StringType.STRING,
new StructType()
.add("b1", new StringType("SPARK.UTF8_LCASE"), false),
StringType.STRING,
false),
false)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,33 @@ class DataTypeSuite extends AnyFunSuite {
assert(dt1.isWriteCompatible(dt2) == expected)
}
}

// Every MapType construction below is expected to fail with IllegalArgumentException because
// the key type contains a string type that is not plain SPARK.UTF8_BINARY. Each key is built
// inside its own intercept block, so an exception raised while building the key would satisfy
// the assertion as well.
test("check MapType cannot be created with collated key") {
intercept[IllegalArgumentException] {
// A SPARK.UTF8_BINARY identifier carrying an explicit version ("23") is expected to be
// rejected as a key as well.
new MapType(new StringType("SPARK.UTF8_BINARY.23"), StringType.STRING, true)
}
// Collated string used directly as the key type.
// (utf8LcaseString / unicodeString are defined elsewhere in the suite; presumably collated
// StringType instances - confirm against their definitions.)
intercept[IllegalArgumentException] {
new MapType(utf8LcaseString, StringType.STRING, true)
}
intercept[IllegalArgumentException] {
new MapType(unicodeString, StringType.STRING, false)
}
// Collated string nested inside an array key.
intercept[IllegalArgumentException] {
new MapType(new ArrayType(unicodeString, true), StringType.STRING, true)
}
// Collated string nested inside a struct key.
// NOTE(review): both fields are named "c1"; this looks like a copy-paste slip and the second
// was probably meant to be "c2" - confirm that the duplicate name is intentional.
intercept[IllegalArgumentException] {
new MapType(
new StructType().add("c1", StringType.STRING).add("c1", utf8LcaseString),
StringType.STRING,
false)
}
// Collated string nested two levels deep: struct field -> array element.
// NOTE(review): same duplicate "c1" field name as above.
intercept[IllegalArgumentException] {
new MapType(
new StructType().add("c1", StringType.STRING)
.add("c1", new ArrayType(unicodeString, false)),
StringType.STRING,
false)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,13 @@ class StructFieldSuite extends AnyFunSuite {
new StructField(
"mapField",
new MapType(
new StructField("key", new StringType("ICU.DE_DE"), false),
new StructField("key", new StringType("SPARK.UTF8_BINARY"), false),
new StructField("value", LongType.LONG, true).withTypeChanges(
Seq(
new TypeChange(ShortType.SHORT, IntegerType.INTEGER),
new TypeChange(IntegerType.INTEGER, LongType.LONG)).asJava)),
false),
FieldMetadata.builder()
.putFieldMetadata(
COLLATIONS_METADATA_KEY,
FieldMetadata.builder()
.putString("mapField.key", "ICU.DE_DE")
.build())
.putFieldMetadataArray(
DELTA_TYPE_CHANGES_KEY,
Array(
Expand Down
Loading