Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#777 #778

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft

#777 #778

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 73 additions & 21 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,19 @@ object DropUnitValues {
}
}

class RecordEncoder[F, G <: HList, H <: HList]
abstract class RecordEncoder[F, G <: HList, H <: HList]
(implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
fields: Lazy[RecordEncoderFields[H]],
newInstanceExprs: Lazy[NewInstanceExprs[G]],
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[F]
) extends TypedEncoder[F] {

import stage1._

def nullable: Boolean = false

def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F]

def catalystRepr: DataType = {
lazy val catalystRepr: DataType = {
val structFields = fields.value.value.map { field =>
StructField(
name = field.name,
Expand All @@ -145,41 +144,94 @@ class RecordEncoder[F, G <: HList, H <: HList]
StructType(structFields)
}

}

object RecordEncoder {

case class ForGeneric[F, G <: HList, H <: HList](
)(implicit
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[F])
extends RecordEncoder[F, G, H] {

import stage1._

def toCatalyst(path: Expression): Expression = {
val nameExprs = fields.value.value.map { field =>
Literal(field.name)
}

val valueExprs = fields.value.value.map { field =>
val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
field.encoder.toCatalyst(fieldPath)
}

// the way exprs are encoded in CreateNamedStruct
val exprs = nameExprs.zip(valueExprs).flatMap {
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
}
val createExpr = stage1.cellsToCatalyst(valueExprs)

val createExpr = CreateNamedStruct(exprs)
val nullExpr = Literal.create(null, createExpr.dataType)

If(IsNull(path), nullExpr, createExpr)
}

def fromCatalyst(path: Expression): Expression = {
val exprs = fields.value.value.map { field =>
field.encoder.fromCatalyst(
GetStructField(path, field.ordinal, Some(field.name)))
}

val newArgs = newInstanceExprs.value.from(exprs)
val newArgs = stage1.fromCatalystToCells(path)
val newExpr = NewInstance(
classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)

val nullExpr = Literal.create(null, jvmRepr)

If(IsNull(path), nullExpr, newExpr)
}
}

case class ForTypedRow[G <: HList, H <: HList](
)(implicit
stage1: RecordEncoderStage1[G, H],
classTag: ClassTag[TypedRow[G]])
extends RecordEncoder[TypedRow[G], G, H] {

import stage1._

private final val _apply = "apply"
private final val _fromInternalRow = "fromInternalRow"

def toCatalyst(path: Expression): Expression = {

val valueExprs = fields.value.value.zipWithIndex.map {
case (field, i) =>
val fieldPath = Invoke(
path,
_apply,
field.encoder.jvmRepr,
Seq(Literal.create(i, IntegerType))
)
field.encoder.toCatalyst(fieldPath)
}

val createExpr = stage1.cellsToCatalyst(valueExprs)

val nullExpr = Literal.create(null, createExpr.dataType)

If(IsNull(path), nullExpr, createExpr)
}

def fromCatalyst(path: Expression): Expression = {

val newArgs = stage1.fromCatalystToCells(path)
val aggregated = CreateStruct(newArgs)

val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType))

val newExpr = Invoke(
Literal.fromObject(partial),
_fromInternalRow,
TypedRow.catalystType,
Seq(aggregated)
)

val nullExpr = Literal.create(null, jvmRepr)

If(IsNull(path), nullExpr, newExpr)
}
}
}

final class RecordFieldEncoder[T](
Expand Down
49 changes: 49 additions & 0 deletions dataset/src/main/scala/frameless/RecordEncoderStage1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package frameless

import org.apache.spark.sql.catalyst.expressions.{
CreateNamedStruct,
Expression,
GetStructField,
Literal
}
import shapeless.{ HList, Lazy }

case class RecordEncoderStage1[G <: HList, H <: HList](
)(implicit
// i1: DropUnitValues.Aux[G, H],
// i2: IsHCons[H],
val fields: Lazy[RecordEncoderFields[H]],
val newInstanceExprs: Lazy[NewInstanceExprs[G]]) {

def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = {
val nameExprs = fields.value.value.map { field => Literal(field.name) }

// the way exprs are encoded in CreateNamedStruct
val exprs = nameExprs.zip(valueExprs).flatMap {
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
}

val createExpr = CreateNamedStruct(exprs)
createExpr
}

def fromCatalystToCells(path: Expression): Seq[Expression] = {
val exprs = fields.value.value.map { field =>
field.encoder.fromCatalyst(
GetStructField(path, field.ordinal, Some(field.name))
)
}

val newArgs = newInstanceExprs.value.from(exprs)
newArgs
}
}

object RecordEncoderStage1 {

implicit def usingDerivation[G <: HList, H <: HList](
implicit
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]]
): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]()
}
12 changes: 10 additions & 2 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -727,15 +727,23 @@ object TypedEncoder {
}

