From 25cc5c33b4cbeff27658ce3014cb23009912e634 Mon Sep 17 00:00:00 2001 From: Chris Twiner Date: Fri, 12 Apr 2024 18:42:16 +0200 Subject: [PATCH] #787 - ensure last/first are run on a single partition - 15.0 databricks doesn't process them on ordered dataset --- .../NonAggregateFunctionsTests.scala | 17 +++++++- .../syntax/FramelessSyntaxTests.scala | 39 +++++++++++++------ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala index 283a45eb4..ac79edce0 100644 --- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala @@ -824,6 +824,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .sorted val aggrTyped = typedDS + .coalesce(1) .orderBy(typedDS('a).asc) .agg(atan(frameless.functions.aggregate.first(typedDS('a)))) .firstOption() @@ -831,6 +832,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .get val aggrSpark = cDS + .coalesce(1) .orderBy("a") .select( sparkFunctions.atan(sparkFunctions.first("a")).as[Double] @@ -883,6 +885,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .sorted val aggrTyped = typedDS + .coalesce(1) .orderBy(typedDS('a).asc, typedDS('b).asc) .agg( atan2( @@ -895,6 +898,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .get val aggrSpark = cDS + .coalesce(1) .orderBy("a", "b") .select( sparkFunctions @@ -948,6 +952,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .sorted val aggrTyped = typedDS + .coalesce(1) .orderBy(typedDS('a).asc) .agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a)))) .firstOption() @@ -955,6 +960,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .get val aggrSpark = cDS + .coalesce(1) .orderBy("a") .select( sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double] @@ -1006,6 +1012,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .sorted val aggrTyped = typedDS + .coalesce(1) .orderBy(typedDS('a).asc) .agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit)) .firstOption() @@ -1013,6 +1020,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .get val aggrSpark = cDS + .coalesce(1) .orderBy("a") .select( sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double] @@ -2039,8 +2047,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) val td = - ds.agg(concat(first(ds('a)), first(ds('b)))).collect().run().toVector + ds.coalesce(1) + .agg(concat(first(ds('a)), first(ds('b)))) + .collect() + .run() + .toVector val spark = ds.dataset + .coalesce(1) .select( sparkFunctions.concat( sparkFunctions.first($"a").as[String], @@ -2092,11 +2105,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) val td = ds + .coalesce(1) .agg(concatWs(",", first(ds('a)), first(ds('b)), last(ds('b)))) .collect() .run() .toVector val spark = ds.dataset + .coalesce(1) .select( sparkFunctions.concat_ws( ",", diff --git a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala index 5108ed581..e1d0d52fc 100644 --- a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala +++ b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala @@ -9,26 +9,37 @@ class FramelessSyntaxTests extends TypedDatasetSuite { // Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits override val sparkDelay = null - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data).dataset val dataframe = dataset.toDF() val typedDataset = dataset.typed val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]] - typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector + typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame + .collect() + .run() + .toVector } test("dataset typed - toTyped") { - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { - val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { + val dataset = + session.createDataset(data)(TypedExpressionEncoder(ev)).typed val dataframe = dataset.toDF() - dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector + dataset + .collect() + .run() + .toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector } check(forAll(prop[Int, String] _)) @@ -38,8 +49,14 @@ class FramelessSyntaxTests extends TypedDatasetSuite { test("frameless typed column and aggregate") { def prop[A: TypedEncoder](a: A, b: A): Prop = { val d = TypedDataset.create((a, b) :: Nil) - (d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&( - d.agg(first(d('_1))).collect().run() ?= d.agg(first(d('_1)).untyped.typedAggregate).collect().run() + (d.coalesce(1).select(d('_1).untyped.typedColumn).collect().run ?= d + .select(d('_1)) + .collect() + .run).&&( + d.coalesce(1).agg(first(d('_1))).collect().run() ?= d + .agg(first(d('_1)).untyped.typedAggregate) + .collect() + .run() ) }