diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 5a1b9101..42f3b8ca 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -143,20 +143,19 @@ object DropUnitValues { } } -class RecordEncoder[F, G <: HList, H <: HList]( +abstract class RecordEncoder[F, G <: HList, H <: HList]( implicit - i0: LabelledGeneric.Aux[F, G], - i1: DropUnitValues.Aux[G, H], - i2: IsHCons[H], - fields: Lazy[RecordEncoderFields[H]], - newInstanceExprs: Lazy[NewInstanceExprs[G]], + stage1: RecordEncoderStage1[G, H], classTag: ClassTag[F]) extends TypedEncoder[F] { + + import stage1._ + def nullable: Boolean = false - def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] + lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F] - def catalystRepr: DataType = { + lazy val catalystRepr: DataType = { val structFields = fields.value.value.map { field => StructField( name = field.name, @@ -169,39 +168,99 @@ class RecordEncoder[F, G <: HList, H <: HList]( StructType(structFields) } - def toCatalyst(path: Expression): Expression = { - val nameExprs = fields.value.value.map { field => Literal(field.name) } +} - val valueExprs = fields.value.value.map { field => - val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) - field.encoder.toCatalyst(fieldPath) - } +object RecordEncoder { + + case class ForGeneric[F, G <: HList, H <: HList]( + )(implicit + stage1: RecordEncoderStage1[G, H], + classTag: ClassTag[F]) + extends RecordEncoder[F, G, H] { + + import stage1._ + + def toCatalyst(path: Expression): Expression = { + + val valueExprs = fields.value.value.map { field => + val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) + field.encoder.toCatalyst(fieldPath) + } + + val createExpr = stage1.cellsToCatalyst(valueExprs) - // the way exprs are encoded in CreateNamedStruct - val exprs = nameExprs.zip(valueExprs).flatMap { - case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil + val nullExpr = Literal.create(null, createExpr.dataType) + + If(IsNull(path), nullExpr, createExpr) } - val createExpr = CreateNamedStruct(exprs) - val nullExpr = Literal.create(null, createExpr.dataType) + def fromCatalyst(path: Expression): Expression = { + + val newArgs = stage1.fromCatalystToCells(path) + + val newExpr = + NewInstance( + classTag.runtimeClass, + newArgs, + jvmRepr, + propagateNull = true + ) + + val nullExpr = Literal.create(null, jvmRepr) - If(IsNull(path), nullExpr, createExpr) + If(IsNull(path), nullExpr, newExpr) + } } - def fromCatalyst(path: Expression): Expression = { - val exprs = fields.value.value.map { field => - field.encoder.fromCatalyst( - GetStructField(path, field.ordinal, Some(field.name)) - ) + case class ForTypedRow[G <: HList, H <: HList]( + )(implicit + stage1: RecordEncoderStage1[G, H], + classTag: ClassTag[TypedRow[G]]) + extends RecordEncoder[TypedRow[G], G, H] { + + import stage1._ + + private final val _apply = "apply" + private final val _fromInternalRow = "fromInternalRow" + + def toCatalyst(path: Expression): Expression = { + + val valueExprs = fields.value.value.zipWithIndex.map { + case (field, i) => + val fieldPath = Invoke( + path, + _apply, + field.encoder.jvmRepr, + Seq(Literal.create(i, IntegerType)) + ) + field.encoder.toCatalyst(fieldPath) + } + + val createExpr = stage1.cellsToCatalyst(valueExprs) + + val nullExpr = Literal.create(null, createExpr.dataType) + + If(IsNull(path), nullExpr, createExpr) } - val newArgs = newInstanceExprs.value.from(exprs) - val newExpr = - NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) + def fromCatalyst(path: Expression): Expression = { - val nullExpr = Literal.create(null, jvmRepr) + val newArgs = stage1.fromCatalystToCells(path) + val aggregated = CreateStruct(newArgs) - If(IsNull(path), nullExpr, newExpr) + val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType)) + + val newExpr = Invoke( + Literal.fromObject(partial), + _fromInternalRow, + TypedRow.catalystType, + Seq(aggregated) + ) + + val nullExpr = Literal.create(null, jvmRepr) + + If(IsNull(path), nullExpr, newExpr) + } } } diff --git a/dataset/src/main/scala/frameless/RecordEncoderStage1.scala b/dataset/src/main/scala/frameless/RecordEncoderStage1.scala new file mode 100644 index 00000000..b7cecf38 --- /dev/null +++ b/dataset/src/main/scala/frameless/RecordEncoderStage1.scala @@ -0,0 +1,49 @@ +package frameless + +import org.apache.spark.sql.catalyst.expressions.{ + CreateNamedStruct, + Expression, + GetStructField, + Literal +} +import shapeless.{ HList, Lazy } + +case class RecordEncoderStage1[G <: HList, H <: HList]( + )(implicit +// i1: DropUnitValues.Aux[G, H], +// i2: IsHCons[H], + val fields: Lazy[RecordEncoderFields[H]], + val newInstanceExprs: Lazy[NewInstanceExprs[G]]) { + + def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = { + val nameExprs = fields.value.value.map { field => Literal(field.name) } + + // the way exprs are encoded in CreateNamedStruct + val exprs = nameExprs.zip(valueExprs).flatMap { + case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil + } + + val createExpr = CreateNamedStruct(exprs) + createExpr + } + + def fromCatalystToCells(path: Expression): Seq[Expression] = { + val exprs = fields.value.value.map { field => + field.encoder.fromCatalyst( + GetStructField(path, field.ordinal, Some(field.name)) + ) + } + + val newArgs = newInstanceExprs.value.from(exprs) + newArgs + } +} + +object RecordEncoderStage1 { + + implicit def usingDerivation[G <: HList, H <: HList]( + implicit + i3: Lazy[RecordEncoderFields[H]], + i4: Lazy[NewInstanceExprs[G]] + ): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]() +} diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index b42b026e..2877dc7d 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -727,7 +727,7 @@ object TypedEncoder { } /** Encodes things as records if there is no Injection defined */ - implicit def usingDerivation[F, G <: HList, H <: HList]( + implicit def deriveForGeneric[F, G <: HList, H <: HList]( implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], @@ -735,7 +735,15 @@ object TypedEncoder { i3: Lazy[RecordEncoderFields[H]], i4: Lazy[NewInstanceExprs[G]], i5: ClassTag[F] - ): TypedEncoder[F] = new RecordEncoder[F, G, H] + ): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]() + + implicit def deriveForTypedRow[G <: HList, H <: HList]( + implicit + i1: DropUnitValues.Aux[G, H], + i2: IsHCons[H], + i3: Lazy[RecordEncoderFields[H]], + i4: Lazy[NewInstanceExprs[G]] + ): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]() /** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */ implicit def usingUserDefinedType[ diff --git a/dataset/src/main/scala/frameless/TypedRow.scala b/dataset/src/main/scala/frameless/TypedRow.scala new file mode 100644 index 00000000..c7abae77 --- /dev/null +++ b/dataset/src/main/scala/frameless/TypedRow.scala @@ -0,0 +1,45 @@ +package frameless + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{ DataType, ObjectType } +import shapeless.HList + +case class TypedRow[T <: HList](row: Row) { + + def apply(i: Int): Any = row.apply(i) +} + +object TypedRow { + + def apply(values: Any*): TypedRow[HList] = { + + val row = Row.fromSeq(values) + TypedRow(row) + } + + case class WithCatalystTypes(schema: Seq[DataType]) { + + def fromInternalRow(row: InternalRow): TypedRow[HList] = { + val data = row.toSeq(schema).toArray + + apply(data: _*) + } + + } + + object WithCatalystTypes {} + + def fromHList[T <: HList]( + hlist: T + ): TypedRow[T] = { + + val cells = hlist.runtimeList + + val row = Row.fromSeq(cells) + TypedRow(row) + } + + lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]]) + +} diff --git a/dataset/src/test/scala/frameless/InjectionTests.scala b/dataset/src/test/scala/frameless/InjectionTests.scala index 9ee25261..3203175f 100644 --- a/dataset/src/test/scala/frameless/InjectionTests.scala +++ b/dataset/src/test/scala/frameless/InjectionTests.scala @@ -202,7 +202,7 @@ class InjectionTests extends TypedDatasetSuite { } test("Resolve ambiguity by importing usingDerivation") { - import TypedEncoder.usingDerivation + import TypedEncoder.deriveForGeneric assert( implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]] ) diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 12178537..93b267ab 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -1,23 +1,12 @@ package frameless +import frameless.RecordEncoderTests.{ A, B, E } +import org.apache.spark.sql.types._ import org.apache.spark.sql.{ Row, functions => F } -import org.apache.spark.sql.types.{ - ArrayType, - BinaryType, - DecimalType, - IntegerType, - LongType, - MapType, - ObjectType, - StringType, - StructField, - StructType -} - -import shapeless.{ HList, LabelledGeneric } -import shapeless.test.illTyped - import org.scalatest.matchers.should.Matchers +import shapeless.record.Record +import shapeless.test.illTyped +import shapeless.{ HList, LabelledGeneric } final class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Unable to encode products made from units only") { @@ -101,6 +90,20 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { ds.collect.head shouldBe obj } + test("shapeless Record") { + + val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc") + val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1) + + val rdd = sc.parallelize(Seq(r2)) + val ds = + session.createDataset(rdd)( + TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]] + ) + + ds.collect.head shouldBe r2 + } + test("Scalar value class") { import RecordEncoderTests._ @@ -632,6 +635,9 @@ object RecordEncoderTests { case class D(m: Map[String, Int]) case class E(b: Set[B]) + val RR = Record.`'x -> Int, 'y -> String` + type RR = RR.T + final class Subject(val name: String) extends AnyVal with Serializable final class Grade(val value: BigDecimal) extends AnyVal with Serializable diff --git a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala index c152abf3..51bc21c4 100644 --- a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala +++ b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala @@ -127,7 +127,7 @@ object RefinedTypesTests { import frameless.refined._ // implicit instances for refined - implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation + implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric - implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation + implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric }