Skip to content

Commit

Permalink
Use PartiallyConstructedTypedColumn for functions.col
Browse files Browse the repository at this point in the history
  • Loading branch information
Itamar Ravid committed Sep 22, 2017
1 parent 897e499 commit 0be79f0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package frameless
package functions

import org.apache.spark.sql.functions.{ col => sparkCol }
import shapeless.Witness

case class PartiallyConstructedTypedColumn[K <: Symbol](column: Witness.Aux[K])
object PartiallyConstructedTypedColumn {
implicit def toTypedColumn[T, K <: Symbol, A](p: PartiallyConstructedTypedColumn[K])(
implicit
exists: TypedColumn.Exists[T, K, A],
encoder: TypedEncoder[A]): TypedColumn[T, A] = {
val untypedExpr = sparkCol(p.column.value.name).as[A](TypedExpressionEncoder[A])
new TypedColumn[T, A](untypedExpr)
}
}
9 changes: 1 addition & 8 deletions dataset/src/main/scala/frameless/functions/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package frameless

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.functions.{ col => sparkCol }
import shapeless.Witness

package object functions extends Udf with UnaryFunctions {
Expand All @@ -20,11 +19,5 @@ package object functions extends Udf with UnaryFunctions {
}
}

def col[T, A](column: Witness.Lt[Symbol])(
implicit
exists: TypedColumn.Exists[T, column.T, A],
encoder: TypedEncoder[A]): TypedColumn[T, A] = {
val untypedExpr = sparkCol(column.value.name).as[A](TypedExpressionEncoder[A])
new TypedColumn[T, A](untypedExpr)
}
def col[K <: Symbol](column: Witness.Aux[K]): PartiallyConstructedTypedColumn[K] = PartiallyConstructedTypedColumn(column)
}
5 changes: 3 additions & 2 deletions dataset/src/test/scala/frameless/SelectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ class SelectTests extends TypedDatasetSuite {
): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val col = functions.col('a)

val dataset2 = dataset.select(A).collect().run().toVector
val symDataset2 = dataset.select(functions.col('a)).collect().run().toVector
val symDataset2 = dataset.select(col).collect().run().toVector
val data2 = data.map { case X4(a, _, _, _) => a }

(dataset2 ?= data2) && (symDataset2 ?= data2)
Expand Down Expand Up @@ -398,4 +399,4 @@ class SelectTests extends TypedDatasetSuite {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
illTyped("""e.select(frameless.functions.aggregate.sum(e('_1)))""")
}
}
}

0 comments on commit 0be79f0

Please sign in to comment.