diff --git a/build.sbt b/build.sbt index 7ba262c62..ded5f1e50 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,5 @@ -val sparkVersion = "3.5.1" +val sparkVersion = + "3.5.1" // "4.0.0-SNAPSHOT" must have the apache_snaps configured val spark34Version = "3.4.2" val spark33Version = "3.3.4" val catsCoreVersion = "2.10.0" @@ -11,10 +12,32 @@ val scalacheck = "1.17.0" val scalacheckEffect = "1.0.4" val refinedVersion = "0.11.1" val nakedFSVersion = "0.1.0" +val shimVersion = "0.0.1-RC4" val Scala212 = "2.12.19" val Scala213 = "2.13.13" +resolvers in Global += Resolver.mavenLocal +resolvers in Global += MavenRepository( + "sonatype-s01-snapshots", + Resolver.SonatypeS01RepositoryRoot + "/snapshots" +) +resolvers in Global += MavenRepository( + "sonatype-s01-releases", + Resolver.SonatypeS01RepositoryRoot + "/releases" +) +resolvers in Global += MavenRepository( + "apache_snaps", + "https://repository.apache.org/content/repositories/snapshots" +) + +import scala.concurrent.duration.DurationInt +import lmcoursier.definitions.CachePolicy + +csrConfiguration := csrConfiguration.value + .withTtl(Some(1.minute)) + .withCachePolicies(Vector(CachePolicy.LocalOnly)) + ThisBuild / tlBaseVersion := "0.16" ThisBuild / crossScalaVersions := Seq(Scala213, Scala212) @@ -87,10 +110,10 @@ lazy val `cats-spark33` = project lazy val dataset = project .settings(name := "frameless-dataset") .settings( - Compile / unmanagedSourceDirectories += baseDirectory.value / "src" / "main" / "spark-3.4+" + Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "spark-3.3+" ) .settings( - Test / unmanagedSourceDirectories += baseDirectory.value / "src" / "test" / "spark-3.3+" + libraryDependencies += "com.sparkutils" %% "shim_runtime_3.5.0.oss_3.5" % shimVersion changing () // 4.0.0.oss_4.0 for 4 snapshot ) .settings(datasetSettings) .settings(sparkDependencies(sparkVersion)) @@ -100,10 +123,10 @@ lazy val `dataset-spark34` = project .settings(name := "frameless-dataset-spark34") .settings(sourceDirectory := (dataset / sourceDirectory).value) .settings( - Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3.4+" + Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+" ) .settings( - Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+" + libraryDependencies += "com.sparkutils" %% "shim_runtime_3.4.1.oss_3.4" % shimVersion changing () ) .settings(datasetSettings) .settings(sparkDependencies(spark34Version)) @@ -114,10 +137,10 @@ lazy val `dataset-spark33` = project .settings(name := "frameless-dataset-spark33") .settings(sourceDirectory := (dataset / sourceDirectory).value) .settings( - Compile / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "main" / "spark-3" + Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+" ) .settings( - Test / unmanagedSourceDirectories += (dataset / baseDirectory).value / "src" / "test" / "spark-3.3+" + libraryDependencies += "com.sparkutils" %% "shim_runtime_3.3.2.oss_3.3" % shimVersion changing () ) .settings(datasetSettings) .settings(sparkDependencies(spark33Version)) @@ -239,11 +262,29 @@ lazy val datasetSettings = imt("frameless.RecordEncoderFields.deriveRecordLast"), mc("frameless.functions.FramelessLit"), mc(f"frameless.functions.FramelessLit$$"), + mc("org.apache.spark.sql.FramelessInternals"), + mc(f"org.apache.spark.sql.FramelessInternals$$"), + 
mc("org.apache.spark.sql.FramelessInternals$DisambiguateLeft"), + mc("org.apache.spark.sql.FramelessInternals$DisambiguateLeft$"), + mc("org.apache.spark.sql.FramelessInternals$DisambiguateRight"), + mc("org.apache.spark.sql.FramelessInternals$DisambiguateRight$"), + mc("org.apache.spark.sql.reflection.package"), + mc("org.apache.spark.sql.reflection.package$"), + mc("org.apache.spark.sql.reflection.package$ScalaSubtypeLock$"), + mc("frameless.MapGroups"), + mc(f"frameless.MapGroups$$"), dmm("frameless.functions.package.litAggr"), - dmm("org.apache.spark.sql.FramelessInternals.column") + dmm("org.apache.spark.sql.FramelessInternals.column"), + dmm("frameless.TypedEncoder.collectionEncoder"), + dmm("frameless.TypedEncoder.setEncoder"), + dmm("frameless.functions.FramelessUdf.evalCode"), + dmm("frameless.functions.FramelessUdf.copy"), + dmm("frameless.functions.FramelessUdf.this"), + dmm("frameless.functions.FramelessUdf.apply"), + imt("frameless.functions.FramelessUdf.apply") ) }, - coverageExcludedPackages := "org.apache.spark.sql.reflection", + coverageExcludedPackages := "frameless.reflection", libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude ("org.apache.hadoop", "hadoop-commons") ) @@ -252,7 +293,18 @@ lazy val refinedSettings = libraryDependencies += "eu.timepit" %% "refined" % refinedVersion ) -lazy val mlSettings = framelessSettings ++ framelessTypedDatasetREPL +lazy val mlSettings = framelessSettings ++ framelessTypedDatasetREPL ++ Seq( + mimaBinaryIssueFilters ++= { + import com.typesafe.tools.mima.core._ + + val mc = ProblemFilters.exclude[MissingClassProblem](_) + + Seq( + mc("org.apache.spark.ml.FramelessInternals"), + mc(f"org.apache.spark.ml.FramelessInternals$$") + ) + } +) lazy val scalac212Options = Seq( "-Xlint:-missing-interpolator,-unused,_", @@ -324,7 +376,10 @@ lazy val framelessSettings = Seq( * [error] +- org.scoverage:scalac-scoverage-reporter_2.12:2.0.7 (depends on 2.1.0) * [error] +- org.scala-lang:scala-compiler:2.12.16 (depends on 1.0.6) */ - libraryDependencySchemes += "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always + libraryDependencySchemes += "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always, + // allow testing on different runtimes, but don't publish / run docs + Test / publishArtifact := true, + Test / packageDoc / publishArtifact := false ) ++ consoleSettings lazy val spark34Settings = Seq[Setting[_]]( diff --git a/cats/src/test/scala/frameless/cats/test.scala b/cats/src/test/scala/frameless/cats/test.scala index d75bc3bfd..faac43163 100644 --- a/cats/src/test/scala/frameless/cats/test.scala +++ b/cats/src/test/scala/frameless/cats/test.scala @@ -7,7 +7,7 @@ import _root_.cats.syntax.all._ import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext => SC} +import org.apache.spark.{ SparkConf, SparkContext => SC } import org.scalatest.compatible.Assertion import org.scalactic.anyvals.PosInt @@ -21,7 +21,11 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.propspec.AnyPropSpec trait SparkTests { - val appID: String = new java.util.Date().toString + math.floor(math.random() * 10E4).toLong.toString + + val appID: String = new java.util.Date().toString + math + .floor(math.random() * 10e4) + .toLong + .toString val conf: SparkConf = new SparkConf() .setMaster("local[*]") @@ -29,16 +33,27 @@ trait SparkTests { .set("spark.ui.enabled", "false") 
.set("spark.app.id", appID) - implicit def session: SparkSession = SparkSession.builder().config(conf).getOrCreate() + implicit def session: SparkSession = + SparkSession.builder().config(conf).getOrCreate() implicit def sc: SparkContext = session.sparkContext - implicit class seqToRdd[A: ClassTag](seq: Seq[A])(implicit sc: SC) { + implicit class seqToRdd[A: ClassTag]( + seq: Seq[A] + )(implicit + sc: SC) { def toRdd: RDD[A] = sc.makeRDD(seq) } } object Tests { - def innerPairwise(mx: Map[String, Int], my: Map[String, Int], check: (Any, Any) => Assertion)(implicit sc: SC): Assertion = { + + def innerPairwise( + mx: Map[String, Int], + my: Map[String, Int], + check: (Any, Any) => Assertion + )(implicit + sc: SC + ): Assertion = { import frameless.cats.implicits._ import frameless.cats.inner._ val xs = sc.parallelize(mx.toSeq) @@ -63,21 +78,31 @@ object Tests { } } -class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with SparkTests { +class Test + extends AnyPropSpec + with Matchers + with ScalaCheckPropertyChecks + with SparkTests { + implicit override val generatorDrivenConfig = PropertyCheckConfiguration(minSize = PosInt(10)) property("spark is working") { - sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1,2,3) + sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1, 2, 3) } property("inner pairwise monoid") { // Make sure we have non-empty map - forAll { (xh: (String, Int), mx: Map[String, Int], yh: (String, Int), my: Map[String, Int]) => - Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _) + forAll { + (xh: (String, Int), + mx: Map[String, Int], + yh: (String, Int), + my: Map[String, Int] + ) => Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _) } } + org.scalatestplus.scalacheck.Checkers property("rdd simple numeric commutative semigroup") { import frameless.cats.implicits._ @@ -110,7 +135,8 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with property("rdd tuple commutative semigroup example") { import frameless.cats.implicits._ forAll { seq: List[(Int, Int)] => - val expectedSum = if (seq.isEmpty) None else Some(Foldable[List].fold(seq)) + val expectedSum = + if (seq.isEmpty) None else Some(Foldable[List].fold(seq)) val rdd = seq.toRdd rdd.csum shouldBe expectedSum.getOrElse(0 -> 0) @@ -120,10 +146,22 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with property("pair rdd numeric commutative semigroup example") { import frameless.cats.implicits._ - val seq = Seq( ("a",2), ("b",3), ("d",6), ("b",2), ("d",1) ) + val seq = Seq(("a", 2), ("b", 3), ("d", 6), ("b", 2), ("d", 1)) val rdd = seq.toRdd - rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",2), ("d",1) ) - rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",3), ("d",6) ) - rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",5), ("d",7) ) + rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 2), + ("d", 1) + ) + rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 3), + ("d", 6) + ) + rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 5), + ("d", 7) + ) } } diff --git a/dataset/src/main/scala/frameless/CollectionCaster.scala b/dataset/src/main/scala/frameless/CollectionCaster.scala new file mode 100644 index 000000000..bf329992e --- /dev/null +++ b/dataset/src/main/scala/frameless/CollectionCaster.scala @@ -0,0 +1,67 @@ +package frameless + 
+import frameless.TypedEncoder.CollectionConversion +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{ + CodegenContext, + CodegenFallback, + ExprCode +} +import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression } +import org.apache.spark.sql.types.{ DataType, ObjectType } + +case class CollectionCaster[F[_], C[_], Y]( + child: Expression, + conversion: CollectionConversion[F, C, Y]) + extends UnaryExpression + with CodegenFallback { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: F[Y] @unchecked => + conversion.convert(col) + case _ => o + } + } + + override def dataType: DataType = child.dataType +} + +case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression) + extends UnaryExpression { + + protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good + override def eval(input: InternalRow): Any = { + val o = child.eval(input).asInstanceOf[Object] + o match { + case col: Set[Y] @unchecked => + col.toSeq + case _ => o + } + } + + def toSeqOr[T](isSet: => T, or: => T): T = + child.dataType match { + case ObjectType(cls) + if classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + isSet + case t => or + } + + override def dataType: DataType = + toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType) + + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = + defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c")) + +} diff --git a/dataset/src/main/scala/frameless/FramelessInternals.scala b/dataset/src/main/scala/frameless/FramelessInternals.scala new file mode 100644 index 000000000..5a705cbf1 --- /dev/null +++ b/dataset/src/main/scala/frameless/FramelessInternals.scala @@ -0,0 +1,114 @@ +package frameless + +import com.sparkutils.shim.expressions.{ + Alias2 => Alias, + CreateStruct1 => CreateStruct +} +import org.apache.spark.sql.shim.{ utils => shimUtils } +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + NamedExpression, + NonSQLExpression +} +import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Project } +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types._ +import org.apache.spark.sql._ + +import scala.reflect.ClassTag + +object FramelessInternals { + + def objectTypeFor[A]( + implicit + classTag: ClassTag[A] + ): ObjectType = ObjectType(classTag.runtimeClass) + + def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = { + ds.toDF() + .queryExecution + .analyzed + .resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver) + .getOrElse { + throw org.apache.spark.sql.ShimUtils.analysisException(ds, colNames) + } + } + + def expr(column: Column): Expression = column.expr + + def logicalPlan(ds: Dataset[_]): LogicalPlan = shimUtils.logicalPlan(ds) + + def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution = + ShimUtils.executePlan(ds, plan) + + def joinPlan( + ds: Dataset[_], + plan: LogicalPlan, + leftPlan: LogicalPlan, + rightPlan: 
LogicalPlan + ): LogicalPlan = { + val joined = executePlan(ds, plan) + val leftOutput = joined.analyzed.output.take(leftPlan.output.length) + val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length) + + Project( + List( + Alias(CreateStruct(leftOutput), "_1")(), + Alias(CreateStruct(rightOutput), "_2")() + ), + joined.analyzed + ) + } + + def mkDataset[T]( + sqlContext: SQLContext, + plan: LogicalPlan, + encoder: Encoder[T] + ): Dataset[T] = + new Dataset(sqlContext, plan, encoder) + + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = + shimUtils.ofRows(sparkSession, logicalPlan) + + // because org.apache.spark.sql.types.UserDefinedType is private[spark] + type UserDefinedType[A >: Null] = + org.apache.spark.sql.types.UserDefinedType[A] + + // below only tested in SelfJoinTests.colLeft and colRight are equivalent to col outside of joins + // - via files (codegen) forces doGenCode eval. + /** Expression to tag columns from the left hand side of join expression. */ + case class DisambiguateLeft[T](tagged: Expression) + extends Expression + with NonSQLExpression { + def eval(input: InternalRow): Any = tagged.eval(input) + def nullable: Boolean = false + def children: Seq[Expression] = tagged :: Nil + def dataType: DataType = tagged.dataType + + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + tagged.genCode(ctx) + + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(newChildren.head) + } + + /** Expression to tag columns from the right hand side of join expression. */ + case class DisambiguateRight[T](tagged: Expression) + extends Expression + with NonSQLExpression { + def eval(input: InternalRow): Any = tagged.eval(input) + def nullable: Boolean = false + def children: Seq[Expression] = tagged :: Nil + def dataType: DataType = tagged.dataType + + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + tagged.genCode(ctx) + + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(newChildren.head) + } +} diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 7427d9de0..574ce4272 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -1,13 +1,18 @@ package frameless -import org.apache.spark.sql.FramelessInternals - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.{ - Invoke, NewInstance, UnwrapOption, WrapOption +import com.sparkutils.shim.expressions.{ + CreateNamedStruct1 => CreateNamedStruct, + GetStructField3 => GetStructField, + UnwrapOption2 => UnwrapOption, + WrapOption2 => WrapOption +} +import com.sparkutils.shim.{ deriveUnitLiteral, ifIsNull } +import org.apache.spark.sql.catalyst.expressions.{ Expression, Literal } +import org.apache.spark.sql.shim.{ + Invoke5 => Invoke, + NewInstance4 => NewInstance } import org.apache.spark.sql.types._ - import shapeless._ import shapeless.labelled.FieldType import shapeless.ops.hlist.IsHCons @@ -16,10 +21,9 @@ import shapeless.ops.record.Keys import scala.reflect.ClassTag case class RecordEncoderField( - ordinal: Int, - name: String, - encoder: TypedEncoder[_] -) + ordinal: Int, + name: String, + encoder: TypedEncoder[_]) trait RecordEncoderFields[T <: HList] extends Serializable { def value: List[RecordEncoderField] @@ -30,33 +34,42 @@ trait RecordEncoderFields[T 
<: HList] extends Serializable { object RecordEncoderFields { - implicit def deriveRecordLast[K <: Symbol, H] - (implicit + implicit def deriveRecordLast[K <: Symbol, H]( + implicit key: Witness.Aux[K], head: RecordFieldEncoder[H] - ): RecordEncoderFields[FieldType[K, H] :: HNil] = new RecordEncoderFields[FieldType[K, H] :: HNil] { + ): RecordEncoderFields[FieldType[K, H] :: HNil] = + new RecordEncoderFields[FieldType[K, H] :: HNil] { def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil } - implicit def deriveRecordCons[K <: Symbol, H, T <: HList] - (implicit + implicit def deriveRecordCons[K <: Symbol, H, T <: HList]( + implicit key: Witness.Aux[K], head: RecordFieldEncoder[H], tail: RecordEncoderFields[T] - ): RecordEncoderFields[FieldType[K, H] :: T] = new RecordEncoderFields[FieldType[K, H] :: T] { + ): RecordEncoderFields[FieldType[K, H] :: T] = + new RecordEncoderFields[FieldType[K, H] :: T] { + def value: List[RecordEncoderField] = - fieldEncoder[K, H] :: tail.value.map(x => x.copy(ordinal = x.ordinal + 1)) - } + fieldEncoder[K, H] :: tail.value.map(x => + x.copy(ordinal = x.ordinal + 1) + ) + } - private def fieldEncoder[K <: Symbol, H](implicit key: Witness.Aux[K], e: RecordFieldEncoder[H]): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) + private def fieldEncoder[K <: Symbol, H]( + implicit + key: Witness.Aux[K], + e: RecordFieldEncoder[H] + ): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) } /** - * Assists the generation of constructor call parameters from a labelled generic representation. - * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions. - * - * @tparam T labelled generic representation of type fields - */ + * Assists the generation of constructor call parameters from a labelled generic representation. + * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions. + * + * @tparam T labelled generic representation of type fields + */ trait NewInstanceExprs[T <: HList] extends Serializable { def from(exprs: List[Expression]): Seq[Expression] } @@ -67,32 +80,43 @@ object NewInstanceExprs { def from(exprs: List[Expression]): Seq[Expression] = Nil } - implicit def deriveUnit[K <: Symbol, T <: HList] - (implicit + implicit def deriveUnit[K <: Symbol, T <: HList]( + implicit tail: NewInstanceExprs[T] - ): NewInstanceExprs[FieldType[K, Unit] :: T] = new NewInstanceExprs[FieldType[K, Unit] :: T] { + ): NewInstanceExprs[FieldType[K, Unit] :: T] = + new NewInstanceExprs[FieldType[K, Unit] :: T] { + def from(exprs: List[Expression]): Seq[Expression] = - Literal.fromObject(()) +: tail.from(exprs) + deriveUnitLiteral +: tail.from(exprs) } - implicit def deriveNonUnit[K <: Symbol, V, T <: HList] - (implicit + implicit def deriveNonUnit[K <: Symbol, V, T <: HList]( + implicit notUnit: V =:!= Unit, tail: NewInstanceExprs[T] - ): NewInstanceExprs[FieldType[K, V] :: T] = new NewInstanceExprs[FieldType[K, V] :: T] { - def from(exprs: List[Expression]): Seq[Expression] = exprs.head +: tail.from(exprs.tail) + ): NewInstanceExprs[FieldType[K, V] :: T] = + new NewInstanceExprs[FieldType[K, V] :: T] { + + def from(exprs: List[Expression]): Seq[Expression] = + exprs.head +: tail.from(exprs.tail) } } /** - * Drops fields with Unit type from labelled generic representation of types. 
- * - * @tparam L labelled generic representation of type fields - */ -trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable { type Out <: HList } + * Drops fields with Unit type from labelled generic representation of types. + * + * @tparam L labelled generic representation of type fields + */ +trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable { + type Out <: HList +} object DropUnitValues { - def apply[L <: HList](implicit dropUnitValues: DropUnitValues[L]): Aux[L, dropUnitValues.Out] = dropUnitValues + + def apply[L <: HList]( + implicit + dropUnitValues: DropUnitValues[L] + ): Aux[L, dropUnitValues.Out] = dropUnitValues type Aux[L <: HList, Out0 <: HList] = DropUnitValues[L] { type Out = Out0 } @@ -101,93 +125,91 @@ object DropUnitValues { def apply(l: HNil): Out = HNil } - implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList] - (implicit - dropUnitValues : DropUnitValues.Aux[T, OutT] - ): Aux[FieldType[K, Unit] :: T, OutT] = new DropUnitValues[FieldType[K, Unit] :: T] { + implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList]( + implicit + dropUnitValues: DropUnitValues.Aux[T, OutT] + ): Aux[FieldType[K, Unit] :: T, OutT] = + new DropUnitValues[FieldType[K, Unit] :: T] { type Out = OutT - def apply(l : FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail) + def apply(l: FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail) } - implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList] - (implicit + implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList]( + implicit nonUnit: V =:!= Unit, - dropUnitValues : DropUnitValues.Aux[T, OutT] - ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] = new DropUnitValues[FieldType[K, V] :: T] { + dropUnitValues: DropUnitValues.Aux[T, OutT] + ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] = + new DropUnitValues[FieldType[K, V] :: T] { type Out = FieldType[K, V] :: OutT - def apply(l : FieldType[K, V] :: T): Out = l.head :: dropUnitValues(l.tail) + def apply(l: FieldType[K, V] :: T): Out = l.head :: dropUnitValues(l.tail) } } -class RecordEncoder[F, G <: HList, H <: HList] - (implicit +class RecordEncoder[F, G <: HList, H <: HList]( + implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons[H], fields: Lazy[RecordEncoderFields[H]], newInstanceExprs: Lazy[NewInstanceExprs[G]], - classTag: ClassTag[F] - ) extends TypedEncoder[F] { - def nullable: Boolean = false - - def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] - - def catalystRepr: DataType = { - val structFields = fields.value.value.map { field => - StructField( - name = field.name, - dataType = field.encoder.catalystRepr, - nullable = field.encoder.nullable, - metadata = Metadata.empty - ) - } - - StructType(structFields) + classTag: ClassTag[F]) + extends TypedEncoder[F] { + def nullable: Boolean = false + + def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] + + def catalystRepr: DataType = { + val structFields = fields.value.value.map { field => + StructField( + name = field.name, + dataType = field.encoder.catalystRepr, + nullable = field.encoder.nullable, + metadata = Metadata.empty + ) } - def toCatalyst(path: Expression): Expression = { - val nameExprs = fields.value.value.map { field => - Literal(field.name) - } - - val valueExprs = fields.value.value.map { field => - val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) - field.encoder.toCatalyst(fieldPath) - } - - // the way exprs are encoded in CreateNamedStruct - val exprs = 
nameExprs.zip(valueExprs).flatMap { - case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil - } + StructType(structFields) + } - val createExpr = CreateNamedStruct(exprs) - val nullExpr = Literal.create(null, createExpr.dataType) + def toCatalyst(path: Expression): Expression = { + val nameExprs = fields.value.value.map { field => Literal(field.name) } - If(IsNull(path), nullExpr, createExpr) + val valueExprs = fields.value.value.map { field => + val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) + field.encoder.toCatalyst(fieldPath) } - def fromCatalyst(path: Expression): Expression = { - val exprs = fields.value.value.map { field => - field.encoder.fromCatalyst( - GetStructField(path, field.ordinal, Some(field.name))) - } + // the way exprs are encoded in CreateNamedStruct + val exprs = nameExprs.zip(valueExprs).flatMap { + case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil + } - val newArgs = newInstanceExprs.value.from(exprs) - val newExpr = NewInstance( - classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) + val createExpr = CreateNamedStruct(exprs) - val nullExpr = Literal.create(null, jvmRepr) + ifIsNull(createExpr.dataType, path, createExpr) + } - If(IsNull(path), nullExpr, newExpr) + def fromCatalyst(path: Expression): Expression = { + val exprs = fields.value.value.map { field => + field.encoder.fromCatalyst( + GetStructField(path, field.ordinal, Some(field.name)) + ) } + + val newArgs = newInstanceExprs.value.from(exprs) + val newExpr = + NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) + + ifIsNull(jvmRepr, path, newExpr) + } } final class RecordFieldEncoder[T]( - val encoder: TypedEncoder[T], - private[frameless] val jvmRepr: DataType, - private[frameless] val fromCatalyst: Expression => Expression, - private[frameless] val toCatalyst: Expression => Expression -) extends Serializable + val encoder: TypedEncoder[T], + private[frameless] val jvmRepr: DataType, + private[frameless] val fromCatalyst: Expression => Expression, + private[frameless] val toCatalyst: Expression => Expression) + extends Serializable object RecordFieldEncoder extends RecordFieldEncoderLowPriority { @@ -198,8 +220,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam K the key type for the fields * @tparam V the inner value type */ - implicit def optionValueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] - (implicit + implicit def optionValueClass[ + F: IsValueClass, + G <: ::[_, HNil], + H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], + K <: Symbol, + V, + KS <: ::[_ <: Symbol, HNil] + ](implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil], @@ -208,49 +236,49 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i5: TypedEncoder[V], i6: ClassTag[F] ): RecordFieldEncoder[Option[F]] = { - val fieldName = i4.head(i3()).name - val innerJvmRepr = ObjectType(i6.runtimeClass) + val fieldName = i4.head(i3()).name + val innerJvmRepr = ObjectType(i6.runtimeClass) - val catalyst: Expression => Expression = { path => - val value = UnwrapOption(innerJvmRepr, path) - val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil) + val catalyst: Expression => Expression = { path => + val value = UnwrapOption(innerJvmRepr, path) + val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil) - i5.toCatalyst(javaValue) - } + i5.toCatalyst(javaValue) + } - val fromCatalyst: 
Expression => Expression = { path => - val javaValue = i5.fromCatalyst(path) - val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) + val fromCatalyst: Expression => Expression = { path => + val javaValue = i5.fromCatalyst(path) + val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) - WrapOption(value, innerJvmRepr) - } + WrapOption(value, innerJvmRepr) + } - val jvmr = ObjectType(classOf[Option[F]]) + val jvmr = ObjectType(classOf[Option[F]]) - new RecordFieldEncoder[Option[F]]( - encoder = new TypedEncoder[Option[F]] { - val nullable = true + new RecordFieldEncoder[Option[F]]( + encoder = new TypedEncoder[Option[F]] { + val nullable = true - val jvmRepr = jvmr + val jvmRepr = jvmr - @inline def catalystRepr: DataType = i5.catalystRepr + @inline def catalystRepr: DataType = i5.catalystRepr - def fromCatalyst(path: Expression): Expression = { - val javaValue = i5.fromCatalyst(path) - val value = NewInstance( - i6.runtimeClass, Seq(javaValue), innerJvmRepr) + def fromCatalyst(path: Expression): Expression = { + val javaValue = i5.fromCatalyst(path) + val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) - WrapOption(value, innerJvmRepr) - } + WrapOption(value, innerJvmRepr) + } - def toCatalyst(path: Expression): Expression = catalyst(path) + def toCatalyst(path: Expression): Expression = catalyst(path) - override def toString: String = s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)" - }, - jvmRepr = jvmr, - fromCatalyst = fromCatalyst, - toCatalyst = catalyst - ) + override def toString: String = + s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)" + }, + jvmRepr = jvmr, + fromCatalyst = fromCatalyst, + toCatalyst = catalyst + ) } /** @@ -259,8 +287,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam H the single field of the value class (with guarantee it's not a `Unit` value) * @tparam V the inner value type */ - implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] - (implicit + implicit def valueClass[ + F: IsValueClass, + G <: ::[_, HNil], + H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], + K <: Symbol, + V, + KS <: ::[_ <: Symbol, HNil] + ](implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil], @@ -269,40 +303,47 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i5: TypedEncoder[V], i6: ClassTag[F] ): RecordFieldEncoder[F] = { - val cls = i6.runtimeClass - val jvmr = i5.jvmRepr - val fieldName = i4.head(i3()).name - - new RecordFieldEncoder[F]( - encoder = new TypedEncoder[F] { - def nullable = i5.nullable - - def jvmRepr = jvmr - - def catalystRepr: DataType = i5.catalystRepr - - def fromCatalyst(path: Expression): Expression = - i5.fromCatalyst(path) - - @inline def toCatalyst(path: Expression): Expression = - i5.toCatalyst(path) - - override def toString: String = s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})" - }, - jvmRepr = FramelessInternals.objectTypeFor[F], - fromCatalyst = { expr: Expression => - NewInstance( - i6.runtimeClass, - i5.fromCatalyst(expr) :: Nil, - ObjectType(i6.runtimeClass)) - }, - toCatalyst = { expr: Expression => - i5.toCatalyst(Invoke(expr, fieldName, jvmr)) - } - ) + val cls = i6.runtimeClass + val jvmr = i5.jvmRepr + val fieldName = i4.head(i3()).name + + new RecordFieldEncoder[F]( + encoder = 
new TypedEncoder[F] { + def nullable = i5.nullable + + def jvmRepr = jvmr + + def catalystRepr: DataType = i5.catalystRepr + + def fromCatalyst(path: Expression): Expression = + i5.fromCatalyst(path) + + @inline def toCatalyst(path: Expression): Expression = + i5.toCatalyst(path) + + override def toString: String = + s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})" + }, + jvmRepr = FramelessInternals.objectTypeFor[F], + fromCatalyst = { expr: Expression => + NewInstance( + i6.runtimeClass, + i5.fromCatalyst(expr) :: Nil, + ObjectType(i6.runtimeClass) + ) + }, + toCatalyst = { expr: Expression => + i5.toCatalyst(Invoke(expr, fieldName, jvmr)) + } + ) } } private[frameless] sealed trait RecordFieldEncoderLowPriority { - implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst) + + implicit def apply[T]( + implicit + e: TypedEncoder[T] + ): RecordFieldEncoder[T] = + new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst) } diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 0bbaf6fed..5a31a8529 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -1,11 +1,14 @@ package frameless -import frameless.functions.{litAggr, lit => flit} +import frameless.functions.{ litAggr, lit => flit } import frameless.syntax._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + Literal +} // 787 - Spark 4 source code compat import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{Column, FramelessInternals} +import org.apache.spark.sql.Column import shapeless._ import shapeless.ops.record.Selector @@ -13,6 +16,15 @@ import shapeless.ops.record.Selector import scala.annotation.implicitNotFound import scala.reflect.ClassTag +import com.sparkutils.shim.expressions.{ + EqualNullSafe2 => EqualNullSafe, + EqualTo2 => EqualTo, + Not1 => Not, + IsNull1 => IsNull, + IsNotNull1 => IsNotNull, + Coalesce1 => Coalesce +} // 787 - Spark 4 source code compat + import scala.language.experimental.macros sealed trait UntypedExpression[T] { @@ -21,91 +33,121 @@ sealed trait UntypedExpression[T] { override def toString: String = expr.toString() } -/** Expression used in `select`-like constructions. - */ -sealed class TypedColumn[T, U](expr: Expression)( - implicit val uenc: TypedEncoder[U] -) extends AbstractTypedColumn[T, U](expr) { +/** + * Expression used in `select`-like constructions. + */ +sealed class TypedColumn[T, U]( + expr: Expression + )(implicit + val uenc: TypedEncoder[U]) + extends AbstractTypedColumn[T, U](expr) { type ThisType[A, B] = TypedColumn[A, B] - def this(column: Column)(implicit uencoder: TypedEncoder[U]) = + def this( + column: Column + )(implicit + uencoder: TypedEncoder[U] + ) = this(FramelessInternals.expr(column)) - override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn + override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = + c.typedColumn override def lit[U1: TypedEncoder](c: U1): TypedColumn[T, U1] = flit(c) } -/** Expression used in `agg`-like constructions. - */ -sealed class TypedAggregate[T, U](expr: Expression)( - implicit val uenc: TypedEncoder[U] -) extends AbstractTypedColumn[T, U](expr) { +/** + * Expression used in `agg`-like constructions. 
+ */ +sealed class TypedAggregate[T, U]( + expr: Expression + )(implicit + val uenc: TypedEncoder[U]) + extends AbstractTypedColumn[T, U](expr) { type ThisType[A, B] = TypedAggregate[A, B] - def this(column: Column)(implicit uencoder: TypedEncoder[U]) = { + def this( + column: Column + )(implicit + uencoder: TypedEncoder[U] + ) = { this(FramelessInternals.expr(column)) } - override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = c.typedAggregate + override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = + c.typedAggregate override def lit[U1: TypedEncoder](c: U1): TypedAggregate[T, U1] = litAggr(c) } -/** Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or - * a [[frameless.TypedColumn]]. - * - * Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - * - * @tparam T phantom type representing the dataset on which this columns is - * selected. When `T = A with B` the selection is on either A or B. - * @tparam U type of column - */ -abstract class AbstractTypedColumn[T, U] - (val expr: Expression) - (implicit val uencoder: TypedEncoder[U]) +/** + * Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or + * a [[frameless.TypedColumn]]. + * + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + * + * @tparam T phantom type representing the dataset on which this columns is + * selected. When `T = A with B` the selection is on either A or B. + * @tparam U type of column + */ +abstract class AbstractTypedColumn[T, U]( + val expr: Expression + )(implicit + val uencoder: TypedEncoder[U]) extends UntypedExpression[T] { self => type ThisType[A, B] <: AbstractTypedColumn[A, B] - /** A helper class to make to simplify working with Optional fields. - * - * {{{ - * val x: TypedColumn[Option[Int]] = _ - * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int) - * }}} - * - * @note Known issue: map() will NOT work when the applied function is a udf(). - * It will compile and then throw a runtime error. - **/ + /** + * A helper class to make to simplify working with Optional fields. + * + * {{{ + * val x: TypedColumn[Option[Int]] = _ + * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int) + * }}} + * + * @note Known issue: map() will NOT work when the applied function is a udf(). + * It will compile and then throw a runtime error. + */ trait Mapper[X] { - def map[G, OutputType[_,_]](u: ThisType[T, X] => OutputType[T,G]) - (implicit - ev: OutputType[T,G] <:< AbstractTypedColumn[T, G] + + def map[G, OutputType[_, _]]( + u: ThisType[T, X] => OutputType[T, G] + )(implicit + ev: OutputType[T, G] <:< AbstractTypedColumn[T, G] ): OutputType[T, Option[G]] = { - u(self.asInstanceOf[ThisType[T, X]]).asInstanceOf[OutputType[T, Option[G]]] + u(self.asInstanceOf[ThisType[T, X]]) + .asInstanceOf[OutputType[T, Option[G]]] } } - /** Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]` - * where `X` is type of the unwrapped Optional. E.g., in the case of `Option[Long]`, - * `X` is of type Long. 
- * - * {{{ - * val x: TypedColumn[Option[Int]] = _ - * x.opt.map(_*2) - * }}} - * */ - def opt[X](implicit x: U <:< Option[X]): Mapper[X] = new Mapper[X] {} + /** + * Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]` + * where `X` is type of the unwrapped Optional. E.g., in the case of `Option[Long]`, + * `X` is of type Long. + * + * {{{ + * val x: TypedColumn[Option[Int]] = _ + * x.opt.map(_*2) + * }}} + */ + def opt[X]( + implicit + x: U <:< Option[X] + ): Mapper[X] = new Mapper[X] {} /** Fall back to an untyped Column */ def untyped: Column = new Column(expr) - private def equalsTo[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = typed { + private def equalsTo[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed { if (uencoder.nullable) EqualNullSafe(self.expr, other.expr) else EqualTo(self.expr, other.expr) } @@ -120,773 +162,1125 @@ abstract class AbstractTypedColumn[T, U] /** Creates a typed column of either TypedColumn or TypedAggregate. */ def lit[U1: TypedEncoder](c: U1): ThisType[T, U1] - /** Equality test. - * {{{ - * df.filter( df.col('a) === 1 ) - * }}} - * - * apache/spark - */ + /** + * Equality test. + * {{{ + * df.filter( df.col('a) === 1 ) + * }}} + * + * apache/spark + */ def ===(u: U): ThisType[T, Boolean] = equalsTo(lit(u)) - /** Equality test. - * {{{ - * df.filter( df.col('a) === df.col('b) ) - * }}} - * - * apache/spark - */ - def ===[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Equality test. + * {{{ + * df.filter( df.col('a) === df.col('b) ) + * }}} + * + * apache/spark + */ + def ===[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = equalsTo(other) - /** Inequality test. - * - * {{{ - * df.filter(df.col('a) =!= df.col('b)) - * }}} - * - * apache/spark - */ - def =!=[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Inequality test. + * + * {{{ + * df.filter(df.col('a) =!= df.col('b)) + * }}} + * + * apache/spark + */ + def =!=[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(Not(equalsTo(other).expr)) - /** Inequality test. - * - * {{{ - * df.filter(df.col('a) =!= "a") - * }}} - * - * apache/spark - */ + /** + * Inequality test. + * + * {{{ + * df.filter(df.col('a) =!= "a") + * }}} + * + * apache/spark + */ def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr)) - /** True if the current expression is an Option and it's None. - * - * apache/spark - */ - def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = + /** + * True if the current expression is an Option and it's None. + * + * apache/spark + */ + def isNone( + implicit + i0: U <:< Option[_] + ): ThisType[T, Boolean] = typed(IsNull(expr)) - /** True if the current expression is an Option and it's not None. - * - * apache/spark - */ - def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = + /** + * True if the current expression is an Option and it's not None. + * + * apache/spark + */ + def isNotNone( + implicit + i0: U <:< Option[_] + ): ThisType[T, Boolean] = typed(IsNotNull(expr)) - /** True if the current expression is a fractional number and is not NaN. - * - * apache/spark - */ - def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] = + /** + * True if the current expression is a fractional number and is not NaN. 
+ * + * apache/spark + */ + def isNaN( + implicit + n: CatalystNaN[U] + ): ThisType[T, Boolean] = typed(self.untyped.isNaN) /** - * True if the value for this optional column `exists` as expected - * (see `Option.exists`). - * - * {{{ - * df.col('opt).isSome(_ === someOtherCol) - * }}} - */ - def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false) + * True if the value for this optional column `exists` as expected + * (see `Option.exists`). + * + * {{{ + * df.col('opt).isSome(_ === someOtherCol) + * }}} + */ + def isSome[V]( + exists: ThisType[T, V] => ThisType[T, Boolean] + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = someOr[V](exists, false) /** - * True if the value for this optional column `exists` as expected, - * or is `None`. (see `Option.forall`). - * - * {{{ - * df.col('opt).isSomeOrNone(_ === someOtherCol) - * }}} - */ - def isSomeOrNone[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, true) - - private def someOr[V](exists: ThisType[T, V] => ThisType[T, Boolean], default: Boolean)(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = { + * True if the value for this optional column `exists` as expected, + * or is `None`. (see `Option.forall`). + * + * {{{ + * df.col('opt).isSomeOrNone(_ === someOtherCol) + * }}} + */ + def isSomeOrNone[V]( + exists: ThisType[T, V] => ThisType[T, Boolean] + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = someOr[V](exists, true) + + private def someOr[V]( + exists: ThisType[T, V] => ThisType[T, Boolean], + default: Boolean + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = { val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr))) } - /** Convert an Optional column by providing a default value. - * - * {{{ - * df(df('opt).getOrElse(df('defaultValue))) - * }}} - */ - def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Convert an Optional column by providing a default value. + * + * {{{ + * df(df('opt).getOrElse(df('defaultValue))) + * }}} + */ + def getOrElse[TT, W, Out]( + default: ThisType[TT, Out] + )(implicit + i0: U =:= Option[Out], + i1: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(Coalesce(Seq(expr, default.expr)))(default.uencoder) - /** Convert an Optional column by providing a default value. - * - * {{{ - * df( df('opt).getOrElse(defaultConstant) ) - * }}} - */ - def getOrElse[Out: TypedEncoder](default: Out)(implicit i0: U =:= Option[Out]): ThisType[T, Out] = + /** + * Convert an Optional column by providing a default value. + * + * {{{ + * df( df('opt).getOrElse(defaultConstant) ) + * }}} + */ + def getOrElse[Out: TypedEncoder]( + default: Out + )(implicit + i0: U =:= Option[Out] + ): ThisType[T, Out] = getOrElse(lit[Out](default)) - /** Sum of this expression and another expression. - * - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people.col('height) plus people.col('weight) ) - * }}} - * - * apache/spark - */ - def plus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Sum of this expression and another expression. + * + * {{{ + * // The following selects the sum of a person's height and weight. 
+ * people.select( people.col('height) plus people.col('weight) ) + * }}} + * + * apache/spark + */ + def plus[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.plus(other.untyped)) - /** Sum of this expression and another expression. - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people.col('height) + people.col('weight) ) - * }}} - * - * apache/spark - */ - def +[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people.col('height) + people.col('weight) ) + * }}} + * + * apache/spark + */ + def +[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = plus(other) - /** Sum of this expression (column) with a constant. - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people('height) + 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def +(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Sum of this expression (column) with a constant. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people('height) + 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def +( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.plus(u)) /** - * Inversion of boolean expression, i.e. NOT. - * {{{ - * // Select rows that are not active (isActive === false) - * df.filter( !df('isActive) ) - * }}} - * - * apache/spark - */ - def unary_!(implicit i0: U <:< Boolean): ThisType[T, Boolean] = + * Inversion of boolean expression, i.e. NOT. + * {{{ + * // Select rows that are not active (isActive === false) + * df.filter( !df('isActive) ) + * }}} + * + * apache/spark + */ + def unary_!( + implicit + i0: U <:< Boolean + ): ThisType[T, Boolean] = typed(!untyped) - /** Unary minus, i.e. negate the expression. - * {{{ - * // Select the amount column and negates all values. - * df.select( -df('amount) ) - * }}} - * - * apache/spark - */ - def unary_-(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * df.select( -df('amount) ) + * }}} + * + * apache/spark + */ + def unary_-( + implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(-self.untyped) - /** Subtraction. Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people.col('height) minus people.col('weight) ) - * }}} - * - * apache/spark - */ - def minus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people.col('height) minus people.col('weight) ) + * }}} + * + * apache/spark + */ + def minus[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.minus(other.untyped)) - /** Subtraction. 
Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people.col('height) - people.col('weight) ) - * }}} - * - * apache/spark - */ - def -[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people.col('height) - people.col('weight) ) + * }}} + * + * apache/spark + */ + def -[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = minus(other) - /** Subtraction. Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people('height) - 1 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def -(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people('height) - 1 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def -( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.minus(u)) - /** Multiplication of this expression and another expression. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) multiply people.col('weight) ) - * }}} - * - * apache/spark - */ - def multiply[TT, W] - (other: ThisType[TT, U]) - (implicit + /** + * Multiplication of this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people.col('height) multiply people.col('weight) ) + * }}} + * + * apache/spark + */ + def multiply[TT, W]( + other: ThisType[TT, U] + )(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W], t: ClassTag[U] ): ThisType[W, U] = typed { - if (t.runtimeClass == BigDecimal(0).getClass) { - // That's apparently the only way to get sound multiplication. - // See https://issues.apache.org/jira/browse/SPARK-22036 - val dt = DecimalType(20, 14) - self.untyped.cast(dt).multiply(other.untyped.cast(dt)) - } else { - self.untyped.multiply(other.untyped) - } + if (t.runtimeClass == BigDecimal(0).getClass) { + // That's apparently the only way to get sound multiplication. + // See https://issues.apache.org/jira/browse/SPARK-22036 + val dt = DecimalType(20, 14) + self.untyped.cast(dt).multiply(other.untyped.cast(dt)) + } else { + self.untyped.multiply(other.untyped) } + } - /** Multiplication of this expression and another expression. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) * people.col('weight) ) - * }}} - * - * apache/spark - */ - def *[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W], t: ClassTag[U]): ThisType[W, U] = + /** + * Multiplication of this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. 
+ * people.select( people.col('height) * people.col('weight) ) + * }}} + * + * apache/spark + */ + def *[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W], + t: ClassTag[U] + ): ThisType[W, U] = multiply(other) - /** Multiplication of this expression a constant. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) * people.col('weight) ) - * }}} - * - * apache/spark - */ - def *(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Multiplication of this expression a constant. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people.col('height) * people.col('weight) ) + * }}} + * + * apache/spark + */ + def *( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.multiply(u)) - /** Modulo (a.k.a. remainder) expression. - * - * apache/spark - */ - def mod[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def mod[Out: TypedEncoder, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(self.untyped.mod(other.untyped)) - /** Modulo (a.k.a. remainder) expression. - * - * apache/spark - */ - def %[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def %[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = mod(other) - /** Modulo (a.k.a. remainder) expression. - * - * apache/spark - */ - def %(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def %( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.mod(u)) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / people('weight) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def divide[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people('height) / people('weight) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def divide[Out: TypedEncoder, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystDivisible[U, Out], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(self.untyped.divide(other.untyped)) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / people('weight) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def /[Out, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], e: TypedEncoder[Out], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. 
+ * people.select( people('height) / people('weight) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def /[Out, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystDivisible[U, Out], + e: TypedEncoder[Out], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = divide(other) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people('height) / 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def /( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, Double] = typed(self.untyped.divide(u)) - /** Returns a descending ordering used in sorting - * - * apache/spark - */ - def desc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] = + /** + * Returns a descending ordering used in sorting + * + * apache/spark + */ + def desc( + implicit + catalystOrdered: CatalystOrdered[U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](untyped.desc) - /** Returns an ascending ordering used in sorting - * - * apache/spark - */ - def asc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] = + /** + * Returns an ascending ordering used in sorting + * + * apache/spark + */ + def asc( + implicit + catalystOrdered: CatalystOrdered[U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](untyped.asc) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseAND (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseAND(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseAND (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseAND( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseAND(u)) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseAND (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseAND[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseAND (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseAND[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseAND(other.untyped)) - /** Bitwise AND this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Int] & -1) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def &(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise AND this expression and another expression (of same type). 
+ * {{{ + * df.select(df.col('colA).cast[Int] & -1) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def &( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseAND(u) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) & (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def &[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) & (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def &[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseAND(other) - /** Bitwise OR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseOR (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseOR (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseOR( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseOR(u)) - /** Bitwise OR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseOR (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def bitwiseOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseOR (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def bitwiseOR[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseOR(other.untyped)) - /** Bitwise OR this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Long] | 1L) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def |(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise OR this expression and another expression (of same type). + * {{{ + * df.select(df.col('colA).cast[Long] | 1L) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def |( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseOR(u) - /** Bitwise OR this expression and another expression. - * {{{ - * df.select(df.col('colA) | (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def |[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) | (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def |[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseOR(other) - /** Bitwise XOR this expression and another expression. 
- * {{{ - * df.select(df.col('colA) bitwiseXOR (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseXOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise XOR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseXOR (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseXOR( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseXOR(u)) - /** Bitwise XOR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseXOR (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def bitwiseXOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise XOR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseXOR (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def bitwiseXOR[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseXOR(other.untyped)) - /** Bitwise XOR this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Long] ^ 1L) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def ^(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise XOR this expression and another expression (of same type). + * {{{ + * df.select(df.col('colA).cast[Long] ^ 1L) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def ^( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseXOR(u) - /** Bitwise XOR this expression and another expression. - * {{{ - * df.select(df.col('colA) ^ (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def ^[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise XOR this expression and another expression. + * {{{ + * df.select(df.col('colA) ^ (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def ^[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseXOR(other) - /** Casts the column to a different type. - * {{{ - * df.select(df('a).cast[Int]) - * }}} - */ - def cast[A: TypedEncoder](implicit c: CatalystCast[U, A]): ThisType[T, A] = + /** + * Casts the column to a different type. 
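// A small sketch of the bitwise operators above (&, |, ^); Flags and `ds` are
// illustrative names, not part of frameless.
import frameless.TypedDataset

object BitwiseExample {
  case class Flags(mask: Int, bits: Int)

  // The operators require a CatalystBitwise instance, so they are only
  // available on integral columns such as Int or Long.
  def lowByteOrBits(ds: TypedDataset[Flags]): TypedDataset[Int] =
    ds.select((ds('mask) & 0xff) | ds('bits))
}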
+ * {{{ + * df.select(df('a).cast[Int]) + * }}} + */ + def cast[A: TypedEncoder]( + implicit + c: CatalystCast[U, A] + ): ThisType[T, A] = typed(self.untyped.cast(TypedEncoder[A].catalystRepr)) /** - * An expression that returns a substring - * {{{ - * df.select(df('a).substr(0, 5)) - * }}} - * - * @param startPos starting position - * @param len length of the substring - */ - def substr(startPos: Int, len: Int)(implicit ev: U =:= String): ThisType[T, String] = + * An expression that returns a substring + * {{{ + * df.select(df('a).substr(0, 5)) + * }}} + * + * @param startPos starting position + * @param len length of the substring + */ + def substr( + startPos: Int, + len: Int + )(implicit + ev: U =:= String + ): ThisType[T, String] = typed(self.untyped.substr(startPos, len)) /** - * An expression that returns a substring - * {{{ - * df.select(df('a).substr(df('b), df('c))) - * }}} - * - * @param startPos expression for the starting position - * @param len expression for the length of the substring - */ - def substr[TT1, TT2, W1, W2](startPos: ThisType[TT1, Int], len: ThisType[TT2, Int]) - (implicit - ev: U =:= String, - w1: With.Aux[T, TT1, W1], - w2: With.Aux[W1, TT2, W2]): ThisType[W2, String] = + * An expression that returns a substring + * {{{ + * df.select(df('a).substr(df('b), df('c))) + * }}} + * + * @param startPos expression for the starting position + * @param len expression for the length of the substring + */ + def substr[TT1, TT2, W1, W2]( + startPos: ThisType[TT1, Int], + len: ThisType[TT2, Int] + )(implicit + ev: U =:= String, + w1: With.Aux[T, TT1, W1], + w2: With.Aux[W1, TT2, W2] + ): ThisType[W2, String] = typed(self.untyped.substr(startPos.untyped, len.untyped)) - /** SQL like expression. Returns a boolean column based on a SQL LIKE match. - * {{{ - * val ds = TypedDataset.create(X2("foo", "bar") :: Nil) - * // true - * ds.select(ds('a).like("foo")) - * - * // Selected column has value "bar" - * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b)) - * }}} - * apache/spark - */ - def like(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * SQL like expression. Returns a boolean column based on a SQL LIKE match. + * {{{ + * val ds = TypedDataset.create(X2("foo", "bar") :: Nil) + * // true + * ds.select(ds('a).like("foo")) + * + * // Selected column has value "bar" + * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b)) + * }}} + * apache/spark + */ + def like( + literal: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.like(literal)) - /** SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match. - * {{{ - * val ds = TypedDataset.create(X1("foo") :: Nil) - * // true - * ds.select(ds('a).rlike("foo")) - * - * // true - * ds.select(ds('a).rlike(".*)) - * }}} - * apache/spark - */ - def rlike(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match. + * {{{ + * val ds = TypedDataset.create(X1("foo") :: Nil) + * // true + * ds.select(ds('a).rlike("foo")) + * + * // true + * ds.select(ds('a).rlike(".*)) + * }}} + * apache/spark + */ + def rlike( + literal: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.rlike(literal)) - /** String contains another string literal. - * {{{ - * df.filter ( df.col('a).contains("foo") ) - * }}} - * - * @param other a string that is being tested against. 
- * apache/spark - */ - def contains(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String contains another string literal. + * {{{ + * df.filter ( df.col('a).contains("foo") ) + * }}} + * + * @param other a string that is being tested against. + * apache/spark + */ + def contains( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.contains(other)) - /** String contains. - * {{{ - * df.filter ( df.col('a).contains(df.col('b) ) - * }}} - * - * @param other a column which values is used as a string that is being tested against. - * apache/spark - */ - def contains[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String contains. + * {{{ + * df.filter ( df.col('a).contains(df.col('b) ) + * }}} + * + * @param other a column which values is used as a string that is being tested against. + * apache/spark + */ + def contains[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.contains(other.untyped)) - /** String starts with another string literal. - * {{{ - * df.filter ( df.col('a).startsWith("foo") - * }}} - * - * @param other a prefix that is being tested against. - * apache/spark - */ - def startsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String starts with another string literal. + * {{{ + * df.filter ( df.col('a).startsWith("foo") + * }}} + * + * @param other a prefix that is being tested against. + * apache/spark + */ + def startsWith( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.startsWith(other)) - /** String starts with. - * {{{ - * df.filter ( df.col('a).startsWith(df.col('b)) - * }}} - * - * @param other a column which values is used as a prefix that is being tested against. - * apache/spark - */ - def startsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String starts with. + * {{{ + * df.filter ( df.col('a).startsWith(df.col('b)) + * }}} + * + * @param other a column which values is used as a prefix that is being tested against. + * apache/spark + */ + def startsWith[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.startsWith(other.untyped)) - /** String ends with another string literal. - * {{{ - * df.filter ( df.col('a).endsWith("foo") - * }}} - * - * @param other a suffix that is being tested against. - * apache/spark - */ - def endsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String ends with another string literal. + * {{{ + * df.filter ( df.col('a).endsWith("foo") + * }}} + * + * @param other a suffix that is being tested against. + * apache/spark + */ + def endsWith( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.endsWith(other)) - /** String ends with. - * {{{ - * df.filter ( df.col('a).endsWith(df.col('b)) - * }}} - * - * @param other a column which values is used as a suffix that is being tested against. - * apache/spark - */ - def endsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String ends with. + * {{{ + * df.filter ( df.col('a).endsWith(df.col('b)) + * }}} + * + * @param other a column which values is used as a suffix that is being tested against. 
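// A sketch of the string predicates above (contains / startsWith / endsWith);
// Account and `accounts` are assumed example names.
import frameless.TypedDataset

object StringPredicateExample {
  case class Account(name: String, email: String)

  // Predicates compose with && into a single TypedColumn[Account, Boolean]
  // that can be passed to filter.
  def corporate(accounts: TypedDataset[Account]): TypedDataset[Account] =
    accounts.filter(
      accounts('email).endsWith("@example.com") &&
        accounts('name).contains("Inc")
    )
}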
+ * apache/spark + */ + def endsWith[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.endsWith(other.untyped)) - /** Boolean AND. - * {{{ - * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) ) - * }}} - */ - def and[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean AND. + * {{{ + * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) ) + * }}} + */ + def and[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.and(other.untyped)) - /** Boolean AND. - * {{{ - * df.filter ( df.col('a) === 1 && df.col('b) > 5) - * }}} - */ - def && [TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean AND. + * {{{ + * df.filter ( df.col('a) === 1 && df.col('b) > 5) + * }}} + */ + def &&[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = and(other) - /** Boolean OR. - * {{{ - * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) ) - * }}} - */ - def or[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean OR. + * {{{ + * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) ) + * }}} + */ + def or[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.or(other.untyped)) - /** Boolean OR. - * {{{ - * df.filter ( df.col('a) === 1 || df.col('b) > 5) - * }}} - */ - def || [TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean OR. + * {{{ + * df.filter ( df.col('a) === 1 || df.col('b) > 5) + * }}} + */ + def ||[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = or(other) - /** Less than. - * - * {{{ - * // The following selects people younger than the maxAge column. - * df.select(df('age) < df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def <[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Less than. + * + * {{{ + * // The following selects people younger than the maxAge column. + * df.select(df('age) < df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def <[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped < other.untyped) - /** Less than or equal to. - * - * {{{ - * // The following selects people younger or equal than the maxAge column. - * df.select(df('age) <= df('maxAge) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def <=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Less than or equal to. + * + * {{{ + * // The following selects people younger or equal than the maxAge column. + * df.select(df('age) <= df('maxAge) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def <=[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped <= other.untyped) - /** Greater than. - * {{{ - * // The following selects people older than the maxAge column. 
- * df.select( df('age) > df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def >[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Greater than. + * {{{ + * // The following selects people older than the maxAge column. + * df.select( df('age) > df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def >[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped > other.untyped) - /** Greater than or equal. - * {{{ - * // The following selects people older or equal than the maxAge column. - * df.select( df('age) >= df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def >=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Greater than or equal. + * {{{ + * // The following selects people older or equal than the maxAge column. + * df.select( df('age) >= df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def >=[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped >= other.untyped) - /** Less than. - * {{{ - * // The following selects people younger than 21. - * df.select( df('age) < 21 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def <(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Less than. + * {{{ + * // The following selects people younger than 21. + * df.select( df('age) < 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped < lit(u)(self.uencoder).untyped) - /** Less than or equal to. - * {{{ - * // The following selects people younger than 22. - * df.select( df('age) <= 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def <=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Less than or equal to. + * {{{ + * // The following selects people younger than 22. + * df.select( df('age) <= 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <=( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped <= lit(u)(self.uencoder).untyped) - /** Greater than. - * {{{ - * // The following selects people older than 21. - * df.select( df('age) > 21 ) - * }}} - * - * @param u another column of the same type - * apache/spark - */ - def >(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Greater than. + * {{{ + * // The following selects people older than 21. + * df.select( df('age) > 21 ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped > lit(u)(self.uencoder).untyped) - /** Greater than or equal. - * {{{ - * // The following selects people older than 20. - * df.select( df('age) >= 21 ) - * }}} - * - * @param u another column of the same type - * apache/spark - */ - def >=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Greater than or equal. + * {{{ + * // The following selects people older than 20. 
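// A sketch of the ordering operators above; Person and `people` are assumed
// example names.
import frameless.TypedDataset

object ComparisonExample {
  case class Person(name: String, age: Int)

  // Both column-vs-constant and column-vs-column comparisons need a
  // CatalystOrdered instance for the column type.
  def workingAge(people: TypedDataset[Person]): TypedDataset[Person] =
    people.filter(people('age) >= 18 && people('age) < 65)
}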
+ * df.select( df('age) >= 21 ) + * }}} + * + * @param u another column of the same type + * apache/spark + */ + def >=( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped >= lit(u)(self.uencoder).untyped) /** - * Returns true if the value of this column is contained in of the arguments. - * {{{ - * // The following selects people with age 15, 20, or 30. - * df.select( df('age).isin(15, 20, 30) ) - * }}} - * - * @param values are constants of the same type - * apache/spark - */ - def isin(values: U*)(implicit e: CatalystIsin[U]): ThisType[T, Boolean] = - typed(self.untyped.isin(values:_*)) - - /** - * True if the current column is between the lower bound and upper bound, inclusive. - * - * @param lowerBound a constant of the same type - * @param upperBound a constant of the same type - * apache/spark - */ - def between(lowerBound: U, upperBound: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = - typed(self.untyped.between(lit(lowerBound)(self.uencoder).untyped, lit(upperBound)(self.uencoder).untyped)) - - /** - * True if the current column is between the lower bound and upper bound, inclusive. - * - * @param lowerBound another column of the same type - * @param upperBound another column of the same type - * apache/spark - */ - def between[TT1, TT2, W1, W2](lowerBound: ThisType[TT1, U], upperBound: ThisType[TT2, U]) - (implicit + * Returns true if the value of this column is contained in of the arguments. + * {{{ + * // The following selects people with age 15, 20, or 30. + * df.select( df('age).isin(15, 20, 30) ) + * }}} + * + * @param values are constants of the same type + * apache/spark + */ + def isin( + values: U* + )(implicit + e: CatalystIsin[U] + ): ThisType[T, Boolean] = + typed(self.untyped.isin(values: _*)) + + /** + * True if the current column is between the lower bound and upper bound, inclusive. + * + * @param lowerBound a constant of the same type + * @param upperBound a constant of the same type + * apache/spark + */ + def between( + lowerBound: U, + upperBound: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = + typed( + self.untyped.between( + lit(lowerBound)(self.uencoder).untyped, + lit(upperBound)(self.uencoder).untyped + ) + ) + + /** + * True if the current column is between the lower bound and upper bound, inclusive. + * + * @param lowerBound another column of the same type + * @param upperBound another column of the same type + * apache/spark + */ + def between[TT1, TT2, W1, W2]( + lowerBound: ThisType[TT1, U], + upperBound: ThisType[TT2, U] + )(implicit i0: CatalystOrdered[U], w0: With.Aux[T, TT1, W1], w1: With.Aux[TT2, W1, W2] ): ThisType[W2, Boolean] = - typed(self.untyped.between(lowerBound.untyped, upperBound.untyped)) + typed(self.untyped.between(lowerBound.untyped, upperBound.untyped)) /** - * Returns a nested column matching the field `symbol`. - * - * @param symbol the field symbol - * @tparam V the type of the nested field - */ - def field[V](symbol: Witness.Lt[Symbol])(implicit + * Returns a nested column matching the field `symbol`. 
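// A sketch of isin and between as documented above; the values and the Person
// type are illustrative assumptions.
import frameless.TypedDataset

object MembershipExample {
  case class Person(name: String, age: Int)

  def milestoneAges(people: TypedDataset[Person]): TypedDataset[Person] =
    people.filter(people('age).isin(18, 21, 65))

  // between is inclusive on both bounds.
  def adults(people: TypedDataset[Person]): TypedDataset[Person] =
    people.filter(people('age).between(18, 120))
}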
+ * + * @param symbol the field symbol + * @tparam V the type of the nested field + */ + def field[V]( + symbol: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[U, symbol.T, V], i1: TypedEncoder[V] - ): ThisType[T, V] = + ): ThisType[T, V] = typed(self.untyped.getField(symbol.value.name)) } - -sealed class SortedTypedColumn[T, U](val expr: Expression)( - implicit - val uencoder: TypedEncoder[U] -) extends UntypedExpression[T] { - - def this(column: Column)(implicit e: TypedEncoder[U]) = { +sealed class SortedTypedColumn[T, U]( + val expr: Expression + )(implicit + val uencoder: TypedEncoder[U]) + extends UntypedExpression[T] { + + def this( + column: Column + )(implicit + e: TypedEncoder[U] + ) = { this(FramelessInternals.expr(column)) } @@ -894,16 +1288,24 @@ sealed class SortedTypedColumn[T, U](val expr: Expression)( } object SortedTypedColumn { - implicit def defaultAscending[T, U : CatalystOrdered](typedColumn: TypedColumn[T, U]): SortedTypedColumn[T, U] = + + implicit def defaultAscending[T, U: CatalystOrdered]( + typedColumn: TypedColumn[T, U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](typedColumn.untyped.asc)(typedColumn.uencoder) - object defaultAscendingPoly extends Poly1 { - implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c)) - implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity) - } + object defaultAscendingPoly extends Poly1 { + + implicit def caseTypedColumn[T, U: CatalystOrdered] = + at[TypedColumn[T, U]](c => defaultAscending(c)) + + implicit def caseTypeSortedColumn[T, U] = + at[SortedTypedColumn[T, U]](identity) + } } object TypedColumn { + /** Evidence that type `T` has column `K` with type `V`. */ @implicitNotFound(msg = "No column ${K} of type ${V} in ${T}") trait Exists[T, K, V] @@ -912,37 +1314,46 @@ object TypedColumn { trait ExistsMany[T, K <: HList, V] object ExistsMany { - implicit def deriveCons[T, KH, KT <: HList, V0, V1] - (implicit + + implicit def deriveCons[T, KH, KT <: HList, V0, V1]( + implicit head: Exists[T, KH, V0], tail: ExistsMany[V0, KT, V1] ): ExistsMany[T, KH :: KT, V1] = - new ExistsMany[T, KH :: KT, V1] {} + new ExistsMany[T, KH :: KT, V1] {} - implicit def deriveHNil[T, K, V](implicit head: Exists[T, K, V]): ExistsMany[T, K :: HNil, V] = + implicit def deriveHNil[T, K, V]( + implicit + head: Exists[T, K, V] + ): ExistsMany[T, K :: HNil, V] = new ExistsMany[T, K :: HNil, V] {} } object Exists { - def apply[T, V](column: Witness)(implicit e: Exists[T, column.T, V]): Exists[T, column.T, V] = e - implicit def deriveRecord[T, H <: HList, K, V] - (implicit + def apply[T, V]( + column: Witness + )(implicit + e: Exists[T, column.T, V] + ): Exists[T, column.T, V] = e + + implicit def deriveRecord[T, H <: HList, K, V]( + implicit i0: LabelledGeneric.Aux[T, H], i1: Selector.Aux[H, K, V] ): Exists[T, K, V] = new Exists[T, K, V] {} } /** - * {{{ - * import frameless.TypedColumn - * - * case class Foo(id: Int, bar: String) - * - * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar } - * val colid = TypedColumn[Foo, Int](_.id) - * }}} - */ + * {{{ + * import frameless.TypedColumn + * + * case class Foo(id: Int, bar: String) + * + * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar } + * val colid = TypedColumn[Foo, Int](_.id) + * }}} + */ def apply[T, U](x: T => U): TypedColumn[T, U] = macro TypedColumnMacroImpl.applyImpl[T, U] diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala 
b/dataset/src/main/scala/frameless/TypedDataset.scala index add2170b2..091272c8b 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -4,36 +4,52 @@ import java.util import frameless.functions.CatalystExplodableCollection import frameless.ops._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, FramelessInternals, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.{ Column, DataFrame, Dataset, SparkSession } +import org.apache.spark.sql.catalyst.expressions.{ + Attribute, + AttributeReference, + Literal +} +import org.apache.spark.sql.catalyst.plans.logical.{ Join, JoinHint } import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.types.StructType import shapeless._ import shapeless.labelled.FieldType -import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler} -import shapeless.ops.record.{Keys, Modifier, Remover, Values} +import shapeless.ops.hlist.{ + Diff, + IsHCons, + Mapper, + Prepend, + ToTraversable, + Tupler +} +import shapeless.ops.record.{ Keys, Modifier, Remover, Values } import scala.language.experimental.macros -/** [[TypedDataset]] is a safer interface for working with `Dataset`. - * - * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you - * know what you are doing. - * - * Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val encoder: TypedEncoder[T]) +/** + * [[TypedDataset]] is a safer interface for working with `Dataset`. + * + * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you + * know what you are doing. + * + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +class TypedDataset[T] protected[frameless] ( + val dataset: Dataset[T] + )(implicit + val encoder: TypedEncoder[T]) extends TypedDatasetForwarded[T] { self => private implicit val spark: SparkSession = dataset.sparkSession - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A](ca: TypedAggregate[T, A]): TypedDataset[A] = { implicit val ea = ca.uencoder val tuple1: TypedDataset[Tuple1[A]] = aggMany(ca) @@ -42,10 +58,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val TypedEncoder[A].catalystRepr match { case StructType(_) => // if column is struct, we use all its fields - val df = tuple1 - .dataset - .selectExpr("_1.*") - .as[A](TypedExpressionEncoder[A]) + val df = + tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A]) TypedDataset.create(df) case other => @@ -54,52 +68,59 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. 
+ * + * apache/spark + */ def agg[A, B]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B] - ): TypedDataset[(A, B)] = { + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B] + ): TypedDataset[(A, B)] = { implicit val (ea, eb) = (ca.uencoder, cb.uencoder) aggMany(ca, cb) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A, B, C]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B], - cc: TypedAggregate[T, C] - ): TypedDataset[(A, B, C)] = { + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B], + cc: TypedAggregate[T, C] + ): TypedDataset[(A, B, C)] = { implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder) aggMany(ca, cb, cc) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A, B, C, D]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B], - cc: TypedAggregate[T, C], - cd: TypedAggregate[T, D] - ): TypedDataset[(A, B, C, D)] = { - implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B], + cc: TypedAggregate[T, C], + cd: TypedAggregate[T, D] + ): TypedDataset[(A, B, C, D)] = { + implicit val (ea, eb, ec, ed) = + (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) aggMany(ca, cb, cc, cd) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ object aggMany extends ProductArgs { - def applyProduct[U <: HList, Out0 <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out0 <: HList, Out]( + columns: U + )(implicit i0: AggregateTypes.Aux[T, U, Out0], i1: ToTraversable.Aux[U, List, UntypedExpression[T]], i2: Tupler.Aux[Out0, Out], @@ -109,7 +130,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val val underlyingColumns = columns.toList[UntypedExpression[T]] val cols: Seq[Column] = for { (c, i) <- columns.toList[UntypedExpression[T]].zipWithIndex - } yield new Column(c.expr).as(s"_${i+1}") + } yield new Column(c.expr).as(s"_${i + 1}") // Workaround to SPARK-20346. One alternative is to allow the result to be Vector(null) for empty DataFrames. // Another one would be to return an Option. @@ -117,129 +138,163 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val for { (c, i) <- underlyingColumns.zipWithIndex if !c.uencoder.nullable - } yield s"_${i+1} is not null" - ).mkString(" or ") + } yield s"_${i + 1} is not null" + ).mkString(" or ") - val selected = dataset.toDF().agg(cols.head, cols.tail:_*).as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](if (filterStr.isEmpty) selected else selected.filter(filterStr)) + val selected = dataset + .toDF() + .agg(cols.head, cols.tail: _*) + .as[Out](TypedExpressionEncoder[Out]) + TypedDataset.create[Out]( + if (filterStr.isEmpty) selected else selected.filter(filterStr) + ) } } /** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. */ - def as[U]()(implicit as: As[T, U]): TypedDataset[U] = { + def as[U]( + )(implicit + as: As[T, U] + ): TypedDataset[U] = { implicit val uencoder = as.encoder TypedDataset.create(dataset.as[U](TypedExpressionEncoder[U])) } - /** Returns a checkpointed version of this [[TypedDataset]]. 
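// A sketch of whole-dataset aggregation with agg, as documented above;
// Purchase and its columns are assumptions for the example.
import frameless.TypedDataset
import frameless.functions.aggregate._

object AggExample {
  case class Purchase(user: String, amount: Double)

  // Aggregates the entire dataset without grouping: a single (sum, avg) row.
  def totals(ds: TypedDataset[Purchase]): TypedDataset[(Double, Double)] =
    ds.agg(sum(ds('amount)), avg(ds('amount)))
}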
Checkpointing can be used to truncate the - * logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. It will be saved to files inside the checkpoint - * directory set with `SparkContext#setCheckpointDir`. - * - * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def checkpoint[F[_]](eager: Boolean)(implicit F: SparkDelay[F]): F[TypedDataset[T]] = + /** + * Returns a checkpointed version of this [[TypedDataset]]. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. + * + * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def checkpoint[F[_]]( + eager: Boolean + )(implicit + F: SparkDelay[F] + ): F[TypedDataset[T]] = F.delay(TypedDataset.create[T](dataset.checkpoint(eager))) - /** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. - * Unlike `as` the projection U may include a subset of the columns of T and the column names and types must agree. - * - * {{{ - * case class Foo(i: Int, j: String) - * case class Bar(j: String) - * - * val t: TypedDataset[Foo] = ... - * val b: TypedDataset[Bar] = t.project[Bar] - * - * case class BarErr(e: String) - * // The following does not compile because `Foo` doesn't have a field with name `e` - * val e: TypedDataset[BarErr] = t.project[BarErr] - * }}} - */ - def project[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = projector.apply(this) - - /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] - * combined. - * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analogous to `UNION ALL` in SQL. - * - * Differs from `Dataset#union` by aligning fields if possible. - * It will not compile if `Datasets` have not compatible schema. - * - * Example: - * {{{ - * case class Foo(x: Int, y: Long) - * case class Bar(y: Long, x: Int) - * case class Faz(x: Int, y: Int, z: Int) - * - * foo: TypedDataset[Foo] = ... - * bar: TypedDataset[Bar] = ... - * faz: TypedDataset[Faz] = ... - * - * foo union bar: TypedDataset[Foo] - * foo union faz: TypedDataset[Foo] - * // won't compile, you need to reverse order, you can't project from less fields to more - * faz union foo - * - * }}} - * - * apache/spark - */ - def union[U: TypedEncoder](other: TypedDataset[U])(implicit projector: SmartProject[U, T]): TypedDataset[T] = + /** + * Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. + * Unlike `as` the projection U may include a subset of the columns of T and the column names and types must agree. + * + * {{{ + * case class Foo(i: Int, j: String) + * case class Bar(j: String) + * + * val t: TypedDataset[Foo] = ... 
+ * val b: TypedDataset[Bar] = t.project[Bar] + * + * case class BarErr(e: String) + * // The following does not compile because `Foo` doesn't have a field with name `e` + * val e: TypedDataset[BarErr] = t.project[BarErr] + * }}} + */ + def project[U]( + implicit + projector: SmartProject[T, U] + ): TypedDataset[U] = projector.apply(this) + + /** + * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. + * + * Differs from `Dataset#union` by aligning fields if possible. + * It will not compile if `Datasets` have not compatible schema. + * + * Example: + * {{{ + * case class Foo(x: Int, y: Long) + * case class Bar(y: Long, x: Int) + * case class Faz(x: Int, y: Int, z: Int) + * + * foo: TypedDataset[Foo] = ... + * bar: TypedDataset[Bar] = ... + * faz: TypedDataset[Faz] = ... + * + * foo union bar: TypedDataset[Foo] + * foo union faz: TypedDataset[Foo] + * // won't compile, you need to reverse order, you can't project from less fields to more + * faz union foo + * + * }}} + * + * apache/spark + */ + def union[U: TypedEncoder]( + other: TypedDataset[U] + )(implicit + projector: SmartProject[U, T] + ): TypedDataset[T] = TypedDataset.create(dataset.union(other.project[T].dataset)) - /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] - * combined. - * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analogous to `UNION ALL` in SQL. - * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. + * + * apache/spark + */ def union(other: TypedDataset[T]): TypedDataset[T] = { TypedDataset.create(dataset.union(other.dataset)) } - /** Returns the number of elements in the [[TypedDataset]]. - * - * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`. - */ - def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] = + /** + * Returns the number of elements in the [[TypedDataset]]. + * + * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`. + */ + def count[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Long] = F.delay(dataset.count()) - /** Returns `TypedColumn` of type `A` given its name (alias for `col`). - * - * {{{ - * tf('id) - * }}} - * - * It is statically checked that column with such name exists and has type `A`. - */ - def apply[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Returns `TypedColumn` of type `A` given its name (alias for `col`). + * + * {{{ + * tf('id) + * }}} + * + * It is statically checked that column with such name exists and has type `A`. + */ + def apply[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = col(column) - /** Returns `TypedColumn` of type `A` given its name. - * - * {{{ - * tf.col('id) - * }}} - * - * It is statically checked that column with such name exists and has type `A`. - */ - def col[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Returns `TypedColumn` of type `A` given its name. 
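// A sketch of union with automatic field alignment, mirroring the Foo/Bar
// example in the scaladoc above.
import frameless.TypedDataset

object UnionExample {
  case class Foo(x: Int, y: Long)
  case class Bar(y: Long, x: Int)

  // Bar is projected onto Foo field-by-field before the UNION ALL.
  def combine(foos: TypedDataset[Foo], bars: TypedDataset[Bar]): TypedDataset[Foo] =
    foos.union(bars)
}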
+ * + * {{{ + * tf.col('id) + * }}} + * + * It is statically checked that column with such name exists and has type `A`. + */ + def col[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = - new TypedColumn[T, A](dataset(column.value.name).as[A](TypedExpressionEncoder[A])) + new TypedColumn[T, A]( + dataset(column.value.name).as[A](TypedExpressionEncoder[A]) + ) - /** Returns `TypedColumn` of type `A` given a lambda indicating the field. + /** + * Returns `TypedColumn` of type `A` given a lambda indicating the field. * * {{{ * td.col(_.id) @@ -250,12 +305,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val def col[A](x: Function1[T, A]): TypedColumn[T, A] = macro TypedColumnMacroImpl.applyImpl[T, A] - /** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`. - * {{{ - * ts: TypedDataset[Foo] = ... - * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)] - * }}} - */ + /** + * Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`. + * {{{ + * ts: TypedDataset[Foo] = ... + * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)] + * }}} + */ def asCol: TypedColumn[T, T] = { val projectedColumn: Column = encoder.catalystRepr match { case StructType(_) => @@ -265,78 +321,98 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val case _ => dataset.col(dataset.columns.head) } - - new TypedColumn[T,T](projectedColumn) + + new TypedColumn[T, T](projectedColumn) } - /** References the entire `TypedDataset[T]` as a single column - * of type `TypedColumn[T,T]` so it can be used in a join operation. - * - * {{{ - * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) = - * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) - * }}} - */ - def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = { + /** + * References the entire `TypedDataset[T]` as a single column + * of type `TypedColumn[T,T]` so it can be used in a join operation. + * + * {{{ + * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) = + * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) + * }}} + */ + def asJoinColValue( + implicit + i0: IsValueClass[T] + ): TypedColumn[T, T] = { import _root_.frameless.syntax._ dataset.col("value").typedColumn } object colMany extends SingletonProductArgs { - def applyProduct[U <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out]( + columns: U + )(implicit i0: TypedColumn.ExistsMany[T, U, Out], i1: TypedEncoder[Out], i2: ToTraversable.Aux[U, List, Symbol] ): TypedColumn[T, Out] = { - val names = columns.toList[Symbol].map(_.name) - val colExpr = FramelessInternals.resolveExpr(dataset, names) - new TypedColumn[T, Out](colExpr) - } + val names = columns.toList[Symbol].map(_.name) + val colExpr = FramelessInternals.resolveExpr(dataset, names) + new TypedColumn[T, Out](colExpr) + } } - /** Right hand side disambiguation of `col` for join expressions. - * To be used when writting self-joins, noop in other circumstances. - * - * Note: In vanilla Spark, disambiguation in self-joins is acheaved using - * String based aliases, which is obviously unsafe. - */ - def colRight[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Right hand side disambiguation of `col` for join expressions. + * To be used when writting self-joins, noop in other circumstances. 
+   *
+   * Note: In vanilla Spark, disambiguation in self-joins is achieved using
+   * String based aliases, which is obviously unsafe.
+   */
+  def colRight[A](
+      column: Witness.Lt[Symbol]
+    )(implicit
      i0: TypedColumn.Exists[T, column.T, A],
      i1: TypedEncoder[A]
    ): TypedColumn[T, A] =
-    new TypedColumn[T, A](FramelessInternals.DisambiguateRight(col(column).expr))
-
-  /** Left hand side disambiguation of `col` for join expressions.
-    * To be used when writting self-joins, noop in other circumstances.
-    *
-    * Note: In vanilla Spark, disambiguation in self-joins is acheaved using
-    * String based aliases, which is obviously unsafe.
-    */
-  def colLeft[A](column: Witness.Lt[Symbol])
-    (implicit
+    new TypedColumn[T, A](
+      FramelessInternals.DisambiguateRight(col(column).expr)
+    )
+
+  /**
+   * Left hand side disambiguation of `col` for join expressions.
+   * To be used when writing self-joins, noop in other circumstances.
+   *
+   * Note: In vanilla Spark, disambiguation in self-joins is achieved using
+   * String based aliases, which is obviously unsafe.
+   */
+  def colLeft[A](
+      column: Witness.Lt[Symbol]
+    )(implicit
      i0: TypedColumn.Exists[T, column.T, A],
      i1: TypedEncoder[A]
    ): TypedColumn[T, A] =
-    new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr))
-
-  /** Returns a `Seq` that contains all the elements in this [[TypedDataset]].
-    *
-    * Running this operation requires moving all the data into the application's driver process, and
-    * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError.
-    *
-    * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`.
-    */
-  def collect[F[_]]()(implicit F: SparkDelay[F]): F[Seq[T]] =
+    new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr))
+
+  /**
+   * Returns a `Seq` that contains all the elements in this [[TypedDataset]].
+   *
+   * Running this operation requires moving all the data into the application's driver process, and
+   * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError.
+   *
+   * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`.
+   */
+  def collect[F[_]](
+    )(implicit
+      F: SparkDelay[F]
+    ): F[Seq[T]] =
    F.delay(dataset.collect().toSeq)

-  /** Optionally returns the first element in this [[TypedDataset]].
-    *
-    * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`.
-    */
-  def firstOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] =
+  /**
+   * Optionally returns the first element in this [[TypedDataset]].
+   *
+   * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`.
+   */
+  def firstOption[F[_]](
+    )(implicit
+      F: SparkDelay[F]
+    ): F[Option[T]] =
    F.delay {
      try {
        Option(dataset.first())
@@ -345,354 +421,462 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
      }
    }

-  /** Returns the first `num` elements of this [[TypedDataset]] as a `Seq`.
-    *
-    * Running take requires moving data into the application's driver process, and doing so with
-    * a very large `num` can crash the driver process with OutOfMemoryError.
-    *
-    * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`.
-    *
-    * apache/spark
-    */
-  def take[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] =
+  /**
+   * Returns the first `num` elements of this [[TypedDataset]] as a `Seq`.
+ * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `num` can crash the driver process with OutOfMemoryError. + * + * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def take[F[_]]( + num: Int + )(implicit + F: SparkDelay[F] + ): F[Seq[T]] = F.delay(dataset.take(num).toSeq) - /** Return an iterator that contains all rows in this [[TypedDataset]]. - * - * The iterator will consume as much memory as the largest partition in this [[TypedDataset]]. - * - * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result - * of a wide transformation (e.g. join with different partitioners), to avoid - * recomputing the input [[TypedDataset]] should be cached first. - * - * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def toLocalIterator[F[_]]()(implicit F: SparkDelay[F]): F[util.Iterator[T]] = + /** + * Return an iterator that contains all rows in this [[TypedDataset]]. + * + * The iterator will consume as much memory as the largest partition in this [[TypedDataset]]. + * + * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input [[TypedDataset]] should be cached first. + * + * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def toLocalIterator[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[util.Iterator[T]] = F.delay(dataset.toLocalIterator()) - /** Alias for firstOption(). - */ - def headOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] = firstOption() + /** + * Alias for firstOption(). + */ + def headOption[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Option[T]] = firstOption() - /** Alias for take(). - */ - def head[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] = take(num) + /** + * Alias for take(). + */ + def head[F[_]]( + num: Int + )(implicit + F: SparkDelay[F] + ): F[Seq[T]] = take(num) // $COVERAGE-OFF$ - /** Alias for firstOption(). - */ - @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0") + /** + * Alias for firstOption(). + */ + @deprecated( + "Method may throw exception. Use headOption or firstOption instead.", + "0.5.0" + ) def head: T = dataset.head() - /** Alias for firstOption(). - */ - @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0") + /** + * Alias for firstOption(). + */ + @deprecated( + "Method may throw exception. Use headOption or firstOption instead.", + "0.5.0" + ) def first: T = dataset.head() // $COVERAGE-ONN$ - /** Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`. 
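// A sketch of the effect-suspending actions above: any F[_] with a frameless
// SparkDelay instance works (for example cats-effect IO via frameless-cats).
// Person and `people` are assumed example names.
import frameless.{ SparkDelay, TypedDataset }

object ActionExample {
  case class Person(name: String, age: Int)

  // Nothing touches the cluster until the caller runs the F effect.
  def firstPage[F[_]: SparkDelay](people: TypedDataset[Person]): F[Seq[Person]] =
    people.take[F](20)
}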
- * - * apache/spark - */ - def show[F[_]](numRows: Int = 20, truncate: Boolean = true)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def show[F[_]]( + numRows: Int = 20, + truncate: Boolean = true + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.show(numRows, truncate)) - /** Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`. - * - * Differs from `TypedDatasetForward#filter` by taking a `TypedColumn[T, Boolean]` instead of a - * `T => Boolean`. Using a column expression instead of a regular function save one Spark → Scala - * deserialization which leads to better performance. - */ + /** + * Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`. + * + * Differs from `TypedDatasetForward#filter` by taking a `TypedColumn[T, Boolean]` instead of a + * `T => Boolean`. Using a column expression instead of a regular function save one Spark → Scala + * deserialization which leads to better performance. + */ def filter(column: TypedColumn[T, Boolean]): TypedDataset[T] = { - val filtered = dataset.toDF() - .filter(column.untyped) - .as[T](TypedExpressionEncoder[T]) + val filtered = + dataset.toDF().filter(column.untyped).as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](filtered) } - /** Runs `func` on each element of this [[TypedDataset]]. - * - * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`. - */ - def foreach[F[_]](func: T => Unit)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Runs `func` on each element of this [[TypedDataset]]. + * + * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`. + */ + def foreach[F[_]]( + func: T => Unit + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.foreach(func)) - /** Runs `func` on each partition of this [[TypedDataset]]. - * - * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`. - */ - def foreachPartition[F[_]](func: Iterator[T] => Unit)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Runs `func` on each partition of this [[TypedDataset]]. + * + * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`. + */ + def foreachPartition[F[_]]( + func: Iterator[T] => Unit + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.foreachPartition(func)) /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column, - * so we can run aggregation on it. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. 
- * - * apache/spark - */ + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column, + * so we can run aggregation on it. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def cube[K1]( - c1: TypedColumn[T, K1] - ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1) - - /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1] + ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1) + + /** + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def cube[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2) - - /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 15 null 1 - * 15 25 1 - * null null 2 - * null 25 1 - * null 20 1 - * 10 null 1 - * 10 20 1 - * - * }}} - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2) + + /** + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 15 null 1 + * 15 25 1 + * null null 2 + * null 25 1 + * null 20 1 + * 10 null 1 + * 10 20 1 + * + * }}} + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. 
+ * + * apache/spark + */ object cubeMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): CubeManyOps[T, TK, K, KT] = new CubeManyOps[T, TK, K, KT](self, groupedBy) + ): CubeManyOps[T, TK, K, KT] = + new CubeManyOps[T, TK, K, KT](self, groupedBy) } /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * apache/spark - */ + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * apache/spark + */ def groupBy[K1]( - c1: TypedColumn[T, K1] - ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1) + c1: TypedColumn[T, K1] + ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1) /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * apache/spark - */ + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * apache/spark + */ def groupBy[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2) - - /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - * - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 10 20 1 - * 15 25 1 - * - * }}} - * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2) + + /** + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. 
+ * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 10 20 1 + * 15 25 1 + * + * }}} + * + * apache/spark + */ object groupByMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): GroupedByManyOps[T, TK, K, KT] = new GroupedByManyOps[T, TK, K, KT](self, groupedBy) + ): GroupedByManyOps[T, TK, K, KT] = + new GroupedByManyOps[T, TK, K, KT](self, groupedBy) } /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column, - * so we can run aggregation on it. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column, + * so we can run aggregation on it. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def rollup[K1]( - c1: TypedColumn[T, K1] - ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1) - - /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1] + ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1) + + /** + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def rollup[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2) - - /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - * - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 15 null 1 - * 15 25 1 - * null null 2 - * 10 null 1 - * 10 20 1 - * - * }}} - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. 
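To make the `Option` wrapping above concrete: cube (and rollup) subtotal rows come back as `None` rather than `null`. A sketch reusing the `MyClass` shape from the Scaladoc example; it is an illustration, not code added by this diff:

    import frameless.TypedDataset
    import frameless.functions.aggregate._

    case class MyClass(a: Int, b: Int, c: Int)

    // the `null` subtotal rows shown in the Scaladoc surface as None here
    def cubeCounts(ds: TypedDataset[MyClass]): TypedDataset[(Option[Int], Option[Int], Long)] =
      ds.cube(ds('a), ds('b)).agg(count[MyClass]())
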
- * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2) + + /** + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 15 null 1 + * 15 25 1 + * null null 2 + * 10 null 1 + * 10 20 1 + * + * }}} + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ object rollupMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): RollupManyOps[T, TK, K, KT] = new RollupManyOps[T, TK, K, KT](self, groupedBy) + ): RollupManyOps[T, TK, K, KT] = + new RollupManyOps[T, TK, K, KT](self, groupedBy) } /** Computes the cartesian project of `this` `Dataset` with the `other` `Dataset` */ - def joinCross[U](other: TypedDataset[U]) - (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = - new TypedDataset(self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross")) - - /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinFull[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "full") - .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])])) - - /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = { - import FramelessInternals._ - - val leftPlan = logicalPlan(dataset) - val rightPlan = logicalPlan(other.dataset) - val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE)) - val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan) - val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)]) - - TypedDataset.create[(T, U)](joinedDs) - } + def joinCross[U]( + other: TypedDataset[U] + )(implicit + e: TypedEncoder[(T, U)] + ): TypedDataset[(T, U)] = + new TypedDataset( + self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross") + ) - /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. 
- */ - def joinLeft[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "left_outer") - .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])) - - /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinLeftSemi[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = - new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftsemi") - .as[T](TypedExpressionEncoder(encoder))) - - /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinLeftAnti[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = - new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftanti") - .as[T](TypedExpressionEncoder(encoder))) - - /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinRight[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "right_outer") - .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)])) + /** + * Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinFull[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(Option[T], Option[U])] + ): TypedDataset[(Option[T], Option[U])] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "full") + .as[(Option[T], Option[U])]( + TypedExpressionEncoder[(Option[T], Option[U])] + ) + ) + + /** + * Computes the inner join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinInner[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(T, U)] + ): TypedDataset[(T, U)] = { + import FramelessInternals._ + + val leftPlan = logicalPlan(dataset) + val rightPlan = logicalPlan(other.dataset) + val join = disambiguate( + Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE) + ) + val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan) + val joinedDs = + mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)]) + + TypedDataset.create[(T, U)](joinedDs) + } + + /** + * Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinLeft[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(T, Option[U])] + ): TypedDataset[(T, Option[U])] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "left_outer") + .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) + ) + + /** + * Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. 
+ */ + def joinLeftSemi[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + ): TypedDataset[T] = + new TypedDataset( + self.dataset + .join(other.dataset, condition.untyped, "leftsemi") + .as[T](TypedExpressionEncoder(encoder)) + ) + + /** + * Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinLeftAnti[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + ): TypedDataset[T] = + new TypedDataset( + self.dataset + .join(other.dataset, condition.untyped, "leftanti") + .as[T](TypedExpressionEncoder(encoder)) + ) + + /** + * Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinRight[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(Option[T], U)] + ): TypedDataset[(Option[T], U)] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "right_outer") + .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)]) + ) private def disambiguate(join: Join): Join = { - val plan = FramelessInternals.ofRows(dataset.sparkSession, join).queryExecution.analyzed.asInstanceOf[Join] + val plan = FramelessInternals + .ofRows(dataset.sparkSession, join) + .queryExecution + .analyzed + .asInstanceOf[Join] val disambiguated = plan.condition.map(_.transform { case FramelessInternals.DisambiguateLeft(tagged: AttributeReference) => val leftDs = FramelessInternals.ofRows(spark, plan.left) @@ -707,43 +891,81 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val plan.copy(condition = disambiguated) } - /** Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R]. - */ - def makeUDF[A: TypedEncoder, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder](f: (A1, A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R]. 
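The join family above keeps its call shape through the reformatting; a brief sketch with made-up `Person`/`Address` types (illustrative only, not part of the patch):

    import frameless.TypedDataset

    case class Person(id: Long, name: String)
    case class Address(personId: Long, city: String)

    // left outer join: the right side becomes optional in the result
    def withAddresses(
        people: TypedDataset[Person],
        addresses: TypedDataset[Address]
      ): TypedDataset[(Person, Option[Address])] =
      people.joinLeft(addresses)(people('id) === addresses('personId))
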
- */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, A5: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = functions.udf(f) - - /** Type-safe projection from type T to Tuple1[A] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R]. + */ + def makeUDF[A: TypedEncoder, R: TypedEncoder](f: A => R): TypedColumn[T, A] => TypedColumn[T, R] = + functions.udf(f) + + /** + * Takes a function from (A1, A2) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R]. + */ + def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder]( + f: (A1, A2) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = + functions.udf(f) + + /** + * Takes a function from (A1, A2, A3) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R]. + */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = + functions.udf(f) + + /** + * Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R]. + */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + A4: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3, A4) => R + ): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4] + ) => TypedColumn[T, R] = functions.udf(f) + + /** + * Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R]. + */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + A4: TypedEncoder, + A5: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3, A4, A5) => R + ): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4], + TypedColumn[T, A5] + ) => TypedColumn[T, R] = functions.udf(f) + + /** + * Type-safe projection from type T to Tuple1[A] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A]( - ca: TypedColumn[T, A] - ): TypedDataset[A] = { + ca: TypedColumn[T, A] + ): TypedDataset[A] = { implicit val ea = ca.uencoder val tuple1: TypedDataset[Tuple1[A]] = selectMany(ca) @@ -753,10 +975,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val TypedEncoder[A].catalystRepr match { case StructType(_) => // if column is struct, we use all its fields - val df = tuple1 - .dataset - .selectExpr("_1.*") - .as[A](TypedExpressionEncoder[A]) + val df = + tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A]) TypedDataset.create(df) case other => @@ -765,217 +985,288 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } - /** Type-safe projection from type T to Tuple2[A,B] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple2[A,B] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
) + * }}} + */ def select[A, B]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B] - ): TypedDataset[(A, B)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B] + ): TypedDataset[(A, B)] = { implicit val (ea, eb) = (ca.uencoder, cb.uencoder) selectMany(ca, cb) } - /** Type-safe projection from type T to Tuple3[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple3[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C] - ): TypedDataset[(A, B, C)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C] + ): TypedDataset[(A, B, C)] = { implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder) selectMany(ca, cb, cc) } - /** Type-safe projection from type T to Tuple4[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple4[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D] - ): TypedDataset[(A, B, C, D)] = { - implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D] + ): TypedDataset[(A, B, C, D)] = { + implicit val (ea, eb, ec, ed) = + (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) selectMany(ca, cb, cc, cd) } - /** Type-safe projection from type T to Tuple5[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple5[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E] - ): TypedDataset[(A, B, C, D, E)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E] + ): TypedDataset[(A, B, C, D, E)] = { implicit val (ea, eb, ec, ed, ee) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder) selectMany(ca, cb, cc, cd, ce) } - /** Type-safe projection from type T to Tuple6[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple6[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F] - ): TypedDataset[(A, B, C, D, E, F)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F] + ): TypedDataset[(A, B, C, D, E, F)] = { implicit val (ea, eb, ec, ed, ee, ef) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf) } - /** Type-safe projection from type T to Tuple7[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple7[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
) + * }}} + */ def select[A, B, C, D, E, F, G]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G] - ): TypedDataset[(A, B, C, D, E, F, G)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G] + ): TypedDataset[(A, B, C, D, E, F, G)] = { implicit val (ea, eb, ec, ed, ee, ef, eg) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg) } - /** Type-safe projection from type T to Tuple8[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple8[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F, G, H]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H] - ): TypedDataset[(A, B, C, D, E, F, G, H)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H] + ): TypedDataset[(A, B, C, D, E, F, G, H)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch) } - /** Type-safe projection from type T to Tuple9[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple9[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F, G, H, I]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H], - ci: TypedColumn[T, I] - ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H], + ci: TypedColumn[T, I] + ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder, + ci.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci) } - /** Type-safe projection from type T to Tuple10[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple10[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
) + * }}} + */ def select[A, B, C, D, E, F, G, H, I, J]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H], - ci: TypedColumn[T, I], - cj: TypedColumn[T, J] - ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H], + ci: TypedColumn[T, I], + cj: TypedColumn[T, J] + ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei, ej) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder, cj.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder, + ci.uencoder, + cj.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci, cj) } object selectMany extends ProductArgs { - def applyProduct[U <: HList, Out0 <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out0 <: HList, Out]( + columns: U + )(implicit i0: ColumnTypes.Aux[T, U, Out0], i1: ToTraversable.Aux[U, List, UntypedExpression[T]], i2: Tupler.Aux[Out0, Out], i3: TypedEncoder[Out] ): TypedDataset[Out] = { - val base = dataset.toDF() - .select(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*) - val selected = base.as[Out](TypedExpressionEncoder[Out]) + val base = dataset + .toDF() + .select( + columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)): _* + ) + val selected = base.as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](selected) - } + TypedDataset.create[Out](selected) + } } /** Sort each partition in the dataset using the columns selected. */ - def sortWithinPartitions[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] = + def sortWithinPartitions[A: CatalystOrdered]( + ca: SortedTypedColumn[T, A] + ): TypedDataset[T] = sortWithinPartitionsMany(ca) /** Sort each partition in the dataset using the columns selected. */ def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B] - ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb) + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B] + ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb) /** Sort each partition in the dataset using the columns selected. */ - def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B], - cc: SortedTypedColumn[T, C] - ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc) - - /** Sort each partition in the dataset by the given column expressions - * Default sort order is ascending. - * {{{ - * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc) - * }}} - */ + def sortWithinPartitions[ + A: CatalystOrdered, + B: CatalystOrdered, + C: CatalystOrdered + ](ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B], + cc: SortedTypedColumn[T, C] + ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc) + + /** + * Sort each partition in the dataset by the given column expressions + * Default sort order is ascending. 
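A compact sketch tying `makeUDF` to the typed `select` overloads shown above; `Order` and its fields are invented for illustration:

    import frameless.TypedDataset

    case class Order(id: Long, qty: Int, price: Double)

    def totals(ds: TypedDataset[Order]): TypedDataset[(Long, Double)] = {
      // (Int, Double) => Double lifted to typed columns of Order
      val total = ds.makeUDF((qty: Int, price: Double) => qty * price)
      ds.select(ds('id), total(ds('qty), ds('price)))
    }
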
+ * {{{ + * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc) + * }}} + */ object sortWithinPartitionsMany extends ProductArgs { - def applyProduct[U <: HList, O <: HList](columns: U) - (implicit + + def applyProduct[U <: HList, O <: HList]( + columns: U + )(implicit i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O], i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]] ): TypedDataset[T] = { - val sorted = dataset.toDF() - .sortWithinPartitions(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*) + val sorted = dataset + .toDF() + .sortWithinPartitions( + i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _* + ) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](sorted) @@ -988,268 +1279,309 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val /** Orders the TypedDataset using the columns selected. */ def orderBy[A: CatalystOrdered, B: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B] - ): TypedDataset[T] = orderByMany(ca, cb) - - /** Orders the TypedDataset using the columns selected. */ - def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B], - cc: SortedTypedColumn[T, C] - ): TypedDataset[T] = orderByMany(ca, cb, cc) - - /** Sort the dataset by any number of column expressions. - * Default sort order is ascending. - * {{{ - * d.orderByMany(d('a), d('b).desc, d('c).asc) - * }}} - */ + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B] + ): TypedDataset[T] = orderByMany(ca, cb) + + /** Orders the TypedDataset using the columns selected. */ + def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B], + cc: SortedTypedColumn[T, C] + ): TypedDataset[T] = orderByMany(ca, cb, cc) + + /** + * Sort the dataset by any number of column expressions. + * Default sort order is ascending. + * {{{ + * d.orderByMany(d('a), d('b).desc, d('c).asc) + * }}} + */ object orderByMany extends ProductArgs { - def applyProduct[U <: HList, O <: HList](columns: U) - (implicit + + def applyProduct[U <: HList, O <: HList]( + columns: U + )(implicit i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O], i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]] ): TypedDataset[T] = { - val sorted = dataset.toDF() - .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*) + val sorted = dataset + .toDF() + .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _*) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](sorted) } } - /** Returns a new Dataset as a tuple with the specified - * column dropped. - * Does not allow for dropping from a single column TypedDataset - * - * {{{ - * val d: TypedDataset[Foo(a: String, b: Int...)] = ??? 
- * val result = TypedDataset[(Int, ...)] = d.drop('a) - * }}} - * @param column column to drop specified as a Symbol - * @param i0 LabelledGeneric derived for T - * @param i1 Remover derived for TRep and column - * @param i2 values of T with column removed - * @param i3 tupler of values - * @param i4 evidence of encoder of the tupled values - * @tparam Out Tupled return type - * @tparam TRep shapeless' record representation of T - * @tparam Removed record of T with column removed - * @tparam ValuesFromRemoved values of T with column removed as an HList - * @tparam V value type of column in T - * @return - */ - def dropTupled[Out, TRep <: HList, Removed <: HList, ValuesFromRemoved <: HList, V] - (column: Witness.Lt[Symbol]) - (implicit + /** + * Returns a new Dataset as a tuple with the specified + * column dropped. + * Does not allow for dropping from a single column TypedDataset + * + * {{{ + * val d: TypedDataset[Foo(a: String, b: Int...)] = ??? + * val result = TypedDataset[(Int, ...)] = d.drop('a) + * }}} + * @param column column to drop specified as a Symbol + * @param i0 LabelledGeneric derived for T + * @param i1 Remover derived for TRep and column + * @param i2 values of T with column removed + * @param i3 tupler of values + * @param i4 evidence of encoder of the tupled values + * @tparam Out Tupled return type + * @tparam TRep shapeless' record representation of T + * @tparam Removed record of T with column removed + * @tparam ValuesFromRemoved values of T with column removed as an HList + * @tparam V value type of column in T + * @return + */ + def dropTupled[ + Out, + TRep <: HList, + Removed <: HList, + ValuesFromRemoved <: HList, + V + ](column: Witness.Lt[Symbol] + )(implicit i0: LabelledGeneric.Aux[T, TRep], i1: Remover.Aux[TRep, column.T, (V, Removed)], i2: Values.Aux[Removed, ValuesFromRemoved], i3: Tupler.Aux[ValuesFromRemoved, Out], i4: TypedEncoder[Out] ): TypedDataset[Out] = { - val dropped = dataset - .toDF() - .drop(column.value.name) - .as[Out](TypedExpressionEncoder[Out]) + val dropped = dataset + .toDF() + .drop(column.value.name) + .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](dropped) - } + TypedDataset.create[Out](dropped) + } /** - * Drops columns as necessary to return `U` - * - * @example - * {{{ - * case class X(i: Int, j: Int, k: Boolean) - * case class Y(i: Int, k: Boolean) - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.drop[Y] - * }}} - * - * @tparam U the output type - * - * @see [[frameless.TypedDataset#project]] - */ - def drop[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = project[U] - - /** Prepends a new column to the Dataset. - * - * {{{ - * case class X(i: Int, j: Int) - * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) - * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10) - * }}} - */ - def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out] - (ca: TypedColumn[T, A]) - (implicit + * Drops columns as necessary to return `U` + * + * @example + * {{{ + * case class X(i: Int, j: Int, k: Boolean) + * case class Y(i: Int, k: Boolean) + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.drop[Y] + * }}} + * + * @tparam U the output type + * + * @see [[frameless.TypedDataset#project]] + */ + def drop[U]( + implicit + projector: SmartProject[T, U] + ): TypedDataset[U] = project[U] + + /** + * Prepends a new column to the Dataset. 
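As a reminder of what `drop[U]` buys over `dropTupled`, a sketch reusing the X/Y shapes from its Scaladoc above (sketch only):

    import frameless.TypedDataset

    case class X(i: Int, j: Int, k: Boolean)
    case class Y(i: Int, k: Boolean)

    // the result keeps its field names instead of degrading to a tuple
    def dropJ(ds: TypedDataset[X]): TypedDataset[Y] = ds.drop[Y]
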
+ * + * {{{ + * case class X(i: Int, j: Int) + * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10) + * }}} + */ + def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out]( + ca: TypedColumn[T, A] + )(implicit i0: Generic.Aux[T, H], i1: Prepend.Aux[H, A :: HNil, FH], i2: Tupler.Aux[FH, Out], i3: TypedEncoder[Out] ): TypedDataset[Out] = { - // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder) - val selected = dataset.toDF().withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped) - .as[Out](TypedExpressionEncoder[Out]) + // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder) + val selected = dataset + .toDF() + .withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped) + .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](selected) + TypedDataset.create[Out](selected) } - /** Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value - * {{{ - * case class X(i: Int, j: Int) - * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil) - * val fNew: TypedDataset[X] = f.withColumn('j, f('i)) // results in X(1, 1) :: Nil - * }}} - * @param column column given as a symbol to replace - * @param replacement column to replace the value with - * @param i0 Evidence that a column with the correct type and name exists - */ + /** + * Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value + * {{{ + * case class X(i: Int, j: Int) + * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil) + * val fNew: TypedDataset[X] = f.withColumn('j, f('i)) // results in X(1, 1) :: Nil + * }}} + * @param column column given as a symbol to replace + * @param replacement column to replace the value with + * @param i0 Evidence that a column with the correct type and name exists + */ def withColumnReplaced[A]( - column: Witness.Lt[Symbol], - replacement: TypedColumn[T, A] - )(implicit - i0: TypedColumn.Exists[T, column.T, A] - ): TypedDataset[T] = { - val updated = dataset.toDF().withColumn(column.value.name, replacement.untyped) + column: Witness.Lt[Symbol], + replacement: TypedColumn[T, A] + )(implicit + i0: TypedColumn.Exists[T, column.T, A] + ): TypedDataset[T] = { + val updated = dataset + .toDF() + .withColumn(column.value.name, replacement.untyped) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](updated) } - /** Adds a column to a Dataset so long as the specified output type, `U`, has - * an extra column from `T` that has type `A`. 
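The two column-adding flavours above differ mainly in how the result type is named; a sketch mirroring the withColumn Scaladoc (note this X/Y shape differs from the drop example earlier):

    import frameless.TypedDataset

    case class X(i: Int, j: Int)
    case class Y(i: Int, j: Int, k: Boolean)

    def flag(ds: TypedDataset[X]): TypedDataset[Y] =
      ds.withColumn[Y](ds('j) === 10) // result type named explicitly

    def flagTupled(ds: TypedDataset[X]): TypedDataset[(Int, Int, Boolean)] =
      ds.withColumnTupled(ds('j) === 10) // result type derived as a tuple
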
- * - * @example - * {{{ - * case class X(i: Int, j: Int) - * case class Y(i: Int, j: Int, k: Boolean) - * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) - * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10) - * }}} - * @param ca The typed column to add - * @param i0 TypeEncoder for output type U - * @param i1 TypeEncoder for added column type A - * @param i2 the LabelledGeneric derived for T - * @param i3 the LabelledGeneric derived for U - * @param i4 proof no fields have been removed - * @param i5 diff from T to U - * @param i6 keys from newFields - * @param i7 the one and only new key - * @param i8 the one and only new field enforcing the type of A exists - * @param i9 the keys of U - * @param iA allows for traversing the keys of U - * @tparam U the output type - * @tparam A The added column type - * @tparam TRep shapeless' record representation of T - * @tparam URep shapeless' record representation of U - * @tparam UKeys the keys of U as an HList - * @tparam NewFields the added fields to T to get U - * @tparam NewKeys the keys of NewFields as an HList - * @tparam NewKey the first, and only, key in NewKey - * - * @see [[frameless.TypedDataset.WithColumnApply#apply]] - */ + /** + * Adds a column to a Dataset so long as the specified output type, `U`, has + * an extra column from `T` that has type `A`. + * + * @example + * {{{ + * case class X(i: Int, j: Int) + * case class Y(i: Int, j: Int, k: Boolean) + * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10) + * }}} + * @param ca The typed column to add + * @param i0 TypeEncoder for output type U + * @param i1 TypeEncoder for added column type A + * @param i2 the LabelledGeneric derived for T + * @param i3 the LabelledGeneric derived for U + * @param i4 proof no fields have been removed + * @param i5 diff from T to U + * @param i6 keys from newFields + * @param i7 the one and only new key + * @param i8 the one and only new field enforcing the type of A exists + * @param i9 the keys of U + * @param iA allows for traversing the keys of U + * @tparam U the output type + * @tparam A The added column type + * @tparam TRep shapeless' record representation of T + * @tparam URep shapeless' record representation of U + * @tparam UKeys the keys of U as an HList + * @tparam NewFields the added fields to T to get U + * @tparam NewKeys the keys of NewFields as an HList + * @tparam NewKey the first, and only, key in NewKey + * + * @see [[frameless.TypedDataset.WithColumnApply#apply]] + */ def withColumn[U] = new WithColumnApply[U] class WithColumnApply[U] { - def apply[A, TRep <: HList, URep <: HList, UKeys <: HList, NewFields <: HList, NewKeys <: HList, NewKey <: Symbol] - (ca: TypedColumn[T, A]) - (implicit - i0: TypedEncoder[U], - i1: TypedEncoder[A], - i2: LabelledGeneric.Aux[T, TRep], - i3: LabelledGeneric.Aux[U, URep], - i4: Diff.Aux[TRep, URep, HNil], - i5: Diff.Aux[URep, TRep, NewFields], - i6: Keys.Aux[NewFields, NewKeys], - i7: IsHCons.Aux[NewKeys, NewKey, HNil], - i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil], - i9: Keys.Aux[URep, UKeys], - iA: ToTraversable.Aux[UKeys, Seq, Symbol] - ): TypedDataset[U] = { + + def apply[ + A, + TRep <: HList, + URep <: HList, + UKeys <: HList, + NewFields <: HList, + NewKeys <: HList, + NewKey <: Symbol + ](ca: TypedColumn[T, A] + )(implicit + i0: TypedEncoder[U], + i1: TypedEncoder[A], + i2: LabelledGeneric.Aux[T, TRep], + i3: LabelledGeneric.Aux[U, URep], + i4: 
Diff.Aux[TRep, URep, HNil], + i5: Diff.Aux[URep, TRep, NewFields], + i6: Keys.Aux[NewFields, NewKeys], + i7: IsHCons.Aux[NewKeys, NewKey, HNil], + i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil], + i9: Keys.Aux[URep, UKeys], + iA: ToTraversable.Aux[UKeys, Seq, Symbol] + ): TypedDataset[U] = { val newColumnName = i7.head(i6()).name - val dfWithNewColumn = dataset - .toDF() - .withColumn(newColumnName, ca.untyped) + val dfWithNewColumn = dataset.toDF().withColumn(newColumnName, ca.untyped) val newColumns = i9.apply().to[Seq].map(_.name).map(dfWithNewColumn.col) - val selected = dfWithNewColumn - .select(newColumns: _*) - .as[U](TypedExpressionEncoder[U]) + val selected = + dfWithNewColumn.select(newColumns: _*).as[U](TypedExpressionEncoder[U]) TypedDataset.create[U](selected) } } /** - * Explodes a single column at a time. It only compiles if the type of column supports this operation. - * - * @example - * - * {{{ - * case class X(i: Int, j: Array[Int]) - * case class Y(i: Int, j: Int) - * - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.explode('j).as[Y] - * }}} - * @param column the column we wish to explode - */ - def explode[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A]], - i1: TypedEncoder[A], - i2: CatalystExplodableCollection[V], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { - import org.apache.spark.sql.functions.{explode => sparkExplode} + * Explodes a single column at a time. It only compiles if the type of column supports this operation. + * + * @example + * + * {{{ + * case class X(i: Int, j: Array[Int]) + * case class Y(i: Int, j: Int) + * + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.explode('j).as[Y] + * }}} + * @param column the column we wish to explode + */ + def explode[ + A, + TRep <: HList, + V[_], + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A]], + i1: TypedEncoder[A], + i2: CatalystExplodableCollection[V], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { + import org.apache.spark.sql.functions.{ explode => sparkExplode } val df = dataset.toDF() val trans = - df - .withColumn(column.value.name, sparkExplode(df(column.value.name))) + df.withColumn(column.value.name, sparkExplode(df(column.value.name))) .as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](trans) } /** - * Explodes a single column at a time. It only compiles if the type of column supports this operation. - * - * @example - * - * {{{ - * case class X(i: Int, j: Map[Int, Int]) - * case class Y(i: Int, j: (Int, Int)) - * - * val f: TypedDataset[X] = ??? 
- * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y] - * }}} - * @param column the column we wish to explode - */ - def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A, B]], - i1: TypedEncoder[A], - i2: TypedEncoder[B], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { - import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol} + * Explodes a single column at a time. It only compiles if the type of column supports this operation. + * + * @example + * + * {{{ + * case class X(i: Int, j: Map[Int, Int]) + * case class Y(i: Int, j: (Int, Int)) + * + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y] + * }}} + * @param column the column we wish to explode + */ + def explodeMap[ + A, + B, + V[_, _], + TRep <: HList, + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A, B]], + i1: TypedEncoder[A], + i2: TypedEncoder[B], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A, B], (A, B), OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { + import org.apache.spark.sql.functions.{ + explode => sparkExplode, + struct => sparkStruct, + col => sparkCol + } val df = dataset.toDF() // select all columns, all original columns and [key, value] columns appeared after the map explode @@ -1271,7 +1603,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val exploded // map explode explodes it into [key, value] columns // the only way to put it into a column is to create a struct - .withColumn(columnRenamed, sparkStruct(exploded("key"), exploded("value"))) + .withColumn( + columnRenamed, + sparkStruct(exploded("key"), exploded("value")) + ) // selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode .select(columns: _*) // rename columns back and form the result @@ -1281,72 +1616,81 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } /** - * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A]. - * - * - * @example - * - * {{{ - * case class X(i: Int, j: Option[Int]) - * case class Y(i: Int, j: Int) - * - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y] - * }}} - * - * @param column the column we wish to flatten - */ - def flattenOption[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A]], - i1: TypedEncoder[A], - i2: V[A] =:= Option[A], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { + * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A]. + * + * @example + * + * {{{ + * case class X(i: Int, j: Option[Int]) + * case class Y(i: Int, j: Int) + * + * val f: TypedDataset[X] = ??? 
+ * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y] + * }}} + * + * @param column the column we wish to flatten + */ + def flattenOption[ + A, + TRep <: HList, + V[_], + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A]], + i1: TypedEncoder[A], + i2: V[A] =:= Option[A], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { val df = dataset.toDF() - val trans = df.filter(df(column.value.name).isNotNull). - as[Out](TypedExpressionEncoder[Out]) + val trans = df + .filter(df(column.value.name).isNotNull) + .as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](trans) } } object TypedDataset { - def create[A](data: Seq[A]) - (implicit + + def create[A]( + data: Seq[A] + )(implicit encoder: TypedEncoder[A], sqlContext: SparkSession ): TypedDataset[A] = { - val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) - TypedDataset.create[A](dataset) - } + TypedDataset.create[A](dataset) + } - def create[A](data: RDD[A]) - (implicit + def create[A]( + data: RDD[A] + )(implicit encoder: TypedEncoder[A], sqlContext: SparkSession ): TypedDataset[A] = { - val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) - TypedDataset.create[A](dataset) - } + TypedDataset.create[A](dataset) + } def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = createUnsafe(dataset.toDF()) /** - * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]]. - * Note that the names and types need to align! - * - * This is an unsafe operation: If the schemas do not align, - * the error will be captured at runtime (not during compilation). - */ + * Creates a [[frameless.TypedDataset]] from a Spark org.apache.spark.sql.DataFrame. + * Note that the names and types need to align! + * + * This is an unsafe operation: If the schemas do not align, + * the error will be captured at runtime (not during compilation). + */ def createUnsafe[A: TypedEncoder](df: DataFrame): TypedDataset[A] = { val e = TypedEncoder[A] val output: Seq[Attribute] = df.queryExecution.analyzed.output @@ -1358,7 +1702,8 @@ object TypedDataset { throw new IllegalStateException( s"Unsupported creation of TypedDataset with ${targetFields.size} column(s) " + s"from a DataFrame with ${output.size} columns. " + - "Try to `select()` the proper columns in the right order before calling `create()`.") + "Try to `select()` the proper columns in the right order before calling `create()`." + ) } // Adapt names if they are not the same (note: types still might not match) @@ -1368,7 +1713,7 @@ object TypedDataset { val canSelect = targetColNames.toSet.subsetOf(output.map(_.name).toSet) val reshaped = if (shouldReshape && canSelect) { - df.select(targetColNames.head, targetColNames.tail:_*) + df.select(targetColNames.head, targetColNames.tail: _*) } else if (shouldReshape) { df.toDF(targetColNames: _*) } else { @@ -1378,9 +1723,14 @@ object TypedDataset { new TypedDataset[A](reshaped.as[A](TypedExpressionEncoder[A])) } - /** Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you - * know what you are doing. 
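A short sketch of the reshaping operations documented above; the case classes mirror those Scaladoc examples and are only illustrative:

    import frameless.TypedDataset

    case class WithVector(i: Int, j: Vector[Int])
    case class WithOption(i: Int, j: Option[Int])
    case class Flat(i: Int, j: Int)

    // one output row per element of j
    def explodeJ(ds: TypedDataset[WithVector]): TypedDataset[Flat] =
      ds.explode('j).as[Flat]

    // drops rows where j is None and unwraps the rest
    def flattenJ(ds: TypedDataset[WithOption]): TypedDataset[Flat] =
      ds.flattenOption('j).as[Flat]
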
*/ - @deprecated("Prefer TypedDataset.create over TypedDataset.unsafeCreate", "0.3.0") + /** + * Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you + * know what you are doing. + */ + @deprecated( + "Prefer TypedDataset.create over TypedDataset.unsafeCreate", + "0.3.0" + ) def unsafeCreate[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = { new TypedDataset[A](dataset) } diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index b42b026ee..235060cf1 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -1,20 +1,17 @@ package frameless import java.math.BigInteger - import java.util.Date - -import java.time.{ Duration, Instant, Period, LocalDate } - +import java.time.{ Duration, Instant, LocalDate, Period } import java.sql.Timestamp - import scala.reflect.ClassTag +import FramelessInternals.UserDefinedType +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + UnsafeArrayData, + Literal +} -import org.apache.spark.sql.FramelessInternals -import org.apache.spark.sql.FramelessInternals.UserDefinedType -import org.apache.spark.sql.{ reflection => ScalaReflection } -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ ArrayBasedMapData, DateTimeUtils, @@ -22,9 +19,22 @@ import org.apache.spark.sql.catalyst.util.{ } import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import shapeless._ import shapeless.ops.hlist.IsHCons +import com.sparkutils.shim.expressions.{ + UnwrapOption2 => UnwrapOption, + WrapOption2 => WrapOption, + MapObjects5 => MapObjects, + ExternalMapToCatalyst7 => ExternalMapToCatalyst +} +import frameless.{ reflection => ScalaReflection } +import org.apache.spark.sql.shim.{ + StaticInvoke4 => StaticInvoke, + NewInstance4 => NewInstance, + Invoke5 => Invoke +} + +import scala.collection.immutable.{ ListSet, TreeSet } abstract class TypedEncoder[T]( implicit @@ -501,10 +511,76 @@ object TypedEncoder { override def toString: String = s"arrayEncoder($jvmRepr)" } - implicit def collectionEncoder[C[X] <: Seq[X], T]( + /** + * Per #804 - when MapObjects is used in interpreted mode the type returned is Seq, not the derived type used in compilation + * + * This type class offers extensible conversion for more specific types. By default Seq, List and Vector for Seq's and Set, TreeSet and ListSet are supported. 
+ * + * @tparam C + */ + trait CollectionConversion[F[_], C[_], Y] extends Serializable { + def convert(c: F[Y]): C[Y] + } + + object CollectionConversion { + + implicit def seqToSeq[Y] = new CollectionConversion[Seq, Seq, Y] { + + override def convert(c: Seq[Y]): Seq[Y] = + c match { + // Stream is produced + case _: Stream[Y] @unchecked => c.toVector.toSeq + case _ => c + } + } + + implicit def seqToVector[Y] = new CollectionConversion[Seq, Vector, Y] { + override def convert(c: Seq[Y]): Vector[Y] = c.toVector + } + + implicit def seqToList[Y] = new CollectionConversion[Seq, List, Y] { + override def convert(c: Seq[Y]): List[Y] = c.toList + } + + implicit def setToSet[Y] = new CollectionConversion[Set, Set, Y] { + override def convert(c: Set[Y]): Set[Y] = c + } + + implicit def setToTreeSet[Y]( + implicit + ordering: Ordering[Y] + ) = new CollectionConversion[Set, TreeSet, Y] { + + override def convert(c: Set[Y]): TreeSet[Y] = + TreeSet.newBuilder.++=(c).result() + } + + implicit def setToListSet[Y] = new CollectionConversion[Set, ListSet, Y] { + + override def convert(c: Set[Y]): ListSet[Y] = + ListSet.newBuilder.++=(c).result() + } + } + + implicit def seqEncoder[C[X] <: Seq[X], T]( implicit i0: Lazy[RecordFieldEncoder[T]], - i1: ClassTag[C[T]] + i1: ClassTag[C[T]], + i2: CollectionConversion[Seq, C, T] + ) = collectionEncoder[Seq, C, T] + + implicit def setEncoder[C[X] <: Set[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[Set, C, T] + ) = collectionEncoder[Set, C, T] + + def collectionEncoder[O[_], C[X], T]( + implicit + i0: Lazy[RecordFieldEncoder[T]], + i1: ClassTag[C[T]], + i2: CollectionConversion[O, C, T] ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] { private lazy val encodeT = i0.value.encoder @@ -521,38 +597,31 @@ object TypedEncoder { if (ScalaReflection.isNativeType(enc.jvmRepr)) { NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr) } else { - MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable) + // converts to Seq, both Set and Seq handling must convert to Seq first + MapObjects( + enc.toCatalyst, + SeqCaster(path), + enc.jvmRepr, + encodeT.nullable + ) } } def fromCatalyst(path: Expression): Expression = - MapObjects( - i0.value.fromCatalyst, - path, - encodeT.catalystRepr, - encodeT.nullable, - Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly - ) + CollectionCaster[O, C, T]( + MapObjects( + i0.value.fromCatalyst, + path, + encodeT.catalystRepr, + encodeT.nullable, + Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling + ), + implicitly[CollectionConversion[O, C, T]] + ) // This will convert Seq to the appropriate C[_] when eval'ing. override def toString: String = s"collectionEncoder($jvmRepr)" } - /** - * @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set. - * @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type. - * @tparam T the element type of the set. - * @return a `TypedEncoder` instance for `Set[T]`. 
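To illustrate the `CollectionConversion` type class added above (the #804 fix): interpreted `MapObjects` hands back a `Seq`, and `CollectionCaster` picks one of these instances to rebuild the declared collection at eval time. A sketch of the conversion on its own, assuming the trait lives inside `object TypedEncoder` as in this hunk:

    import frameless.TypedEncoder.CollectionConversion

    // resolve the instance for the declared target collection and apply it
    def toDeclared[C[_]](xs: Seq[Int])(implicit cc: CollectionConversion[Seq, C, Int]): C[Int] =
      cc.convert(xs)

    val asVector: Vector[Int] = toDeclared[Vector](Seq(1, 2, 3))
    val asList: List[Int] = toDeclared[List](Seq(1, 2, 3))
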
- */ - implicit def setEncoder[T]( - implicit - i1: shapeless.Lazy[RecordFieldEncoder[T]], - i2: ClassTag[Set[T]] - ): TypedEncoder[Set[T]] = { - implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet) - - TypedEncoder.usingInjection - } - /** * @tparam A the key type * @tparam B the value type diff --git a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala index 5b78cd292..62d06a802 100644 --- a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala @@ -1,49 +1,32 @@ package frameless import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If} import org.apache.spark.sql.types.StructType object TypedExpressionEncoder { - /** In Spark, DataFrame has always schema of StructType - * - * DataFrames of primitive types become records - * with a single field called "value" set in ExpressionEncoder. - */ + /** + * In Spark, DataFrame has always schema of StructType + * + * DataFrames of primitive types become records + * with a single field called "value" set in ExpressionEncoder. + */ def targetStructType[A](encoder: TypedEncoder[A]): StructType = - encoder.catalystRepr match { - case x: StructType => - if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true))) - else x - - case dt => new StructType().add("value", dt, nullable = encoder.nullable) - } - - def apply[T](implicit encoder: TypedEncoder[T]): Encoder[T] = { - val in = BoundReference(0, encoder.jvmRepr, encoder.nullable) - - val (out, serializer) = encoder.toCatalyst(in) match { - case it @ If(_, _, _: CreateNamedStruct) => { - val out = GetColumnByOrdinal(0, encoder.catalystRepr) - - out -> it - } - - case other => { - val out = GetColumnByOrdinal(0, encoder.catalystRepr) - - out -> other - } - } - - new ExpressionEncoder[T]( - objSerializer = serializer, - objDeserializer = encoder.fromCatalyst(out), - clsTag = encoder.classTag + org.apache.spark.sql.ShimUtils + .targetStructType(encoder.catalystRepr, encoder.nullable) + + def apply[T]( + implicit + encoder: TypedEncoder[T] + ): Encoder[T] = { + import encoder._ + org.apache.spark.sql.ShimUtils.expressionEncoder[T]( + jvmRepr, + nullable, + toCatalyst, + catalystRepr, + fromCatalyst ) } -} +} diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index e371ea048..ad137a4d6 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -1,73 +1,91 @@ package frameless package functions -import org.apache.spark.sql.FramelessInternals.expr -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.{functions => sparkFunctions} +import FramelessInternals.expr +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.{ functions => sparkFunctions } import frameless.syntax._ -import scala.annotation.nowarn +import com.sparkutils.shim.expressions.{ + Coalesce1 => Coalesce, + functions => shimFunctions +} trait AggregateFunctions { - /** Aggregate function: returns the number of items in a group. 
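Relating to the TypedExpressionEncoder rewrite above: from the caller's side nothing changes, only the plumbing now goes through ShimUtils. A minimal sketch (the Point case class and session setup are illustrative, not from the patch):

    import org.apache.spark.sql.{ Encoder, SparkSession }
    import frameless.{ TypedEncoder, TypedExpressionEncoder }

    case class Point(x: Double, y: Double)

    val spark = SparkSession.builder().master("local[*]").appName("shim-demo").getOrCreate()

    // TypedEncoder[Point] is derived generically; TypedExpressionEncoder turns it into
    // the Encoder that Spark's untyped APIs expect, now via ShimUtils.expressionEncoder.
    implicit val pointEncoder: Encoder[Point] = TypedExpressionEncoder[Point]

    val points = spark.createDataset(Seq(Point(1.0, 2.0), Point(3.0, 4.0)))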
- * - * apache/spark - */ + + /** + * Aggregate function: returns the number of items in a group. + * + * apache/spark + */ def count[T](): TypedAggregate[T, Long] = sparkFunctions.count(sparkFunctions.lit(1)).typedAggregate - /** Aggregate function: returns the number of items in a group for which the selected column is not null. - * - * apache/spark - */ + /** + * Aggregate function: returns the number of items in a group for which the selected column is not null. + * + * apache/spark + */ def count[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = sparkFunctions.count(column.untyped).typedAggregate - /** Aggregate function: returns the number of distinct items in a group. - * - * apache/spark - */ + /** + * Aggregate function: returns the number of distinct items in a group. + * + * apache/spark + */ def countDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = sparkFunctions.countDistinct(column.untyped).typedAggregate - /** Aggregate function: returns the approximate number of distinct items in a group. - */ + /** + * Aggregate function: returns the approximate number of distinct items in a group. + */ def approxCountDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = sparkFunctions.approx_count_distinct(column.untyped).typedAggregate - /** Aggregate function: returns the approximate number of distinct items in a group. - * - * @param rsd maximum estimation error allowed (default = 0.05) - * - * apache/spark - */ - def approxCountDistinct[T](column: TypedColumn[T, _], rsd: Double): TypedAggregate[T, Long] = + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @param rsd maximum estimation error allowed (default = 0.05) + * + * apache/spark + */ + def approxCountDistinct[T]( + column: TypedColumn[T, _], + rsd: Double + ): TypedAggregate[T, Long] = sparkFunctions.approx_count_distinct(column.untyped, rsd).typedAggregate - /** Aggregate function: returns a list of objects with duplicates. - * - * apache/spark - */ - def collectList[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = + /** + * Aggregate function: returns a list of objects with duplicates. + * + * apache/spark + */ + def collectList[T, A: TypedEncoder]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Vector[A]] = sparkFunctions.collect_list(column.untyped).typedAggregate - /** Aggregate function: returns a set of objects with duplicate elements eliminated. - * - * apache/spark - */ + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * apache/spark + */ def collectSet[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = sparkFunctions.collect_set(column.untyped).typedAggregate - /** Aggregate function: returns the sum of all values in the given column. - * - * apache/spark - */ - def sum[A, T, Out](column: TypedColumn[T, A])( - implicit - summable: CatalystSummable[A, Out], - oencoder: TypedEncoder[Out], - aencoder: TypedEncoder[A] - ): TypedAggregate[T, Out] = { + /** + * Aggregate function: returns the sum of all values in the given column. 
+ * + * apache/spark + */ + def sum[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + summable: CatalystSummable[A, Out], + oencoder: TypedEncoder[Out], + aencoder: TypedEncoder[A] + ): TypedAggregate[T, Out] = { val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr) val sumExpr = expr(sparkFunctions.sum(column.untyped)) val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr)) @@ -75,204 +93,238 @@ trait AggregateFunctions { new TypedAggregate[T, Out](sumOrZero) } - /** Aggregate function: returns the sum of distinct values in the column. - * - * apache/spark - */ - @nowarn // supress sparkFunctions.sumDistinct call which is used to maintain Spark 3.1.x backwards compat - def sumDistinct[A, T, Out](column: TypedColumn[T, A])( - implicit - summable: CatalystSummable[A, Out], - oencoder: TypedEncoder[Out], - aencoder: TypedEncoder[A] - ): TypedAggregate[T, Out] = { + /** + * Aggregate function: returns the sum of distinct values in the column. + * + * apache/spark + */ + def sumDistinct[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + summable: CatalystSummable[A, Out], + oencoder: TypedEncoder[Out], + aencoder: TypedEncoder[A] + ): TypedAggregate[T, Out] = { val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr) - val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped)) + val sumExpr = expr(shimFunctions.sumDistinct(column.untyped)) val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr)) new TypedAggregate[T, Out](sumOrZero) } - /** Aggregate function: returns the average of the values in a group. - * - * apache/spark - */ - def avg[A, T, Out](column: TypedColumn[T, A])( - implicit - averageable: CatalystAverageable[A, Out], - oencoder: TypedEncoder[Out] - ): TypedAggregate[T, Out] = { + /** + * Aggregate function: returns the average of the values in a group. + * + * apache/spark + */ + def avg[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + averageable: CatalystAverageable[A, Out], + oencoder: TypedEncoder[Out] + ): TypedAggregate[T, Out] = { new TypedAggregate[T, Out](sparkFunctions.avg(column.untyped)) } - /** Aggregate function: returns the unbiased variance of the values in a group. - * - * @note In Spark variance always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]] - * - * apache/spark - */ - def variance[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = + /** + * Aggregate function: returns the unbiased variance of the values in a group. + * + * @note In Spark variance always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]] + * + * apache/spark + */ + def variance[A: CatalystVariance, T]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Double] = sparkFunctions.variance(column.untyped).typedAggregate - /** Aggregate function: returns the sample standard deviation. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]] - * - * apache/spark - */ + /** + * Aggregate function: returns the sample standard deviation. 
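A usage sketch for the aggregations above (not part of the patch; assumes an implicit SparkSession in scope and a hypothetical Purchase case class). The Coalesce with the CatalystSummable zero is what makes an empty Dataset sum to 0L rather than null, and sumDistinct now resolves through shimFunctions rather than the deprecated sparkFunctions call:

    import frameless.TypedDataset
    import frameless.functions.aggregate._

    case class Purchase(user: String, amount: Long)

    val purchases: TypedDataset[Purchase] =
      TypedDataset.create(Seq(Purchase("a", 10L), Purchase("a", 10L), Purchase("b", 7L)))

    // empty input aggregates to 0L, not null, thanks to the zero coalesce
    val total: TypedDataset[Long] = purchases.agg(sum(purchases('amount)))

    // same call site as before; only the underlying function reference changed
    val distinctTotal: TypedDataset[Long] = purchases.agg(sumDistinct(purchases('amount)))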
+ * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]] + * + * apache/spark + */ def stddev[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = sparkFunctions.stddev(column.untyped).typedAggregate /** - * Aggregate function: returns the standard deviation of a column by population. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]] - * - * apache/spark - */ - def stddevPop[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the standard deviation of a column by population. + * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]] + * + * apache/spark + */ + def stddevPop[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.stddev_pop(column.cast[Double].untyped) ) } /** - * Aggregate function: returns the standard deviation of a column by sample. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]] - * - * apache/spark - */ - def stddevSamp[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double] ): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the standard deviation of a column by sample. + * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]] + * + * apache/spark + */ + def stddevSamp[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.stddev_samp(column.cast[Double].untyped) ) } - /** Aggregate function: returns the maximum value of the column in a group. - * - * apache/spark - */ + /** + * Aggregate function: returns the maximum value of the column in a group. + * + * apache/spark + */ def max[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.max(column.untyped).typedAggregate } - /** Aggregate function: returns the minimum value of the column in a group. - * - * apache/spark - */ + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * apache/spark + */ def min[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.min(column.untyped).typedAggregate } - /** Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. 
If all values are null, then null is returned. - * - * apache/spark - */ + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * apache/spark + */ def first[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { sparkFunctions.first(column.untyped).typedAggregate(column.uencoder) } /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * apache/spark - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * apache/spark + */ def last[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.last(column.untyped).typedAggregate } /** - * Aggregate function: returns the Pearson Correlation Coefficient for two columns. - * - * @note In Spark corr always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]] - * - * apache/spark - */ - def corr[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. + * + * @note In Spark corr always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]] + * + * apache/spark + */ + def corr[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.corr(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .corr(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the covariance of two collumns. - * - * @note In Spark covar_pop always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]] - * - * apache/spark - */ - def covarPop[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the covariance of two collumns. 
+ * + * @note In Spark covar_pop always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]] + * + * apache/spark + */ + def covarPop[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the covariance of two columns. - * - * @note In Spark covar_samp always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]] - * - * apache/spark - */ - def covarSamp[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the covariance of two columns. + * + * @note In Spark covar_samp always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]] + * + * apache/spark + */ + def covarSamp[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } - + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the kurtosis of a column. - * - * @note In Spark kurtosis always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]] - * - * apache/spark - */ - def kurtosis[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the kurtosis of a column. + * + * @note In Spark kurtosis always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]] + * + * apache/spark + */ + def kurtosis[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.kurtosis(column.cast[Double].untyped) ) } /** - * Aggregate function: returns the skewness of a column. - * - * @note In Spark skewness always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]] - * - * apache/spark - */ - def skewness[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the skewness of a column. 
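The statistical aggregations above (corr, covarPop, covarSamp, kurtosis, skewness) all cast their inputs to Double and surface Spark's nullable result as Option[Double]. A sketch, assuming an implicit SparkSession, a hypothetical Reading case class, and the usual CatalystCast evidence for Double columns:

    import frameless.TypedDataset
    import frameless.functions.aggregate._

    case class Reading(sensor: String, temp: Double, humidity: Double)

    val readings = TypedDataset.create(
      Seq(Reading("s1", 20.5, 0.40), Reading("s1", 22.0, 0.38), Reading("s1", 23.5, 0.35))
    )

    // the result is Option[Double] because Spark may return null (e.g. on an empty group)
    val correlation: TypedDataset[Option[Double]] =
      readings.agg(corr(readings('temp), readings('humidity)))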
+ * + * @note In Spark skewness always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]] + * + * apache/spark + */ + def skewness[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.skewness(column.cast[Double].untyped) ) diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala index 939bf5b8d..396b7ff43 100644 --- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala @@ -1,537 +1,718 @@ package frameless package functions -import org.apache.spark.sql.{Column, functions => sparkFunctions} +import org.apache.spark.sql.{ Column, functions => sparkFunctions } + +import com.sparkutils.shim.expressions.{ functions => shimFunctions } -import scala.annotation.nowarn import scala.util.matching.Regex trait NonAggregateFunctions { - /** Non-Aggregate function: calculates the SHA-2 digest of a binary column and returns the value as a 40 character hex string - * - * apache/spark - */ - def sha2[T](column: AbstractTypedColumn[T, Array[Byte]], numBits: Int): column.ThisType[T, String] = + + /** + * Non-Aggregate function: calculates the SHA-2 digest of a binary column and returns the value as a 40 character hex string + * + * apache/spark + */ + def sha2[T]( + column: AbstractTypedColumn[T, Array[Byte]], + numBits: Int + ): column.ThisType[T, String] = column.typed(sparkFunctions.sha2(column.untyped, numBits)) - /** Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string - * - * apache/spark - */ + /** + * Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string + * + * apache/spark + */ def sha1[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] = column.typed(sparkFunctions.sha1(column.untyped)) - /** Non-Aggregate function: returns a cyclic redundancy check value of a binary column as long. - * - * apache/spark - */ + /** + * Non-Aggregate function: returns a cyclic redundancy check value of a binary column as long. + * + * apache/spark + */ def crc32[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, Long] = column.typed(sparkFunctions.crc32(column.untyped)) + /** - * Non-Aggregate function: returns the negated value of column. - * - * apache/spark - */ - def negate[A, B, T](column: AbstractTypedColumn[T,A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], - i1: TypedEncoder[B] - ): column.ThisType[T,B] = + * Non-Aggregate function: returns the negated value of column. + * + * apache/spark + */ + def negate[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.negate(column.untyped)) /** - * Non-Aggregate function: logical not. - * - * apache/spark - */ - def not[T](column: AbstractTypedColumn[T,Boolean]): column.ThisType[T,Boolean] = + * Non-Aggregate function: logical not. 
+ * + * apache/spark + */ + def not[T](column: AbstractTypedColumn[T, Boolean]): column.ThisType[T, Boolean] = column.typed(sparkFunctions.not(column.untyped)) /** - * Non-Aggregate function: Convert a number in a string column from one base to another. - * - * apache/spark - */ - def conv[T](column: AbstractTypedColumn[T,String], fromBase: Int, toBase: Int): column.ThisType[T,String] = - column.typed(sparkFunctions.conv(column.untyped,fromBase,toBase)) + * Non-Aggregate function: Convert a number in a string column from one base to another. + * + * apache/spark + */ + def conv[T]( + column: AbstractTypedColumn[T, String], + fromBase: Int, + toBase: Int + ): column.ThisType[T, String] = + column.typed(sparkFunctions.conv(column.untyped, fromBase, toBase)) - /** Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees. - * - * apache/spark - */ - def degrees[A,T](column: AbstractTypedColumn[T,A]): column.ThisType[T,Double] = + /** + * Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * apache/spark + */ + def degrees[A, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, Double] = column.typed(sparkFunctions.degrees(column.untyped)) - /** Non-Aggregate function: returns the ceiling of a numeric column - * - * apache/spark - */ - def ceil[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit + /** + * Non-Aggregate function: returns the ceiling of a numeric column + * + * apache/spark + */ + def ceil[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystRound[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = + column.typed(sparkFunctions.ceil(column.untyped))(i1) + + /** + * Non-Aggregate function: returns the floor of a numeric column + * + * apache/spark + */ + def floor[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit i0: CatalystRound[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.ceil(column.untyped))(i1) - - /** Non-Aggregate function: returns the floor of a numeric column - * - * apache/spark - */ - def floor[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit - i0: CatalystRound[A, B], - i1: TypedEncoder[B] - ): column.ThisType[T, B] = column.typed(sparkFunctions.floor(column.untyped))(i1) - /** Non-Aggregate function: unsigned shift the the given value numBits right. If given long, will return long else it will return an integer. - * - * apache/spark - */ - @nowarn // supress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat - def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + /** + * Non-Aggregate function: unsigned shift the the given value numBits right. If given long, will return long else it will return an integer. + * + * apache/spark + */ + def shiftRightUnsigned[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits)) - - /** Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer. 
- * - * apache/spark - */ - @nowarn // supress sparkFunctions.shiftReft call which is used to maintain Spark 3.1.x backwards compat - def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + column.typed(shimFunctions.shiftRightUnsigned(column.untyped, numBits)) + + /** + * Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer. + * + * apache/spark + */ + def shiftRight[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftRight(column.untyped, numBits)) - - /** Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer. - * - * apache/spark - */ - @nowarn // supress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat - def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + column.typed(shimFunctions.shiftRight(column.untyped, numBits)) + + /** + * Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer. + * + * apache/spark + */ + def shiftLeft[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftLeft(column.untyped, numBits)) - - /** Non-Aggregate function: returns the absolute value of a numeric column - * - * apache/spark - */ - def abs[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit - i0: CatalystNumericWithJavaBigDecimal[A, B], - i1: TypedEncoder[B] + column.typed(shimFunctions.shiftLeft(column.untyped, numBits)) + + /** + * Non-Aggregate function: returns the absolute value of a numeric column + * + * apache/spark + */ + def abs[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.abs(column.untyped))(i1) - - /** Non-Aggregate function: Computes the cosine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def cos[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.cos(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic cosine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def cosh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.cosh(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the signum of the given value. - * - * Spark will expect a Double value for this expression. 
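An aside on the shift functions above: shiftRightUnsigned, shiftRight and shiftLeft now delegate to shimFunctions, which is why the @nowarn suppressions for the deprecated sparkFunctions calls disappear; the call sites themselves are unchanged. A sketch, assuming an implicit SparkSession and a hypothetical Flags case class:

    import frameless.TypedDataset
    import frameless.functions.nonAggregate._

    case class Flags(mask: Long)

    val flags = TypedDataset.create(Seq(Flags(8L), Flags(1024L)))

    // CatalystBitShift keeps Long columns as Long
    val halved  = flags.select(shiftRight(flags('mask), 1)) // 4L, 512L
    val doubled = flags.select(shiftLeft(flags('mask), 1))  // 16L, 2048L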
See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def signum[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + column.typed(sparkFunctions.abs(column.untyped))(i1) + + /** + * Non-Aggregate function: Computes the cosine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def cos[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.cos(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic cosine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def cosh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.cosh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the signum of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def signum[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.signum(column.cast[Double].untyped)) - /** Non-Aggregate function: Computes the sine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def sin[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.sin(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic sine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def sinh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.sinh(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the tangent of the given column. - * - * Spark will expect a Double value for this expression. 
See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def tan[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.tan(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic tangent of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def tanh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.tanh(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the acos of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def acos[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.acos(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns true if value is contained with in the array in the specified column - * - * apache/spark - */ - def arrayContains[C[_]: CatalystCollection, A, T](column: AbstractTypedColumn[T, C[A]], value: A): column.ThisType[T, Boolean] = + /** + * Non-Aggregate function: Computes the sine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def sin[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.sin(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic sine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def sinh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.sinh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the tangent of the given column. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def tan[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.tan(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic tangent of the given value. + * + * Spark will expect a Double value for this expression. 
See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def tanh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.tanh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the acos of a numeric column + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def acos[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.acos(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns true if value is contained with in the array in the specified column + * + * apache/spark + */ + def arrayContains[C[_]: CatalystCollection, A, T]( + column: AbstractTypedColumn[T, C[A]], + value: A + ): column.ThisType[T, Boolean] = column.typed(sparkFunctions.array_contains(column.untyped, value)) - /** Non-Aggregate function: returns the atan of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan[A, T](column: AbstractTypedColumn[T,A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.atan(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the asin of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def asin[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.asin(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan2[A, B, T](l: TypedColumn[T, A], r: TypedColumn[T, B]) - (implicit + /** + * Non-Aggregate function: returns the atan of a numeric column + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.atan(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the asin of a numeric column + * + * Spark will expect a Double value for this expression. 
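For arrayContains above, the CatalystCollection evidence restricts the container to collection types Catalyst can represent (such as Vector, List, Seq and Array). A sketch, assuming an implicit SparkSession and a hypothetical Order case class:

    import frameless.TypedDataset
    import frameless.functions.nonAggregate._

    case class Order(items: Vector[String])

    val orders = TypedDataset.create(Seq(Order(Vector("apple", "pear"))))

    // membership test against a literal value
    val hasApple: TypedDataset[Boolean] =
      orders.select(arrayContains(orders('items), "apple"))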
See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def asin[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.asin(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan2[A, B, T]( + l: TypedColumn[T, A], + r: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedColumn[T, Double] = - r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - - /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan2[A, B, T](l: TypedAggregate[T, A], r: TypedAggregate[T, B]) - (implicit + r.typed( + sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped) + ) + + /** + * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan2[A, B, T]( + l: TypedAggregate[T, A], + r: TypedAggregate[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Double] = - r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - - def atan2[B, T](l: Double, r: TypedColumn[T, B]) - (implicit i0: CatalystCast[B, Double]): TypedColumn[T, Double] = - atan2(r.lit(l), r) - - def atan2[A, T](l: TypedColumn[T, A], r: Double) - (implicit i0: CatalystCast[A, Double]): TypedColumn[T, Double] = - atan2(l, l.lit(r)) - - def atan2[B, T](l: Double, r: TypedAggregate[T, B]) - (implicit i0: CatalystCast[B, Double]): TypedAggregate[T, Double] = - atan2(r.lit(l), r) - - def atan2[A, T](l: TypedAggregate[T, A], r: Double) - (implicit i0: CatalystCast[A, Double]): TypedAggregate[T, Double] = - atan2(l, l.lit(r)) - - /** Non-Aggregate function: returns the square root value of a numeric column. 
- * - * apache/spark - */ - def sqrt[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + r.typed( + sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped) + ) + + def atan2[B, T]( + l: Double, + r: TypedColumn[T, B] + )(implicit + i0: CatalystCast[B, Double] + ): TypedColumn[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T]( + l: TypedColumn[T, A], + r: Double + )(implicit + i0: CatalystCast[A, Double] + ): TypedColumn[T, Double] = + atan2(l, l.lit(r)) + + def atan2[B, T]( + l: Double, + r: TypedAggregate[T, B] + )(implicit + i0: CatalystCast[B, Double] + ): TypedAggregate[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T]( + l: TypedAggregate[T, A], + r: Double + )(implicit + i0: CatalystCast[A, Double] + ): TypedAggregate[T, Double] = + atan2(l, l.lit(r)) + + /** + * Non-Aggregate function: returns the square root value of a numeric column. + * + * apache/spark + */ + def sqrt[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.sqrt(column.cast[Double].untyped)) - /** Non-Aggregate function: returns the cubic root value of a numeric column. - * - * apache/spark - */ - def cbrt[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + /** + * Non-Aggregate function: returns the cubic root value of a numeric column. + * + * apache/spark + */ + def cbrt[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.cbrt(column.cast[Double].untyped)) - /** Non-Aggregate function: returns the exponential value of a numeric column. - * - * apache/spark - */ - def exp[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + /** + * Non-Aggregate function: returns the exponential value of a numeric column. + * + * apache/spark + */ + def exp[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.exp(column.cast[Double].untyped)) - /** Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. - * - * apache/spark - */ - def round[A, B, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. + * + * apache/spark + */ + def round[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.round(column.untyped))(i1) - /** Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. - * - * apache/spark - */ - def round[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode + * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. 
+ * + * apache/spark + */ + def round[A, B, T]( + column: AbstractTypedColumn[T, A], + scale: Int + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.round(column.untyped, scale))(i1) - /** Non-Aggregate function: Bankers Rounding - returns the rounded to 0 decimal places value with HALF_EVEN round mode - * of a numeric column. - * - * apache/spark - */ - def bround[A, B, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Bankers Rounding - returns the rounded to 0 decimal places value with HALF_EVEN round mode + * of a numeric column. + * + * apache/spark + */ + def bround[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.bround(column.untyped))(i1) - /** Non-Aggregate function: Bankers Rounding - returns the rounded to `scale` decimal places value with HALF_EVEN round mode - * of a numeric column. If `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. - * - * apache/spark - */ - def bround[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Bankers Rounding - returns the rounded to `scale` decimal places value with HALF_EVEN round mode + * of a numeric column. If `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * + * apache/spark + */ + def bround[A, B, T]( + column: AbstractTypedColumn[T, A], + scale: Int + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.bround(column.untyped, scale))(i1) /** - * Computes the natural logarithm of the given value. - * - * apache/spark - */ - def log[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the natural logarithm of the given value. + * + * apache/spark + */ + def log[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log(column.untyped)) /** - * Returns the first argument-base logarithm of the second argument. - * - * apache/spark - */ - def log[A, T](base: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the first argument-base logarithm of the second argument. + * + * apache/spark + */ + def log[A, T]( + base: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log(base, column.untyped)) /** - * Computes the logarithm of the given column in base 2. - * - * apache/spark - */ - def log2[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the logarithm of the given column in base 2. + * + * apache/spark + */ + def log2[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log2(column.untyped)) /** - * Computes the natural logarithm of the given value plus one. 
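To make the round/bround distinction above concrete: round applies HALF_UP, bround applies HALF_EVEN (banker's rounding), and both take an optional scale. A sketch, assuming an implicit SparkSession and a hypothetical Price case class:

    import frameless.TypedDataset
    import frameless.functions.nonAggregate._

    case class Price(amount: Double)

    val prices = TypedDataset.create(Seq(Price(2.345), Price(2.355)))

    // HALF_UP at two decimal places
    val halfUp = prices.select(round(prices('amount), 2))

    // HALF_EVEN at two decimal places
    val halfEven = prices.select(bround(prices('amount), 2))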
- * - * apache/spark - */ - def log1p[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the natural logarithm of the given value plus one. + * + * apache/spark + */ + def log1p[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log1p(column.untyped)) /** - * Computes the logarithm of the given column in base 10. - * - * apache/spark - */ - def log10[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the logarithm of the given column in base 10. + * + * apache/spark + */ + def log10[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log10(column.untyped)) - /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(column.untyped, column2.untyped)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](column: AbstractTypedColumn[T, A], l: Double)( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + column: AbstractTypedColumn[T, A], + l: Double + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(column.untyped, l)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](l: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + l: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(l, column.untyped)) /** - * Returns the value of the first argument raised to the power of the second argument. - * - * apache/spark - */ - def pow[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. + * + * apache/spark + */ + def pow[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(column.untyped, column2.untyped)) /** - * Returns the value of the first argument raised to the power of the second argument. 
- * - * apache/spark - */ - def pow[A, T](column: AbstractTypedColumn[T, A], l: Double)( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. + * + * apache/spark + */ + def pow[A, T]( + column: AbstractTypedColumn[T, A], + l: Double + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(column.untyped, l)) /** - * Returns the value of the first argument raised to the power of the second argument. - * - * apache/spark - */ - def pow[A, T](l: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. + * + * apache/spark + */ + def pow[A, T]( + l: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(l, column.untyped)) /** - * Returns the positive value of dividend mod divisor. - * - * apache/spark - */ - def pmod[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: TypedEncoder[A] - ): column.ThisType[T, A] = + * Returns the positive value of dividend mod divisor. + * + * apache/spark + */ + def pmod[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: TypedEncoder[A] + ): column.ThisType[T, A] = column.typed(sparkFunctions.pmod(column.untyped, column2.untyped)) - - /** Non-Aggregate function: Returns the string representation of the binary value of the given long - * column. For example, bin("12") returns "1100". - * - * apache/spark - */ + /** + * Non-Aggregate function: Returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * apache/spark + */ def bin[T](column: AbstractTypedColumn[T, Long]): column.ThisType[T, String] = column.typed(sparkFunctions.bin(column.untyped)) /** - * Calculates the MD5 digest of a binary column and returns the value - * as a 32 character hex string. - * - * apache/spark - */ - def md5[T, A](column: AbstractTypedColumn[T, A])(implicit i0: TypedEncoder[A]): column.ThisType[T, String] = + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. + * + * apache/spark + */ + def md5[T, A]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: TypedEncoder[A] + ): column.ThisType[T, String] = column.typed(sparkFunctions.md5(column.untyped)) /** - * Computes the factorial of the given value. - * - * apache/spark - */ - def factorial[T](column: AbstractTypedColumn[T, Long])(implicit i0: TypedEncoder[Long]): column.ThisType[T, Long] = + * Computes the factorial of the given value. + * + * apache/spark + */ + def factorial[T]( + column: AbstractTypedColumn[T, Long] + )(implicit + i0: TypedEncoder[Long] + ): column.ThisType[T, Long] = column.typed(sparkFunctions.factorial(column.untyped)) - /** Non-Aggregate function: Computes bitwise NOT. - * - * apache/spark - */ - @nowarn // supress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat - def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] = - column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder) - - /** Non-Aggregate function: file name of the current Spark task. 
Empty string if row did not originate from - * a file - * - * apache/spark - */ + /** + * Non-Aggregate function: Computes bitwise NOT. + * + * apache/spark + */ + def bitwiseNOT[A: CatalystBitwise, T]( + column: AbstractTypedColumn[T, A] + ): column.ThisType[T, A] = + column.typed(shimFunctions.bitwiseNOT(column.untyped))(column.uencoder) + + /** + * Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from + * a file + * + * apache/spark + */ def inputFileName[T](): TypedColumn[T, String] = new TypedColumn[T, String](sparkFunctions.input_file_name()) - /** Non-Aggregate function: generates monotonically increasing id - * - * apache/spark - */ + /** + * Non-Aggregate function: generates monotonically increasing id + * + * apache/spark + */ def monotonicallyIncreasingId[T](): TypedColumn[T, Long] = { new TypedColumn[T, Long](sparkFunctions.monotonically_increasing_id()) } - /** Non-Aggregate function: Evaluates a list of conditions and returns one of multiple - * possible result expressions. If none match, otherwise is returned - * {{{ - * when(ds('boolField), ds('a)) - * .when(ds('otherBoolField), lit(123)) - * .otherwise(ds('b)) - * }}} - * apache/spark - */ - def when[T, A](condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = + /** + * Non-Aggregate function: Evaluates a list of conditions and returns one of multiple + * possible result expressions. If none match, otherwise is returned + * {{{ + * when(ds('boolField), ds('a)) + * .when(ds('otherBoolField), lit(123)) + * .otherwise(ds('b)) + * }}} + * apache/spark + */ + def when[T, A]( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ): When[T, A] = new When[T, A](condition, value) class When[T, A] private (untypedC: Column) { - private[functions] def this(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]) = + private[functions] def this( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ) = this(sparkFunctions.when(condition.untyped, value.untyped)) - def when(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = + def when( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ): When[T, A] = new When[T, A](untypedC.when(condition.untyped, value.untyped)) def otherwise(value: AbstractTypedColumn[T, A]): value.ThisType[T, A] = @@ -542,172 +723,219 @@ trait NonAggregateFunctions { // String functions ////////////////////////////////////////////////////////////////////////////////////////////// - - /** Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column - * - * apache/spark - */ + /** + * Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column + * + * apache/spark + */ def ascii[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Int] = column.typed(sparkFunctions.ascii(column.untyped)) - /** Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column. - * This is the reverse of unbase64. - * - * apache/spark - */ + /** + * Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. 
+ * + * apache/spark + */ def base64[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] = column.typed(sparkFunctions.base64(column.untyped)) - /** Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column. - * This is the reverse of base64. - * - * apache/spark - */ + /** + * Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. + * + * apache/spark + */ def unbase64[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Array[Byte]] = column.typed(sparkFunctions.unbase64(column.untyped)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ def concat[T](columns: TypedColumn[T, String]*): TypedColumn[T, String] = new TypedColumn(sparkFunctions.concat(columns.map(_.untyped): _*)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ def concat[T](columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = new TypedAggregate(sparkFunctions.concat(columns.map(_.untyped): _*)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, - * using the given separator. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ - def concatWs[T](sep: String, columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = - new TypedAggregate(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*)) - - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, - * using the given separator. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column, + * using the given separator. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concatWs[T]( + sep: String, + columns: TypedAggregate[T, String]* + ): TypedAggregate[T, String] = + new TypedAggregate( + sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*) + ) + + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column, + * using the given separator. 
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ def concatWs[T](sep: String, columns: TypedColumn[T, String]*): TypedColumn[T, String] = new TypedColumn(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*)) - /** Non-Aggregate function: Locates the position of the first occurrence of substring column - * in given string - * - * @note The position is not zero based, but 1 based index. Returns 0 if substr - * could not be found in str. - * - * apache/spark - */ - def instr[T](str: AbstractTypedColumn[T, String], substring: String): str.ThisType[T, Int] = + /** + * Non-Aggregate function: Locates the position of the first occurrence of substring column + * in given string + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr + * could not be found in str. + * + * apache/spark + */ + def instr[T]( + str: AbstractTypedColumn[T, String], + substring: String + ): str.ThisType[T, Int] = str.typed(sparkFunctions.instr(str.untyped, substring)) - /** Non-Aggregate function: Computes the length of a given string. - * - * apache/spark - */ - //TODO: Also for binary + /** + * Non-Aggregate function: Computes the length of a given string. + * + * apache/spark + */ + // TODO: Also for binary def length[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Int] = str.typed(sparkFunctions.length(str.untyped)) - /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. - * - * apache/spark - */ - def levenshtein[T](l: TypedColumn[T, String], r: TypedColumn[T, String]): TypedColumn[T, Int] = + /** + * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. + * + * apache/spark + */ + def levenshtein[T]( + l: TypedColumn[T, String], + r: TypedColumn[T, String] + ): TypedColumn[T, Int] = l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped)) - /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. - * - * apache/spark - */ - def levenshtein[T](l: TypedAggregate[T, String], r: TypedAggregate[T, String]): TypedAggregate[T, Int] = + /** + * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. + * + * apache/spark + */ + def levenshtein[T]( + l: TypedAggregate[T, String], + r: TypedAggregate[T, String] + ): TypedAggregate[T, Int] = l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped)) - /** Non-Aggregate function: Converts a string column to lower case. - * - * apache/spark - */ + /** + * Non-Aggregate function: Converts a string column to lower case. + * + * apache/spark + */ def lower[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.lower(str.untyped)) - /** Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer - * than len, the return value is shortened to len characters. - * - * apache/spark - */ - def lpad[T](str: AbstractTypedColumn[T, String], - len: Int, - pad: String): str.ThisType[T, String] = + /** + * Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. 
+ * + * apache/spark + */ + def lpad[T]( + str: AbstractTypedColumn[T, String], + len: Int, + pad: String + ): str.ThisType[T, String] = str.typed(sparkFunctions.lpad(str.untyped, len, pad)) - /** Non-Aggregate function: Trim the spaces from left end for the specified string value. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from left end for the specified string value. + * + * apache/spark + */ def ltrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.ltrim(str.untyped)) - /** Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep. - * - * apache/spark - */ - def regexpReplace[T](str: AbstractTypedColumn[T, String], - pattern: Regex, - replacement: String): str.ThisType[T, String] = - str.typed(sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement)) - + /** + * Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep. + * + * apache/spark + */ + def regexpReplace[T]( + str: AbstractTypedColumn[T, String], + pattern: Regex, + replacement: String + ): str.ThisType[T, String] = + str.typed( + sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement) + ) - /** Non-Aggregate function: Reverses the string column and returns it as a new string column. - * - * apache/spark - */ + /** + * Non-Aggregate function: Reverses the string column and returns it as a new string column. + * + * apache/spark + */ def reverse[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.reverse(str.untyped)) - /** Non-Aggregate function: Right-pad the string column with pad to a length of len. - * If the string column is longer than len, the return value is shortened to len characters. - * - * apache/spark - */ - def rpad[T](str: AbstractTypedColumn[T, String], len: Int, pad: String): str.ThisType[T, String] = + /** + * Non-Aggregate function: Right-pad the string column with pad to a length of len. + * If the string column is longer than len, the return value is shortened to len characters. + * + * apache/spark + */ + def rpad[T]( + str: AbstractTypedColumn[T, String], + len: Int, + pad: String + ): str.ThisType[T, String] = str.typed(sparkFunctions.rpad(str.untyped, len, pad)) - /** Non-Aggregate function: Trim the spaces from right end for the specified string value. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from right end for the specified string value. + * + * apache/spark + */ def rtrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.rtrim(str.untyped)) - /** Non-Aggregate function: Substring starts at `pos` and is of length `len` - * - * apache/spark - */ - //TODO: Also for byte array - def substring[T](str: AbstractTypedColumn[T, String], pos: Int, len: Int): str.ThisType[T, String] = + /** + * Non-Aggregate function: Substring starts at `pos` and is of length `len` + * + * apache/spark + */ + // TODO: Also for byte array + def substring[T]( + str: AbstractTypedColumn[T, String], + pos: Int, + len: Int + ): str.ThisType[T, String] = str.typed(sparkFunctions.substring(str.untyped, pos, len)) - /** Non-Aggregate function: Trim the spaces from both ends for the specified string column. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from both ends for the specified string column. 
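As a quick orientation for the string helpers being reformatted above, here is a minimal usage sketch; the User type, column name and imports are assumptions for illustration, not part of this diff (the import assumes the usual frameless.functions.nonAggregate entry point that mixes in this trait):

    import frameless._
    import frameless.functions.nonAggregate._

    // Hypothetical schema, purely for illustration.
    case class User(id: Long, email: String)

    def normalizedEmails(ds: TypedDataset[User]): TypedDataset[String] =
      // Strip leading whitespace, then lower-case, staying in typed columns throughout.
      ds.select(lower(ltrim(ds('email))))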
+ * + * apache/spark + */ def trim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.trim(str.untyped)) - /** Non-Aggregate function: Converts a string column to upper case. - * - * apache/spark - */ + /** + * Non-Aggregate function: Converts a string column to upper case. + * + * apache/spark + */ def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.upper(str.untyped)) @@ -715,93 +943,103 @@ trait NonAggregateFunctions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string. - * - * Differs from `Column#year` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string. + * + * Differs from `Column#year` by wrapping it's result into an `Option`. + * + * apache/spark + */ def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.year(str.untyped)) - /** Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string. - * - * Differs from `Column#quarter` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string. + * + * Differs from `Column#quarter` by wrapping it's result into an `Option`. + * + * apache/spark + */ def quarter[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.quarter(str.untyped)) - /** Non-Aggregate function Extracts the month as an integer from a given date/timestamp/string. - * - * Differs from `Column#month` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function Extracts the month as an integer from a given date/timestamp/string. + * + * Differs from `Column#month` by wrapping it's result into an `Option`. + * + * apache/spark + */ def month[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.month(str.untyped)) - /** Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string. - * - * Differs from `Column#dayofweek` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofweek` by wrapping it's result into an `Option`. + * + * apache/spark + */ def dayofweek[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofweek(str.untyped)) - /** Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string. - * - * Differs from `Column#dayofmonth` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofmonth` by wrapping it's result into an `Option`. + * + * apache/spark + */ def dayofmonth[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofmonth(str.untyped)) - /** Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string. 
- * - * Differs from `Column#dayofyear` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofyear` by wrapping it's result into an `Option`. + * + * apache/spark + */ def dayofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofyear(str.untyped)) - /** Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string. - * - * Differs from `Column#hour` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string. + * + * Differs from `Column#hour` by wrapping it's result into an `Option`. + * + * apache/spark + */ def hour[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.hour(str.untyped)) - /** Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string. - * - * Differs from `Column#minute` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string. + * + * Differs from `Column#minute` by wrapping it's result into an `Option`. + * + * apache/spark + */ def minute[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.minute(str.untyped)) - /** Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string. - * - * Differs from `Column#second` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string. + * + * Differs from `Column#second` by wrapping it's result into an `Option`. + * + * apache/spark + */ def second[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.second(str.untyped)) - /** Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string. - * - * Differs from `Column#weekofyear` by wrapping it's result into an `Option`. - * - * apache/spark - */ + /** + * Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string. + * + * Differs from `Column#weekofyear` by wrapping it's result into an `Option`. 
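The Option results of the date/time extractors above surface unparseable inputs as None at the type level rather than as nulls. A small sketch, assuming a hypothetical Event type and the usual nonAggregate import:

    import frameless._
    import frameless.functions.nonAggregate._

    // Hypothetical schema; any date/timestamp/string column behaves the same way.
    case class Event(id: Long, ts: String)

    def eventHours(ds: TypedDataset[Event]): TypedDataset[Option[Int]] =
      // hour differs from Column#hour by wrapping its result in Option,
      // so values Spark cannot parse come back as None instead of null.
      ds.select(hour(ds('ts)))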
+ * + * apache/spark + */ def weekofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.weekofyear(str.untyped)) } diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index 93ba7f118..c34e8561e 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -2,132 +2,179 @@ package frameless package functions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + LeafExpression, + NonSQLExpression +} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.DataType import shapeless.syntax.std.tuple._ -/** Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ +/** + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ trait Udf { - /** Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A, R: TypedEncoder](f: A => R): TypedColumn[T, A] => TypedColumn[T, R] = { u => - val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R]) + val scalaUdf = FramelessUdf( + f, + List(u), + TypedEncoder[R], + s => f(s.head.asInstanceOf[A]) + ) new TypedColumn[T, R](scalaUdf) } - /** Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, R: TypedEncoder](f: (A1, A2) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => f(s.head.asInstanceOf[A1], s(1).asInstanceOf[A2]) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). 
+ * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1, A2, A3) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3], + s(3).asInstanceOf[A4] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf( + f, + us.toList[UntypedExpression[T]], + TypedEncoder[R], + s => + f( + s.head.asInstanceOf[A1], + s(1).asInstanceOf[A2], + s(2).asInstanceOf[A3], + s(3).asInstanceOf[A4], + s(4).asInstanceOf[A5] + ) + ) new TypedColumn[T, R](scalaUdf) - } + } } /** - * NB: Implementation detail, isn't intended to be directly used. - * - * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. - */ + * NB: Implementation detail, isn't intended to be directly used. + * + * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. 
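For context on the new parameter threaded through the overloads above: every udf arity now also hands FramelessUdf an untyped Seq[Any] => Any that re-applies f after casting each argument, which the interpreted eval path further down uses in place of generated code. A caller-side usage sketch, with a hypothetical Person type and column name:

    import frameless._
    import frameless.functions.udf

    // Hypothetical schema, purely for illustration.
    case class Person(name: String, age: Int)

    def greetings(ds: TypedDataset[Person]): TypedDataset[String] = {
      // 1-argument overload: the Seq[Any] => Any it builds simply casts s.head to String.
      val greet = udf[Person, String, String](n => s"hello, $n")
      ds.select(greet(ds('name)))
    }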
+ */ +// Possibly add UserDefinedExpression trait to stop the functions being registered and used as aggregates case class FramelessUdf[T, R]( - function: AnyRef, - encoders: Seq[TypedEncoder[_]], - children: Seq[Expression], - rencoder: TypedEncoder[R] -) extends Expression with NonSQLExpression { + function: AnyRef, + encoders: Seq[TypedEncoder[_]], + children: Seq[Expression], + rencoder: TypedEncoder[R], + evalFunction: Seq[Any] => Any) + extends Expression + with NonSQLExpression { override def nullable: Boolean = rencoder.nullable + override def toString: String = s"FramelessUdf(${children.mkString(", ")})" - lazy val evalCode = { - val ctx = new CodegenContext() - val eval = genCode(ctx) + lazy val typedEnc = + TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]] - val codeBody = s""" - public scala.Function1 generate(Object[] references) { - return new FramelessUdfEvalImpl(references); - } + lazy val isSerializedAsStructForTopLevel = + typedEnc.isSerializedAsStructForTopLevel - class FramelessUdfEvalImpl extends scala.runtime.AbstractFunction1 { - private final Object[] references; - ${ctx.declareMutableStates()} - ${ctx.declareAddedFunctions()} - - public FramelessUdfEvalImpl(Object[] references) { - this.references = references; - ${ctx.initMutableStates()} - } - - public java.lang.Object apply(java.lang.Object z) { - InternalRow ${ctx.INPUT_ROW} = (InternalRow) z; - ${eval.code} - return ${eval.isNull} ? ((Object)null) : ((Object)${eval.value}); - } - } - """ - - val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + def eval(input: InternalRow): Any = { + val jvmTypes = children.map(_.eval(input)) - val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + val returnJvm = evalFunction(jvmTypes).asInstanceOf[R] - codegen - } + val returnCatalyst = typedEnc.createSerializer().apply(returnJvm) + val retval = + if (returnCatalyst == null) + null + else if (isSerializedAsStructForTopLevel) + returnCatalyst + else + returnCatalyst.get(0, dataType) - def eval(input: InternalRow): Any = { - evalCode(input) + retval } def dataType: DataType = rencoder.catalystRepr @@ -139,29 +186,45 @@ case class FramelessUdf[T, R]( val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName val funcClassName = s"scala.Function${children.size}" val funcExpressionIdx = ctx.references.size - 1 - val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"), - v => s"$v = ($funcClassName)((($framelessUdfClassName)references" + - s"[$funcExpressionIdx]).function());") - - val (argsCode, funcArguments) = encoders.zip(children).map { - case (encoder, child) => - val eval = child.genCode(ctx) - val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) - val argTerm = ctx.freshName("arg") - val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + val funcTerm = ctx.addMutableState( + funcClassName, + ctx.freshName("udf"), + v => + s"$v = ($funcClassName)((($framelessUdfClassName)references" + + s"[$funcExpressionIdx]).function());" + ) - (convert, argTerm) - }.unzip + val (argsCode, funcArguments) = encoders + .zip(children) + .map { + case (encoder, child) => + val eval = child.genCode(ctx) + val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) + val argTerm = ctx.freshName("arg") + val convert = + s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? 
(($codeTpe)null) : (($codeTpe)(${eval.value}));" + + (convert, argTerm) + } + .unzip val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr) - val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal")) - val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull")) + val internalTerm = + ctx.addMutableState(internalTpe, ctx.freshName("internal")) + val internalNullTerm = + ctx.addMutableState("boolean", ctx.freshName("internalNull")) // CTw - can't inject the term, may have to duplicate old code for parity - val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true) + val internalExpr = Spark2_4_LambdaVariable( + internalTerm, + internalNullTerm, + rencoder.jvmRepr, + true + ) val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx) - ev.copy(code = code""" + ev.copy( + code = code""" ${argsCode.mkString("\n")} $internalTerm = @@ -175,21 +238,28 @@ case class FramelessUdf[T, R]( ) } - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(children = newChildren) } case class Spark2_4_LambdaVariable( - value: String, - isNull: String, - dataType: DataType, - nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) + extends LeafExpression + with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = + InternalRow.getAccessor(dataType) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { - assert(input.numFields == 1, - "The input row of interpreted LambdaVariable should have only 1 field.") + assert( + input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field." 
+ ) if (nullable && input.isNullAt(0)) { null } else { @@ -197,7 +267,10 @@ case class Spark2_4_LambdaVariable( } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = { val isNullValue = if (nullable) { JavaCode.isNullVariable(isNull) } else { @@ -208,15 +281,18 @@ case class Spark2_4_LambdaVariable( } object FramelessUdf { + // Spark needs case class with `children` field to mutate it def apply[T, R]( - function: AnyRef, - cols: Seq[UntypedExpression[T]], - rencoder: TypedEncoder[R] - ): FramelessUdf[T, R] = FramelessUdf( + function: AnyRef, + cols: Seq[UntypedExpression[T]], + rencoder: TypedEncoder[R], + evalFunction: Seq[Any] => Any + ): FramelessUdf[T, R] = FramelessUdf( function = function, encoders = cols.map(_.uencoder).toList, children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList, - rencoder = rencoder + rencoder = rencoder, + evalFunction = evalFunction ) } diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index 1a57101e0..543925e00 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -1,13 +1,12 @@ package frameless +import frameless.{ reflection => ScalaReflection } import scala.reflect.ClassTag import shapeless._ import shapeless.labelled.FieldType import shapeless.ops.hlist.IsHCons import shapeless.ops.record.{ Keys, Values } - -import org.apache.spark.sql.{ reflection => ScalaReflection } import org.apache.spark.sql.catalyst.expressions.Literal package object functions extends Udf with UnaryFunctions { diff --git a/dataset/src/main/scala/frameless/ops/GroupByOps.scala b/dataset/src/main/scala/frameless/ops/GroupByOps.scala index 3feeaca59..d63870c9b 100644 --- a/dataset/src/main/scala/frameless/ops/GroupByOps.scala +++ b/dataset/src/main/scala/frameless/ops/GroupByOps.scala @@ -3,36 +3,51 @@ package ops import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.{Column, Dataset, FramelessInternals, RelationalGroupedDataset} +import org.apache.spark.sql.{ Column, Dataset, RelationalGroupedDataset } import shapeless._ -import shapeless.ops.hlist.{Length, Mapped, Prepend, ToList, ToTraversable, Tupler} +import shapeless.ops.hlist.{ + Length, + Mapped, + Prepend, + ToList, + ToTraversable, + Tupler +} +import com.sparkutils.shim.expressions.{ MapGroups4 => MapGroups } +import frameless.FramelessInternals -class GroupedByManyOps[T, TK <: HList, K <: HList, KT] - (self: TypedDataset[T], groupedBy: TK) - (implicit +class GroupedByManyOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i3: Tupler.Aux[K, KT] - ) extends AggregatingOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.groupBy(cols: _*)) { + i3: Tupler.Aux[K, KT]) + extends AggregatingOps[T, TK, K, KT]( + self, + groupedBy, + (dataset, cols) => dataset.groupBy(cols: _*) + ) { + object agg extends ProductArgs { - def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1] - (columns: TC) - (implicit + + def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1]( + columns: TC + )(implicit i3: AggregateTypes.Aux[T, TC, C], i4: Prepend.Aux[K, C, Out0], i5: Tupler.Aux[Out0, Out1], i6: TypedEncoder[Out1], 
i7: ToTraversable.Aux[TC, List, UntypedExpression[T]] ): TypedDataset[Out1] = { - aggregate[TC, Out1](columns) - } + aggregate[TC, Out1](columns) + } } } class GroupedBy1Ops[K1, V]( - self: TypedDataset[V], - g1: TypedColumn[V, K1] -) { + self: TypedDataset[V], + g1: TypedColumn[V, K1]) { private def underlying = new GroupedByManyOps(self, g1 :: HNil) private implicit def eg1 = g1.uencoder @@ -41,49 +56,75 @@ class GroupedBy1Ops[K1, V]( underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(K1, U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(K1, U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(K1, U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(K1, U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. 
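A minimal sketch of how the agg overloads above are reached from user code; the Purchase type and column names are assumptions, and sum comes from frameless.functions.aggregate:

    import frameless._
    import frameless.functions.aggregate._

    // Hypothetical schema, purely for illustration.
    case class Purchase(customer: String, amount: Long)

    def totalsPerCustomer(ds: TypedDataset[Purchase]): TypedDataset[(String, Long)] =
      // groupBy on one column yields a GroupedBy1Ops; agg with a single aggregate
      // produces a dataset of (key, aggregate) tuples.
      ds.groupBy(ds('customer)).agg(sum(ds('amount)))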
+ */ object deserialized { + def mapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => U): TypedDataset[U] = { underlying.deserialized.mapGroups(AggregatingOps.tuple1(f)) } - def flatMapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: (K1, Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(AggregatingOps.tuple1(f)) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): PivotNotValues[V, TypedColumn[V,K1] :: HNil, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: HNil, P] = PivotNotValues(self, g1 :: HNil, pivotColumn) } - class GroupedBy2Ops[K1, K2, V]( - self: TypedDataset[V], - g1: TypedColumn[V, K1], - g2: TypedColumn[V, K2] -) { + self: TypedDataset[V], + g1: TypedColumn[V, K1], + g2: TypedColumn[V, K2]) { private def underlying = new GroupedByManyOps(self, g1 :: g2 :: HNil) private implicit def eg1 = g1.uencoder private implicit def eg2 = g2.uencoder @@ -93,57 +134,88 @@ class GroupedBy2Ops[K1, K2, V]( underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, K2, U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(K1, K2, U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, K2, U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(K1, K2, U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, K2, U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder - underlying.agg(c1 , c2 , c3 , c4) + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(K1, K2, U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - - /** Methods on `TypedDataset[T]` that go through a full serialization and - * 
deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. + */ object deserialized { - def mapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => U): TypedDataset[U] = { + + def mapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => U + ): TypedDataset[U] = { underlying.deserialized.mapGroups(f) } - def flatMapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(f) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): - PivotNotValues[V, TypedColumn[V,K1] :: TypedColumn[V, K2] :: HNil, P] = - PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn) + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: TypedColumn[V, K2] :: HNil, P] = + PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn) } -private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] - (self: TypedDataset[T], groupedBy: TK, groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset) - (implicit +private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK, + groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i2: Tupler.Aux[K, KT] - ) { - def aggregate[TC <: HList, Out1](columns: TC) - (implicit - i7: TypedEncoder[Out1], - i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] - ): TypedDataset[Out1] = { + i2: Tupler.Aux[K, KT]) { + + def aggregate[TC <: HList, Out1]( + columns: TC + )(implicit + i7: TypedEncoder[Out1], + i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] + ): TypedDataset[Out1] = { def expr(c: UntypedExpression[T]): Column = new Column(c.expr) val groupByExprs = groupedBy.toList[UntypedExpression[T]].map(expr) @@ -159,25 +231,32 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] TypedDataset.create[Out1](aggregated) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. 
+ */ object deserialized { + def mapGroups[U: TypedEncoder]( - f: (KT, Iterator[T]) => U - )(implicit e: TypedEncoder[KT]): TypedDataset[U] = { + f: (KT, Iterator[T]) => U + )(implicit + e: TypedEncoder[KT] + ): TypedDataset[U] = { val func = (key: KT, it: Iterator[T]) => Iterator(f(key, it)) flatMapGroups(func) } def flatMapGroups[U: TypedEncoder]( - f: (KT, Iterator[T]) => TraversableOnce[U] - )(implicit e: TypedEncoder[KT]): TypedDataset[U] = { + f: (KT, Iterator[T]) => TraversableOnce[U] + )(implicit + e: TypedEncoder[KT] + ): TypedDataset[U] = { implicit val tendcoder = self.encoder val cols = groupedBy.toList[UntypedExpression[T]] val logicalPlan = FramelessInternals.logicalPlan(self.dataset) - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) + val withKeyColumns = + logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = FramelessInternals.executePlan(self.dataset, withKey) val keyAttributes = executed.analyzed.output.takeRight(cols.size) @@ -188,7 +267,11 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] keyAttributes, dataAttributes, executed.analyzed - )(TypedExpressionEncoder[KT], TypedExpressionEncoder[T], TypedExpressionEncoder[U]) + )( + TypedExpressionEncoder[KT], + TypedExpressionEncoder[T], + TypedExpressionEncoder[U] + ) val groupedAndFlatMapped = FramelessInternals.mkDataset( self.dataset.sqlContext, @@ -201,66 +284,95 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] } private def retainGroupColumns: Boolean = { - self.dataset.sqlContext.getConf("spark.sql.retainGroupColumns", "true").toBoolean + self.dataset.sqlContext + .getConf("spark.sql.retainGroupColumns", "true") + .toBoolean } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[T, P]): PivotNotValues[T, TK, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[T, P] + ): PivotNotValues[T, TK, P] = PivotNotValues(self, groupedBy, pivotColumn) } private[ops] object AggregatingOps { + /** Utility function to help Spark with serialization of closures */ def tuple1[K1, V, U](f: (K1, Iterator[V]) => U): (Tuple1[K1], Iterator[V]) => U = { (x: Tuple1[K1], it: Iterator[V]) => f(x._1, it) } } -/** Represents a typed Pivot operation. - */ +/** + * Represents a typed Pivot operation. 
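The deserialized block above is where the shim-provided MapGroups (aliased from com.sparkutils.shim.expressions.MapGroups4 in this file's imports) enters the logical plan. A usage sketch with a made-up Reading type:

    import frameless._

    // Hypothetical schema, purely for illustration.
    case class Reading(sensor: String, value: Double)

    def maxPerSensor(ds: TypedDataset[Reading]): TypedDataset[Reading] =
      // Runs on deserialized Reading values outside Catalyst; the grouping node is
      // built through the MapGroups alias rather than Spark's class directly.
      ds.groupBy(ds('sensor))
        .deserialized
        .mapGroups((_, readings) => readings.maxBy(_.value))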
+ */ final case class Pivot[T, GroupedColumns <: HList, PivotType, Values <: HList]( - ds: TypedDataset[T], - groupedBy: GroupedColumns, - pivotedBy: TypedColumn[T, PivotType], - values: Values -) { + ds: TypedDataset[T], + groupedBy: GroupedColumns, + pivotedBy: TypedColumn[T, PivotType], + values: Values) { object agg extends ProductArgs { - def applyProduct[AggrColumns <: HList, AggrColumnTypes <: HList, GroupedColumnTypes <: HList, NumValues <: Nat, TypesForPivotedValues <: HList, TypesForPivotedValuesOpt <: HList, OutAsHList <: HList, Out] - (aggrColumns: AggrColumns) - (implicit + + def applyProduct[ + AggrColumns <: HList, + AggrColumnTypes <: HList, + GroupedColumnTypes <: HList, + NumValues <: Nat, + TypesForPivotedValues <: HList, + TypesForPivotedValuesOpt <: HList, + OutAsHList <: HList, + Out + ](aggrColumns: AggrColumns + )(implicit i0: AggregateTypes.Aux[T, AggrColumns, AggrColumnTypes], i1: ColumnTypes.Aux[T, GroupedColumns, GroupedColumnTypes], i2: Length.Aux[Values, NumValues], i3: Repeat.Aux[AggrColumnTypes, NumValues, TypesForPivotedValues], i4: Mapped.Aux[TypesForPivotedValues, Option, TypesForPivotedValuesOpt], - i5: Prepend.Aux[GroupedColumnTypes, TypesForPivotedValuesOpt, OutAsHList], + i5: Prepend.Aux[ + GroupedColumnTypes, + TypesForPivotedValuesOpt, + OutAsHList + ], i6: Tupler.Aux[OutAsHList, Out], i7: TypedEncoder[Out] ): TypedDataset[Out] = { - def mapAny[X](h: HList)(f: Any => X): List[X] = - h match { - case HNil => Nil - case x :: xs => f(x) :: mapAny(xs)(f) - } - - val aggCols: Seq[Column] = mapAny(aggrColumns)(x => new Column(x.asInstanceOf[TypedAggregate[_,_]].expr)) - val tmp = ds.dataset.toDF() - .groupBy(mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _*) - .pivot(pivotedBy.untyped.toString, mapAny(values)(identity)) - .agg(aggCols.head, aggCols.tail:_*) - .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create(tmp) - } + def mapAny[X](h: HList)(f: Any => X): List[X] = + h match { + case HNil => Nil + case x :: xs => f(x) :: mapAny(xs)(f) + } + + val aggCols: Seq[Column] = mapAny(aggrColumns)(x => + new Column(x.asInstanceOf[TypedAggregate[_, _]].expr) + ) + val tmp = ds.dataset + .toDF() + .groupBy( + mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _* + ) + .pivot(pivotedBy.untyped.toString, mapAny(values)(identity)) + .agg(aggCols.head, aggCols.tail: _*) + .as[Out](TypedExpressionEncoder[Out]) + TypedDataset.create(tmp) + } } } final case class PivotNotValues[T, GroupedColumns <: HList, PivotType]( - ds: TypedDataset[T], - groupedBy: GroupedColumns, - pivotedBy: TypedColumn[T, PivotType] -) extends ProductArgs { - - def onProduct[Values <: HList](values: Values)( - implicit validValues: ToList[Values, PivotType] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work - ): Pivot[T, GroupedColumns, PivotType, Values] = Pivot(ds, groupedBy, pivotedBy, values) + ds: TypedDataset[T], + groupedBy: GroupedColumns, + pivotedBy: TypedColumn[T, PivotType]) + extends ProductArgs { + + def onProduct[Values <: HList]( + values: Values + )(implicit + validValues: ToList[ + Values, + PivotType + ] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work + ): Pivot[T, GroupedColumns, PivotType, Values] = + Pivot(ds, groupedBy, pivotedBy, values) } diff --git a/dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala b/dataset/src/main/scala/frameless/reflection/package.scala similarity index 53% rename from dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala rename to 
dataset/src/main/scala/frameless/reflection/package.scala index 07090a8db..5a38baa71 100644 --- a/dataset/src/main/scala/org/apache/spark/sql/reflection/package.scala +++ b/dataset/src/main/scala/frameless/reflection/package.scala @@ -1,26 +1,6 @@ -package org.apache.spark.sql +package frameless -import org.apache.spark.sql.catalyst.ScalaReflection.{ - cleanUpReflectionObjects, - getClassFromType, - localTypeOf -} -import org.apache.spark.sql.types.{ - BinaryType, - BooleanType, - ByteType, - CalendarIntervalType, - DataType, - Decimal, - DecimalType, - DoubleType, - FloatType, - IntegerType, - LongType, - NullType, - ObjectType, - ShortType -} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval /** @@ -45,6 +25,61 @@ package object reflection { import universe._ + // Since we are creating a runtime mirror using the class loader of current thread, + // we need to use def at here. So, every time we call mirror, it is using the + // class loader of the current thread. + def mirror: universe.Mirror = { + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + } + + /** + * Any codes calling `scala.reflect.api.Types.TypeApi.<:<` should be wrapped by this method to + * clean up the Scala reflection garbage automatically. Otherwise, it will leak some objects to + * `scala.reflect.runtime.JavaUniverse.undoLog`. + * + * @see https://github.com/scala/bug/issues/8302 + */ + def cleanUpReflectionObjects[T](func: => T): T = { + universe.asInstanceOf[scala.reflect.runtime.JavaUniverse].undoLog.undo(func) + } + + /** + * Return the Scala Type for `T` in the current classloader mirror. + * + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 + */ + def localTypeOf[T: TypeTag]: `Type` = { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.dealias + } + + /* + * Retrieves the runtime class corresponding to the provided type. + */ + def getClassFromType(tpe: Type): Class[_] = + mirror.runtimeClass(erasure(tpe).dealias.typeSymbol.asClass) + + private def erasure(tpe: Type): Type = { + // For user-defined AnyVal classes, we should not erasure it. Otherwise, it will + // resolve to underlying type which wrapped by this class, e.g erasure + // `case class Foo(i: Int) extends AnyVal` will return type `Int` instead of `Foo`. + // But, for other types, we do need to erasure it. For example, we need to erasure + // `scala.Any` to `java.lang.Object` in order to load it from Java ClassLoader. + // Please see SPARK-17368 & SPARK-31190 for more details. + if ( + isSubtype(tpe, localTypeOf[AnyVal]) && !tpe.toString.startsWith("scala") + ) { + tpe + } else { + tpe.erasure + } + } + /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. 
Special handling is also used for Arrays including @@ -62,7 +97,7 @@ package object reflection { * * See https://github.com/scala/bug/issues/10766 */ - private[sql] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { + private def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { ScalaSubtypeLock.synchronized { tpe1 <:< tpe2 } diff --git a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala deleted file mode 100644 index 5459230d4..000000000 --- a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala +++ /dev/null @@ -1,73 +0,0 @@ -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct} -import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ObjectType -import scala.reflect.ClassTag - -object FramelessInternals { - def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass) - - def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = { - ds.toDF().queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""") - } - } - - def expr(column: Column): Expression = column.expr - - def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan - - def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution = - ds.sparkSession.sessionState.executePlan(plan) - - def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = { - val joined = executePlan(ds, plan) - val leftOutput = joined.analyzed.output.take(leftPlan.output.length) - val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length) - - Project(List( - Alias(CreateStruct(leftOutput), "_1")(), - Alias(CreateStruct(rightOutput), "_2")() - ), joined.analyzed) - } - - def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] = - new Dataset(sqlContext, plan, encoder) - - def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = - Dataset.ofRows(sparkSession, logicalPlan) - - // because org.apache.spark.sql.types.UserDefinedType is private[spark] - type UserDefinedType[A >: Null] = org.apache.spark.sql.types.UserDefinedType[A] - - // below only tested in SelfJoinTests.colLeft and colRight are equivalent to col outside of joins - // - via files (codegen) forces doGenCode eval. - /** Expression to tag columns from the left hand side of join expression. 
*/ - case class DisambiguateLeft[T](tagged: Expression) extends Expression with NonSQLExpression { - def eval(input: InternalRow): Any = tagged.eval(input) - def nullable: Boolean = false - def children: Seq[Expression] = tagged :: Nil - def dataType: DataType = tagged.dataType - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx) - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) - } - - /** Expression to tag columns from the right hand side of join expression. */ - case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression { - def eval(input: InternalRow): Any = tagged.eval(input) - def nullable: Boolean = false - def children: Seq[Expression] = tagged :: Nil - def dataType: DataType = tagged.dataType - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx) - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) - } -} diff --git a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala b/dataset/src/main/spark-3.4+/frameless/MapGroups.scala deleted file mode 100644 index 6856acba4..000000000 --- a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala +++ /dev/null @@ -1,21 +0,0 @@ -package frameless - -import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups} - -object MapGroups { - def apply[K: Encoder, T: Encoder, U: Encoder]( - func: (K, Iterator[T]) => TraversableOnce[U], - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - child: LogicalPlan - ): LogicalPlan = - SMapGroups( - func, - groupingAttributes, - dataAttributes, - Seq(), // #698 - no order given - child - ) -} diff --git a/dataset/src/main/spark-3/frameless/MapGroups.scala b/dataset/src/main/spark-3/frameless/MapGroups.scala deleted file mode 100644 index 3fd27f333..000000000 --- a/dataset/src/main/spark-3/frameless/MapGroups.scala +++ /dev/null @@ -1,14 +0,0 @@ -package frameless - -import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups} - -object MapGroups { - def apply[K: Encoder, T: Encoder, U: Encoder]( - func: (K, Iterator[T]) => TraversableOnce[U], - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - child: LogicalPlan - ): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child) -} diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala index 4d9b5547d..6bd2f88d6 100644 --- a/dataset/src/test/scala/frameless/CreateTests.scala +++ b/dataset/src/test/scala/frameless/CreateTests.scala @@ -1,6 +1,6 @@ package frameless -import org.scalacheck.{Arbitrary, Prop} +import org.scalacheck.{ Arbitrary, Prop } import org.scalacheck.Prop._ import scala.reflect.ClassTag @@ -13,29 +13,40 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("creation using X4 derived DataFrames") { def prop[ - A: TypedEncoder, - B: TypedEncoder, - C: TypedEncoder, - D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = { + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]] + ): Prop = { val ds = TypedDataset.create(data) - TypedDataset.createUnsafe[X4[A, B, C, 
D]](ds.toDF()).collect().run() ?= data + TypedDataset + .createUnsafe[X4[A, B, C, D]](ds.toDF()) + .collect() + .run() ?= data } check(forAll(prop[Int, Char, X2[Option[Country], Country], Int] _)) check(forAll(prop[X2[Int, Int], Int, Boolean, Vector[Food]] _)) check(forAll(prop[String, Food, X3[Food, Country, Boolean], Int] _)) check(forAll(prop[String, Food, X3U[Food, Country, Boolean], Int] _)) - check(forAll(prop[ - Option[Vector[Food]], - Vector[Vector[X2[Vector[(Person, X1[Char])], Country]]], - X3[Food, Country, String], - Vector[(Food, Country)]] _)) + check( + forAll( + prop[Option[Vector[Food]], Vector[ + Vector[X2[Vector[(Person, X1[Char])], Country]] + ], X3[Food, Country, String], Vector[(Food, Country)]] _ + ) + ) } test("array fields") { def prop[T: Arbitrary: TypedEncoder: ClassTag] = forAll { - (d1: Array[T], d2: Array[Option[T]], d3: Array[X1[T]], d4: Array[X1[Option[T]]], - d5: X1[Array[T]]) => + (d1: Array[T], + d2: Array[Option[T]], + d3: Array[X1[T]], + d4: Array[X1[Option[T]]], + d5: X1[Array[T]] + ) => TypedDataset.create(Seq(d1)).collect().run().head.sameElements(d1) && TypedDataset.create(Seq(d2)).collect().run().head.sameElements(d2) && TypedDataset.create(Seq(d3)).collect().run().head.sameElements(d3) && @@ -55,13 +66,17 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("vector fields") { def prop[T: Arbitrary: TypedEncoder] = forAll { - (d1: Vector[T], d2: Vector[Option[T]], d3: Vector[X1[T]], d4: Vector[X1[Option[T]]], - d5: X1[Vector[T]]) => - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && - (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && - (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && - (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && - (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) + (d1: Vector[T], + d2: Vector[Option[T]], + d3: Vector[X1[T]], + d4: Vector[X1[Option[T]]], + d5: X1[Vector[T]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && + (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && + (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && + (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) } check(prop[Boolean]) @@ -77,9 +92,13 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("list fields") { def prop[T: Arbitrary: TypedEncoder] = forAll { - (d1: List[T], d2: List[Option[T]], d3: List[X1[T]], d4: List[X1[Option[T]]], - d5: X1[List[T]]) => - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (d1: List[T], + d2: List[Option[T]], + d3: List[X1[T]], + d4: List[X1[Option[T]]], + d5: X1[List[T]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && @@ -98,16 +117,23 @@ class CreateTests extends TypedDatasetSuite with Matchers { } test("Map fields (scala.Predef.Map / scala.collection.immutable.Map)") { - def prop[A: Arbitrary: NotCatalystNullable: TypedEncoder, B: Arbitrary: NotCatalystNullable: TypedEncoder] = forAll { - (d1: Map[A, B], d2: Map[B, A], d3: Map[A, Option[B]], - d4: Map[A, X1[B]], d5: Map[X1[A], B], d6: Map[X1[A], X1[B]]) => - - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && - (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && - (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && - 
(TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && - (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) && - (TypedDataset.create(Seq(d6)).collect().run().head ?= d6) + def prop[ + A: Arbitrary: NotCatalystNullable: TypedEncoder, + B: Arbitrary: NotCatalystNullable: TypedEncoder + ] = forAll { + (d1: Map[A, B], + d2: Map[B, A], + d3: Map[A, Option[B]], + d4: Map[A, X1[B]], + d5: Map[X1[A], B], + d6: Map[X1[A], X1[B]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && + (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && + (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && + (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) && + (TypedDataset.create(Seq(d6)).collect().run().head ?= d6) } check(prop[String, String]) @@ -123,14 +149,17 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("maps with Option keys should not resolve the TypedEncoder") { val data: Seq[Map[Option[Int], Int]] = Seq(Map(Some(5) -> 5)) - illTyped("TypedDataset.create(data)", ".*could not find implicit value for parameter encoder.*") + illTyped( + "TypedDataset.create(data)", + ".*could not find implicit value for parameter encoder.*" + ) } test("not aligned columns should throw an exception") { - val v = Vector(X2(1,2)) + val v = Vector(X2(1, 2)) val df = TypedDataset.create(v).dataset.toDF() - a [IllegalStateException] should be thrownBy { + a[IllegalStateException] should be thrownBy { TypedDataset.createUnsafe[X1[Int]](df).show().run() } } @@ -138,15 +167,25 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("dataset with different column order") { // e.g. when loading data from partitioned dataset // the partition columns get appended to the end of the underlying relation - def prop[A: Arbitrary: TypedEncoder, B: Arbitrary: TypedEncoder] = forAll { - (a1: A, b1: B) => { - val ds = TypedDataset.create( - Vector((b1, a1)) - ).dataset.toDF("b", "a").as[X2[A, B]](TypedExpressionEncoder[X2[A, B]]) - TypedDataset.create(ds).collect().run().head ?= X2(a1, b1) - + def prop[A: Arbitrary: TypedEncoder, B: Arbitrary: TypedEncoder] = + forAll { (a1: A, b1: B) => + { + // this code separates the different column order from the 'as'ing, requires createUnsafe to actually work. 
+ // using create directly falsely assumes no type checking will take place on the actually incorrect encoders (DBR 14.3 does this) + val df = TypedDataset + .create( + Vector((b1, a1)) + ) + .dataset + .toDF("b", "a") + TypedDataset + .createUnsafe(df)(TypedEncoder[X2[A, B]]) + .collect() + .run() + .head ?= X2(a1, b1) + + } } - } check(prop[X1[Double], X1[X1[SQLDate]]]) check(prop[String, Int]) } diff --git a/dataset/src/test/scala/frameless/EncoderTests.scala b/dataset/src/test/scala/frameless/EncoderTests.scala index 4ebf5d93f..ab1f35811 100644 --- a/dataset/src/test/scala/frameless/EncoderTests.scala +++ b/dataset/src/test/scala/frameless/EncoderTests.scala @@ -1,7 +1,6 @@ package frameless -import scala.collection.immutable.Set - +import scala.collection.immutable.{ ListSet, Set, TreeSet } import org.scalatest.matchers.should.Matchers object EncoderTests { @@ -10,6 +9,8 @@ object EncoderTests { case class InstantRow(i: java.time.Instant) case class DurationRow(d: java.time.Duration) case class PeriodRow(p: java.time.Period) + + case class ContainerOf[CC[X] <: Iterable[X]](a: CC[X1[Int]]) } class EncoderTests extends TypedDatasetSuite with Matchers { @@ -32,4 +33,55 @@ class EncoderTests extends TypedDatasetSuite with Matchers { test("It should encode java.time.Period") { implicitly[TypedEncoder[PeriodRow]] } + + def performCollection[C[X] <: Iterable[X]]( + toType: Seq[X1[Int]] => C[X1[Int]] + )(implicit + ce: TypedEncoder[C[X1[Int]]] + ): (Unit, Unit) = evalCodeGens { + + implicit val cte = TypedExpressionEncoder[C[X1[Int]]] + implicit val e = implicitly[TypedEncoder[ContainerOf[C]]] + implicit val te = TypedExpressionEncoder[ContainerOf[C]] + implicit val xe = implicitly[TypedEncoder[X1[ContainerOf[C]]]] + implicit val xte = TypedExpressionEncoder[X1[ContainerOf[C]]] + val v = toType((1 to 20).map(X1(_))) + val ds = { + sqlContext.createDataset(Seq(X1[ContainerOf[C]](ContainerOf[C](v)))) + } + ds.head.a.a shouldBe v + () + } + + test("It should serde a Seq of Objects") { + performCollection[Seq](_) + } + + test("It should serde a Set of Objects") { + performCollection[Set](_) + } + + test("It should serde a Vector of Objects") { + performCollection[Vector](_.toVector) + } + + test("It should serde a TreeSet of Objects") { + // only needed for 2.12 + implicit val ordering = new Ordering[X1[Int]] { + val intordering = implicitly[Ordering[Int]] + + override def compare(x: X1[Int], y: X1[Int]): Int = + intordering.compare(x.a, y.a) + } + + performCollection[TreeSet](TreeSet.newBuilder.++=(_).result()) + } + + test("It should serde a List of Objects") { + performCollection[List](_.toList) + } + + test("It should serde a ListSet of Objects") { + performCollection[ListSet](ListSet.newBuilder.++=(_).result()) + } } diff --git a/dataset/src/test/scala/frameless/OrderByTests.scala b/dataset/src/test/scala/frameless/OrderByTests.scala index 98bd7442d..fd659bdc4 100644 --- a/dataset/src/test/scala/frameless/OrderByTests.scala +++ b/dataset/src/test/scala/frameless/OrderByTests.scala @@ -7,19 +7,25 @@ import org.apache.spark.sql.Column import org.scalatest.matchers.should.Matchers class OrderByTests extends TypedDatasetSuite with Matchers { - def sortings[A : CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = Seq( - (_.desc, _.desc), - (_.asc, _.asc), - (t => t, t => t) //default ascending - ) + + def sortings[A: CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = + Seq( + (_.desc, _.desc), + (_.asc, _.asc), + (t 
=> t, t => t) // default ascending + ) test("single column non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = { val ds = TypedDataset.create(data) - sortings[A, X1[A]].map { case (typ, untyp) => - ds.dataset.orderBy(untyp(ds.dataset.col("a"))).collect().toVector.?=( - ds.orderBy(typ(ds('a))).collect().run().toVector) + sortings[A, X1[A]].map { + case (typ, untyp) => + ds.dataset + .orderBy(untyp(ds.dataset.col("a"))) + .collect() + .toVector + .?=(ds.orderBy(typ(ds('a))).collect().run().toVector) }.reduce(_ && _) } @@ -36,12 +42,16 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("single column non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = { val ds = TypedDataset.create(data) - sortings[A, X1[A]].map { case (typ, untyp) => - ds.dataset.sortWithinPartitions(untyp(ds.dataset.col("a"))).collect().toVector.?=( - ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector) + sortings[A, X1[A]].map { + case (typ, untyp) => + ds.dataset + .sortWithinPartitions(untyp(ds.dataset.col("a"))) + .collect() + .toVector + .?=(ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector) }.reduce(_ && _) } @@ -58,15 +68,34 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("two columns non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) => - val vanillaSpark = ds.dataset.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector - vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&( - vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b))).collect().run().toVector - ) - }.reduce(_ && _) + sortings[A, X2[A, B]].reverse + .zip(sortings[B, X2[A, B]]) + .map { + case ((typA, untypA), (typB, untypB)) => + val vanillaSpark = ds.dataset + .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))) + .collect() + .toVector + vanillaSpark + .?=( + ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector + ) + .&&( + vanillaSpark ?= ds + .orderByMany(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -75,15 +104,40 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("two columns non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) => - val vanillaSpark = ds.dataset.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector - vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&( - vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), 
typB(ds('b))).collect().run().toVector - ) - }.reduce(_ && _) + sortings[A, X2[A, B]].reverse + .zip(sortings[B, X2[A, B]]) + .map { + case ((typA, untypA), (typB, untypB)) => + val vanillaSpark = ds.dataset + .sortWithinPartitions( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")) + ) + .collect() + .toVector + vanillaSpark + .?=( + ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + .sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -92,21 +146,43 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("three columns non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X3[A, B, A]] + ): Prop = { val ds = TypedDataset.create(data) sortings[A, X3[A, B, A]].reverse .zip(sortings[B, X3[A, B, A]]) .zip(sortings[A, X3[A, B, A]]) - .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => - val vanillaSpark = ds.dataset - .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c"))) - .collect().toVector - - vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&( - vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector - ) - }.reduce(_ && _) + .map { + case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => + val vanillaSpark = ds.dataset + .orderBy( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")), + untypA2(ds.dataset.col("c")) + ) + .collect() + .toVector + + vanillaSpark + .?=( + ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))) + .collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + .orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -115,21 +191,50 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("three columns non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X3[A, B, A]] + ): Prop = { val ds = TypedDataset.create(data) sortings[A, X3[A, B, A]].reverse .zip(sortings[B, X3[A, B, A]]) .zip(sortings[A, X3[A, B, A]]) - .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => - val vanillaSpark = ds.dataset - .sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c"))) - .collect().toVector - - vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&( - vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector - ) - }.reduce(_ && _) + .map { + case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => + val vanillaSpark = ds.dataset + .sortWithinPartitions( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")), + untypA2(ds.dataset.col("c")) + ) + .collect() + .toVector + + vanillaSpark + .?=( + ds.sortWithinPartitions( + typA(ds('a)), + typB(ds('b)), + typA2(ds('c)) + ).collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + 
.sortWithinPartitionsMany( + typA(ds('a)), + typB(ds('b)), + typA2(ds('c)) + ) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -138,13 +243,28 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("sort support for mixed default and explicit ordering") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A, B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - ds.dataset.orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=( - ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) && - ds.dataset.sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=( - ds.sortWithinPartitionsMany(ds('a), ds('b).desc).collect().run().toVector) + ds.dataset + .orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc) + .collect() + .toVector + .?=(ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) && + ds.dataset + .sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc) + .collect() + .toVector + .?=( + ds.sortWithinPartitionsMany(ds('a), ds('b).desc) + .collect() + .run() + .toVector + ) } check(forAll(prop[SQLDate, Long] _)) @@ -159,50 +279,75 @@ class OrderByTests extends TypedDatasetSuite with Matchers { illTyped("""d.sortWithinPartitions(d('b).desc)""") } - test("derives a CatalystOrdered for case classes when all fields are comparable") { + test( + "derives a CatalystOrdered for case classes when all fields are comparable" + ) { type T[A, B] = X3[Int, Boolean, X2[A, B]] def prop[ - A: TypedEncoder : CatalystOrdered, - B: TypedEncoder : CatalystOrdered - ](data: Vector[T[A, B]]): Prop = { + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[T[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[X2[A, B], T[A, B]].map { case (typX2, untypX2) => - val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("c"))).collect().toVector - val frameless = ds.orderBy(typX2(ds('c))).collect().run.toVector - vanilla ?= frameless + sortings[X2[A, B], T[A, B]].map { + case (typX2, untypX2) => + val vanilla = ds.dataset + .orderBy(untypX2(ds.dataset.col("c"))) + .collect() + .toVector + .map(_.c) + val frameless = + ds.orderBy(typX2(ds('c))).collect().run.toVector.map(_.c) + vanilla ?= frameless }.reduce(_ && _) } check(forAll(prop[Int, Long] _)) check(forAll(prop[(String, SQLDate), Float] _)) // Check that nested case classes are properly derived too - check(forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _)) + check( + forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _) + ) } test("derives a CatalystOrdered for tuples when all fields are comparable") { type T[A, B] = X2[Int, (A, B)] def prop[ - A: TypedEncoder : CatalystOrdered, - B: TypedEncoder : CatalystOrdered - ](data: Vector[T[A, B]]): Prop = { + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[T[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[(A, B), T[A, B]].map { case (typX2, untypX2) => - val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("b"))).collect().toVector - val frameless = ds.orderBy(typX2(ds('b))).collect().run.toVector - vanilla ?= frameless + sortings[(A, B), T[A, B]].map { + case (typX2, untypX2) => + val vanilla = ds.dataset + .orderBy(untypX2(ds.dataset.col("b"))) + 
.collect() + .toVector + .map(_.b) + val frameless = + ds.orderBy(typX2(ds('b))).collect().run.toVector.map(_.b) + vanilla ?= frameless }.reduce(_ && _) } check(forAll(prop[Int, Long] _)) check(forAll(prop[(String, SQLDate), Float] _)) - check(forAll(prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _)) + check( + forAll( + prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _ + ) + ) } test("fails to compile when one of the field isn't comparable") { type T = X2[Int, X2[Int, Map[String, String]]] val d = TypedDataset.create(X2(1, X2(2, Map("not" -> "comparable"))) :: Nil) - illTyped("d.orderBy(d('b).desc)", """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]].""") + illTyped( + "d.orderBy(d('b).desc)", + """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]].""" + ) } } diff --git a/dataset/src/test/scala/frameless/SchemaTests.scala b/dataset/src/test/scala/frameless/SchemaTests.scala index 92fd33057..1fe9a3274 100644 --- a/dataset/src/test/scala/frameless/SchemaTests.scala +++ b/dataset/src/test/scala/frameless/SchemaTests.scala @@ -2,7 +2,7 @@ package frameless import frameless.functions.aggregate._ import frameless.functions._ -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ Metadata, StructType } import org.scalacheck.Prop import org.scalacheck.Prop._ import org.scalatest.matchers.should.Matchers @@ -10,10 +10,17 @@ import org.scalatest.matchers.should.Matchers class SchemaTests extends TypedDatasetSuite with Matchers { def structToNonNullable(struct: StructType): StructType = { - StructType(struct.fields.map( f => f.copy(nullable = false))) + StructType( + struct.fields.map(f => + f.copy(nullable = false, metadata = Metadata.empty) + ) + ) // Spark 4 puts metadata in _2 in schema test } - def prop[A](dataset: TypedDataset[A], ignoreNullable: Boolean = false): Prop = { + def prop[A]( + dataset: TypedDataset[A], + ignoreNullable: Boolean = false + ): Prop = { val schema = dataset.dataset.schema Prop.all( @@ -24,7 +31,9 @@ class SchemaTests extends TypedDatasetSuite with Matchers { if (!ignoreNullable) TypedExpressionEncoder.targetStructType(dataset.encoder) ?= schema else - structToNonNullable(TypedExpressionEncoder.targetStructType(dataset.encoder)) ?= structToNonNullable(schema) + structToNonNullable( + TypedExpressionEncoder.targetStructType(dataset.encoder) + ) ?= structToNonNullable(schema) ) } diff --git a/dataset/src/test/scala/frameless/SelfJoinTests.scala b/dataset/src/test/scala/frameless/SelfJoinTests.scala index cede7be2a..4495654d4 100644 --- a/dataset/src/test/scala/frameless/SelfJoinTests.scala +++ b/dataset/src/test/scala/frameless/SelfJoinTests.scala @@ -2,13 +2,18 @@ package frameless import org.scalacheck.Prop import org.scalacheck.Prop._ -import org.apache.spark.sql.{SparkSession, functions => sparkFunctions} +import org.apache.spark.sql.{ SparkSession, functions => sparkFunctions } class SelfJoinTests extends TypedDatasetSuite { + // Without crossJoin.enabled=true Spark doesn't like trivial join conditions: // [error] Join condition is missing or trivial. // [error] Use the CROSS JOIN syntax to allow cartesian products between these relations. 
- def allowTrivialJoin[T](body: => T)(implicit session: SparkSession): T = { + def allowTrivialJoin[T]( + body: => T + )(implicit + session: SparkSession + ): T = { val crossJoin = "spark.sql.crossJoin.enabled" val oldSetting = session.conf.get(crossJoin) session.conf.set(crossJoin, "true") @@ -17,7 +22,11 @@ class SelfJoinTests extends TypedDatasetSuite { result } - def allowAmbiguousJoin[T](body: => T)(implicit session: SparkSession): T = { + def allowAmbiguousJoin[T]( + body: => T + )(implicit + session: SparkSession + ): T = { val crossJoin = "spark.sql.analyzer.failAmbiguousSelfJoin" val oldSetting = session.conf.get(crossJoin) session.conf.set(crossJoin, "false") @@ -27,22 +36,26 @@ } test("self join with colLeft/colRight disambiguation") { - def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering - ](dx: List[X2[A, B]], d: X2[A, B]): Prop = allowAmbiguousJoin { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + dx: List[X2[A, B]], + d: X2[A, B] + ): Prop = allowAmbiguousJoin { val data = d :: dx val ds = TypedDataset.create(data) // This is the way to write unambiguous self-join in vanilla, see https://goo.gl/XnkSUD val df1 = ds.dataset.as("df1") val df2 = ds.dataset.as("df2") - val vanilla = df1.join(df2, - sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a")).count() + val vanilla = df1 + .join(df2, sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a")) + .count() - val typed = ds.joinInner(ds)( - ds.colLeft('a) === ds.colRight('a) - ).count().run() + val typed = ds + .joinInner(ds)( + ds.colLeft('a) === ds.colRight('a) + ) + .count() + .run() vanilla ?= typed } @@ -51,47 +64,60 @@ } test("trivial self join") { - def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering - ](dx: List[X2[A, B]], d: X2[A, B]): Prop = - allowTrivialJoin { allowAmbiguousJoin { - - val data = d :: dx - val ds = TypedDataset.create(data) - val untyped = ds.dataset - // Interestingly, even with aliasing it seems that it's impossible to - // obtain a trivial join condition of shape df1.a == df1.a, Spark we - // always interpret that as df1.a == df2.a. For the purpose of this - // test we fall-back to lit(true) instead. - // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a") - val trivial = sparkFunctions.lit(true) - val vanilla = untyped.as("df1").join(untyped.as("df2"), trivial).count() - - val typed = ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run - vanilla ?= typed - } } + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + dx: List[X2[A, B]], + d: X2[A, B] + ): Prop = + allowTrivialJoin { + allowAmbiguousJoin { + + val data = d :: dx + val ds = TypedDataset.create(data) + val untyped = ds.dataset + // Interestingly, even with aliasing it seems that it's impossible to + // obtain a trivial join condition of shape df1.a == df1.a; Spark will + // always interpret that as df1.a == df2.a. For the purpose of this + // test we fall back to lit(true) instead.
+ // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a") + val trivial = sparkFunctions.lit(true) + val vanilla = + untyped.as("df1").join(untyped.as("df2"), trivial).count() + + val typed = + ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run + vanilla ?= typed + } + } check(prop[Int, Int] _) } test("self join with unambiguous expression") { def prop[ - A : TypedEncoder : CatalystNumeric : Ordering, - B : TypedEncoder : Ordering - ](data: List[X3[A, A, B]]): Prop = allowAmbiguousJoin { + A: TypedEncoder: CatalystNumeric: Ordering, + B: TypedEncoder: Ordering + ](data: List[X3[A, A, B]] + ): Prop = allowAmbiguousJoin { val ds = TypedDataset.create(data) val df1 = ds.dataset.alias("df1") val df2 = ds.dataset.alias("df2") - val vanilla = df1.join(df2, - (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) === - (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b"))).count() - - val typed = ds.joinInner(ds)( - (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds.colRight('b)) - ).count().run() + val vanilla = df1 + .join( + df2, + (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) === + (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b")) + ) + .count() + + val typed = ds + .joinInner(ds)( + (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds + .colRight('b)) + ) + .count() + .run() vanilla ?= typed } @@ -99,41 +125,57 @@ check(prop[Int, Int] _) } - test("Do you want ambiguous self join? This is how you get ambiguous self join.") { + test( + "Do you want ambiguous self join? This is how you get ambiguous self join." + ) { def prop[ - A : TypedEncoder : CatalystNumeric : Ordering, - B : TypedEncoder : Ordering - ](data: List[X3[A, A, B]]): Prop = - allowTrivialJoin { allowAmbiguousJoin { - val ds = TypedDataset.create(data) - - // The point I'm making here is that it "behaves just like Spark". I - // don't know (or really care about how) how Spark disambiguates that - // internally... - val vanilla = ds.dataset.join(ds.dataset, - (ds.dataset("a") + ds.dataset("b")) === - (ds.dataset("a") + ds.dataset("b"))).count() - - val typed = ds.joinInner(ds)( - (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b)) - ).count().run() - - vanilla ?= typed - } } - - check(prop[Int, Int] _) - } + A: TypedEncoder: CatalystNumeric: Ordering, + B: TypedEncoder: Ordering + ](data: List[X3[A, A, B]] + ): Prop = + allowTrivialJoin { + allowAmbiguousJoin { + val ds = TypedDataset.create(data) + + // The point I'm making here is that it "behaves just like Spark". I + // don't know (or really care about) how Spark disambiguates that + // internally...
+ val vanilla = ds.dataset + .join( + ds.dataset, + (ds.dataset("a") + ds.dataset("b")) === + (ds.dataset("a") + ds.dataset("b")) + ) + .count() + + val typed = ds + .joinInner(ds)( + (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b)) + ) + .count() + .run() + + vanilla ?= typed + } + } + + check(prop[Int, Int] _) + } test("colLeft and colRight are equivalent to col outside of joins") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - ex4: TypedEncoder[X4[A, B, C, D]] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + ex4: TypedEncoder[X4[A, B, C, D]] + ): Prop = { val dataset = TypedDataset.create(data) - val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector - val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector - val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector + val selectedCol = + dataset.select(dataset.col[A]('a)).collect().run().toVector + val selectedColLeft = + dataset.select(dataset.colLeft[A]('a)).collect().run().toVector + val selectedColRight = + dataset.select(dataset.colRight[A]('a)).collect().run().toVector (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight) } @@ -145,16 +187,26 @@ class SelfJoinTests extends TypedDatasetSuite { } test("colLeft and colRight are equivalent to col outside of joins - via files (codegen)") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - ex4: TypedEncoder[X4[A, B, C, D]] - ): Prop = { - TypedDataset.create(data).write.mode("overwrite").parquet("./target/testData") - val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]](session.read.parquet("./target/testData")) - val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector - val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector - val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + ex4: TypedEncoder[X4[A, B, C, D]] + ): Prop = { + TypedDataset + .create(data) + .write + .mode("overwrite") + .parquet(s"$TEST_OUTPUT_DIR/testData_selfjoins") + val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]]( + session.read.parquet(s"$TEST_OUTPUT_DIR/testData_selfjoins") + ) + val selectedCol = + dataset.select(dataset.col[A]('a)).collect().run().toVector + val selectedColLeft = + dataset.select(dataset.colLeft[A]('a)).collect().run().toVector + val selectedColRight = + dataset.select(dataset.colRight[A]('a)).collect().run().toVector (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight) } diff --git a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala index 8a4697835..ba16c12e2 100644 --- a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala +++ b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala @@ -2,28 +2,35 @@ package frameless import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem import org.apache.hadoop.fs.local.StreamingFS -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLContext, SparkSession} +import org.apache.spark.{ SparkConf, SparkContext } +import org.apache.spark.sql.{ SQLContext, SparkSession } import org.scalactic.anyvals.PosZInt import org.scalatest.BeforeAndAfterAll import org.scalatestplus.scalacheck.Checkers import 
org.scalacheck.Prop import org.scalacheck.Prop._ -import scala.util.{Properties, Try} +import scala.util.{ Properties, Try } import org.scalatest.funsuite.AnyFunSuite trait SparkTesting { self: BeforeAndAfterAll => - val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString + val appID: String = new java.util.Date().toString + math + .floor(math.random * 10e4) + .toLong + .toString /** * Allows bare naked to be used instead of winutils for testing / dev */ def registerFS(sparkConf: SparkConf): SparkConf = { if (System.getProperty("os.name").startsWith("Windows")) - sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName). - set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName) + sparkConf + .set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName) + .set( + "spark.hadoop.fs.AbstractFileSystem.file.impl", + classOf[StreamingFS].getName + ) else sparkConf } @@ -33,6 +40,10 @@ trait SparkTesting { self: BeforeAndAfterAll => .setAppName("test") .set("spark.ui.enabled", "false") .set("spark.app.id", appID) + .set( + "spark.sql.ansi.enabled", + "false" + ) // 43 tests fail on overflow / casting issues private var s: SparkSession = _ @@ -40,9 +51,9 @@ trait SparkTesting { self: BeforeAndAfterAll => implicit def sc: SparkContext = session.sparkContext implicit def sqlContext: SQLContext = session.sqlContext - def registerOptimizations(sqlContext: SQLContext): Unit = { } + def registerOptimizations(sqlContext: SQLContext): Unit = {} - def addSparkConfigProperties(config: SparkConf): Unit = { } + def addSparkConfigProperties(config: SparkConf): Unit = {} override def beforeAll(): Unit = { assert(s == null) @@ -51,7 +62,7 @@ trait SparkTesting { self: BeforeAndAfterAll => registerOptimizations(sqlContext) } - override def afterAll(): Unit = { + override def afterAll(): Unit = if (shouldCloseSession) { if (s != null) { s.stop() s = null @@ -59,11 +70,16 @@ trait SparkTesting { self: BeforeAndAfterAll => } } +class TypedDatasetSuite + extends AnyFunSuite + with Checkers + with BeforeAndAfterAll + with SparkTesting { -class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting { // Limit size of generated collections and number of checks to avoid OutOfMemoryError implicit override val generatorDrivenConfig: PropertyCheckConfiguration = { - def getPosZInt(name: String, default: PosZInt) = Properties.envOrNone(s"FRAMELESS_GEN_${name}") + def getPosZInt(name: String, default: PosZInt) = Properties + .envOrNone(s"FRAMELESS_GEN_${name}") .flatMap(s => Try(s.toInt).toOption) .flatMap(PosZInt.from) .getOrElse(default) @@ -75,17 +91,24 @@ class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll implicit val sparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob - def approximatelyEqual[A](a: A, b: A)(implicit numeric: Numeric[A]): Prop = { + def approximatelyEqual[A]( + a: A, + b: A + )(implicit + numeric: Numeric[A] + ): Prop = { val da = numeric.toDouble(a) val db = numeric.toDouble(b) - val epsilon = 1E-6 + val epsilon = 1e-6 // Spark has a weird behaviour concerning expressions that should return Inf // Most of the time they return NaN instead, for instance stddev of Seq(-7.827553978923477E227, -5.009124275715786E153) - if((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved + if ((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved else if ( (da - db).abs < epsilon || - (da - db).abs < da.abs / 100) - 
proved - else falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." + (da - db).abs < da.abs / 100 + ) + proved + else + falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." } } diff --git a/dataset/src/test/scala/frameless/UdtEncodedClass.scala b/dataset/src/test/scala/frameless/UdtEncodedClass.scala index 4e5c2c6d9..b98f74a11 100644 --- a/dataset/src/test/scala/frameless/UdtEncodedClass.scala +++ b/dataset/src/test/scala/frameless/UdtEncodedClass.scala @@ -1,14 +1,19 @@ package frameless import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{ + GenericInternalRow, + UnsafeArrayData +} import org.apache.spark.sql.types._ -import org.apache.spark.sql.FramelessInternals.UserDefinedType +import FramelessInternals.UserDefinedType @SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt]) class UdtEncodedClass(val a: Int, val b: Array[Double]) { + override def equals(other: Any): Boolean = other match { - case that: UdtEncodedClass => a == that.a && java.util.Arrays.equals(b, that.b) + case that: UdtEncodedClass => + a == that.a && java.util.Arrays.equals(b, that.b) case _ => false } @@ -25,11 +30,18 @@ object UdtEncodedClass { } class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] { + def sqlType: DataType = { - StructType(Seq( - StructField("a", IntegerType, nullable = false), - StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false) - )) + StructType( + Seq( + StructField("a", IntegerType, nullable = false), + StructField( + "b", + ArrayType(DoubleType, containsNull = false), + nullable = false + ) + ) + ) } def serialize(obj: UdtEncodedClass): InternalRow = { @@ -40,7 +52,8 @@ class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] { } def deserialize(datum: Any): UdtEncodedClass = datum match { - case row: InternalRow => new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray()) + case row: InternalRow => + new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray()) } def userClass: Class[UdtEncodedClass] = classOf[UdtEncodedClass] diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala index 201d93c63..7580c2a04 100644 --- a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala @@ -1,31 +1,37 @@ package frameless package functions -import frameless.{TypedAggregate, TypedColumn} +import frameless.{ TypedAggregate, TypedColumn } import frameless.functions.aggregate._ -import org.apache.spark.sql.{Column, Encoder} -import org.scalacheck.{Gen, Prop} +import org.apache.spark.sql.{ Column, Encoder } +import org.scalacheck.{ Gen, Prop } import org.scalacheck.Prop._ import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException class AggregateFunctionsTests extends TypedDatasetSuite { - def sparkSchema[A: TypedEncoder, U](f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U]): Prop = { + + def sparkSchema[A: TypedEncoder, U]( + f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U] + ): Prop = { val df = TypedDataset.create[X1[A]](Nil) val col = f(df.col('a)) val sumDf = df.agg(col) - TypedExpressionEncoder.targetStructType(sumDf.encoder) ?= sumDf.dataset.schema + 
TypedExpressionEncoder.targetStructType( + sumDf.encoder + ) ?= sumDf.dataset.schema } test("sum") { case class Sum4Tests[A, B](sum: Seq[A] => B) - def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - summable: CatalystSummable[A, Out], - summer: Sum4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + summable: CatalystSummable[A, Out], + summer: Sum4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) @@ -33,7 +39,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite { datasetSum match { case x :: Nil => approximatelyEqual(summer.sum(xs), x) - case other => falsified + case other => falsified } } @@ -61,27 +67,31 @@ class AggregateFunctionsTests extends TypedDatasetSuite { test("sumDistinct") { case class Sum4Tests[A, B](sum: Seq[A] => B) - def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - summable: CatalystSummable[A, Out], - summer: Sum4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + summable: CatalystSummable[A, Out], + summer: Sum4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - val datasetSum: List[Out] = dataset.agg(sumDistinct(A)).collect().run().toList + val datasetSum: List[Out] = + dataset.agg(sumDistinct(A)).collect().run().toList datasetSum match { case x :: Nil => approximatelyEqual(summer.sum(xs), x) - case other => falsified + case other => falsified } } // Replicate Spark's behaviour : Ints and Shorts are cast to Long // https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L37 implicit def summerLong = Sum4Tests[Long, Long](_.toSet.sum) - implicit def summerInt = Sum4Tests[Int, Long]( x => x.toSet.map((_:Int).toLong).sum) - implicit def summerShort = Sum4Tests[Short, Long](x => x.toSet.map((_:Short).toLong).sum) + implicit def summerInt = + Sum4Tests[Int, Long](x => x.toSet.map((_: Int).toLong).sum) + implicit def summerShort = + Sum4Tests[Short, Long](x => x.toSet.map((_: Short).toLong).sum) check(forAll(prop[Long, Long] _)) check(forAll(prop[Int, Long] _)) @@ -95,33 +105,41 @@ class AggregateFunctionsTests extends TypedDatasetSuite { test("avg") { case class Averager4Tests[A, B](avg: Seq[A] => B) - def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - averageable: CatalystAverageable[A, Out], - averager: Averager4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + averageable: CatalystAverageable[A, Out], + averager: Averager4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) val datasetAvg: Vector[Out] = dataset.agg(avg(A)).collect().run().toVector if (datasetAvg.size > 2) falsified - else xs match { - case Nil => datasetAvg ?= Vector() - case _ :: _ => datasetAvg.headOption match { - case Some(x) => approximatelyEqual(averager.avg(xs), x) - case None => falsified + else + xs match { + case Nil => datasetAvg ?= Vector() + case _ :: _ => + datasetAvg.headOption match { + case Some(x) => approximatelyEqual(averager.avg(xs), x) + case None => falsified + } } - } } // Replicate Spark's behaviour : If the datatype isn't BigDecimal cast type to Double // 
https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L50 - implicit def averageDecimal = Averager4Tests[BigDecimal, BigDecimal](as => as.sum/as.size) - implicit def averageDouble = Averager4Tests[Double, Double](as => as.sum/as.size) - implicit def averageLong = Averager4Tests[Long, Double](as => as.map(_.toDouble).sum/as.size) - implicit def averageInt = Averager4Tests[Int, Double](as => as.map(_.toDouble).sum/as.size) - implicit def averageShort = Averager4Tests[Short, Double](as => as.map(_.toDouble).sum/as.size) + implicit def averageDecimal = + Averager4Tests[BigDecimal, BigDecimal](as => as.sum / as.size) + implicit def averageDouble = + Averager4Tests[Double, Double](as => as.sum / as.size) + implicit def averageLong = + Averager4Tests[Long, Double](as => as.map(_.toDouble).sum / as.size) + implicit def averageInt = + Averager4Tests[Int, Double](as => as.map(_.toDouble).sum / as.size) + implicit def averageShort = + Averager4Tests[Short, Double](as => as.map(_.toDouble).sum / as.size) /* under 3.4 an oddity was detected: Falsified after 2 successful property evaluations. @@ -141,20 +159,27 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("stddev and variance") { - def prop[A: TypedEncoder : CatalystVariance : Numeric](xs: List[A]): Prop = { + def prop[A: TypedEncoder: CatalystVariance: Numeric](xs: List[A]): Prop = { val numeric = implicitly[Numeric[A]] val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - val datasetStdOpt = dataset.agg(stddev(A)).collect().run().toVector.headOption - val datasetVarOpt = dataset.agg(variance(A)).collect().run().toVector.headOption + val datasetStdOpt = + dataset.agg(stddev(A)).collect().run().toVector.headOption + val datasetVarOpt = + dataset.agg(variance(A)).collect().run().toVector.headOption - val std = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev() - val `var` = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance() + val std = + sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev() + val `var` = + sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance() (datasetStdOpt, datasetVarOpt) match { case (Some(datasetStd), Some(datasetVar)) => - approximatelyEqual(datasetStd, std) && approximatelyEqual(datasetVar, `var`) + approximatelyEqual(datasetStd, std) && approximatelyEqual( + datasetVar, + `var` + ) case _ => proved } } @@ -167,9 +192,17 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("litAggr") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](xs: List[A], b: B, c: C): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + xs: List[A], + b: B, + c: C + ): Prop = { val dataset = TypedDataset.create(xs) - val (r1, rb, rc, rcount) = dataset.agg(count().lit(1), litAggr(b), litAggr(c), count()).collect().run().head + val (r1, rb, rc, rcount) = dataset + .agg(count().lit(1), litAggr(b), litAggr(c), count()) + .collect() + .run() + .head (rcount ?= xs.size.toLong) && (r1 ?= 1) && (rb ?= b) && (rc ?= c) } @@ -203,7 +236,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("max") { - def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered]( + xs: List[A] + )(implicit + o: Ordering[A] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) val datasetMax = 
dataset.agg(max(A)).collect().run().toList @@ -225,14 +262,18 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val A = dataset.col[Long]('a) val datasetMax = dataset.agg(max(A) * 2).collect().run().headOption - datasetMax ?= (if(xs.isEmpty) None else Some(xs.max * 2)) + datasetMax ?= (if (xs.isEmpty) None else Some(xs.max * 2)) } check(forAll(prop _)) } test("min") { - def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered]( + xs: List[A] + )(implicit + o: Ordering[A] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) @@ -250,13 +291,21 @@ } test("first") { - def prop[A: TypedEncoder](xs: List[A]): Prop = { + def prop[A: TypedEncoder: Ordering: CatalystOrdered](xs: List[A]): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - - val datasetFirst = dataset.agg(first(A)).collect().run().toList - - datasetFirst ?= xs.headOption.toList + // servers do not necessarily return rows in the order they were given + val sxs = xs.sorted + + val datasetFirst = dataset + .coalesce(1) + .orderBy(A: SortedTypedColumn[X1[A], A]) + .agg(first(A)) + .collect() + .run() + .toList + + datasetFirst ?= sxs.headOption.toList } check(forAll(prop[BigDecimal] _)) @@ -269,13 +318,21 @@ } test("last") { - def prop[A: TypedEncoder](xs: List[A]): Prop = { + def prop[A: TypedEncoder: Ordering: CatalystOrdered](xs: List[A]): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - - val datasetLast = dataset.agg(last(A)).collect().run().toList - - datasetLast ?= xs.lastOption.toList + // servers do not necessarily return rows in the order they were given + val sxs = xs.sorted + + val datasetLast = dataset + .coalesce(1) + .orderBy(A: SortedTypedColumn[X1[A], A]) + .agg(last(A)) + .collect() + .run() + .toList + + datasetLast ?= sxs.lastOption.toList } check(forAll(prop[BigDecimal] _)) @@ -301,8 +358,13 @@ check { forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] => val tds = TypedDataset.create(xs) - val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run() - tdsRes.toMap ?= xs.groupBy(_._1).mapValues(_.map(_._2).distinct.size.toLong).toSeq.toMap + val tdsRes: Seq[(Int, Long)] = + tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run() + tdsRes.toMap ?= xs + .groupBy(_._1) + .mapValues(_.map(_._2).distinct.size.toLong) + .toSeq + .toMap } } } @@ -310,7 +372,11 @@ test("approxCountDistinct") { // Simple version of #approximatelyEqual() // Default maximum estimation error of HyperLogLog in Spark is 5% - def approxEqual(actual: Long, estimated: Long, allowedDeviationPercentile: Double = 0.05): Boolean = { + def approxEqual( + actual: Long, + estimated: Long, + allowedDeviationPercentile: Double = 0.05 + ): Boolean = { val delta: Long = Math.abs(actual - estimated) delta / actual.toDouble < allowedDeviationPercentile * 2 } @@ -319,7 +385,11 @@ forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] => val tds = TypedDataset.create(xs) val tdsRes: Seq[(Int, Long, Long)] = - tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))).collect().run() + tds + .groupBy(tds('_1)) + .agg(countDistinct(tds('_2)),
approxCountDistinct(tds('_2))) + .collect() + .run() tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2) } } } @@ -329,18 +399,28 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val tds = TypedDataset.create(xs) val allowedError = 0.1 // 10% val tdsRes: Seq[(Int, Long, Long)] = - tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2), allowedError)).collect().run() + tds + .groupBy(tds('_1)) + .agg( + countDistinct(tds('_2)), + approxCountDistinct(tds('_2), allowedError) + ) + .collect() + .run() tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2, allowedError) } } } } test("collectList") { - def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = { + def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = { val tds = TypedDataset.create(xs) - val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run() + val tdsRes: Seq[(A, Vector[A])] = + tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run() - tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toVector.sorted } + tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map { + case (k, v) => k -> v.map(_.b).toVector.sorted + } } check(forAll(prop[Long] _)) @@ -350,11 +430,14 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("collectSet") { - def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = { + def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = { val tds = TypedDataset.create(xs) - val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run() + val tdsRes: Seq[(A, Vector[A])] = + tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run() - tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toSet } + tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map { + case (k, v) => k -> v.map(_.b).toSet + } } check(forAll(prop[Long] _)) @@ -379,92 +462,113 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[BigDecimal] _)) } - - def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder] - ( - xs: List[X3[Int, A, B]] - ) - ( - framelessFun: (TypedColumn[X3[Int, A, B], A], TypedColumn[X3[Int, A, B], B]) => TypedAggregate[X3[Int, A, B], Option[Double]], - sparkFun: (Column, Column) => Column - ) - ( - implicit - encEv: Encoder[(Int, A, B)], - encEv2: Encoder[(Int,Option[Double])], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = { + def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(framelessFun: ( + TypedColumn[X3[Int, A, B], A], + TypedColumn[X3[Int, A, B], B] + ) => TypedAggregate[X3[Int, A, B], Option[Double]], + sparkFun: (Column, Column) => Column, + fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[ + BigDecimal + ], Option[BigDecimal]] = identity + )(implicit + encEv: Encoder[(Int, A, B)], + encEv2: Encoder[(Int, Option[Double])], + encEv3: Encoder[(Int, Option[BigDecimal])], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = { val tds = TypedDataset.create(xs) // Typed implementation of bivar stats function - val tdBivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b), tds('c))).deserialized.map(kv => - (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) - ).collect().run() + val tdBivar = tds + .groupBy(tds('a)) + 
.agg(framelessFun(tds('b), tds('c))) + .deserialized + .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))) + .collect() + .run() val cDF = session.createDataset(xs.map(x => (x.a, x.b, x.c))) // Comparison implementation of bivar stats functions val compBivar = cDF .groupBy(cDF("_1")) .agg(sparkFun(cDF("_2"), cDF("_3"))) - .map( - row => { - val grp = row.getInt(0) - (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) - } - ) + .map(row => { + val grp = row.getInt(0) + (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) + }) // Should be the same - tdBivar.toMap ?= compBivar.collect().toMap + // tdBivar.toMap ?= compBivar.collect().toMap + DoubleBehaviourUtils.compareMaps( + tdBivar.toMap, + compBivar.collect().toMap, + fudger + ) } - def univariatePropTemplate[A: TypedEncoder] - ( - xs: List[X2[Int, A]] - ) - ( - framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[X2[Int, A], Option[Double]], - sparkFun: (Column) => Column - ) - ( - implicit - encEv: Encoder[(Int, A)], - encEv2: Encoder[(Int,Option[Double])], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = { + def univariatePropTemplate[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[ + X2[Int, A], + Option[Double] + ], + sparkFun: (Column) => Column, + fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[ + BigDecimal + ], Option[BigDecimal]] = identity + )(implicit + encEv: Encoder[(Int, A)], + encEv2: Encoder[(Int, Option[Double])], + encEv3: Encoder[(Int, Option[BigDecimal])], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = { val tds = TypedDataset.create(xs) - //typed implementation of univariate stats function - val tdUnivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b))).deserialized.map(kv => - (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) - ).collect().run() + // typed implementation of univariate stats function + val tdUnivar = tds + .groupBy(tds('a)) + .agg(framelessFun(tds('b))) + .deserialized + .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))) + .collect() + .run() val cDF = session.createDataset(xs.map(x => (x.a, x.b))) // Comparison implementation of bivar stats functions val compUnivar = cDF .groupBy(cDF("_1")) .agg(sparkFun(cDF("_2"))) - .map( - row => { - val grp = row.getInt(0) - (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) - } - ) + .map(row => { + val grp = row.getInt(0) + (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) + }) // Should be the same - tdUnivar.toMap ?= compUnivar.collect().toMap + // tdUnivar.toMap ?= compUnivar.collect().toMap + DoubleBehaviourUtils.compareMaps( + tdUnivar.toMap, + compUnivar.collect().toMap, + fudger + ) } test("corr") { val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)(corr[A,B,X3[Int, A, B]],org.apache.spark.sql.functions.corr) + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( + corr[A, B, X3[Int, A, B]], + org.apache.spark.sql.functions.corr + ) check(forAll(prop[Double, Double] _)) check(forAll(prop[Double, Int] _)) @@ -477,14 +581,16 @@ class AggregateFunctionsTests extends TypedDatasetSuite 
{ val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)( + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( covarPop[A, B, X3[Int, A, B]], - org.apache.spark.sql.functions.covar_pop + org.apache.spark.sql.functions.covar_pop, + fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("100")) ) check(forAll(prop[Double, Double] _)) @@ -498,14 +604,16 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)( + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( covarSamp[A, B, X3[Int, A, B]], - org.apache.spark.sql.functions.covar_samp + org.apache.spark.sql.functions.covar_samp, + fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("10")) ) check(forAll(prop[Double, Double] _)) @@ -519,13 +627,15 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( kurtosis[A, X2[Int, A]], - org.apache.spark.sql.functions.kurtosis + org.apache.spark.sql.functions.kurtosis, + fudger = DoubleBehaviourUtils.tolerance(_, BigDecimal("0.1")) ) check(forAll(prop[Double] _)) @@ -539,11 +649,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( skewness[A, X2[Int, A]], org.apache.spark.sql.functions.skewness ) @@ -559,11 +670,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( stddevPop[A, X2[Int, A]], org.apache.spark.sql.functions.stddev_pop ) @@ -579,11 +691,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - 
evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( stddevSamp[A, X2[Int, A]], org.apache.spark.sql.functions.stddev_samp ) diff --git a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala index f3a8be581..7c5599e0a 100644 --- a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala +++ b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala @@ -1,20 +1,171 @@ package frameless package functions +import org.scalacheck.Prop +import org.scalacheck.Prop.AnyOperators +import org.scalacheck.util.Pretty +import shapeless.{ Lens, OpticDefns } + /** - * Some statistical functions in Spark can result in Double, Double.NaN or Null. - * This tends to break ?= of the property based testing. Use the nanNullHandler function - * here to alleviate this by mapping this NaN and Null to None. This will result in - * functioning comparison again. - */ + * Some statistical functions in Spark can result in Double, Double.NaN or Null. + * This tends to break ?= of the property based testing. Use the nanNullHandler function + * here to alleviate this by mapping this NaN and Null to None. This will result in + * functioning comparison again. + * + * Values are truncated to allow a chance of mitigating serialization issues + */ object DoubleBehaviourUtils { + + val dp5 = BigDecimal(0.00001) + // Mapping with this function is needed because spark uses Double.NaN for some semantics in the // correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN - private val nanHandler: Double => Option[Double] = value => if (!value.equals(Double.NaN)) Option(value) else None + private val nanHandler: Double => Option[Double] = value => + if (!value.equals(Double.NaN)) Option(value) else None + // Making sure that null => None and does not result in 0.0d because of row.getAs[Double]'s use of .asInstanceOf - val nanNullHandler: Any => Option[Double] = { + val nanNullHandler: Any => Option[BigDecimal] = { case null => None - case d: Double => nanHandler(d) + case d: Double => + nanHandler(d).map(truncate) case _ => ??? 
} + + /** ensure different serializations are 'comparable' */ + def truncate(d: Double): BigDecimal = + if (d == Double.NegativeInfinity || d == Double.PositiveInfinity) + BigDecimal("1000000.000000") * (if (d == Double.PositiveInfinity) 1 + else -1) + else + BigDecimal(d).setScale( + 6, + if (d > 0) + BigDecimal.RoundingMode.FLOOR + else + BigDecimal.RoundingMode.CEILING + ) + + import shapeless._ + + def tolerantCompareVectors[K, CC[X] <: Seq[X]]( + v1: CC[K], + v2: CC[K], + of: BigDecimal + )(fudgers: Seq[OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]]] + ): Prop = compareVectors(v1, v2)(fudgers.map(f => (f, tolerance(_, of)))) + + def compareVectors[K, CC[X] <: Seq[X]]( + v1: CC[K], + v2: CC[K] + )(fudgers: Seq[ + (OpticDefns.RootLens[K] => Lens[K, Option[BigDecimal]], + Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[ + BigDecimal + ], Option[BigDecimal]] + ) + ] + ): Prop = + if (v1.size != v2.size) + Prop.falsified :| { + "Expected Seq of size " + v1.size + " but got " + v2.size + } + else { + val together = v1.zip(v2) + val m = + together.map { p => + fudgers.foldLeft(p) { (curr, nf) => + val theLens = nf._1(lens[K]) + val p = (theLens.get(curr._1), theLens.get(curr._2)) + val (nl, nr) = nf._2(p) + (theLens.set(curr._1)(nl), theLens.set(curr._2)(nr)) + } + }.toMap + + m.keys.toVector ?= m.values.toVector + } + + def compareMaps[K]( + m1: Map[K, Option[BigDecimal]], + m2: Map[K, Option[BigDecimal]], + fudger: Tuple2[Option[BigDecimal], Option[BigDecimal]] => Tuple2[Option[ + BigDecimal + ], Option[BigDecimal]] + ): Prop = { + def compareKey(k: K): Prop = { + val m1v = m1.get(k) + val m2v = m2.get(k) + if (!m2v.isDefined) + Prop.falsified :| { + val expKey = Pretty.pretty[K](k, Pretty.Params(0)) + "Expected key of " + expKey + " in right side map" + } + else { + val (v1, v2) = fudger((m1v.get, m2v.get)) + if (v1 == v2) + Prop.proved + else + Prop.falsified :| { + val expKey = Pretty.pretty[K](k, Pretty.Params(0)) + val leftVal = + Pretty.pretty[Option[BigDecimal]](v1, Pretty.Params(0)) + val rightVal = + Pretty.pretty[Option[BigDecimal]](v2, Pretty.Params(0)) + "For key of " + expKey + " expected " + leftVal + " got " + rightVal + } + } + } + + if (m1.size != m2.size) + Prop.falsified :| { + "Expected map of size " + m1.size + " but got " + m2.size + } + else + m1.keys.foldLeft(Prop.passed) { (curr, elem) => curr && compareKey(elem) } + } + + /** running covar_pop and kurtosis multiple times is giving slightly different results */ + def tolerance( + p: Tuple2[Option[BigDecimal], Option[BigDecimal]], + of: BigDecimal + ): Tuple2[Option[BigDecimal], Option[BigDecimal]] = { + val comb = p._1.flatMap(a => p._2.map(b => (a, b))) + if (comb.isEmpty) + p + else { + val (l, r) = comb.get + if ((l.max(r) - l.min(r)).abs < of) + // tolerate it + (Some(l), Some(l)) + else + p + } + } + + import shapeless._ + + def tl[X]( + lensf: OpticDefns.RootLens[X] => Lens[X, Option[BigDecimal]], + of: BigDecimal + ): (X, X) => (X, X) = + (l: X, r: X) => { + val theLens = lensf(lens[X]) + val (nl, rl) = tolerance((theLens.get(l), theLens.get(r)), of) + (theLens.set(l)(nl), theLens.set(r)(rl)) + } + +} + +/** drop in conversion for doubles to handle serialization on cluster */ +trait ToDecimal[A] { + def truncate(a: A): Option[BigDecimal] + +} + +object ToDecimal { + + implicit val doubleToDecimal: ToDecimal[Double] = new ToDecimal[Double] { + + override def truncate(a: Double): Option[BigDecimal] = + DoubleBehaviourUtils.nanNullHandler(a) + } } diff --git 
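// Illustrative sketch (editor's addition, not part of the patch above): how the
// DoubleBehaviourUtils helpers introduced in this diff are intended to be used in a
// property check. Only the object defined above is assumed; the object name
// `ToleranceSketch` and the sample values are invented for illustration.
import org.scalacheck.Prop
import frameless.functions.DoubleBehaviourUtils

object ToleranceSketch {

  // NaN and null both normalise to None; finite doubles are truncated to six
  // decimal places as BigDecimal, and infinities clamp to +/-1000000.000000.
  val fromNaN: Option[BigDecimal] = DoubleBehaviourUtils.nanNullHandler(Double.NaN) // None
  val fromNull: Option[BigDecimal] = DoubleBehaviourUtils.nanNullHandler(null) // None
  val truncated: Option[BigDecimal] = DoubleBehaviourUtils.nanNullHandler(1.23456789) // Some(1.234567)

  // Per-group results from the typed and the vanilla Spark side are then compared
  // through an optional tolerance ("fudger"), as the covar_pop/kurtosis tests do.
  val typedSide = Map(1 -> Option(BigDecimal("42.000001")))
  val vanillaSide = Map(1 -> Option(BigDecimal("42.000002")))

  val prop: Prop = DoubleBehaviourUtils.compareMaps(
    typedSide,
    vanillaSide,
    DoubleBehaviourUtils.tolerance(_, BigDecimal("0.1")) // treat sub-0.1 differences as equal
  )
}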
a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala index 470d58e5f..ac79edce0 100644 --- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala @@ -7,40 +7,48 @@ import java.nio.charset.StandardCharsets import frameless.functions.nonAggregate._ import org.apache.commons.io.FileUtils -import org.apache.spark.sql.{Column, Encoder, SaveMode, functions => sparkFunctions} +import org.apache.spark.sql.{ + Column, + Encoder, + SaveMode, + functions => sparkFunctions +} import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } import scala.annotation.nowarn class NonAggregateFunctionsTests extends TypedDatasetSuite { - val testTempFiles = "target/testoutput" + val testTempFiles = s"$TEST_OUTPUT_DIR/naFtestoutput" object NonNegativeGenerators { + val doubleGen = for { - s <- Gen.chooseNum(1, Int.MaxValue) - e <- Gen.chooseNum(1, Int.MaxValue) + s <- Gen.chooseNum(1, Int.MaxValue) + e <- Gen.chooseNum(1, Int.MaxValue) res: Double = s.toDouble / e.toDouble } yield res - val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue) + val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue) val shortGen: Gen[Short] = Gen.chooseNum(1, Short.MaxValue) - val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue) - val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue) + val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue) + val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue) } object NonNegativeArbitraryNumericValues { import NonNegativeGenerators._ - implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen) - implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen) - implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen) - implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen) - implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen) + implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen) + implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen) + implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen) + implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen) + implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen) } private val base64Encoder = Base64.getEncoder + private def base64X1String(x1: X1[String]): X1[String] = { - def base64(str: String): String = base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8)) + def base64(str: String): String = + base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8)) x1.copy(a = base64(x1.a)) } @@ -53,9 +61,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](values: List[X1[A]])( - implicit encX1:Encoder[X1[A]], - catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]], + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.negate(cDS("a"))) @@ -65,11 +76,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(negate(col)) - .collect() - .run() - .toList + val res = 
typedDS.select(negate(col)).collect().run().toList res ?= resCompare } @@ -77,7 +84,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Byte, Byte] _)) check(forAll(prop[Short, Short] _)) check(forAll(prop[Int, Int] _)) - check(forAll(prop[Long, Long] _)) + check(forAll(prop[Long, Long] _)) check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } @@ -85,7 +92,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Boolean]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[Boolean]]) = { + def prop( + values: List[X1[Boolean]], + fromBase: Int, + toBase: Int + )(implicit + encX1: Encoder[X1[Boolean]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS @@ -96,11 +109,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(not(col)) - .collect() - .run() - .toList + val res = typedDS.select(not(col)).collect().run().toList res ?= resCompare } @@ -112,7 +121,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[String]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[String]]) = { + def prop( + values: List[X1[String]], + fromBase: Int, + toBase: Int + )(implicit + encX1: Encoder[X1[String]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS @@ -123,11 +138,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(conv(col, fromBase, toBase)) - .collect() - .run() - .toList + val res = + typedDS.select(conv(col, fromBase, toBase)).collect().run().toList res ?= resCompare } @@ -139,7 +151,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.degrees(cDS("a"))) @@ -149,11 +165,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(degrees(col)) - .collect() - .run() - .toList + val res = typedDS.select(degrees(col)).collect().run().toList res ?= resCompare } @@ -161,12 +173,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Byte] _)) check(forAll(prop[Short] _)) check(forAll(prop[Int] _)) - check(forAll(prop[Long] _)) + check(forAll(prop[Long] _)) check(forAll(prop[BigDecimal] _)) } - def propBitShift[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]]) - (typedCol: TypedColumn[X1[A], B], sparkFunc: (Column,Int) => Column, numBits: Int): Prop = { + def propBitShift[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], B], + sparkFunc: (Column, Int) => Column, + numBits: Int + ): Prop = { val spark = session import spark.implicits._ @@ -176,11 +192,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .toList - val res = typedDS - .select(typedCol) - .collect() - .run() - .toList + val res = typedDS.select(typedCol).collect().run().toList res ?= resCompare } @@ 
-190,11 +202,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftRightUnsigned, numBits) + propBitShift(typedDS)( + shiftRightUnsigned(typedDS('a), numBits), + sparkFunctions.shiftRightUnsigned, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -209,11 +229,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftRight call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftRight, numBits) + propBitShift(typedDS)( + shiftRight(typedDS('a), numBits), + sparkFunctions.shiftRight, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -228,11 +256,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftLeft, numBits) + propBitShift(typedDS)( + shiftLeft(typedDS('a), numBits), + sparkFunctions.shiftLeft, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -246,27 +282,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]])( - implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystRound[A, B], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.ceil(cDS("a"))) .map(_.getAs[B](0)) .collect() - .toList.map{ - case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0) - case other => other - }.asInstanceOf[List[B]] - + .toList + .map { + case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0) + case other => other + } + .asInstanceOf[List[B]] val typedDS = TypedDataset.create(values) - val res = typedDS - 
.select(ceil(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(ceil(typedDS('a))).collect().run().toList res ?= resCompare } @@ -282,20 +317,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { Seq(224, 256, 384, 512).map { numBits => val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.sha2(cDS("a"), numBits)) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(sha2(typedDS('a), numBits)) .collect() - .run() .toList + + val typedDS = TypedDataset.create(values) + val res = + typedDS.select(sha2(typedDS('a), numBits)).collect().run().toList res ?= resCompare }.reduce(_ && _) } @@ -307,20 +344,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.sha1(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(sha1(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(sha1(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -331,7 +369,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.crc32(cDS("a"))) @@ -340,11 +382,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .toList val typedDS = TypedDataset.create(values) - val res = typedDS - .select(crc32(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(crc32(typedDS('a))).collect().run().toList res ?= resCompare } @@ -356,27 +394,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]])( - implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystRound[A, B], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.floor(cDS("a"))) .map(_.getAs[B](0)) .collect() - .toList.map{ - case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0) - case other => other - }.asInstanceOf[List[B]] - + .toList + .map { + case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0) + case other => other + } + .asInstanceOf[List[B]] val typedDS = TypedDataset.create(values) - val res = typedDS - .select(floor(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(floor(typedDS('a))).collect().run().toList res ?= resCompare } @@ -387,35 +424,35 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite 
{ check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } - test("abs big decimal") { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder] - (values: List[X1[A]]) - ( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B], - encX1:Encoder[X1[A]] - )= { - val cDS = session.createDataset(values) - val resCompare = cDS - .select(sparkFunctions.abs(cDS("a"))) - .map(_.getAs[B](0)) - .collect().toList + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B], + encX1: Encoder[X1[A]] + ) = { + val cDS = session.createDataset(values) + val resCompare = cDS + .select(sparkFunctions.abs(cDS("a"))) + .map(_.getAs[B](0)) + .collect() + .toList - val typedDS = TypedDataset.create(values) - val col = typedDS('a) - val res = typedDS - .select( - abs(col) - ) - .collect() - .run() - .toList + val typedDS = TypedDataset.create(values) + val col = typedDS('a) + val res = typedDS + .select( + abs(col) + ) + .collect() + .run() + .toList - res ?= resCompare - } + res ?= resCompare + } check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } @@ -424,26 +461,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder] - (values: List[X1[A]]) - ( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.abs(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(abs(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(abs(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -453,36 +486,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - def propTrigonometric[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]]) - (typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column): Prop = { - val spark = session - import spark.implicits._ + def propTrigonometric[A: CatalystNumeric: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], Double], + sparkFunc: Column => Column + ): Prop = { + val spark = session + import spark.implicits._ - val resCompare = typedDS.dataset - .select(sparkFunc($"a")) - .map(_.getAs[Double](0)) - .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + val resCompare = typedDS.dataset + .select(sparkFunc($"a")) + .map(_.getAs[Double](0)) + .map(DoubleBehaviourUtils.nanNullHandler) + .collect() + .toList - val res = typedDS - .select(typedCol) - .deserialized - .map(DoubleBehaviourUtils.nanNullHandler) - .collect() - .run() - .toList + val res = typedDS + .select(typedCol) + .deserialized + .map(DoubleBehaviourUtils.nanNullHandler) + .collect() + .run() + .toList - res ?= resCompare + res ?= resCompare } test("cos") { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos) + def 
prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos) } check(forAll(prop[Int] _)) @@ -497,10 +537,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh) } check(forAll(prop[Int] _)) @@ -515,10 +558,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos) } check(forAll(prop[Int] _)) @@ -529,16 +575,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - - test("signum") { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum) } check(forAll(prop[Int] _)) @@ -553,10 +600,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin) } check(forAll(prop[Int] _)) @@ -571,10 +621,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh) } check(forAll(prop[Int] _)) @@ -589,10 +642,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: 
CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin) } check(forAll(prop[Int] _)) @@ -607,10 +663,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan) } check(forAll(prop[Int] _)) @@ -625,10 +684,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh) } check(forAll(prop[Int] _)) @@ -639,48 +701,46 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - /* - * Currently not all Collection types play nice with the Encoders. - * This test needs to be readressed and Set readded to the Collection Typeclass once these issues are resolved. - * - * [[https://issues.apache.org/jira/browse/SPARK-18891]] - * [[https://issues.apache.org/jira/browse/SPARK-21204]] - */ - test("arrayContains"){ + /* + * Currently not all Collection types play nice with the Encoders. + * This test needs to be readressed and Set readded to the Collection Typeclass once these issues are resolved. 
+ * + * [[https://issues.apache.org/jira/browse/SPARK-18891]] + * [[https://issues.apache.org/jira/browse/SPARK-21204]] + */ + test("arrayContains") { val spark = session import spark.implicits._ val listLength = 10 val idxs = Stream.continually(Range(0, listLength)).flatten.toIterator - abstract class Nth[A, C[A]:CatalystCollection] { + abstract class Nth[A, C[A]: CatalystCollection] { - def nth(c:C[A], idx:Int):A + def nth(c: C[A], idx: Int): A } - implicit def deriveListNth[A] : Nth[A, List] = new Nth[A, List] { + implicit def deriveListNth[A]: Nth[A, List] = new Nth[A, List] { override def nth(c: List[A], idx: Int): A = c(idx) } - implicit def deriveSeqNth[A] : Nth[A, Seq] = new Nth[A, Seq] { + implicit def deriveSeqNth[A]: Nth[A, Seq] = new Nth[A, Seq] { override def nth(c: Seq[A], idx: Int): A = c(idx) } - implicit def deriveVectorNth[A] : Nth[A, Vector] = new Nth[A, Vector] { + implicit def deriveVectorNth[A]: Nth[A, Vector] = new Nth[A, Vector] { override def nth(c: Vector[A], idx: Int): A = c(idx) } - implicit def deriveArrayNth[A] : Nth[A, Array] = new Nth[A, Array] { + implicit def deriveArrayNth[A]: Nth[A, Array] = new Nth[A, Array] { override def nth(c: Array[A], idx: Int): A = c(idx) } - - def prop[C[_] : CatalystCollection] - ( + def prop[C[_]: CatalystCollection]( values: C[Int], - shouldBeIn:Boolean) - ( - implicit nth:Nth[Int, C], + shouldBeIn: Boolean + )(implicit + nth: Nth[Int, C], encEv: Encoder[C[Int]], tEncEv: TypedEncoder[C[Int]] ) = { @@ -691,7 +751,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = cDS .select(sparkFunctions.array_contains(cDS("value"), contained)) .map(_.getAs[Boolean](0)) - .collect().toList + .collect() + .toList val typedDS = TypedDataset.create(List(X1(values))) val res = typedDS @@ -705,10 +766,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)), - Gen.oneOf(true,false) - ) - (prop[List]) + Gen.listOfN(listLength, Gen.choose(0, 100)), + Gen.oneOf(true, false) + )(prop[List]) ) /*check( Looks like there is no Typed Encoder for Seq type yet @@ -721,18 +781,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toVector), - Gen.oneOf(true,false) - ) - (prop[Vector]) + Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toVector), + Gen.oneOf(true, false) + )(prop[Vector]) ) check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toArray), - Gen.oneOf(true,false) - ) - (prop[Array]) + Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toArray), + Gen.oneOf(true, false) + )(prop[Array]) ) } @@ -740,14 +798,20 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: A, values: List[X1[A]])(implicit encX1: Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered]( + na: A, + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(X1(na) :: values) val resCompare = cDS .select(sparkFunctions.atan(cDS("a"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList + .sorted val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -757,16 +821,28 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .run() .toList + .sorted - val aggrTyped = typedDS.agg(atan( - 
frameless.functions.aggregate.first(typedDS('a))) - ).firstOption().run().get + val aggrTyped = typedDS + .coalesce(1) + .orderBy(typedDS('a).asc) + .agg(atan(frameless.functions.aggregate.first(typedDS('a)))) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan(sparkFunctions.first("a")).as[Double] - ).first() + val aggrSpark = cDS + .coalesce(1) + .orderBy("a") + .select( + sparkFunctions.atan(sparkFunctions.first("a")).as[Double] + ) + .first() - (res ?= resCompare).&&(aggrTyped ?= aggrSpark) + (res ?= resCompare).&&( + DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils + .nanNullHandler(aggrSpark) + ) } check(forAll(prop[Int] _)) @@ -781,16 +857,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder, - B: CatalystNumeric : TypedEncoder : Encoder](na: X2[A, B], values: List[X2[A, B]]) - (implicit encEv: Encoder[X2[A,B]]) = { + def prop[ + A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered, + B: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered + ](na: X2[A, B], + values: List[X2[A, B]] + )(implicit + encEv: Encoder[X2[A, B]] + ) = { val cDS = session.createDataset(na +: values) val resCompare = cDS .select(sparkFunctions.atan2(cDS("a"), cDS("b"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList + .sorted val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -800,20 +882,37 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .run() .toList + .sorted + + val aggrTyped = typedDS + .coalesce(1) + .orderBy(typedDS('a).asc, typedDS('b).asc) + .agg( + atan2( + frameless.functions.aggregate.first(typedDS('a)), + frameless.functions.aggregate.first(typedDS('b)) + ) + ) + .firstOption() + .run() + .get - val aggrTyped = typedDS.agg(atan2( - frameless.functions.aggregate.first(typedDS('a)), - frameless.functions.aggregate.first(typedDS('b))) - ).firstOption().run().get - - val aggrSpark = cDS.select( - sparkFunctions.atan2(sparkFunctions.first("a"),sparkFunctions.first("b")).as[Double] - ).first() + val aggrSpark = cDS + .coalesce(1) + .orderBy("a", "b") + .select( + sparkFunctions + .atan2(sparkFunctions.first("a"), sparkFunctions.first("b")) + .as[Double] + ) + .first() - (res ?= resCompare).&&(aggrTyped ?= aggrSpark) + (res ?= resCompare).&&( + DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils + .nanNullHandler(aggrSpark) + ) } - check(forAll(prop[Int, Long] _)) check(forAll(prop[Long, Int] _)) check(forAll(prop[Short, Byte] _)) @@ -826,15 +925,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered]( + na: X1[A], + value: List[X1[A]], + lit: Double + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(na +: value) val resCompare = cDS .select(sparkFunctions.atan2(lit, cDS("a"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList + .sorted val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -844,17 +949,28 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .run() .toList + .sorted - val aggrTyped = typedDS.agg(atan2( - lit, - 
frameless.functions.aggregate.first(typedDS('a))) - ).firstOption().run().get + val aggrTyped = typedDS + .coalesce(1) + .orderBy(typedDS('a).asc) + .agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a)))) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double] - ).first() + val aggrSpark = cDS + .coalesce(1) + .orderBy("a") + .select( + sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double] + ) + .first() - (res ?= resCompare).&&(aggrTyped ?= aggrSpark) + (res ?= resCompare).&&( + DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils + .nanNullHandler(aggrSpark) + ) } check(forAll(prop[Int] _)) @@ -869,15 +985,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder: CatalystOrdered]( + na: X1[A], + value: List[X1[A]], + lit: Double + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(na +: value) val resCompare = cDS .select(sparkFunctions.atan2(cDS("a"), lit)) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList + .sorted val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -887,20 +1009,30 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .run() .toList + .sorted - val aggrTyped = typedDS.agg(atan2( - frameless.functions.aggregate.first(typedDS('a)), - lit) - ).firstOption().run().get + val aggrTyped = typedDS + .coalesce(1) + .orderBy(typedDS('a).asc) + .agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit)) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double] - ).first() + val aggrSpark = cDS + .coalesce(1) + .orderBy("a") + .select( + sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double] + ) + .first() - (res ?= resCompare).&&(aggrTyped ?= aggrSpark) + (res ?= resCompare).&&( + DoubleBehaviourUtils.nanNullHandler(aggrTyped) ?= DoubleBehaviourUtils + .nanNullHandler(aggrSpark) + ) } - check(forAll(prop[Int] _)) check(forAll(prop[Long] _)) check(forAll(prop[Short] _)) @@ -909,9 +1041,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - def mathProp[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])( - typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column - ): Prop = { + def mathProp[A: CatalystNumeric: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], Double], + sparkFunc: Column => Column + ): Prop = { val spark = session import spark.implicits._ @@ -919,7 +1053,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunc($"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(typedCol) @@ -936,7 +1071,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) 
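// Illustrative sketch (editor's addition, not part of the patch above): the rationale
// behind adding coalesce(1).orderBy(...) before taking `first` in the atan/atan2
// aggregate comparisons. Spark's `first` depends on partitioning and row order, so
// both the typed and the vanilla side are pinned to one partition with an explicit
// ordering to make the comparison deterministic. `DeterministicFirstSketch` and the
// sample data are invented; only standard Spark APIs are used.
import org.apache.spark.sql.{ functions => F, SparkSession }

object DeterministicFirstSketch {

  def firstValue(spark: SparkSession): Double = {
    import spark.implicits._

    val ds = Seq(3.0, 1.0, 2.0).toDS()

    ds.coalesce(1) // one partition -> a single, stable scan
      .orderBy("value") // fix the row order inside that partition
      .select(F.first("value").as[Double]) // `first` now always sees 1.0 first
      .first()
  }
}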
mathProp(typedDS)(sqrt(typedDS('a)), sparkFunctions.sqrt) } @@ -953,7 +1092,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(cbrt(typedDS('a)), sparkFunctions.cbrt) } @@ -970,7 +1113,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(exp(typedDS('a)), sparkFunctions.exp) } @@ -987,7 +1134,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]]): Prop = { + def prop[A: TypedEncoder: Encoder](values: List[X1[A]]): Prop = { val spark = session import spark.implicits._ @@ -996,14 +1143,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.md5($"a")) .map(_.getAs[String](0)) - .collect().toList - - val res = typedDS - .select(md5(typedDS('a))) .collect() - .run() .toList + val res = typedDS.select(md5(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1022,14 +1166,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.factorial($"a")) .map(_.getAs[Long](0)) - .collect().toList - - val res = typedDS - .select(factorial(typedDS('a))) .collect() - .run() .toList + val res = typedDS.select(factorial(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1040,24 +1181,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(round(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(round(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1071,25 +1213,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"))) 
.map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(round(col)) - .collect() - .run() - .toList + val res = typedDS.select(round(col)).collect().run().toList res ?= resCompare } @@ -1101,24 +1245,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"), 1)) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(round(typedDS('a), 1)) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(round(typedDS('a), 1)).collect().run().toList + res ?= resCompare } @@ -1132,25 +1277,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"), 0)) .map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(round(col, 0)) - .collect() - .run() - .toList + val res = typedDS.select(round(col, 0)).collect().run().toList res ?= resCompare } @@ -1162,24 +1309,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bround(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bround(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bround(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1187,31 +1335,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Long] _)) check(forAll(prop[Short] _)) check(forAll(prop[Double] _)) - } + } test("bround big decimal") { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: 
Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bround(cDS("a"))) .map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(bround(col)) - .collect() - .run() - .toList + val res = typedDS.select(bround(col)).collect().run().toList res ?= resCompare } @@ -1219,63 +1369,66 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[BigDecimal] _)) } - test("bround with scale") { - val spark = session - import spark.implicits._ + test("bround with scale") { + val spark = session + import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], encX1: Encoder[X1[A]] ) = { - val cDS = session.createDataset(values) - val resCompare = cDS - .select(sparkFunctions.bround(cDS("a"), 1)) - .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bround(typedDS('a), 1)) - .collect() - .run() - .toList + val cDS = session.createDataset(values) + val resCompare = cDS + .select(sparkFunctions.bround(cDS("a"), 1)) + .map(_.getAs[A](0)) + .collect() + .toList - res ?= resCompare - } + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bround(typedDS('a), 1)).collect().run().toList - check(forAll(prop[Int] _)) - check(forAll(prop[Long] _)) - check(forAll(prop[Short] _)) - check(forAll(prop[Double] _)) + res ?= resCompare } - test("bround big decimal with scale") { - val spark = session - import spark.implicits._ + check(forAll(prop[Int] _)) + check(forAll(prop[Long] _)) + check(forAll(prop[Short] _)) + check(forAll(prop[Double] _)) + } - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] + test("bround big decimal with scale") { + val spark = session + import spark.implicits._ + + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] ) = { - val cDS = session.createDataset(values) - - val resCompare = cDS - .select(sparkFunctions.bround(cDS("a"), 0)) - .map(_.getAs[java.math.BigDecimal](0)) - .collect() - .toList.map(_.setScale(0)) - - val typedDS = TypedDataset.create(values) - val col = typedDS('a) - val res = typedDS - .select(bround(col, 0)) - .collect() - .run() - .toList - - res ?= resCompare - } + val cDS = session.createDataset(values) + + val resCompare = cDS + .select(sparkFunctions.bround(cDS("a"), 0)) + .map(_.getAs[java.math.BigDecimal](0)) + .collect() + .toList + .map(_.setScale(0)) + + val typedDS = TypedDataset.create(values) + val col = typedDS('a) + val res = typedDS.select(bround(col, 0)).collect().run().toList + + res ?= resCompare + } check(forAll(prop[BigDecimal] _)) } @@ -1285,10 +1438,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: 
CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1297,7 +1450,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.log(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(log(base, typedDS('a))) @@ -1322,7 +1476,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log(typedDS('a)), sparkFunctions.log) } @@ -1339,7 +1497,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log2(typedDS('a)), sparkFunctions.log2) } @@ -1356,7 +1518,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log1p(typedDS('a)), sparkFunctions.log1p) } @@ -1373,7 +1539,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log10(typedDS('a)), sparkFunctions.log10) } @@ -1389,20 +1559,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values:List[X1[Array[Byte]]])(implicit encX1:Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.base64(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(base64(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(base64(typedDS('a))).collect().run().toList + val backAndForth = typedDS .select(base64(unbase64(base64(typedDS('a))))) .collect() @@ -1419,10 +1590,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + 
def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1431,7 +1602,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.hypot(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res2 = typedDS .select(hypot(typedDS('a), base)) @@ -1463,9 +1635,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1474,7 +1646,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.hypot($"b", $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(hypot(typedDS('b), typedDS('a))) @@ -1498,10 +1671,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1510,7 +1683,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(pow(base, typedDS('a))) @@ -1524,7 +1698,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow($"a", base)) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res2 = typedDS .select(pow(typedDS('a), base)) @@ -1534,7 +1709,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - (res ?= resCompare) && (res2 ?= resCompare2) + (res ?= resCompare) && (res2 ?= resCompare2) } check(forAll(prop[Int] _)) @@ -1548,9 +1723,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1559,7 +1734,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow($"b", $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(pow(typedDS('b), typedDS('a))) @@ -1584,9 +1760,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1594,14 +1770,12 @@ class 
NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.pmod($"b", $"a")) .map(_.getAs[A](0)) - .collect().toList - - val res = typedDS - .select(pmod(typedDS('b), typedDS('a))) .collect() - .run() .toList + val res = + typedDS.select(pmod(typedDS('b), typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1616,71 +1790,73 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = { + def prop( + values: List[X1[String]] + )(implicit + encX1: Encoder[X1[String]] + ) = { val valuesBase64 = values.map(base64X1String) val cDS = session.createDataset(valuesBase64) val resCompare = cDS .select(sparkFunctions.unbase64(cDS("a"))) .map(_.getAs[Array[Byte]](0)) - .collect().toList - - val typedDS = TypedDataset.create(valuesBase64) - val res = typedDS - .select(unbase64(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(valuesBase64) + val res = typedDS.select(unbase64(typedDS('a))).collect().run().toList + res.map(_.toList) ?= resCompare.map(_.toList) } check(forAll(prop _)) } - test("bin"){ + test("bin") { val spark = session import spark.implicits._ - def prop(values:List[X1[Long]])(implicit encX1:Encoder[X1[Long]]) = { + def prop( + values: List[X1[Long]] + )(implicit + encX1: Encoder[X1[Long]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bin(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bin(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bin(typedDS('a))).collect().run().toList + res ?= resCompare } check(forAll(prop _)) } - test("bitwiseNOT"){ + test("bitwiseNOT") { val spark = session import spark.implicits._ @nowarn // supress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat - def prop[A: CatalystBitwise : TypedEncoder : Encoder] - (values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystBitwise: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bitwiseNOT(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bitwiseNOT(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bitwiseNOT(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1694,11 +1870,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder]( - toFile1: List[X1[A]], - toFile2: List[X1[A]], - inMem: List[X1[A]] - )(implicit x2Gen: Encoder[X2[A, String]], x3Gen: Encoder[X3[A, String, String]]) = { + def prop[A: TypedEncoder]( + toFile1: List[X1[A]], + toFile2: List[X1[A]], + inMem: List[X1[A]] + )(implicit + x2Gen: Encoder[X2[A, String]], + x3Gen: Encoder[X3[A, String, String]] + ) = { val file1Path = testTempFiles + "/file1" val file2Path = testTempFiles + "/file2" @@ -1719,7 +1898,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val unioned = ds1.union(ds2).union(ds3) - val withFileName = unioned.withColumn[X3[A, String, String]](inputFileName[X2[A, String]]()) + val withFileName = unioned + 
.withColumn[X3[A, String, String]](inputFileName[X2[A, String]]()) .collect() .run() .toVector @@ -1727,10 +1907,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val grouped = withFileName.groupBy(_.b).mapValues(_.map(_.c).toSet) grouped.foldLeft(passed) { (p, g) => - p && secure { g._1 match { - case "" => g._2.head == "" //Empty string if didn't come from file - case f => g._2.forall(_.contains(f)) - }}} + p && secure { + g._1 match { + case "" => g._2.head == "" // Empty string if didn't come from file + case f => g._2.forall(_.contains(f)) + } + } + } } check(forAll(prop[String] _)) @@ -1740,17 +1923,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder](xs: List[X1[A]])(implicit x2en: Encoder[X2[A, Long]]) = { + def prop[A: TypedEncoder]( + xs: List[X1[A]] + )(implicit + x2en: Encoder[X2[A, Long]] + ) = { val ds = TypedDataset.create(xs) - val result = ds.withColumn[X2[A, Long]](monotonicallyIncreasingId()) + val result = ds + .withColumn[X2[A, Long]](monotonicallyIncreasingId()) .collect() .run() .toVector val ids = result.map(_.b) (ids.toSet.size ?= ids.length) && - (ids.sorted ?= ids) + (ids.sorted ?= ids) } check(forAll(prop[String] _)) @@ -1760,13 +1948,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder : Encoder] - (condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = { - val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil) + def prop[A: TypedEncoder: Encoder]( + condition1: Boolean, + condition2: Boolean, + value1: A, + value2: A, + otherwise: A + ) = { + val ds = TypedDataset.create( + X5(condition1, condition2, value1, value2, otherwise) :: Nil + ) - val untypedWhen = ds.toDF() + val untypedWhen = ds + .toDF() .select( - sparkFunctions.when(sparkFunctions.col("a"), sparkFunctions.col("c")) + sparkFunctions + .when(sparkFunctions.col("a"), sparkFunctions.col("c")) .when(sparkFunctions.col("b"), sparkFunctions.col("d")) .otherwise(sparkFunctions.col("e")) ) @@ -1776,9 +1973,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedWhen = ds .select( - when(ds('a), ds('c)) - .when(ds('b), ds('d)) - .otherwise(ds('e)) + when(ds('a), ds('c)).when(ds('b), ds('d)).otherwise(ds('e)) ) .collect() .run() @@ -1800,17 +1995,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.ascii($"a")) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(ascii(ds('a))) - .collect() - .run() - .toVector + val typed = ds.select(ascii(ds('a))).collect().run().toVector typed ?= sparkResult }) @@ -1828,19 +2020,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.concat($"a", $"b")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(concat(ds('a), ds('b))) - .collect() - .run() - .toVector + val typed = ds.select(concat(ds('a), ds('b))).collect().run().toVector - (typed ?= sparkResult).&&(typed ?= values.map(x => s"${x.a}${x.b}").toVector) + (typed ?= sparkResult).&&( + typed ?= values.map(x => s"${x.a}${x.b}").toVector + ) }) } 
@@ -1855,10 +2046,23 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val td = ds.agg(concat(first(ds('a)),first(ds('b)))).collect().run().toVector - val spark = ds.dataset.select(sparkFunctions.concat( - sparkFunctions.first($"a").as[String], - sparkFunctions.first($"b").as[String])).as[String].collect().toVector + val td = + ds.coalesce(1) + .agg(concat(first(ds('a)), first(ds('b)))) + .collect() + .run() + .toVector + val spark = ds.dataset + .coalesce(1) + .select( + sparkFunctions.concat( + sparkFunctions.first($"a").as[String], + sparkFunctions.first($"b").as[String] + ) + ) + .as[String] + .collect() + .toVector td ?= spark }) } @@ -1875,17 +2079,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.concat_ws(",", $"a", $"b")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(concatWs(",", ds('a), ds('b))) - .collect() - .run() - .toVector + val typed = + ds.select(concatWs(",", ds('a), ds('b))).collect().run().toVector typed ?= sparkResult }) @@ -1902,11 +2104,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val td = ds.agg(concatWs(",",first(ds('a)),first(ds('b)), last(ds('b)))).collect().run().toVector - val spark = ds.dataset.select(sparkFunctions.concat_ws(",", - sparkFunctions.first($"a").as[String], - sparkFunctions.first($"b").as[String], - sparkFunctions.last($"b").as[String])).as[String].collect().toVector + val td = ds + .coalesce(1) + .agg(concatWs(",", first(ds('a)), first(ds('b)), last(ds('b)))) + .collect() + .run() + .toVector + val spark = ds.dataset + .coalesce(1) + .select( + sparkFunctions.concat_ws( + ",", + sparkFunctions.first($"a").as[String], + sparkFunctions.first($"b").as[String], + sparkFunctions.last($"b").as[String] + ) + ) + .as[String] + .collect() + .toVector td ?= spark }) } @@ -1917,17 +2133,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.nonEmptyListOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(x => X1(x + values.head))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.instr($"a", values.head)) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(instr(ds('a), values.head)) - .collect() - .run() - .toVector + val typed = ds.select(instr(ds('a), values.head)).collect().run().toVector typed ?= sparkResult }) @@ -1939,17 +2152,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.length($"a")) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(length(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(length(ds[String]('a))).collect().run().toVector (typed ?= sparkResult).&&(values.map(_.a.length).toVector ?= typed) }) @@ -1961,26 +2171,49 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { (na: X1[String], values: List[X1[String]]) => val ds = TypedDataset.create(na +: values) - val sparkResult = ds.toDF() - 
.select(sparkFunctions.levenshtein($"a", sparkFunctions.concat($"a",sparkFunctions.lit("Hello")))) + val sparkResult = ds + .toDF() + .select( + sparkFunctions.levenshtein( + $"a", + sparkFunctions.concat($"a", sparkFunctions.lit("Hello")) + ) + ) .map(_.getAs[Int](0)) .collect() .toVector + .sorted val typed = ds - .select(levenshtein(ds('a), concat(ds('a),lit("Hello")))) + .select(levenshtein(ds('a), concat(ds('a), lit("Hello")))) .collect() .run() .toVector + .sorted val cDS = ds.dataset - val aggrTyped = ds.agg( - levenshtein(frameless.functions.aggregate.first(ds('a)), litAggr("Hello")) - ).firstOption().run().get + val aggrTyped = ds + .coalesce(1) + .orderBy(ds('a).asc) + .agg( + levenshtein( + frameless.functions.aggregate.first(ds('a)), + litAggr("Hello") + ) + ) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello")).as[Int] - ).first() + val aggrSpark = cDS + .coalesce(1) + .orderBy("a") + .select( + sparkFunctions + .levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello")) + .as[Int] + ) + .first() (typed ?= sparkResult).&&(aggrTyped ?= aggrSpark) }) @@ -1992,7 +2225,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { (values: List[X1[String]], n: Int) => val ds = TypedDataset.create(values.map(x => X1(s"$n${x.a}-$n$n"))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.regexp_replace($"a", "\\d+", "n")) .map(_.getAs[String](0)) .collect() @@ -2014,17 +2248,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.reverse($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(reverse(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(reverse(ds[String]('a))).collect().run().toVector (typed ?= sparkResult).&&(values.map(_.a.reverse).toVector ?= typed) }) @@ -2036,17 +2267,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.rpad($"a", 5, "hello")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(rpad(ds[String]('a), 5, "hello")) - .collect() - .run() - .toVector + val typed = + ds.select(rpad(ds[String]('a), 5, "hello")).collect().run().toVector typed ?= sparkResult }) @@ -2058,17 +2287,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.lpad($"a", 5, "hello")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(lpad(ds[String]('a), 5, "hello")) - .collect() - .run() - .toVector + val typed = + ds.select(lpad(ds[String]('a), 5, "hello")).collect().run().toVector typed ?= sparkResult }) @@ -2080,17 +2307,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.rtrim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(rtrim(ds[String]('a))) - .collect() - .run() - 
.toVector + val typed = ds.select(rtrim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2102,17 +2326,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.ltrim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(ltrim(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(ltrim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2124,17 +2345,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.substring($"a", 5, 3)) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(substring(ds[String]('a), 5, 3)) - .collect() - .run() - .toVector + val typed = + ds.select(substring(ds[String]('a), 5, 3)).collect().run().toVector typed ?= sparkResult }) @@ -2146,17 +2365,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.trim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(trim(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(trim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2168,17 +2384,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(X1(_))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.upper($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(upper(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(upper(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2190,27 +2403,29 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(X1(_))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.lower($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(lower(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(lower(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) } test("Empty vararg tests") { - def prop[A : TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = { + def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = { val ds = TypedDataset.create(data) - val frameless = ds.select(ds('a), concat(), ds('b), concatWs(":")).collect().run().toVector - val framelessAggr = ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector + val frameless = ds + .select(ds('a), concat(), ds('b), concatWs(":")) + .collect() + .run() + .toVector + val framelessAggr = + ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector val scala = data.map(x => (x.a, "", x.b, "")) val scalaAggr = Vector(("", "", 2)) (frameless ?= scala).&&(framelessAggr ?= scalaAggr) @@ -2220,8 +2435,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Option[Boolean], 
Long] _)) } - def dateTimeStringProp(typedDS: TypedDataset[X1[String]]) - (typedCol: TypedColumn[X1[String], Option[Int]], sparkFunc: Column => Column): Prop = { + def dateTimeStringProp( + typedDS: TypedDataset[X1[String]] + )(typedCol: TypedColumn[X1[String], Option[Int]], + sparkFunc: Column => Column + ): Prop = { val spark = session import spark.implicits._ @@ -2231,11 +2449,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .toList - val typed = typedDS - .select(typedCol) - .collect() - .run() - .toList + val typed = typedDS.select(typedCol).collect().run().toList typed ?= sparkResult } @@ -2244,10 +2458,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { - val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year) - } + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { + val ds = TypedDataset.create(data) + dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year) + } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) check(forAll(prop _)) @@ -2257,7 +2475,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(quarter(ds[String]('a)), sparkFunctions.quarter) } @@ -2270,7 +2492,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(month(ds[String]('a)), sparkFunctions.month) } @@ -2283,9 +2509,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(dayofweek(ds[String]('a)), sparkFunctions.dayofweek) + dateTimeStringProp(ds)( + dayofweek(ds[String]('a)), + sparkFunctions.dayofweek + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2296,9 +2529,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(dayofmonth(ds[String]('a)), sparkFunctions.dayofmonth) + dateTimeStringProp(ds)( + dayofmonth(ds[String]('a)), + sparkFunctions.dayofmonth + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2309,9 +2549,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - 
dateTimeStringProp(ds)(dayofyear(ds[String]('a)), sparkFunctions.dayofyear) + dateTimeStringProp(ds)( + dayofyear(ds[String]('a)), + sparkFunctions.dayofyear + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2322,7 +2569,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(hour(ds[String]('a)), sparkFunctions.hour) } @@ -2335,7 +2586,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(minute(ds[String]('a)), sparkFunctions.minute) } @@ -2348,7 +2603,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(second(ds[String]('a)), sparkFunctions.second) } @@ -2361,9 +2620,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(weekofyear(ds[String]('a)), sparkFunctions.weekofyear) + dateTimeStringProp(ds)( + weekofyear(ds[String]('a)), + sparkFunctions.weekofyear + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) diff --git a/dataset/src/test/scala/frameless/functions/UdfTests.scala b/dataset/src/test/scala/frameless/functions/UdfTests.scala index 10e65180f..af452cba4 100644 --- a/dataset/src/test/scala/frameless/functions/UdfTests.scala +++ b/dataset/src/test/scala/frameless/functions/UdfTests.scala @@ -4,182 +4,257 @@ package functions import org.scalacheck.Prop import org.scalacheck.Prop._ +import scala.collection.immutable.{ ListSet, TreeSet } + class UdfTests extends TypedDatasetSuite { test("one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X1[A]], f1: A => B): Prop = { - val dataset: TypedDataset[X1[A]] = TypedDataset.create(data) - val u1 = udf[X1[A], A, B](f1) - val u2 = dataset.makeUDF(f1) - val A = dataset.col[A]('a) - - // filter forces whole codegen - val codegen = dataset.deserialized.filter((_:X1[A]) => true).select(u1(A)).collect().run().toVector - - // otherwise it uses local relation - val local = dataset.select(u2(A)).collect().run().toVector - - val d = data.map(x => f1(x.a)) - - (codegen ?= d) && (local ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder]( + data: Vector[X1[A]], + f1: A => B + ): Prop = { + val dataset: TypedDataset[X1[A]] = TypedDataset.create(data) + val u1 = udf[X1[A], A, B](f1) + val u2 = dataset.makeUDF(f1) + val A = dataset.col[A]('a) + + // filter forces whole codegen + val codegen = dataset.deserialized + .filter((_: X1[A]) => true) + .select(u1(A)) + .collect() + .run() + .toVector + + // otherwise it uses local relation + val local = 
dataset.select(u2(A)).collect().run().toVector + + val d = data.map(x => f1(x.a)) + + (codegen ?= d) && (local ?= d) + } + + check(forAll(prop[Int, Int] _)) + check(forAll(prop[String, String] _)) + check(forAll(prop[Option[Int], Option[Int]] _)) + check(forAll(prop[X1[Int], X1[Int]] _)) + check(forAll(prop[X1[Option[Int]], X1[Option[Int]]] _)) + + // TODO doesn't work for the same reason as `collect` + // check(forAll(prop[X1[Option[X1[Int]]], X1[Option[X1[Option[Int]]]]] _)) + + // Vector/List isn't supported by MapObjects, not all collections are equal see #804 + check(forAll(prop[Option[Seq[String]], Option[Seq[String]]] _)) + check(forAll(prop[Option[List[String]], Option[List[String]]] _)) + check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _)) + + // ListSet/TreeSet weren't supported before #804 + check(forAll(prop[Option[Set[String]], Option[Set[String]]] _)) + check(forAll(prop[Option[ListSet[String]], Option[ListSet[String]]] _)) + check(forAll(prop[Option[TreeSet[String]], Option[TreeSet[String]]] _)) + + def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = + prop(Vector(X1(a)), f) + + check( + forAll( + prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _ + ) + ) + check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _)) } - - check(forAll(prop[Int, Int] _)) - check(forAll(prop[String, String] _)) - check(forAll(prop[Option[Int], Option[Int]] _)) - check(forAll(prop[X1[Int], X1[Int]] _)) - check(forAll(prop[X1[Option[Int]], X1[Option[Int]]] _)) - - // TODO doesn't work for the same reason as `collect` - // check(forAll(prop[X1[Option[X1[Int]]], X1[Option[X1[Option[Int]]]]] _)) - - check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _)) - - def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = prop(Vector(X1(a)), f) - - check(forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _)) - check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _)) } test("multiple one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: A => A, f2: B => B, f3: C => C): Prop = { - val dataset = TypedDataset.create(data) - val u11 = udf[X3[A, B, C], A, A](f1) - val u21 = udf[X3[A, B, C], B, B](f2) - val u31 = udf[X3[A, B, C], C, C](f3) - val u12 = dataset.makeUDF(f1) - val u22 = dataset.makeUDF(f2) - val u32 = dataset.makeUDF(f3) - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector - val dataset22 = dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector - val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c))) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: A => A, + f2: B => B, + f3: C => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u11 = udf[X3[A, B, C], A, A](f1) + val u21 = udf[X3[A, B, C], B, B](f2) + val u31 = udf[X3[A, B, C], C, C](f3) + val u12 = dataset.makeUDF(f1) + val u22 = dataset.makeUDF(f2) + val u32 = dataset.makeUDF(f3) + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = + dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector + val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c))) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + 
check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) + check(forAll(prop[X3[Int, String, Boolean], Int, Int] _)) + check(forAll(prop[X3U[Int, String, Boolean], Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) - check(forAll(prop[X3[Int, String, Boolean], Int, Int] _)) - check(forAll(prop[X3U[Int, String, Boolean], Int, Int] _)) } test("two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X3[A, B, C], A, B, C](f1) - val u2 = dataset.makeUDF(f1) - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - - val dataset21 = dataset.select(u1(A, B)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B)).collect().run().toVector - val d = data.map(x => f1(x.a, x.b)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: (A, B) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X3[A, B, C], A, B, C](f1) + val u2 = dataset.makeUDF(f1) + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + + val dataset21 = dataset.select(u1(A, B)).collect().run().toVector + val dataset22 = dataset.select(u2(A, B)).collect().run().toVector + val d = data.map(x => f1(x.a, x.b)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("multiple two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C, f2: (B, C) => A): Prop = { - val dataset = TypedDataset.create(data) - val u11 = udf[X3[A, B, C], A, B, C](f1) - val u12 = dataset.makeUDF(f1) - val u21 = udf[X3[A, B, C], B, C, A](f2) - val u22 = dataset.makeUDF(f2) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u11(A, B), u21(B, C)).collect().run().toVector - val dataset22 = dataset.select(u12(A, B), u22(B, C)).collect().run().toVector - val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c))) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: (A, B) => C, + f2: (B, C) => A + ): Prop = { + val dataset = TypedDataset.create(data) + val u11 = udf[X3[A, B, C], A, B, C](f1) + val u12 = dataset.makeUDF(f1) + val u21 = udf[X3[A, B, C], B, C, A](f2) + val u22 = dataset.makeUDF(f2) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = + dataset.select(u11(A, B), u21(B, C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A, B), u22(B, C)).collect().run().toVector + val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c))) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("three argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f: (A, B, C) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X3[A, B, C], A, B, C, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) 
- val C = dataset.col[C]('c) - - val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector - val d = data.map(x => f(x.a, x.b, x.c)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f: (A, B, C) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X3[A, B, C], A, B, C, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + + val dataset21 = dataset.select(u1(A, B, C)).collect().run().toVector + val dataset22 = dataset.select(u2(A, B, C)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int] _)) + check(forAll(prop[String, Int, Int] _)) + } } - - check(forAll(prop[Int, Int, Int] _)) - check(forAll(prop[String, Int, Int] _)) } test("four argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder] - (data: Vector[X4[A, B, C, D]], f: (A, B, C, D) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - val D = dataset.col[D]('d) - - val dataset21 = dataset.select(u1(A, B, C, D)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C, D)).collect().run().toVector - val d = data.map(x => f(x.a, x.b, x.c, x.d)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]], + f: (A, B, C, D) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + val D = dataset.col[D]('d) + + val dataset21 = + dataset.select(u1(A, B, C, D)).collect().run().toVector + val dataset22 = + dataset.select(u2(A, B, C, D)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c, x.d)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int, Int] _)) + check(forAll(prop[String, Int, Int, String] _)) + check(forAll(prop[String, String, String, String] _)) + check(forAll(prop[String, Long, String, String] _)) + check(forAll(prop[String, Boolean, Boolean, String] _)) + } } - - check(forAll(prop[Int, Int, Int, Int] _)) - check(forAll(prop[String, Int, Int, String] _)) - check(forAll(prop[String, String, String, String] _)) - check(forAll(prop[String, Long, String, String] _)) - check(forAll(prop[String, Boolean, Boolean, String] _)) } test("five argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder, E: TypedEncoder] - (data: Vector[X5[A, B, C, D, E]], f: (A, B, C, D, E) => C): Prop = { - val dataset = TypedDataset.create(data) - val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f) - val u2 = dataset.makeUDF(f) - - val A = dataset.col[A]('a) - val B = dataset.col[B]('b) - val C = dataset.col[C]('c) - val D = dataset.col[D]('d) - val E = dataset.col[E]('e) - - val dataset21 = dataset.select(u1(A, B, C, D, E)).collect().run().toVector - val dataset22 = dataset.select(u2(A, B, C, D, E)).collect().run().toVector - val d = data.map(x => f(x.a, 
x.b, x.c, x.d, x.e)) - - (dataset21 ?= d) && (dataset22 ?= d) + evalCodeGens { + forceInterpreted { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder, + E: TypedEncoder + ](data: Vector[X5[A, B, C, D, E]], + f: (A, B, C, D, E) => C + ): Prop = { + val dataset = TypedDataset.create(data) + val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f) + val u2 = dataset.makeUDF(f) + + val A = dataset.col[A]('a) + val B = dataset.col[B]('b) + val C = dataset.col[C]('c) + val D = dataset.col[D]('d) + val E = dataset.col[E]('e) + + val dataset21 = + dataset.select(u1(A, B, C, D, E)).collect().run().toVector + val dataset22 = + dataset.select(u2(A, B, C, D, E)).collect().run().toVector + val d = data.map(x => f(x.a, x.b, x.c, x.d, x.e)) + + (dataset21 ?= d) && (dataset22 ?= d) + } + + check(forAll(prop[Int, Int, Int, Int, Int] _)) + } } - - check(forAll(prop[Int, Int, Int, Int, Int] _)) } } diff --git a/dataset/src/test/scala/frameless/ops/CubeTests.scala b/dataset/src/test/scala/frameless/ops/CubeTests.scala index 7a06822b9..5fa61a14e 100644 --- a/dataset/src/test/scala/frameless/ops/CubeTests.scala +++ b/dataset/src/test/scala/frameless/ops/CubeTests.scala @@ -1,6 +1,8 @@ package frameless package ops +import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors } +import frameless.functions.ToDecimal import frameless.functions.aggregate._ import org.scalacheck.Prop import org.scalacheck.Prop._ @@ -8,14 +10,28 @@ import org.scalacheck.Prop._ class CubeTests extends TypedDatasetSuite { test("cube('a).agg(count())") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.cube(A).agg(count()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = dataset + .cube(A) + .agg(count()) + .collect() + .run() + .toVector + .sortBy(t => (t._2, t._1)) + val expected = dataset.dataset + .cube("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(t => (t._2, t._1)) received ?= expected } @@ -24,15 +40,34 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a, 'b).agg(count())") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + Out: TypedEncoder: Numeric: Ordering + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.cube(A, B).agg(count()).collect().run().toVector.sortBy(_._3) - val expected = dataset.dataset.cube("a", "b").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3) + val received = dataset + .cube(A, B) + .agg(count()) + .collect() + .run() + .toVector + .sortBy(t => (t._3, t._2, t._1)) + val expected = dataset.dataset + .cube("a", "b") + .count() + .collect() + .toVector + .map(row => + 
(Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2)) + ) + .sortBy(t => (t._3, t._2, t._1)) received ?= expected } @@ -41,15 +76,32 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).agg(sum('b)") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.cube(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").sum("b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2) + val received = dataset + .cube(A) + .agg(sum(B)) + .collect() + .run() + .toVector + .sortBy(t => (t._2, t._1)) + val expected = dataset.dataset + .cube("a") + .sum("b") + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))) + .sortBy(t => (t._2, t._1)) received ?= expected } @@ -58,15 +110,22 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).mapGroups('a, sum('b))") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric] - (data: List[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.cube(A) - .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } - .collect().run().toVector.sortBy(_._1) - val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) + val received = dataset + .cube(A) + .deserialized + .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } + .collect() + .run() + .toVector + .sortBy(_._1) + val expected = + data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) received ?= expected } @@ -76,61 +135,137 @@ class CubeTests extends TypedDatasetSuite { test("cube('a).agg(sum('b), sum('c)) to cube('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - C: TypedEncoder, - OutB: TypedEncoder : Numeric, - OutC: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - summableB: CatalystSummable[B, OutB], - summableC: CatalystSummable[C, OutC] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + C: TypedEncoder, + OutB: TypedEncoder: Numeric, + OutC: TypedEncoder: Numeric: ToDecimal + ](data: List[X3[A, B, C]] + )(implicit + summableB: CatalystSummable[B, OutB], + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) + val toDecOpt = implicitly[ToDecimal[OutC]].truncate _ + val framelessSumBC = dataset .cube(A) .agg(sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBC = dataset.dataset.cube("a").sum("b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(t => (t._1, t._2, t._3)) + + val sparkSumBC = dataset.dataset + .cube("a") + .sum("b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + 
toDecOpt(row.getAs[OutC](2)) + ) + ) + .sortBy(t => (t._1, t._2, t._3)) val framelessSumBCB = dataset .cube(A) .agg(sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBCB = dataset.dataset.cube("a").sum("b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(t => (t._1, t._2, t._3)) + + val sparkSumBCB = dataset.dataset + .cube("a") + .sum("b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3) + ) + ) + .sortBy(t => (t._1, t._2, t._3)) val framelessSumBCBC = dataset .cube(A) .agg(sum(B), sum(C), sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBCBC = dataset.dataset.cube("a").sum("b", "c", "b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _5 = toDecOpt(row._5))) + .sortBy(t => (t._1, t._2, t._3)) + + val sparkSumBCBC = dataset.dataset + .cube("a") + .sum("b", "c", "b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3), + toDecOpt(row.getAs[OutC](4)) + ) + ) + .sortBy(t => (t._1, t._2, t._3)) val framelessSumBCBCB = dataset .cube(A) .agg(sum(B), sum(C), sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBCBCB = dataset.dataset.cube("a").sum("b", "c", "b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5))) - .sortBy(_._1) - - (framelessSumBC ?= sparkSumBC) - .&&(framelessSumBCB ?= sparkSumBCB) - .&&(framelessSumBCBC ?= sparkSumBCBC) - .&&(framelessSumBCBCB ?= sparkSumBCBCB) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _5 = toDecOpt(row._5))) + .sortBy(t => (t._1, t._2, t._3)) + + val sparkSumBCBCB = dataset.dataset + .cube("a") + .sum("b", "c", "b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3), + toDecOpt(row.getAs[OutC](4)), + row.getAs[OutB](5) + ) + ) + .sortBy(t => (t._1, t._2, t._3)) + + (tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3))) + .&&( + tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)( + Seq(l => l._3) + ) + ) + .&&( + tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)( + Seq(l => l._3, l => l._5) + ) + ) + .&&( + tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)( + Seq(l => l._3, l => l._5) + ) + ) } check(forAll(prop[String, Long, Double, Long, Double] _)) @@ -138,34 +273,52 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).agg(sum('c), sum('d))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - D: TypedEncoder, - OutC: TypedEncoder : Numeric, - OutD: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - summableC: CatalystSummable[C, OutC], - summableD: CatalystSummable[D, OutD] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + D: TypedEncoder, + OutC: TypedEncoder: Numeric, + OutD: 
TypedEncoder: Numeric: ToDecimal + ](data: List[X4[A, B, C, D]] + )(implicit + summableC: CatalystSummable[C, OutC], + summableD: CatalystSummable[D, OutD] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) val D = dataset.col[D]('d) + val toDecOpt = implicitly[ToDecimal[OutD]].truncate _ + val framelessSumByAB = dataset .cube(A, B) .agg(sum(C), sum(D)) - .collect().run().toVector.sortBy(x => (x._1, x._2)) + .collect() + .run() + .toVector + .map(row => row.copy(_4 = toDecOpt(row._4))) + .sortBy(x => (x._1, x._2)) val sparkSumByAB = dataset.dataset - .cube("a", "b").sum("c", "d").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3))) + .cube("a", "b") + .sum("c", "d") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + toDecOpt(row.getAs[OutD](3)) + ) + ) .sortBy(x => (x._1, x._2)) - framelessSumByAB ?= sparkSumByAB + tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)( + Seq(l => l._4) + ) } check(forAll(prop[Byte, Int, Long, Double, Long, Double] _)) @@ -173,11 +326,17 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).agg(sum('c)) to cube('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - OutC: TypedEncoder: Numeric - ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + OutC: TypedEncoder: Numeric: ToDecimal + ](data: List[X3[A, B, C]] + )(implicit + summableC: CatalystSummable[C, OutC] + ): Prop = { + + val toDecOpt = implicitly[ToDecimal[OutC]].truncate _ + val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -186,63 +345,162 @@ class CubeTests extends TypedDatasetSuite { val framelessSumC = dataset .cube(A, B) .agg(sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumC = dataset.dataset - .cube("a", "b").sum("c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))) - .sortBy(_._2) + .cube("a", "b") + .sum("c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCC = dataset .cube(A, B) .agg(sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _4 = toDecOpt(row._4))) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCC = dataset.dataset - .cube("a", "b").sum("c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3))) - .sortBy(_._2) + .cube("a", "b") + .sum("c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5) + ) + ) + .sortBy(t => 
(t._2, t._1, t._3)) val sparkSumCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4))) - .sortBy(_._2) + .cube("a", "b") + .sum("c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5), + _6 = toDecOpt(row._6) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5))) - .sortBy(_._2) + .cube("a", "b") + .sum("c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)), + toDecOpt(row.getAs[OutC](5)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5), + _6 = toDecOpt(row._6), + _7 = toDecOpt(row._7) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCCCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6))) - .sortBy(_._2) - - (framelessSumC ?= sparkSumC) && - (framelessSumCC ?= sparkSumCC) && - (framelessSumCCC ?= sparkSumCCC) && - (framelessSumCCCC ?= sparkSumCCCC) && - (framelessSumCCCCC ?= sparkSumCCCCC) + .cube("a", "b") + .sum("c", "c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)), + toDecOpt(row.getAs[OutC](5)), + toDecOpt(row.getAs[OutC](6)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) + + (tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) && + (tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)( + Seq(l => l._3, l => l._4) + )) && + (tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5) + )) && + (tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5, l => l._6) + )) && + (tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7) + )) } check(forAll(prop[String, Long, Double, Double] _)) @@ -250,22 +508,30 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).mapGroups('a, 'b, sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Numeric + ](data: 
List[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val framelessSumByAB = dataset .cube(A, B) - .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } - .collect().run().toVector.sortBy(x => (x._1, x._2)) + .deserialized + .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } + .collect() + .run() + .toVector + .sortBy(x => (x._1, x._2)) - val sumByAB = data.groupBy(x => (x.a, x.b)) + val sumByAB = data + .groupBy(x => (x.a, x.b)) .mapValues { xs => xs.map(_.c).sum } - .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2)) + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(x => (x._1, x._2)) framelessSumByAB ?= sumByAB } @@ -274,17 +540,19 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).mapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder: Ordering, - B: TypedEncoder: Ordering, - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .cube(A) - .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted)) - .collect().run().toMap + .deserialized + .mapGroups((a, xs) => (a, xs.toVector.sorted)) + .collect() + .run() + .toMap val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted } @@ -297,21 +565,23 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).flatMapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .cube(A) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(_.a).toSeq + .groupBy(_.a) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -325,22 +595,26 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Ordering - ](data: Vector[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](data: Vector[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col[A]('a) val cB = dataset.col[B]('b) val datasetGrouped = dataset .cube(cA, cB) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(t => (t.a, t.b)).toSeq + .groupBy(t => (t.a, t.b)) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -353,18 +627,32 @@ class CubeTests extends TypedDatasetSuite { } test("cubeMany('a).agg(sum('b))") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received 
= dataset.cubeMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = dataset + .cubeMany(A) + .agg(count[X1[A]]()) + .collect() + .run() + .toVector + .sortBy(_.swap) + val expected = dataset.dataset + .cube("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_.swap) received ?= expected } check(forAll(prop[Int, Long] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/PivotTest.scala b/dataset/src/test/scala/frameless/ops/PivotTest.scala index dd9bf5e61..251a41c50 100644 --- a/dataset/src/test/scala/frameless/ops/PivotTest.scala +++ b/dataset/src/test/scala/frameless/ops/PivotTest.scala @@ -2,12 +2,13 @@ package frameless package ops import frameless.functions.aggregate._ -import org.apache.spark.sql.{functions => sparkFunctions} +import org.apache.spark.sql.{ functions => sparkFunctions } import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop._ -import org.scalacheck.{Gen, Prop} +import org.scalacheck.{ Gen, Prop } class PivotTest extends TypedDatasetSuite { + def withCustomGenX4: Gen[Vector[X4[String, String, Int, Boolean]]] = { val kvPairGen: Gen[X4[String, String, Int, Boolean]] = for { a <- Gen.oneOf(Seq("1", "2", "3", "4")) @@ -22,77 +23,113 @@ class PivotTest extends TypedDatasetSuite { test("X4[Boolean, String, Int, Boolean] pivot on String") { def prop(data: Vector[X4[String, String, Int, Boolean]]): Prop = { val d = TypedDataset.create(data) - val frameless = d.groupBy(d('a)). - pivot(d('b)).on("a", "b", "c"). - agg(sum(d('c)), first(d('d))).collect().run().toVector + val frameless = d + .coalesce(1) + .orderBy(d('a).asc, d('d).asc) + .groupBy(d('a)) + .pivot(d('b)) + .on("a", "b", "c") + .agg(sum(d('c)), first(d('d))) + .collect() + .run() + .toVector - val spark = d.dataset.groupBy("a") + val spark = d.dataset + .coalesce(1) + .orderBy("a", "d") + .groupBy("a") .pivot("b", Seq("a", "b", "c")) - .agg(sparkFunctions.sum("c"), sparkFunctions.first("d")).collect().toVector + .agg(sparkFunctions.sum("c"), sparkFunctions.first("d")) + .collect() + .toVector - (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0))).&&( - frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1)))).&&( - frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2)))).&&( - frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3)))).&&( - frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4)))).&&( - frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5)))).&&( - frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6)))) + (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0))) + .&&(frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1)))) + .&&(frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2)))) + .&&(frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3)))) + .&&(frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4)))) + .&&(frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5)))) + .&&(frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6)))) } check(forAll(withCustomGenX4)(prop)) } test("Pivot on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), 
X3("a", true, false)) val d = TypedDataset.create(x) - d.groupByMany(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). - collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false + d.groupByMany(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false } test("Pivot with groupBy on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) - d.groupBy(d('a), d('b)). - pivot(d('c)).on(1L, 20L). - agg(count[X3[String, String, Long]]()). - collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + d.groupBy(d('a), d('b)) + .pivot(d('c)) + .on(1L, 20L) + .agg(count[X3[String, String, Long]]()) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with cube on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) d.cube(d('a), d('b)) - .pivot(d('c)).on(1L, 20L) + .pivot(d('c)) + .on(1L, 20L) .agg(count[X3[String, String, Long]]()) - .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with cube on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) val d = TypedDataset.create(x) - d.cube(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). - collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) + d.cube(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) } test("Pivot with rollup on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) d.rollup(d('a), d('b)) - .pivot(d('c)).on(1L, 20L) + .pivot(d('c)) + .on(1L, 20L) .agg(count[X3[String, String, Long]]()) - .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with rollup on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) val d = TypedDataset.create(x) - d.rollupMany(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). 
- collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) + d.rollupMany(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/RollupTests.scala b/dataset/src/test/scala/frameless/ops/RollupTests.scala index da73ef8d0..20cd4f405 100644 --- a/dataset/src/test/scala/frameless/ops/RollupTests.scala +++ b/dataset/src/test/scala/frameless/ops/RollupTests.scala @@ -1,6 +1,8 @@ package frameless package ops +import frameless.functions.DoubleBehaviourUtils.{ dp5, tolerantCompareVectors } +import frameless.functions.ToDecimal import frameless.functions.aggregate._ import org.scalacheck.Prop import org.scalacheck.Prop._ @@ -8,14 +10,23 @@ import org.scalacheck.Prop._ class RollupTests extends TypedDatasetSuite { test("rollup('a).agg(count())") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = + dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_.swap) + val expected = dataset.dataset + .rollup("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_.swap) received ?= expected } @@ -24,15 +35,34 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a, 'b).agg(count())") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.rollup(A, B).agg(count()).collect().run().toVector.sortBy(_._3) - val expected = dataset.dataset.rollup("a", "b").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3) + val received = dataset + .rollup(A, B) + .agg(count()) + .collect() + .run() + .toVector + .sortBy(t => (t._3, t._2, t._1)) + val expected = dataset.dataset + .rollup("a", "b") + .count() + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2)) + ) + .sortBy(t => (t._3, t._2, t._1)) received ?= expected } @@ -41,15 +71,27 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).agg(sum('b)") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = 
dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").sum("b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2) + val received = + dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_.swap) + val expected = dataset.dataset + .rollup("a") + .sum("b") + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))) + .sortBy(_.swap) received ?= expected } @@ -58,15 +100,22 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).mapGroups('a, sum('b))") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric] - (data: List[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollup(A) - .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } - .collect().run().toVector.sortBy(_._1) - val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) + val received = dataset + .rollup(A) + .deserialized + .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } + .collect() + .run() + .toVector + .sortBy(identity) + val expected = + data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(identity) received ?= expected } @@ -76,61 +125,138 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a).agg(sum('b), sum('c)) to rollup('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - C: TypedEncoder, - OutB: TypedEncoder : Numeric, - OutC: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - summableB: CatalystSummable[B, OutB], - summableC: CatalystSummable[C, OutC] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + C: TypedEncoder, + OutB: TypedEncoder: Numeric, + OutC: TypedEncoder: Numeric: ToDecimal + ](data: List[X3[A, B, C]] + )(implicit + summableB: CatalystSummable[B, OutB], + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) + val toDecOpt = implicitly[ToDecimal[OutC]].truncate _ + val framelessSumBC = dataset .rollup(A) .agg(sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBC = dataset.dataset.rollup("a").sum("b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(identity) + + val sparkSumBC = dataset.dataset + .rollup("a") + .sum("b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)) + ) + ) + .sortBy(identity) val framelessSumBCB = dataset .rollup(A) .agg(sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBCB = dataset.dataset.rollup("a").sum("b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(identity) + + val sparkSumBCB = dataset.dataset + .rollup("a") + .sum("b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3) + ) + ) + .sortBy(identity) val 
framelessSumBCBC = dataset .rollup(A) .agg(sum(B), sum(C), sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) - - val sparkSumBCBC = dataset.dataset.rollup("a").sum("b", "c", "b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4))) - .sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _5 = toDecOpt(row._5))) + .sortBy(identity) + + val sparkSumBCBC = dataset.dataset + .rollup("a") + .sum("b", "c", "b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3), + toDecOpt(row.getAs[OutC](4)) + ) + ) + .sortBy(identity) val framelessSumBCBCB = dataset .rollup(A) .agg(sum(B), sum(C), sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _5 = toDecOpt(row._5))) + .sortBy(identity) + + val sparkSumBCBCB = dataset.dataset + .rollup("a") + .sum("b", "c", "b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + toDecOpt(row.getAs[OutC](2)), + row.getAs[OutB](3), + toDecOpt(row.getAs[OutC](4)), + row.getAs[OutB](5) + ) + ) + .sortBy(identity) + + (tolerantCompareVectors(framelessSumBC, sparkSumBC, dp5)(Seq(l => l._3))) + .&&( + tolerantCompareVectors(framelessSumBCB, sparkSumBCB, dp5)( + Seq(l => l._3) + ) + ) + .&&( + tolerantCompareVectors(framelessSumBCBC, sparkSumBCBC, dp5)( + Seq(l => l._3, l => l._5) + ) + ) + .&&( + tolerantCompareVectors(framelessSumBCBCB, sparkSumBCBCB, dp5)( + Seq(l => l._3, l => l._5) + ) + ) - val sparkSumBCBCB = dataset.dataset.rollup("a").sum("b", "c", "b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5))) - .sortBy(_._1) - - (framelessSumBC ?= sparkSumBC) - .&&(framelessSumBCB ?= sparkSumBCB) - .&&(framelessSumBCBC ?= sparkSumBCBC) - .&&(framelessSumBCBCB ?= sparkSumBCBCB) } check(forAll(prop[String, Long, Double, Long, Double] _)) @@ -138,34 +264,52 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).agg(sum('c), sum('d))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - D: TypedEncoder, - OutC: TypedEncoder : Numeric, - OutD: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - summableC: CatalystSummable[C, OutC], - summableD: CatalystSummable[D, OutD] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + D: TypedEncoder, + OutC: TypedEncoder: Numeric, + OutD: TypedEncoder: Numeric: ToDecimal + ](data: List[X4[A, B, C, D]] + )(implicit + summableC: CatalystSummable[C, OutC], + summableD: CatalystSummable[D, OutD] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) val D = dataset.col[D]('d) + val toDecOpt = implicitly[ToDecimal[OutD]].truncate _ + val framelessSumByAB = dataset .rollup(A, B) .agg(sum(C), sum(D)) - .collect().run().toVector.sortBy(_._2) + .collect() + .run() + .toVector + .map(row => row.copy(_4 = toDecOpt(row._4))) + .sortBy(t => (t._2, t._1, t._3, t._4)) val sparkSumByAB = dataset.dataset - .rollup("a", "b").sum("c", "d").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3))) - .sortBy(_._2) - 
- framelessSumByAB ?= sparkSumByAB + .rollup("a", "b") + .sum("c", "d") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + toDecOpt(row.getAs[OutD](3)) + ) + ) + .sortBy(t => (t._2, t._1, t._3, t._4)) + + tolerantCompareVectors(framelessSumByAB, sparkSumByAB, dp5)( + Seq(l => l._4) + ) } check(forAll(prop[Byte, Int, Long, Double, Long, Double] _)) @@ -173,76 +317,180 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).agg(sum('c)) to rollup('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - OutC: TypedEncoder: Numeric - ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + OutC: TypedEncoder: Numeric: ToDecimal + ](data: List[X3[A, B, C]] + )(implicit + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) + val toDecOpt = implicitly[ToDecimal[OutC]].truncate _ + val framelessSumC = dataset .rollup(A, B) .agg(sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3))) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumC = dataset.dataset - .rollup("a", "b").sum("c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))) - .sortBy(_._2) + .rollup("a", "b") + .sum("c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCC = dataset .rollup(A, B) .agg(sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => row.copy(_3 = toDecOpt(row._3), _4 = toDecOpt(row._4))) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCC = dataset.dataset - .rollup("a", "b").sum("c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3))) - .sortBy(_._2) + .rollup("a", "b") + .sum("c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4))) - .sortBy(_._2) + .rollup("a", "b") + .sum("c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5), + _6 = 
toDecOpt(row._6) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5))) - .sortBy(_._2) + .rollup("a", "b") + .sum("c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)), + toDecOpt(row.getAs[OutC](5)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val framelessSumCCCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector - .sortBy(_._2) + .collect() + .run() + .toVector + .map(row => + row.copy( + _3 = toDecOpt(row._3), + _4 = toDecOpt(row._4), + _5 = toDecOpt(row._5), + _6 = toDecOpt(row._6), + _7 = toDecOpt(row._7) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) val sparkSumCCCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6))) - .sortBy(_._2) - - (framelessSumC ?= sparkSumC) && - (framelessSumCC ?= sparkSumCC) && - (framelessSumCCC ?= sparkSumCCC) && - (framelessSumCCCC ?= sparkSumCCCC) && - (framelessSumCCCCC ?= sparkSumCCCCC) + .rollup("a", "b") + .sum("c", "c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + toDecOpt(row.getAs[OutC](2)), + toDecOpt(row.getAs[OutC](3)), + toDecOpt(row.getAs[OutC](4)), + toDecOpt(row.getAs[OutC](5)), + toDecOpt(row.getAs[OutC](6)) + ) + ) + .sortBy(t => (t._2, t._1, t._3)) + + (tolerantCompareVectors(framelessSumC, sparkSumC, dp5)(Seq(l => l._3))) && + (tolerantCompareVectors(framelessSumCC, sparkSumCC, dp5)( + Seq(l => l._3, l => l._4) + )) && + (tolerantCompareVectors(framelessSumCCC, sparkSumCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5) + )) && + (tolerantCompareVectors(framelessSumCCCC, sparkSumCCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5, l => l._6) + )) && + (tolerantCompareVectors(framelessSumCCCCC, sparkSumCCCCC, dp5)( + Seq(l => l._3, l => l._4, l => l._5, l => l._6, l => l._7) + )) } check(forAll(prop[String, Long, Double, Double] _)) @@ -250,22 +498,30 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).mapGroups('a, 'b, sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val framelessSumByAB = dataset .rollup(A, B) - .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } - .collect().run().toVector.sortBy(x => (x._1, x._2)) - - val sumByAB = data.groupBy(x => (x.a, x.b)) + .deserialized + .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } + .collect() + .run() + .toVector + .sortBy(identity) + + val sumByAB = data + .groupBy(x => (x.a, x.b)) .mapValues { xs => xs.map(_.c).sum } - .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2)) + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(identity) framelessSumByAB ?= sumByAB } @@ -274,17 +530,19 @@ class RollupTests 
extends TypedDatasetSuite { } test("rollup('a).mapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder: Ordering, - B: TypedEncoder: Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .rollup(A) - .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted)) - .collect().run().toMap + .deserialized + .mapGroups((a, xs) => (a, xs.toVector.sorted)) + .collect() + .run() + .toMap val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted } @@ -297,21 +555,23 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).flatMapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .rollup(A) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(_.a).toSeq + .groupBy(_.a) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -325,22 +585,26 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Ordering - ](data: Vector[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](data: Vector[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col[A]('a) val cB = dataset.col[B]('b) val datasetGrouped = dataset .rollup(cA, cB) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(t => (t.a, t.b)).toSeq + .groupBy(t => (t.a, t.b)) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -353,18 +617,32 @@ class RollupTests extends TypedDatasetSuite { } test("rollupMany('a).agg(sum('b))") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollupMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = dataset + .rollupMany(A) + .agg(count[X1[A]]()) + .collect() + .run() + .toVector + .sortBy(_.swap) + val expected = dataset.dataset + .rollup("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_.swap) received ?= expected } check(forAll(prop[Int, Long] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala index 82ff375c9..8085582a2 100644 --- 
a/dataset/src/test/scala/frameless/package.scala +++ b/dataset/src/test/scala/frameless/package.scala @@ -1,9 +1,13 @@ import java.time.format.DateTimeFormatter -import java.time.{LocalDateTime => JavaLocalDateTime} +import java.time.{ LocalDateTime => JavaLocalDateTime } +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.scalacheck.{ Arbitrary, Cogen, Gen } -import org.scalacheck.{Arbitrary, Gen} +import scala.collection.immutable.{ ListSet, TreeSet } package object frameless { + /** Fixed decimal point to avoid precision problems specific to Spark */ implicit val arbBigDecimal: Arbitrary[BigDecimal] = Arbitrary { for { @@ -30,11 +34,62 @@ package object frameless { } // see issue with scalacheck non serializable Vector: https://github.com/rickynils/scalacheck/issues/315 - implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] = + implicit def arbVector[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[Vector[A]] = Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector)) def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary + implicit def arbSeq[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[scala.collection.Seq[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector.toSeq)) + + def seqGen[A: Arbitrary]: Gen[scala.collection.Seq[A]] = arbSeq[A].arbitrary + + implicit def arbList[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[List[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(_.toList)) + + def listGen[A: Arbitrary]: Gen[List[A]] = arbList[A].arbitrary + + implicit def arbSet[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[Set[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(Set.newBuilder.++=(_).result())) + + def setGen[A: Arbitrary]: Gen[Set[A]] = arbSet[A].arbitrary + + implicit def cogenListSet[A: Cogen: Ordering]: Cogen[ListSet[A]] = + Cogen.it(_.toVector.sorted.iterator) + + implicit def arbListSet[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[ListSet[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(ListSet.newBuilder.++=(_).result())) + + def listSetGen[A: Arbitrary]: Gen[ListSet[A]] = arbListSet[A].arbitrary + + implicit def cogenTreeSet[A: Cogen: Ordering]: Cogen[TreeSet[A]] = + Cogen.it(_.toVector.sorted.iterator) + + implicit def arbTreeSet[A]( + implicit + A: Arbitrary[A], + o: Ordering[A] + ): Arbitrary[TreeSet[A]] = + Arbitrary(Gen.listOf(A.arbitrary).map(TreeSet.newBuilder.++=(_).result())) + + def treeSetGen[A: Arbitrary: Ordering]: Gen[TreeSet[A]] = + arbTreeSet[A].arbitrary + implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary { for { int <- Arbitrary.arbitrary[Int] @@ -42,7 +97,8 @@ package object frameless { } yield new UdtEncodedClass(int, doubles.toArray) } - val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") + val dateTimeFormatter: DateTimeFormatter = + DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary { for { @@ -61,7 +117,27 @@ package object frameless { localDate <- listOfDates } yield localDate.format(dateTimeFormatter) - val TEST_OUTPUT_DIR = "target/test-output" + private var outputDir: String = _ + + /** allow test usage on non-build environments */ + def setOutputDir(path: String): Unit = { + outputDir = path + } + + lazy val TEST_OUTPUT_DIR = + if (outputDir ne null) + outputDir + else + "target/test-output" + + private var shouldClose = true + + /** allow test usage on non-build environments */ + def 
setShouldCloseSession(shouldClose: Boolean): Unit = { + this.shouldClose = shouldClose + } + + lazy val shouldCloseSession = shouldClose /** * Will dive down causes until either the cause is true or there are no more causes @@ -72,11 +148,10 @@ package object frameless { def anyCauseHas(t: Throwable, f: Throwable => Boolean): Boolean = if (f(t)) true + else if (t.getCause ne null) + anyCauseHas(t.getCause, f) else - if (t.getCause ne null) - anyCauseHas(t.getCause, f) - else - false + false /** * Runs up to maxRuns and outputs the number of failures (times thrown) @@ -85,11 +160,11 @@ package object frameless { * @tparam T * @return the last passing thunk, or null */ - def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T ={ + def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T = { var i = 0 var r = null.asInstanceOf[T] var passed = 0 - while(i < maxRuns){ + while (i < maxRuns) { i += 1 try { r = thunk @@ -98,29 +173,36 @@ package object frameless { println(s"run $i successful") } } catch { - case t: Throwable => System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") + case t: Throwable => + System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") } } if (passed != maxRuns) { - System.err.println(s"had ${maxRuns - passed} failures out of $maxRuns runs") + System.err.println( + s"had ${maxRuns - passed} failures out of $maxRuns runs" + ) } r } - /** + /** * Runs a given thunk up to maxRuns times, restarting the thunk if tolerantOf the thrown Throwable is true * @param tolerantOf * @param maxRuns default of 20 * @param thunk * @return either a successful run result or the last error will be thrown */ - def tolerantRun[T](tolerantOf: Throwable => Boolean, maxRuns: Int = 20)(thunk: => T): T ={ + def tolerantRun[T]( + tolerantOf: Throwable => Boolean, + maxRuns: Int = 20 + )(thunk: => T + ): T = { var passed = false var i = 0 var res: T = null.asInstanceOf[T] var thrown: Throwable = null - while((i < maxRuns) && !passed) { + while ((i < maxRuns) && !passed) { try { i += 1 res = thunk @@ -139,4 +221,58 @@ package object frameless { } res } + + // from Quality, which is from Spark test versions + + // if this blows then debug on CodeGenerator 1294, 1299 and grab code.body + def forceCodeGen[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.CODEGEN_ONLY.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + def forceInterpreted[T](f: => T): T = { + val codegenMode = CodegenObjectFactoryMode.NO_CODEGEN.toString + + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + f + } + } + + /** + * runs the same test with both eval and codegen, then does the same again using resolveWith + * + * @param f + * @tparam T + * @return + */ + def evalCodeGens[T](f: => T): (T, T) = + (forceInterpreted(f), forceCodeGen(f)) + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. 
+ */ + protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => conf.setConfString(k, v) } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + } diff --git a/dataset/src/test/scala/frameless/sql/package.scala b/dataset/src/test/scala/frameless/sql/package.scala index fcb45b03d..1da73bd35 100644 --- a/dataset/src/test/scala/frameless/sql/package.scala +++ b/dataset/src/test/scala/frameless/sql/package.scala @@ -1,16 +1,18 @@ package frameless import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.expressions.{And, Or} +import com.sparkutils.shim.expressions.{ And2 => And, Or2 => Or } package object sql { + implicit class ExpressionOps(val self: Expression) extends AnyVal { + def toList: List[Expression] = { def rec(expr: Expression, acc: List[Expression]): List[Expression] = { expr match { case And(left, right) => rec(left, rec(right, acc)) - case Or(left, right) => rec(left, rec(right, acc)) - case e => e +: acc + case Or(left, right) => rec(left, rec(right, acc)) + case e => e +: acc } } diff --git a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala index 5108ed581..e1d0d52fc 100644 --- a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala +++ b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala @@ -9,26 +9,37 @@ class FramelessSyntaxTests extends TypedDatasetSuite { // Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits override val sparkDelay = null - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data).dataset val dataframe = dataset.toDF() val typedDataset = dataset.typed val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]] - typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector + typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame + .collect() + .run() + .toVector } test("dataset typed - toTyped") { - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { - val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { + val dataset = + session.createDataset(data)(TypedExpressionEncoder(ev)).typed val dataframe = dataset.toDF() - dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector + dataset + .collect() + .run() + .toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector } check(forAll(prop[Int, String] _)) @@ -38,8 +49,14 @@ class FramelessSyntaxTests extends TypedDatasetSuite { test("frameless typed column and aggregate") { def prop[A: TypedEncoder](a: A, b: A): Prop = { val d = TypedDataset.create((a, b) :: Nil) - (d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&( - d.agg(first(d('_1))).collect().run() ?= 
d.agg(first(d('_1)).untyped.typedAggregate).collect().run() + (d.coalesce(1).select(d('_1).untyped.typedColumn).collect().run ?= d + .select(d('_1)) + .collect() + .run).&&( + d.coalesce(1).agg(first(d('_1))).collect().run() ?= d + .agg(first(d('_1)).untyped.typedAggregate) + .collect() + .run() ) } diff --git a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala index 36a443fb5..c7107dd7a 100644 --- a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala +++ b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala @@ -2,19 +2,23 @@ package frameless.sql.rules import frameless._ import frameless.functions.Lit -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant} -import org.apache.spark.sql.sources.{EqualTo, GreaterThanOrEqual, IsNotNull} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ + microsToInstant, + instantToMicros +} +import org.apache.spark.sql.sources.{ EqualTo, GreaterThanOrEqual, IsNotNull } import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import java.time.Instant class FramelessLitPushDownTests extends SQLRulesSuite { - private val now: Long = currentTimestamp() + private val now: Long = instantToMicros(Instant.now()) test("java.sql.Timestamp push-down") { val expected = java.sql.Timestamp.from(microsToInstant(now)) val expectedStructure = X1(SQLTimestamp(now)) - val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) + val expectedPushDownFilters = + List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) predicatePushDownTest[SQLTimestamp]( expectedStructure, @@ -27,7 +31,8 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("java.time.Instant push-down") { val expected = java.sql.Timestamp.from(microsToInstant(now)) val expectedStructure = X1(microsToInstant(now)) - val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) + val expectedPushDownFilters = + List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) predicatePushDownTest[Instant]( expectedStructure, @@ -40,7 +45,10 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("struct push-down") { type Payload = X4[Int, Int, Int, Int] val expectedStructure = X1(X4(1, 2, 3, 4)) - val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema) + val expected = new GenericRowWithSchema( + Array(1, 2, 3, 4), + TypedExpressionEncoder[Payload].schema + ) val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected)) predicatePushDownTest[Payload]( diff --git a/ml/src/main/scala/frameless/ml/package.scala b/ml/src/main/scala/frameless/ml/package.scala index d1c306158..a2ef8ae62 100644 --- a/ml/src/main/scala/frameless/ml/package.scala +++ b/ml/src/main/scala/frameless/ml/package.scala @@ -1,13 +1,15 @@ package frameless -import org.apache.spark.sql.FramelessInternals.UserDefinedType -import org.apache.spark.ml.FramelessInternals -import org.apache.spark.ml.linalg.{Matrix, Vector} +import FramelessInternals.UserDefinedType +import org.apache.spark.sql.shim.{ mlUtils => MLFramelessInternals } +import org.apache.spark.ml.linalg.{ Matrix, Vector } package object ml { - implicit val mlVectorUdt: UserDefinedType[Vector] = FramelessInternals.vectorUdt + implicit val mlVectorUdt: UserDefinedType[Vector] = + 
MLFramelessInternals.vectorUdt - implicit val mlMatrixUdt: UserDefinedType[Matrix] = FramelessInternals.matrixUdt + implicit val mlMatrixUdt: UserDefinedType[Matrix] = + MLFramelessInternals.matrixUdt } diff --git a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala b/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala deleted file mode 100644 index bec43cd11..000000000 --- a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala +++ /dev/null @@ -1,13 +0,0 @@ -package org.apache.spark.ml - -import org.apache.spark.ml.linalg.{MatrixUDT, VectorUDT} - -object FramelessInternals { - - // because org.apache.spark.ml.linalg.VectorUDT is private[spark] - val vectorUdt = new VectorUDT - - // because org.apache.spark.ml.linalg.MatrixUDT is private[spark] - val matrixUdt = new MatrixUDT - -} diff --git a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala index dba59454c..7cf56baab 100644 --- a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala +++ b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala @@ -2,26 +2,33 @@ package frameless.refined import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.objects.{ - Invoke, NewInstance, UnwrapOption, WrapOption -} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types._ +import com.sparkutils.shim.expressions.{ + UnwrapOption2 => UnwrapOption, + WrapOption2 => WrapOption +} +import org.apache.spark.sql.shim.{ + Invoke5 => Invoke, + NewInstance4 => NewInstance +} + import eu.timepit.refined.api.RefType import frameless.{ TypedEncoder, RecordFieldEncoder } private[refined] trait RefinedFieldEncoders { + /** * @tparam T the refined type (e.g. `String`) */ implicit def optionRefined[F[_, _], T, R]( - implicit + implicit i0: RefType[F], i1: TypedEncoder[T], - i2: ClassTag[F[T, R]], - ): RecordFieldEncoder[Option[F[T, R]]] = + i2: ClassTag[F[T, R]] + ): RecordFieldEncoder[Option[F[T, R]]] = RecordFieldEncoder[Option[F[T, R]]](new TypedEncoder[Option[F[T, R]]] { def nullable = true @@ -54,11 +61,11 @@ private[refined] trait RefinedFieldEncoders { * @tparam T the refined type (e.g. `String`) */ implicit def refined[F[_, _], T, R]( - implicit + implicit i0: RefType[F], i1: TypedEncoder[T], - i2: ClassTag[F[T, R]], - ): RecordFieldEncoder[F[T, R]] = + i2: ClassTag[F[T, R]] + ): RecordFieldEncoder[F[T, R]] = RecordFieldEncoder[F[T, R]](new TypedEncoder[F[T, R]] { def nullable = i1.nullable @@ -76,4 +83,3 @@ private[refined] trait RefinedFieldEncoders { override def toString = s"refined[${i2.runtimeClass.getName}]" }) } -