From a54980796d62a8c004e19b9715d7b937ffde2a8c Mon Sep 17 00:00:00 2001 From: Chris Twiner Date: Wed, 2 Oct 2024 17:00:27 +0200 Subject: [PATCH] #787 - use shim'd joinWith - requires extra implicit --- .../main/scala/frameless/TypedDataset.scala | 51 +++++++++---------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index c08414cd..a4267fc5 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -4,27 +4,16 @@ import java.util import frameless.functions.CatalystExplodableCollection import frameless.ops._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{ Column, DataFrame, Dataset, SparkSession } -import org.apache.spark.sql.catalyst.expressions.{ - Attribute, - AttributeReference, - Literal -} -import org.apache.spark.sql.catalyst.plans.logical.{ Join, JoinHint } +import org.apache.spark.sql.{Column, DataFrame, Dataset, ShimUtils, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.ShimUtils.column import org.apache.spark.sql.types.StructType import shapeless._ import shapeless.labelled.FieldType -import shapeless.ops.hlist.{ - Diff, - IsHCons, - Mapper, - Prepend, - ToTraversable, - Tupler -} -import shapeless.ops.record.{ Keys, Modifier, Remover, Values } +import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler} +import shapeless.ops.record.{Keys, Modifier, Remover, Values} import scala.language.experimental.macros @@ -767,7 +756,8 @@ class TypedDataset[T] protected[frameless] ( e: TypedEncoder[(T, U)] ): TypedDataset[(T, U)] = new TypedDataset( - self.dataset.joinWith(other.dataset, column(Literal(true)), "cross") + ShimUtils.joinWith(dataset, other.dataset, column(Literal(true)), "cross")(TypedExpressionEncoder[(T, U)]) + //self.dataset.joinWith(other.dataset, column(Literal(true)), "cross") ) /** @@ -778,14 +768,17 @@ class TypedDataset[T] protected[frameless] ( other: TypedDataset[U] )(condition: TypedColumn[T with U, Boolean] )(implicit - e: TypedEncoder[(Option[T], Option[U])] + e: TypedEncoder[(Option[T], Option[U])], + to: TypedEncoder[(T, U)] ): TypedDataset[(Option[T], Option[U])] = new TypedDataset( - self.dataset + ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "full")(TypedExpressionEncoder[(T, U)]) + .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])]) + /*self.dataset .joinWith(other.dataset, condition.untyped, "full") .as[(Option[T], Option[U])]( TypedExpressionEncoder[(Option[T], Option[U])] - ) + )*/ ) /** @@ -820,12 +813,15 @@ class TypedDataset[T] protected[frameless] ( other: TypedDataset[U] )(condition: TypedColumn[T with U, Boolean] )(implicit - e: TypedEncoder[(T, Option[U])] + e: TypedEncoder[(T, Option[U])], + to: TypedEncoder[(T, U)] ): TypedDataset[(T, Option[U])] = new TypedDataset( - self.dataset - .joinWith(other.dataset, condition.untyped, "left_outer") + ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "left_outer")(TypedExpressionEncoder[(T, U)]) .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) + /*self.dataset + .joinWith(other.dataset, condition.untyped, "left_outer") + .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])*/ ) /** @@ -864,12 +860,15 @@ class TypedDataset[T] protected[frameless] ( other: TypedDataset[U] )(condition: TypedColumn[T with U, Boolean] )(implicit - e: TypedEncoder[(Option[T], U)] + e: TypedEncoder[(Option[T], U)], + to: TypedEncoder[(T, U)] ): TypedDataset[(Option[T], U)] = new TypedDataset( - self.dataset - .joinWith(other.dataset, condition.untyped, "right_outer") + ShimUtils.joinWith( self.dataset, other.dataset, condition.untyped, "right_outer")(TypedExpressionEncoder[(T, U)]) .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)]) + /*self.dataset + .joinWith(other.dataset, condition.untyped, "right_outer") + .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)])*/ ) private def disambiguate(join: Join): Join = {