-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a test case for the new TypedRow encoder
implemented the proposal
- Loading branch information
Showing
7 changed files
with
203 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
dataset/src/main/scala/frameless/RecordEncoderStage1.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package frameless | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{ | ||
CreateNamedStruct, | ||
Expression, | ||
GetStructField, | ||
Literal | ||
} | ||
import shapeless.{ HList, Lazy } | ||
|
||
case class RecordEncoderStage1[G <: HList, H <: HList]( | ||
)(implicit | ||
// i1: DropUnitValues.Aux[G, H], | ||
// i2: IsHCons[H], | ||
val fields: Lazy[RecordEncoderFields[H]], | ||
val newInstanceExprs: Lazy[NewInstanceExprs[G]]) { | ||
|
||
def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = { | ||
val nameExprs = fields.value.value.map { field => Literal(field.name) } | ||
|
||
// the way exprs are encoded in CreateNamedStruct | ||
val exprs = nameExprs.zip(valueExprs).flatMap { | ||
case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil | ||
} | ||
|
||
val createExpr = CreateNamedStruct(exprs) | ||
createExpr | ||
} | ||
|
||
def fromCatalystToCells(path: Expression): Seq[Expression] = { | ||
val exprs = fields.value.value.map { field => | ||
field.encoder.fromCatalyst( | ||
GetStructField(path, field.ordinal, Some(field.name)) | ||
) | ||
} | ||
|
||
val newArgs = newInstanceExprs.value.from(exprs) | ||
newArgs | ||
} | ||
} | ||
|
||
object RecordEncoderStage1 { | ||
|
||
implicit def usingDerivation[G <: HList, H <: HList]( | ||
implicit | ||
i3: Lazy[RecordEncoderFields[H]], | ||
i4: Lazy[NewInstanceExprs[G]] | ||
): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package frameless | ||
|
||
import org.apache.spark.sql.Row | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.types.{ DataType, ObjectType } | ||
import shapeless.HList | ||
|
||
case class TypedRow[T <: HList](row: Row) { | ||
|
||
def apply(i: Int): Any = row.apply(i) | ||
} | ||
|
||
object TypedRow { | ||
|
||
def apply(values: Any*): TypedRow[HList] = { | ||
|
||
val row = Row.fromSeq(values) | ||
TypedRow(row) | ||
} | ||
|
||
case class WithCatalystTypes(schema: Seq[DataType]) { | ||
|
||
def fromInternalRow(row: InternalRow): TypedRow[HList] = { | ||
val data = row.toSeq(schema).toArray | ||
|
||
apply(data: _*) | ||
} | ||
|
||
} | ||
|
||
object WithCatalystTypes {} | ||
|
||
def fromHList[T <: HList]( | ||
hlist: T | ||
): TypedRow[T] = { | ||
|
||
val cells = hlist.runtimeList | ||
|
||
val row = Row.fromSeq(cells) | ||
TypedRow(row) | ||
} | ||
|
||
lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]]) | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters