Skip to content

Commit

Permalink
typelevel#787 - ensure last/first are run on a single partition - 15.…
Browse files Browse the repository at this point in the history
…0 databricks doesn't process them on ordered dataset
  • Loading branch information
chris-twiner committed Apr 12, 2024
1 parent b6189b1 commit 25cc5c3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -824,13 +824,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.sorted

val aggrTyped = typedDS
.coalesce(1)
.orderBy(typedDS('a).asc)
.agg(atan(frameless.functions.aggregate.first(typedDS('a))))
.firstOption()
.run()
.get

val aggrSpark = cDS
.coalesce(1)
.orderBy("a")
.select(
sparkFunctions.atan(sparkFunctions.first("a")).as[Double]
Expand Down Expand Up @@ -883,6 +885,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.sorted

val aggrTyped = typedDS
.coalesce(1)
.orderBy(typedDS('a).asc, typedDS('b).asc)
.agg(
atan2(
Expand All @@ -895,6 +898,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.get

val aggrSpark = cDS
.coalesce(1)
.orderBy("a", "b")
.select(
sparkFunctions
Expand Down Expand Up @@ -948,13 +952,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.sorted

val aggrTyped = typedDS
.coalesce(1)
.orderBy(typedDS('a).asc)
.agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a))))
.firstOption()
.run()
.get

val aggrSpark = cDS
.coalesce(1)
.orderBy("a")
.select(
sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double]
Expand Down Expand Up @@ -1006,13 +1012,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.sorted

val aggrTyped = typedDS
.coalesce(1)
.orderBy(typedDS('a).asc)
.agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit))
.firstOption()
.run()
.get

val aggrSpark = cDS
.coalesce(1)
.orderBy("a")
.select(
sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double]
Expand Down Expand Up @@ -2039,8 +2047,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
val td =
ds.agg(concat(first(ds('a)), first(ds('b)))).collect().run().toVector
ds.coalesce(1)
.agg(concat(first(ds('a)), first(ds('b))))
.collect()
.run()
.toVector
val spark = ds.dataset
.coalesce(1)
.select(
sparkFunctions.concat(
sparkFunctions.first($"a").as[String],
Expand Down Expand Up @@ -2092,11 +2105,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
val td = ds
.coalesce(1)
.agg(concatWs(",", first(ds('a)), first(ds('b)), last(ds('b))))
.collect()
.run()
.toVector
val spark = ds.dataset
.coalesce(1)
.select(
sparkFunctions.concat_ws(
",",
Expand Down
39 changes: 28 additions & 11 deletions dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,37 @@ class FramelessSyntaxTests extends TypedDatasetSuite {
// Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits
override val sparkDelay = null

def prop[A, B](data: Vector[X2[A, B]])(
implicit ev: TypedEncoder[X2[A, B]]
): Prop = {
def prop[A, B](
data: Vector[X2[A, B]]
)(implicit
ev: TypedEncoder[X2[A, B]]
): Prop = {
val dataset = TypedDataset.create(data).dataset
val dataframe = dataset.toDF()

val typedDataset = dataset.typed
val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]]

typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector
typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame
.collect()
.run()
.toVector
}

test("dataset typed - toTyped") {
def prop[A, B](data: Vector[X2[A, B]])(
implicit ev: TypedEncoder[X2[A, B]]
): Prop = {
val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed
def prop[A, B](
data: Vector[X2[A, B]]
)(implicit
ev: TypedEncoder[X2[A, B]]
): Prop = {
val dataset =
session.createDataset(data)(TypedExpressionEncoder(ev)).typed
val dataframe = dataset.toDF()

dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
dataset
.collect()
.run()
.toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
}

check(forAll(prop[Int, String] _))
Expand All @@ -38,8 +49,14 @@ class FramelessSyntaxTests extends TypedDatasetSuite {
test("frameless typed column and aggregate") {
def prop[A: TypedEncoder](a: A, b: A): Prop = {
val d = TypedDataset.create((a, b) :: Nil)
(d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&(
d.agg(first(d('_1))).collect().run() ?= d.agg(first(d('_1)).untyped.typedAggregate).collect().run()
(d.coalesce(1).select(d('_1).untyped.typedColumn).collect().run ?= d
.select(d('_1))
.collect()
.run).&&(
d.coalesce(1).agg(first(d('_1))).collect().run() ?= d
.agg(first(d('_1)).untyped.typedAggregate)
.collect()
.run()
)
}

Expand Down

0 comments on commit 25cc5c3

Please sign in to comment.