typelevel#787 - [Un]WrapOption, Invoke, NewInstance, GetStructField, ifIsNull, GetColumnByOrdinal, MapObjects and TypedExpressionEncoder shimmed
chris-twiner committed Feb 28, 2024
1 parent cb259fa commit b8d4f05
Showing 5 changed files with 28 additions and 55 deletions.
build.sbt: 12 changes (11 additions and 1 deletion)
@@ -15,9 +15,19 @@ val shimVersion = "0.0.1-SNAPSHOT"

val Scala212 = "2.12.18"
val Scala213 = "2.13.12"

/*
//resolvers in Global += Resolver.mavenLocal
resolvers in Global += MavenRepository(
"sonatype-s01-snapshots",
Resolver.SonatypeS01RepositoryRoot + "/snapshots"
)
import scala.concurrent.duration.DurationInt
import lmcoursier.definitions.CachePolicy
csrConfiguration := csrConfiguration.value
.withTtl(Some(1.minute))
.withCachePolicies(Vector(CachePolicy.LocalOnly))
*/
ThisBuild / tlBaseVersion := "0.16"

ThisBuild / crossScalaVersions := Seq(Scala213, Scala212)
dataset/src/main/scala/frameless/RecordEncoder.scala: 19 changes (7 additions and 12 deletions)
@@ -1,13 +1,11 @@
package frameless

import com.sparkutils.shim.expressions.{CreateNamedStruct1 => CreateNamedStruct, GetStructField3 => GetStructField, UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption}
import com.sparkutils.shim.{deriveUnitLiteral, ifIsNull}
import org.apache.spark.sql.FramelessInternals

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
Invoke, NewInstance, UnwrapOption, WrapOption
}
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct => _, GetStructField => _, _}
import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance}
import org.apache.spark.sql.types._

import shapeless._
import shapeless.labelled.FieldType
import shapeless.ops.hlist.IsHCons
@@ -72,7 +70,7 @@ object NewInstanceExprs {
tail: NewInstanceExprs[T]
): NewInstanceExprs[FieldType[K, Unit] :: T] = new NewInstanceExprs[FieldType[K, Unit] :: T] {
def from(exprs: List[Expression]): Seq[Expression] =
Literal.fromObject(()) +: tail.from(exprs)
deriveUnitLiteral +: tail.from(exprs)
}

implicit def deriveNonUnit[K <: Symbol, V, T <: HList]
@@ -161,9 +159,8 @@ class RecordEncoder[F, G <: HList, H <: HList]
}

val createExpr = CreateNamedStruct(exprs)
val nullExpr = Literal.create(null, createExpr.dataType)

If(IsNull(path), nullExpr, createExpr)
ifIsNull(createExpr.dataType, path, createExpr)
}

def fromCatalyst(path: Expression): Expression = {
@@ -176,9 +173,7 @@ class RecordEncoder[F, G <: HList, H <: HList]
val newExpr = NewInstance(
classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)

val nullExpr = Literal.create(null, jvmRepr)

If(IsNull(path), nullExpr, newExpr)
ifIsNull(jvmRepr, path, newExpr)
}
}

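The ifIsNull and deriveUnitLiteral helpers imported from com.sparkutils.shim replace the inline null-guard and Unit literal removed above. A minimal sketch of equivalent behaviour, reconstructed from the removed lines (the actual shim implementations may differ per Spark version):

// Hedged sketch only: stand-ins illustrating what the shim helpers do,
// based on the pre-shim code deleted in this commit.
import org.apache.spark.sql.catalyst.expressions.{Expression, If, IsNull, Literal}
import org.apache.spark.sql.types.DataType

object ShimHelpersSketch {

  // Mirrors the removed If(IsNull(path), Literal.create(null, dt), expr):
  // produce a typed null when the input is null, otherwise the built value.
  def ifIsNull(dt: DataType, path: Expression, value: Expression): Expression =
    If(IsNull(path), Literal.create(null, dt), value)

  // Mirrors the removed Literal.fromObject(()) used for Unit-typed fields.
  def deriveUnitLiteral: Expression = Literal.fromObject(())
}
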
dataset/src/main/scala/frameless/TypedEncoder.scala: 11 changes (3 additions and 8 deletions)
@@ -1,20 +1,15 @@
package frameless

import java.math.BigInteger

import java.util.Date

import java.time.{ Duration, Instant, Period, LocalDate }

import java.time.{Duration, Instant, LocalDate, Period}
import java.sql.Timestamp

import scala.reflect.ClassTag

import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.sql.{ reflection => ScalaReflection }
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{StaticInvoke => _, _}
import org.apache.spark.sql.catalyst.util.{
ArrayBasedMapData,
DateTimeUtils,
@@ -25,8 +20,8 @@ import org.apache.spark.unsafe.types.UTF8String

import shapeless._
import shapeless.ops.hlist.IsHCons

import org.apache.spark.sql.shim.{StaticInvoke4 => StaticInvoke}
import com.sparkutils.shim.expressions.{UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption, MapObjects5 => MapObjects, ExternalMapToCatalyst7 => ExternalMapToCatalyst}
import org.apache.spark.sql.shim.{StaticInvoke4 => StaticInvoke, NewInstance4 => NewInstance, Invoke5 => Invoke}

abstract class TypedEncoder[T](
implicit
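The renamed imports here (UnwrapOption2 => UnwrapOption, MapObjects5 => MapObjects, StaticInvoke4 => StaticInvoke, and so on) keep the TypedEncoder call sites unchanged: each shim name is assumed to expose a fixed-arity apply that forwards to whatever constructor the Spark version on the classpath provides. A hypothetical illustration of that forwarding pattern (the object below is illustrative, not the shim's actual code):

// Hypothetical forwarding object, not the real com.sparkutils shim: it pins a
// stable four-argument call shape and delegates to Spark's own StaticInvoke,
// letting the remaining parameters take their defaults.
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{StaticInvoke => SparkStaticInvoke}
import org.apache.spark.sql.types.DataType

object StaticInvoke4Sketch {

  def apply(
      staticObject: Class[_],
      dataType: DataType,
      functionName: String,
      arguments: Seq[Expression]): Expression =
    SparkStaticInvoke(staticObject, dataType, functionName, arguments)
}
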
dataset/src/main/scala/frameless/TypedExpressionEncoder.scala: 35 changes (4 additions and 31 deletions)
@@ -1,9 +1,6 @@
package frameless

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If}
import org.apache.spark.sql.types.StructType

object TypedExpressionEncoder {
@@ -14,36 +11,12 @@ object TypedExpressionEncoder {
* with a single field called "value" set in ExpressionEncoder.
*/
def targetStructType[A](encoder: TypedEncoder[A]): StructType =
encoder.catalystRepr match {
case x: StructType =>
if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true)))
else x

case dt => new StructType().add("value", dt, nullable = encoder.nullable)
}
org.apache.spark.sql.ShimUtils.targetStructType(encoder.catalystRepr, encoder.nullable)

def apply[T](implicit encoder: TypedEncoder[T]): Encoder[T] = {
val in = BoundReference(0, encoder.jvmRepr, encoder.nullable)

val (out, serializer) = encoder.toCatalyst(in) match {
case it @ If(_, _, _: CreateNamedStruct) => {
val out = GetColumnByOrdinal(0, encoder.catalystRepr)

out -> it
}

case other => {
val out = GetColumnByOrdinal(0, encoder.catalystRepr)

out -> other
}
}

new ExpressionEncoder[T](
objSerializer = serializer,
objDeserializer = encoder.fromCatalyst(out),
clsTag = encoder.classTag
)
import encoder._
org.apache.spark.sql.ShimUtils.expressionEncoder[T](jvmRepr, nullable, toCatalyst, catalystRepr, fromCatalyst)
}

}
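ShimUtils.targetStructType and ShimUtils.expressionEncoder are assumed to centralise the logic deleted above, so that differences in ExpressionEncoder construction across Spark versions live in the shim rather than in frameless. For reference, a hedged reconstruction of the struct-type logic as it stood before this commit (taken from the removed lines; the shim's own version may differ):

// Reconstruction of the removed targetStructType body, not ShimUtils itself.
import org.apache.spark.sql.types.{DataType, StructType}

object TargetStructTypeSketch {

  def targetStructType(catalystRepr: DataType, nullable: Boolean): StructType =
    catalystRepr match {
      // a struct stays a struct; a nullable encoder relaxes every field
      case x: StructType =>
        if (nullable) StructType(x.fields.map(_.copy(nullable = true))) else x
      // any other type is wrapped in a single "value" field
      case dt => new StructType().add("value", dt, nullable = nullable)
    }
}
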

@@ -3,11 +3,11 @@ package frameless.refined
import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
Invoke, NewInstance, UnwrapOption, WrapOption
}
import org.apache.spark.sql.types._

import com.sparkutils.shim.expressions.{UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption}
import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance}

import eu.timepit.refined.api.RefType

import frameless.{ TypedEncoder, RecordFieldEncoder }
