Skip to content

Commit

Permalink
(cherry picked from commit c2f3492)
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Mar 21, 2024
1 parent b82d266 commit e36eac2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
6 changes: 5 additions & 1 deletion dataset/src/main/scala/frameless/functions/Udf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ trait Udf {
*
* Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
*/
// Possibly add UserDefinedExpression trait to stop the functions being registered and used as aggregates
case class FramelessUdf[T, R](
function: AnyRef,
encoders: Seq[TypedEncoder[_]],
Expand All @@ -172,6 +173,9 @@ case class FramelessUdf[T, R](
/**
 * Encoder for the UDF result type `R`, built on first access.
 *
 * It is produced by `TypedExpressionEncoder` from `rencoder` and then
 * viewed as an `ExpressionEncoder[R]` through an explicit cast.
 */
lazy val typedEnc: ExpressionEncoder[R] = {
  val encoder = TypedExpressionEncoder[R](rencoder)
  encoder.asInstanceOf[ExpressionEncoder[R]]
}

/**
 * The value of `typedEnc.isSerializedAsStructForTopLevel`, read from the
 * encoder once (on first access) and then reused.
 */
lazy val isSerializedAsStructForTopLevel: Boolean = {
  val encoder = typedEnc
  encoder.isSerializedAsStructForTopLevel
}

def eval(input: InternalRow): Any = {
val jvmTypes = children.map(_.eval(input))

Expand All @@ -181,7 +185,7 @@ case class FramelessUdf[T, R](
val retval =
if (returnCatalyst == null)
null
else if (typedEnc.isSerializedAsStructForTopLevel)
else if (isSerializedAsStructForTopLevel)
returnCatalyst
else
returnCatalyst.get(0, dataType)
Expand Down
53 changes: 52 additions & 1 deletion dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import java.time.format.DateTimeFormatter
import java.time.{ LocalDateTime => JavaLocalDateTime }
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.internal.SQLConf
import org.scalacheck.{ Arbitrary, Cogen, Gen }

import org.scalacheck.{ Arbitrary, Gen }
import scala.collection.immutable.{ ListSet, TreeSet }

package object frameless {

Expand Down Expand Up @@ -39,6 +42,54 @@ package object frameless {

def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary

/**
 * `Arbitrary` instance for `scala.collection.Seq[A]`.
 *
 * A list of elements is drawn with `Gen.listOf` from the element generator,
 * converted to a `Vector` and exposed through `toSeq`.
 */
implicit def arbSeq[A](implicit A: Arbitrary[A]): Arbitrary[scala.collection.Seq[A]] =
  Arbitrary {
    for (elems <- Gen.listOf(A.arbitrary)) yield elems.toVector.toSeq
  }

/** Generator of `scala.collection.Seq[A]`, taken from the [[arbSeq]] instance. */
def seqGen[A: Arbitrary]: Gen[scala.collection.Seq[A]] =
  Arbitrary.arbitrary[scala.collection.Seq[A]](arbSeq[A])

/**
 * `Arbitrary` instance for `List[A]`, backed by `Gen.listOf` over the
 * element generator.
 */
implicit def arbList[A](implicit A: Arbitrary[A]): Arbitrary[List[A]] =
  Arbitrary {
    for (elems <- Gen.listOf(A.arbitrary)) yield elems.toList
  }

/** Generator of `List[A]`, taken from the [[arbList]] instance. */
def listGen[A: Arbitrary]: Gen[List[A]] =
  Arbitrary.arbitrary[List[A]](arbList[A])

/**
 * `Arbitrary` instance for `Set[A]`.
 *
 * A list of elements is drawn with `Gen.listOf` and poured into a fresh
 * `Set` builder.
 */
implicit def arbSet[A](implicit A: Arbitrary[A]): Arbitrary[Set[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map { elems =>
      (Set.newBuilder[A] ++= elems).result()
    }
  }

/** Generator of `Set[A]`, taken from the [[arbSet]] instance. */
def setGen[A: Arbitrary]: Gen[Set[A]] =
  Arbitrary.arbitrary[Set[A]](arbSet[A])

/**
 * `Cogen` instance for `ListSet[A]`.
 *
 * The set's elements are copied to a `Vector`, sorted with the implicit
 * `Ordering[A]`, and the resulting iterator is handed to `Cogen.it`.
 */
implicit def cogenListSet[A: Cogen: Ordering]: Cogen[ListSet[A]] =
  Cogen.it { (set: ListSet[A]) =>
    val ordered = set.toVector.sorted
    ordered.iterator
  }

/**
 * `Arbitrary` instance for `ListSet[A]`.
 *
 * A list of elements is drawn with `Gen.listOf` and poured into a fresh
 * `ListSet` builder.
 */
implicit def arbListSet[A](implicit A: Arbitrary[A]): Arbitrary[ListSet[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map { elems =>
      (ListSet.newBuilder[A] ++= elems).result()
    }
  }

/** Generator of `ListSet[A]`, taken from the [[arbListSet]] instance. */
def listSetGen[A: Arbitrary]: Gen[ListSet[A]] =
  Arbitrary.arbitrary[ListSet[A]](arbListSet[A])

/**
 * `Cogen` instance for `TreeSet[A]`.
 *
 * The set's elements are copied to a `Vector`, sorted with the implicit
 * `Ordering[A]`, and the resulting iterator is handed to `Cogen.it`.
 */
implicit def cogenTreeSet[A: Cogen: Ordering]: Cogen[TreeSet[A]] =
  Cogen.it { (set: TreeSet[A]) =>
    val ordered = set.toVector.sorted
    ordered.iterator
  }

/**
 * `Arbitrary` instance for `TreeSet[A]`.
 *
 * A list of elements is drawn with `Gen.listOf` and poured into a fresh
 * `TreeSet` builder, which uses the implicit ordering `o`.
 */
implicit def arbTreeSet[A](implicit A: Arbitrary[A], o: Ordering[A]): Arbitrary[TreeSet[A]] =
  Arbitrary {
    Gen.listOf(A.arbitrary).map { elems =>
      (TreeSet.newBuilder[A] ++= elems).result()
    }
  }

/** Generator of `TreeSet[A]`, taken from the [[arbTreeSet]] instance. */
def treeSetGen[A: Arbitrary: Ordering]: Gen[TreeSet[A]] =
  Arbitrary.arbitrary[TreeSet[A]](arbTreeSet[A])

implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary {
for {
int <- Arbitrary.arbitrary[Int]
Expand Down

0 comments on commit e36eac2

Please sign in to comment.