diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java index b2a8d2b8d7b..48d10de5580 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java @@ -18,6 +18,7 @@ import io.delta.kernel.annotation.Evolving; import java.util.Objects; +import java.util.function.Predicate; /** * Represent {@code array} data type @@ -85,6 +86,11 @@ public boolean isNested() { return true; } + @Override + public boolean existsRecursively(Predicate predicate) { + return super.existsRecursively(predicate) || getElementType().existsRecursively(predicate); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java index ec23877ea08..619dc7316d2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java @@ -17,6 +17,7 @@ package io.delta.kernel.types; import io.delta.kernel.annotation.Evolving; +import java.util.function.Predicate; /** * Base class for all data types. @@ -61,6 +62,17 @@ public boolean isWriteCompatible(DataType dataType) { */ public abstract boolean isNested(); + /** + * Returns {@code true} if the provided {@code predicate} matches this type or any of its nested + * child types. 
+ * + * @param predicate the predicate to test against this type (and recursively its children) + * @return {@code true} on a match, {@code false} otherwise or when {@code predicate} is null + */ + public boolean existsRecursively(Predicate predicate) { + return predicate != null && predicate.test(this); + } + @Override public abstract int hashCode(); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java index b46fe327009..28002de20dd 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java @@ -15,8 +15,11 @@ */ package io.delta.kernel.types; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; + import io.delta.kernel.annotation.Evolving; import java.util.Objects; +import java.util.function.Predicate; /** * Data type representing a {@code map} type. @@ -33,11 +36,13 @@ public class MapType extends DataType { public static final String MAP_VALUE_NAME = "value"; public MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { + validateKeyType(keyType); this.keyField = new StructField(MAP_KEY_NAME, keyType, false); this.valueField = new StructField(MAP_VALUE_NAME, valueType, valueContainsNull); } public MapType(StructField keyField, StructField valueField) { + validateKeyType(keyField.getDataType()); this.keyField = keyField; this.valueField = valueField; } @@ -115,6 +120,13 @@ public boolean isNested() { return true; } + @Override + public boolean existsRecursively(Predicate predicate) { + return super.existsRecursively(predicate) + || getKeyType().existsRecursively(predicate) + || getValueType().existsRecursively(predicate); + } + @Override public int hashCode() { return Objects.hash(keyField, valueField); @@ -124,4 +136,22 @@ public int hashCode() { public String toString() { return String.format("map[%s, %s]", getKeyType(), getValueType()); } + + /** + * Asserts that the given {@code 
keyType} is valid for a map's key type. Disallows {@code + * StringType} with non-SPARK.UTF8_BINARY collation anywhere within the key type, including when + * nested inside complex types; a violation fails the {@code checkArgument} precondition. + */ + private void validateKeyType(DataType keyType) { + checkArgument( + !keyType.existsRecursively( + dataType -> { + if (dataType instanceof StringType) { + StringType stringType = (StringType) dataType; + return !stringType.getCollationIdentifier().isSparkUTF8BinaryCollation(); + } + return false; + }), + "Map key type cannot contain StringType with non-SPARK.UTF8_BINARY collation"); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java index d68bfb5ec27..332cc53ca9d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java @@ -20,6 +20,7 @@ import io.delta.kernel.internal.types.DataTypeJsonSerDe; import io.delta.kernel.internal.util.Tuple2; import java.util.*; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -214,6 +215,19 @@ public boolean isNested() { return true; } + @Override + public boolean existsRecursively(Predicate predicate) { + if (super.existsRecursively(predicate)) { + return true; + } + for (StructField field : fields) { + if (field.getDataType().existsRecursively(predicate)) { + return true; + } + } + return false; + } + @Override public String toString() { return String.format( diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala index 1430b5f2954..3128039d5a3 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala +++ 
b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/types/DataTypeJsonSerDeSuite.scala @@ -266,7 +266,6 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite { metadataJson = Some( s"""{ |"$COLLATIONS_METADATA_KEY": { - | "tags.key": "SPARK.UTF8_LCASE", | "tags.value": "ICU.UNICODE" |}, |"delta.typeChanges": [ @@ -282,7 +281,7 @@ class DataTypeJsonSerDeSuite extends AnyFunSuite { .add( "tags", new MapType( - new StructField("key", new StringType("SPARK.UTF8_LCASE"), false), + new StructField("key", new StringType("SPARK.UTF8_BINARY"), false), new StructField("value", new StringType("ICU.UNICODE"), false) .withTypeChanges(Seq(new TypeChange(BinaryType.BINARY, StringType.STRING)).asJava)), true) @@ -607,9 +606,7 @@ object DataTypeJsonSerDeSuite { mapTypeJson("\"string\"", "\"string\"", true), false, metadataJson = Some( - s"""{"$COLLATIONS_METADATA_KEY" - | : {"b2.key" : "ICU.UNICODE_CI", - | "b2.value" : "SPARK.UTF8_LCASE"}}""".stripMargin)), + s"""{"$COLLATIONS_METADATA_KEY" : {"b2.value" : "SPARK.UTF8_LCASE"}}""")), structFieldJson("b3", arrayTypeJson("\"string\"", false), true), structFieldJson("b4", mapTypeJson("\"string\"", "\"string\"", false), false))), true), @@ -618,10 +615,7 @@ object DataTypeJsonSerDeSuite { structTypeJson(Seq( structFieldJson("b1", "\"string\"", false), structFieldJson("b2", arrayTypeJson("\"integer\"", false), true))), - false, - metadataJson = Some( - s"""{"$COLLATIONS_METADATA_KEY" - | : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))), + false))), new StructType() .add( "a1", @@ -635,7 +629,7 @@ object DataTypeJsonSerDeSuite { .add( "b2", new MapType( - new StringType("ICU.UNICODE_CI"), + StringType.STRING, new StringType("SPARK.UTF8_LCASE"), true), false) @@ -681,10 +675,7 @@ object DataTypeJsonSerDeSuite { | : {\"c2\" : \"ICU.UNICODE\"}}""".stripMargin)), structFieldJson("c3", "\"string\"", true))), true), - true, - metadataJson = Some( - s"""{"$COLLATIONS_METADATA_KEY" - | : {"b1.key.element.element" : 
"SPARK.UTF8_LCASE"}}""".stripMargin)), + true), structFieldJson("b2", "\"long\"", true))), true), structFieldJson( @@ -702,10 +693,7 @@ object DataTypeJsonSerDeSuite { | : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))), false), false), - true, - metadataJson = Some( - s"""{"$COLLATIONS_METADATA_KEY" - | : {"a3.element.key" : "ICU.UNICODE_CI"}}""".stripMargin)), + true), structFieldJson( "a4", arrayTypeJson( @@ -722,6 +710,7 @@ object DataTypeJsonSerDeSuite { structFieldJson( "a5", mapTypeJson( + "\"string\"", structTypeJson(Seq( structFieldJson( "b1", @@ -730,7 +719,6 @@ object DataTypeJsonSerDeSuite { metadataJson = Some( s"""{"$COLLATIONS_METADATA_KEY" | : {"b1" : "SPARK.UTF8_LCASE"}}""".stripMargin)))), - "\"string\"", false), false))), new StructType() @@ -743,7 +731,7 @@ object DataTypeJsonSerDeSuite { new MapType( new ArrayType( new ArrayType( - new StringType("SPARK.UTF8_LCASE"), + StringType.STRING, true), true), new StructType() @@ -757,7 +745,7 @@ object DataTypeJsonSerDeSuite { "a3", new ArrayType( new MapType( - new StringType("ICU.UNICODE_CI"), + StringType.STRING, new StructType() .add("b1", new StringType("SPARK.UTF8_LCASE"), false), false), @@ -773,9 +761,9 @@ object DataTypeJsonSerDeSuite { .add( "a5", new MapType( + StringType.STRING, new StructType() .add("b1", new StringType("SPARK.UTF8_LCASE"), false), - StringType.STRING, false), false))) diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala index 839232f3090..de8ace32805 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala @@ -147,4 +147,33 @@ class DataTypeSuite extends AnyFunSuite { assert(dt1.isWriteCompatible(dt2) == expected) } } + + test("check MapType cannot be created with collated key") { + intercept[IllegalArgumentException] { + // invalid version for 
SPARK.UTF8_BINARY + new MapType(new StringType("SPARK.UTF8_BINARY.23"), StringType.STRING, true) + } + intercept[IllegalArgumentException] { + new MapType(utf8LcaseString, StringType.STRING, true) + } + intercept[IllegalArgumentException] { + new MapType(unicodeString, StringType.STRING, false) + } + intercept[IllegalArgumentException] { + new MapType(new ArrayType(unicodeString, true), StringType.STRING, true) + } + intercept[IllegalArgumentException] { + new MapType( + new StructType().add("c1", StringType.STRING).add("c1", utf8LcaseString), + StringType.STRING, + false) + } + intercept[IllegalArgumentException] { + new MapType( + new StructType().add("c1", StringType.STRING) + .add("c1", new ArrayType(unicodeString, false)), + StringType.STRING, + false) + } + } } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructFieldSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructFieldSuite.scala index 4385175cb66..7cf29e65f69 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructFieldSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/StructFieldSuite.scala @@ -266,18 +266,13 @@ class StructFieldSuite extends AnyFunSuite { new StructField( "mapField", new MapType( - new StructField("key", new StringType("ICU.DE_DE"), false), + new StructField("key", new StringType("SPARK.UTF8_BINARY"), false), new StructField("value", LongType.LONG, true).withTypeChanges( Seq( new TypeChange(ShortType.SHORT, IntegerType.INTEGER), new TypeChange(IntegerType.INTEGER, LongType.LONG)).asJava)), false), FieldMetadata.builder() - .putFieldMetadata( - COLLATIONS_METADATA_KEY, - FieldMetadata.builder() - .putString("mapField.key", "ICU.DE_DE") - .build()) .putFieldMetadataArray( DELTA_TYPE_CHANGES_KEY, Array(