Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,27 @@ lazy val schema = crossProject(JSPlatform, JVMPlatform, NativePlatform)
.nativeSettings(nativeSettings)
.settings(
compileOrder := CompileOrder.JavaThenScala,
scalacOptions ++= (if (scalaVersion.value == Scala3)
Seq(
"-explain"
)
else
Seq(
"-opt:l:method"
)),
scalacOptions ++=
(CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, _)) =>
Seq(
"-opt:l:method"
)
case _ =>
Seq(
"-explain"
)
}),
libraryDependencies ++= Seq(
"dev.zio" %%% "zio-test" % "2.1.16" % Test,
"dev.zio" %%% "zio-test-sbt" % "2.1.16" % Test
),
) ++ (CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, _)) =>
Seq(
"org.scala-lang" % "scala-reflect" % scalaVersion.value
)
case _ => Seq()
}),
testFrameworks += new TestFramework("zio.test.sbt.ZTestFramework")
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package zio.blocks.schema

trait SchemaVersionSpecific {
import scala.language.experimental.macros

def derived[A]: Schema[A] = macro SchemaVersionSpecific.derived[A]
}

object SchemaVersionSpecific {
import scala.reflect.macros.blackbox
import scala.reflect.NameTransformer

def derived[A: c.WeakTypeTag](c: blackbox.Context): c.Expr[Schema[A]] = {
import c.universe._
import c.internal._

def fail(msg: String): Nothing = c.abort(c.enclosingPosition, msg)

def isNonAbstractScalaClass(tpe: Type): Boolean =
tpe.typeSymbol.isClass && !tpe.typeSymbol.isAbstract && !tpe.typeSymbol.isJava

def companion(typeSymbol: Symbol): Symbol = {
val comp = typeSymbol.companion
if (comp.isModule) comp
else {
val ownerChainOf = (s: Symbol) =>
Iterator.iterate(s)(_.owner).takeWhile(x => x != NoSymbol).toVector.reverseIterator
val path = ownerChainOf(typeSymbol)
.zipAll(ownerChainOf(enclosingOwner), NoSymbol, NoSymbol)
.dropWhile { case (x, y) => x == y }
.takeWhile { case (x, _) => x != NoSymbol }
.map { case (x, _) => x.name.toTermName }
if (path.isEmpty) NoSymbol
else c.typecheck(path.foldLeft[Tree](Ident(path.next()))(Select(_, _)), silent = true).symbol
}
}

def toName(sym: Symbol): (List[String], List[String], String) = {
var values = List.empty[String]
var packages = List.empty[String]
var owner = companion(sym).owner
while (owner != NoSymbol) {
val name = NameTransformer.decode(owner.name.toString)
if (owner.isPackage || owner.isPackageClass) packages = name :: packages
else values = name :: values
owner = owner.owner
}
(packages.tail, values, NameTransformer.decode(sym.name.toString))
}

val tpe = weakTypeOf[A].dealias
if (isNonAbstractScalaClass(tpe)) {
case class FieldInfo(name: String, tpe: Type, getter: Symbol)

val tpeTypeSym = tpe.typeSymbol
val tpeName = toName(tpeTypeSym)
val tpeClassSym = tpeTypeSym.asClass
val primaryConstructor = tpe.decls.collectFirst {
case m: MethodSymbol if m.isPrimaryConstructor => m
}.getOrElse(fail(s"Cannot find a primary constructor for '$tpe'"))
val tpeTypeArgs = tpe.typeArgs
val tpeTypeParams = tpeClassSym.typeParams
val tpeParams = primaryConstructor.paramLists
val fieldInfos = tpeParams.map(_.map { param =>
val sym = param.asTerm
val name = sym.name
FieldInfo(
name = NameTransformer.decode(name.toString),
tpe = {
val originFieldTpe = sym.typeSignature.dealias
if (tpeTypeArgs.isEmpty) originFieldTpe
else originFieldTpe.substituteTypes(tpeTypeParams, tpeTypeArgs)
},
getter = {
tpe.members
.filter(_.name == name)
.collectFirst {
case m: MethodSymbol if m.isParamAccessor && m.isGetter =>
m
}
.getOrElse(fail(s"Cannot find '$name' parameter of '$tpe' in the primary constructor."))
}
)
})
// TODO: use `fieldInfos` to generate remaining `Reflect.Record.fields` and `Reflect.Record.recordBinding`
c.Expr[Schema[A]](
q"""new _root_.zio.blocks.schema.Schema[$tpe](
reflect = _root_.zio.blocks.schema.Reflect.Record[_root_.zio.blocks.schema.binding.Binding, $tpe](
fields = Nil,
typeName = TypeName(
namespace = Namespace(
packages = ${tpeName._1},
values = ${tpeName._2}
),
name = ${tpeName._3}
),
recordBinding = null,
doc = Doc.Empty,
modifiers = Nil
)
)"""
)
} else fail(s"Cannot derive '${typeOf[Schema[_]]}' for '$tpe'.")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package zio.blocks.schema

trait SchemaVersionSpecific {
inline def derived[A]: Schema[A] = ${ SchemaVersionSpecific.derived }
}

object SchemaVersionSpecific {
import scala.quoted._

def derived[A: Type](using Quotes): Expr[Schema[A]] = {
import quotes.reflect._
import zio.blocks.schema.binding.Binding

def fail(msg: String): Nothing = report.errorAndAbort(msg, Position.ofMacroExpansion)

def isNonAbstractScalaClass(tpe: TypeRepr): Boolean = tpe.classSymbol.fold(false) { sym =>
val flags = sym.flags
!flags.is(Flags.Abstract) && !flags.is(Flags.JavaDefined) && !flags.is(Flags.Trait)
}

def toName(sym: Symbol): (List[String], List[String], String) = {
var values = List.empty[String]
var packages = List.empty[String]
var owner = sym.owner
while (owner != quotes.reflect.defn.RootClass) {
val name = owner.name.toString
if (owner.flags.is(Flags.Package)) packages = name :: packages
else if (owner.flags.is(Flags.Module)) values = name.substring(0, name.length - 1) :: values
else values = name :: values
owner = owner.owner
}
(packages, values, sym.name.toString)
}

val tpe = TypeRepr.of[A].dealias
if (isNonAbstractScalaClass(tpe)) {
case class FieldInfo(name: String, tpe: TypeRepr, getter: Symbol)

val tpeTypeSym = tpe.typeSymbol
val tpeName = toName(tpeTypeSym)
val tpeClassSym = tpe.classSymbol.get
val primaryConstructor =
if (tpeClassSym.primaryConstructor.exists) tpeClassSym.primaryConstructor
else fail(s"Cannot find a primary constructor for '$tpe'")
val tpeTypeArgs = tpe match
case AppliedType(_, typeArgs) => typeArgs
case _ => Nil
val (tpeTypeParams, tpeParams) = primaryConstructor.paramSymss match {
case tps :: ps if tps.exists(_.isTypeParam) => (tps, ps)
case ps => (Nil, ps)
}
val fieldInfos = tpeParams.map(_.map { sym =>
val name = sym.name
FieldInfo(
name = name,
tpe = {
val originFieldType = tpe.memberType(sym).dealias
if (tpeTypeArgs.isEmpty) originFieldType
else originFieldType.substituteTypes(tpeTypeParams, tpeTypeArgs)
},
getter = {
val fieldMember = tpeClassSym.fieldMember(name)
if (fieldMember.exists) fieldMember
else {
tpeClassSym
.methodMember(name)
.find(_.flags.is(Flags.ParamAccessor | Flags.CaseAccessor))
.getOrElse(fail(s"Cannot find '$name' parameter of '${tpe.show}' in the primary constructor."))
}
}
)
})
// TODO: use `fieldInfos` to generate remaining `Reflect.Record.fields` and `Reflect.Record.recordBinding`
'{
new Schema[A](
reflect = new Reflect.Record[Binding, A](
fields = Nil,
typeName = TypeName(
namespace = Namespace(
packages = ${ Expr(tpeName._1) },
values = ${ Expr(tpeName._2) }
),
name = ${ Expr(tpeName._3) }
),
recordBinding = null,
doc = Doc.Empty,
modifiers = Nil
)
)
}.asExprOf[Schema[A]]
} else fail(s"Cannot derive '${TypeRepr.of[Schema[_]].show}' for '${tpe.show}'.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ final case class Schema[A](reflect: Reflect.Bound[A]) {
def toDynamicValue(value: A): DynamicValue = ??? // TODO
}

object Schema {
object Schema extends SchemaVersionSpecific {
def apply[A](implicit schema: Schema[A]): Schema[A] = schema

implicit val dynamic: Schema[DynamicValue] = Schema(Reflect.dynamic[Binding])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package zio.blocks.schema

import zio.Scope
import zio.blocks.schema.Reflect.Primitive
import zio.blocks.schema.binding.RegisterOffset.RegisterOffset
import zio.blocks.schema.binding._
import zio.test.Assertion._
import zio.test._

object SchemaVersionSpecificSpec extends ZIOSpecDefault {
def spec: Spec[TestEnvironment with Scope, Any] = suite("SchemaVersionSpecificSpec")(
suite("Reflect.Record")(
test("derives schema using 'derives' keyword") {
case class Record1(b: Byte, i: Int) derives Schema

assert(Schema[Record1])(
equalTo(
new Schema[Record1](
reflect = Reflect.Record[Binding, Record1](
fields = Nil,
typeName = TypeName(
namespace = Namespace(
packages = Seq("zio", "blocks", "schema"),
values = Seq("SchemaVersionSpecificSpec", "spec")
),
name = "Record1"
),
recordBinding = null,
doc = Doc.Empty,
modifiers = Nil
)
)
)
)
}
)
)
}
24 changes: 12 additions & 12 deletions schema/shared/src/test/scala/zio/blocks/schema/OpticSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ object OpticSpec extends ZIOSpecDefault {
Reflect.boolean[Binding].asTerm("b"),
Reflect.float[Binding].asTerm("f")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record1"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record1"),
recordBinding = Binding.Record(
constructor = new Constructor[Record1] {
def usedRegisters: RegisterOffset = RegisterOffset(booleans = 1, floats = 1)
Expand Down Expand Up @@ -691,7 +691,7 @@ object OpticSpec extends ZIOSpecDefault {
Reflect.vector(Reflect.int[Binding]).asTerm("vi"),
Record1.reflect.asTerm("r1")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record2"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record2"),
recordBinding = Binding.Record(
constructor = new Constructor[Record2] {
def usedRegisters: RegisterOffset = RegisterOffset(longs = 1, objects = 2)
Expand Down Expand Up @@ -734,7 +734,7 @@ object OpticSpec extends ZIOSpecDefault {
Record2.reflect.asTerm("r2"),
Reflect.Deferred(() => Variant1.reflect).asTerm("v1")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record3"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record3"),
recordBinding = Binding.Record(
constructor = new Constructor[Record3] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 3)
Expand Down Expand Up @@ -779,7 +779,7 @@ object OpticSpec extends ZIOSpecDefault {
Case2.reflect.asTerm("c2"),
Reflect.Deferred(() => Variant2.reflect).asTerm("v2")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant1"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant1"),
variantBinding = Binding.Variant(
discriminator = new Discriminator[Variant1] {
def discriminate(a: Variant1): Int = a match {
Expand Down Expand Up @@ -839,7 +839,7 @@ object OpticSpec extends ZIOSpecDefault {
fields = List(
Reflect.double[Binding].asTerm("d")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case1"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case1"),
recordBinding = Binding.Record(
constructor = new Constructor[Case1] {
def usedRegisters: RegisterOffset = RegisterOffset(doubles = 1)
Expand Down Expand Up @@ -867,7 +867,7 @@ object OpticSpec extends ZIOSpecDefault {
fields = List(
Record3.reflect.asTerm("r3")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case2"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case2"),
recordBinding = Binding.Record(
constructor = new Constructor[Case2] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
Expand Down Expand Up @@ -898,7 +898,7 @@ object OpticSpec extends ZIOSpecDefault {
Case4.reflect.asTerm("c4"),
Reflect.Deferred(() => Variant3.reflect).asTerm("v3")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant2"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant2"),
variantBinding = Binding.Variant(
discriminator = new Discriminator[Variant2] {
def discriminate(a: Variant2): Int = a match {
Expand Down Expand Up @@ -949,7 +949,7 @@ object OpticSpec extends ZIOSpecDefault {
fields = List(
Reflect.Deferred(() => Variant1.reflect).asTerm("v1")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case3"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case3"),
recordBinding = Binding.Record(
constructor = new Constructor[Case3] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
Expand Down Expand Up @@ -983,7 +983,7 @@ object OpticSpec extends ZIOSpecDefault {
fields = List(
Reflect.list(Record3.reflect).asTerm("lr3")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case4"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case4"),
recordBinding = Binding.Record(
constructor = new Constructor[Case4] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ object OpticSpec extends ZIOSpecDefault {
Case5.reflect.asTerm("c5"),
Case6.reflect.asTerm("c6")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant3"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant3"),
variantBinding = Binding.Variant(
discriminator = new Discriminator[Variant3] {
def discriminate(a: Variant3): Int = a match {
Expand Down Expand Up @@ -1053,7 +1053,7 @@ object OpticSpec extends ZIOSpecDefault {
Reflect.set(Reflect.int[Binding]).asTerm("si"),
Reflect.array(Reflect.string[Binding]).asTerm("as")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case5"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case5"),
recordBinding = Binding.Record(
constructor = new Constructor[Case5] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 2)
Expand Down Expand Up @@ -1088,7 +1088,7 @@ object OpticSpec extends ZIOSpecDefault {
fields = List(
Reflect.Deferred(() => Variant2.reflect).asTerm("v2")
),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case6"),
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case6"),
recordBinding = Binding.Record(
constructor = new Constructor[Case6] {
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
Expand Down
Loading
Loading