/** Encodes things as records if there is no Injection defined */
implicit def usingDerivation[F, G <: HList, H <: HList](
implicit def deriveForGeneric[F, G <: HList, H <: HList](
implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]],
i5: ClassTag[F]
): TypedEncoder[F] = new RecordEncoder[F, G, H]
): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]()

implicit def deriveForTypedRow[G <: HList, H <: HList](
implicit
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
i3: Lazy[RecordEncoderFields[H]],
i4: Lazy[NewInstanceExprs[G]]
): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]()

/** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */
implicit def usingUserDefinedType[
Expand Down
45 changes: 45 additions & 0 deletions dataset/src/main/scala/frameless/TypedRow.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package frameless

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{ DataType, ObjectType }
import shapeless.HList

case class TypedRow[T <: HList](row: Row) {

def apply(i: Int): Any = row.apply(i)
}

object TypedRow {

def apply(values: Any*): TypedRow[HList] = {

val row = Row.fromSeq(values)
TypedRow(row)
}

case class WithCatalystTypes(schema: Seq[DataType]) {

def fromInternalRow(row: InternalRow): TypedRow[HList] = {
val data = row.toSeq(schema).toArray

apply(data: _*)
}

}

object WithCatalystTypes {}

def fromHList[T <: HList](
hlist: T
): TypedRow[T] = {

val cells = hlist.runtimeList

val row = Row.fromSeq(cells)
TypedRow(row)
}

lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]])

}
2 changes: 1 addition & 1 deletion dataset/src/test/scala/frameless/InjectionTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class InjectionTests extends TypedDatasetSuite {
}

test("Resolve ambiguity by importing usingDerivation") {
import TypedEncoder.usingDerivation
import TypedEncoder.deriveForGeneric
assert(implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]])
check(forAll(prop[Person] _))
}
Expand Down
46 changes: 29 additions & 17 deletions dataset/src/test/scala/frameless/RecordEncoderTests.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
package frameless

import org.apache.spark.sql.{Row, functions => F}
import org.apache.spark.sql.types.{
ArrayType,
BinaryType,
DecimalType,
IntegerType,
LongType,
MapType,
ObjectType,
StringType,
StructField,
StructType
}

import shapeless.{HList, LabelledGeneric}
import shapeless.test.illTyped

import frameless.RecordEncoderTests.{ A, B, E }
import org.apache.spark.sql.types._
import org.apache.spark.sql.{ Row, functions => F }
import org.scalatest.matchers.should.Matchers
import shapeless.record.Record
import shapeless.test.illTyped
import shapeless.{ HList, LabelledGeneric }

final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
test("Unable to encode products made from units only") {
Expand Down Expand Up @@ -87,6 +76,26 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
ds.collect.head shouldBe obj
}

test("TypedRow") {

val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc")
val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1)

val rdd = sc.parallelize(Seq(r2))
val ds =
session.createDataset(rdd)(
TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]]
)

ds.schema.treeString shouldBe
"""root
| |-- x: integer (nullable = true)
| |-- y: string (nullable = true)
|""".stripMargin

ds.collect.head shouldBe r2
}

test("Scalar value class") {
import RecordEncoderTests._

Expand Down Expand Up @@ -548,6 +557,9 @@ object RecordEncoderTests {
case class D(m: Map[String, Int])
case class E(b: Set[B])

val RR = Record.`'x -> Int, 'y -> String`
type RR = RR.T

final class Subject(val name: String) extends AnyVal with Serializable

final class Grade(val value: BigDecimal) extends AnyVal with Serializable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ object RefinedTypesTests {

import frameless.refined._ // implicit instances for refined

implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation
implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric

implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation
implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric
}
Loading