Skip to content

Commit 61c9dba

Browse files
committed
Add Schema.derived
1 parent d117969 commit 61c9dba

File tree

7 files changed

+325
-30
lines changed

7 files changed

+325
-30
lines changed

build.sbt

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,27 @@ lazy val schema = crossProject(JSPlatform, JVMPlatform, NativePlatform)
5757
.nativeSettings(nativeSettings)
5858
.settings(
5959
compileOrder := CompileOrder.JavaThenScala,
60-
scalacOptions ++= (if (scalaVersion.value == Scala3)
61-
Seq(
62-
"-explain"
63-
)
64-
else
65-
Seq(
66-
"-opt:l:method"
67-
)),
60+
scalacOptions ++=
61+
(CrossVersion.partialVersion(scalaVersion.value) match {
62+
case Some((2, _)) =>
63+
Seq(
64+
"-opt:l:method"
65+
)
66+
case _ =>
67+
Seq(
68+
"-explain"
69+
)
70+
}),
6871
libraryDependencies ++= Seq(
6972
"dev.zio" %%% "zio-test" % "2.1.16" % Test,
7073
"dev.zio" %%% "zio-test-sbt" % "2.1.16" % Test
71-
),
74+
) ++ (CrossVersion.partialVersion(scalaVersion.value) match {
75+
case Some((2, _)) =>
76+
Seq(
77+
"org.scala-lang" % "scala-reflect" % scalaVersion.value
78+
)
79+
case _ => Seq()
80+
}),
7281
testFrameworks += new TestFramework("zio.test.sbt.ZTestFramework")
7382
)
7483

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package zio.blocks.schema
2+
3+
trait SchemaVersionSpecific {
4+
import scala.language.experimental.macros
5+
6+
def derived[A]: Schema[A] = macro SchemaVersionSpecific.derived[A]
7+
}
8+
9+
object SchemaVersionSpecific {
10+
import scala.reflect.macros.blackbox
11+
import scala.reflect.NameTransformer
12+
13+
def derived[A: c.WeakTypeTag](c: blackbox.Context): c.Expr[Schema[A]] = {
14+
import c.universe._
15+
import c.internal._
16+
17+
def fail(msg: String): Nothing = c.abort(c.enclosingPosition, msg)
18+
19+
def isNonAbstractScalaClass(tpe: Type): Boolean =
20+
tpe.typeSymbol.isClass && !tpe.typeSymbol.isAbstract && !tpe.typeSymbol.isJava
21+
22+
def companion(typeSymbol: Symbol): Symbol = {
23+
val comp = typeSymbol.companion
24+
if (comp.isModule) comp
25+
else {
26+
val ownerChainOf = (s: Symbol) =>
27+
Iterator.iterate(s)(_.owner).takeWhile(x => x != NoSymbol).toVector.reverseIterator
28+
val path = ownerChainOf(typeSymbol)
29+
.zipAll(ownerChainOf(enclosingOwner), NoSymbol, NoSymbol)
30+
.dropWhile { case (x, y) => x == y }
31+
.takeWhile { case (x, _) => x != NoSymbol }
32+
.map { case (x, _) => x.name.toTermName }
33+
if (path.isEmpty) NoSymbol
34+
else c.typecheck(path.foldLeft[Tree](Ident(path.next()))(Select(_, _)), silent = true).symbol
35+
}
36+
}
37+
38+
def toName(sym: Symbol): (List[String], List[String], String) = {
39+
var values = List.empty[String]
40+
var packages = List.empty[String]
41+
var owner = companion(sym).owner
42+
while (owner != NoSymbol) {
43+
val name = NameTransformer.decode(owner.name.toString)
44+
if (owner.isPackage || owner.isPackageClass) packages = name :: packages
45+
else values = name :: values
46+
owner = owner.owner
47+
}
48+
(packages.tail, values, NameTransformer.decode(sym.name.toString))
49+
}
50+
51+
val tpe = weakTypeOf[A].dealias
52+
if (isNonAbstractScalaClass(tpe)) {
53+
case class FieldInfo(name: String, tpe: Type, getter: Symbol)
54+
55+
val tpeTypeSym = tpe.typeSymbol
56+
val tpeName = toName(tpeTypeSym)
57+
val tpeClassSym = tpeTypeSym.asClass
58+
val primaryConstructor = tpe.decls.collectFirst {
59+
case m: MethodSymbol if m.isPrimaryConstructor => m
60+
}.getOrElse(fail(s"Cannot find a primary constructor for '$tpe'"))
61+
val tpeTypeArgs = tpe.typeArgs
62+
val tpeTypeParams = tpeClassSym.typeParams
63+
val tpeParams = primaryConstructor.paramLists
64+
val fieldInfos = tpeParams.map(_.map { param =>
65+
val sym = param.asTerm
66+
val name = sym.name
67+
FieldInfo(
68+
name = NameTransformer.decode(name.toString),
69+
tpe = {
70+
val originFieldTpe = sym.typeSignature.dealias
71+
if (tpeTypeArgs.isEmpty) originFieldTpe
72+
else originFieldTpe.substituteTypes(tpeTypeParams, tpeTypeArgs)
73+
},
74+
getter = {
75+
tpe.members
76+
.filter(_.name == name)
77+
.collectFirst {
78+
case m: MethodSymbol if m.isParamAccessor && m.isGetter =>
79+
m
80+
}
81+
.getOrElse(fail(s"Cannot find '$name' parameter of '$tpe' in the primary constructor."))
82+
}
83+
)
84+
})
85+
// TODO: use `fieldInfos` to generate remaining `Reflect.Record.fields` and `Reflect.Record.recordBinding`
86+
c.Expr[Schema[A]](
87+
q"""new _root_.zio.blocks.schema.Schema[$tpe](
88+
reflect = _root_.zio.blocks.schema.Reflect.Record[_root_.zio.blocks.schema.binding.Binding, $tpe](
89+
fields = Nil,
90+
typeName = TypeName(
91+
namespace = Namespace(
92+
packages = ${tpeName._1},
93+
values = ${tpeName._2}
94+
),
95+
name = ${tpeName._3}
96+
),
97+
recordBinding = null,
98+
doc = Doc.Empty,
99+
modifiers = Nil
100+
)
101+
)"""
102+
)
103+
} else fail(s"Cannot derive '${typeOf[Schema[_]]}' for '$tpe'.")
104+
}
105+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package zio.blocks.schema
2+
3+
trait SchemaVersionSpecific {
4+
inline def derived[A]: Schema[A] = ${ SchemaVersionSpecific.derived }
5+
}
6+
7+
object SchemaVersionSpecific {
8+
import scala.quoted._
9+
10+
def derived[A: Type](using Quotes): Expr[Schema[A]] = {
11+
import quotes.reflect._
12+
import zio.blocks.schema.binding.Binding
13+
14+
def fail(msg: String): Nothing = report.errorAndAbort(msg, Position.ofMacroExpansion)
15+
16+
def isNonAbstractScalaClass(tpe: TypeRepr): Boolean = tpe.classSymbol.fold(false) { sym =>
17+
val flags = sym.flags
18+
!flags.is(Flags.Abstract) && !flags.is(Flags.JavaDefined) && !flags.is(Flags.Trait)
19+
}
20+
21+
def toName(sym: Symbol): (List[String], List[String], String) = {
22+
var values = List.empty[String]
23+
var packages = List.empty[String]
24+
var owner = sym.owner
25+
while (owner != quotes.reflect.defn.RootClass) {
26+
val name = owner.name.toString
27+
if (owner.flags.is(Flags.Package)) packages = name :: packages
28+
else if (owner.flags.is(Flags.Module)) values = name.substring(0, name.length - 1) :: values
29+
else values = name :: values
30+
owner = owner.owner
31+
}
32+
(packages, values, sym.name.toString)
33+
}
34+
35+
val tpe = TypeRepr.of[A].dealias
36+
if (isNonAbstractScalaClass(tpe)) {
37+
case class FieldInfo(name: String, tpe: TypeRepr, getter: Symbol)
38+
39+
val tpeTypeSym = tpe.typeSymbol
40+
val tpeName = toName(tpeTypeSym)
41+
val tpeClassSym = tpe.classSymbol.get
42+
val primaryConstructor =
43+
if (tpeClassSym.primaryConstructor.exists) tpeClassSym.primaryConstructor
44+
else fail(s"Cannot find a primary constructor for '$tpe'")
45+
val tpeTypeArgs = tpe match
46+
case AppliedType(_, typeArgs) => typeArgs
47+
case _ => Nil
48+
val (tpeTypeParams, tpeParams) = primaryConstructor.paramSymss match {
49+
case tps :: ps if tps.exists(_.isTypeParam) => (tps, ps)
50+
case ps => (Nil, ps)
51+
}
52+
val fieldInfos = tpeParams.map(_.map { sym =>
53+
val name = sym.name
54+
FieldInfo(
55+
name = name,
56+
tpe = {
57+
val originFieldType = tpe.memberType(sym).dealias
58+
if (tpeTypeArgs.isEmpty) originFieldType
59+
else originFieldType.substituteTypes(tpeTypeParams, tpeTypeArgs)
60+
},
61+
getter = {
62+
val fieldMember = tpeClassSym.fieldMember(name)
63+
if (fieldMember.exists) fieldMember
64+
else {
65+
tpeClassSym
66+
.methodMember(name)
67+
.find(_.flags.is(Flags.ParamAccessor | Flags.CaseAccessor))
68+
.getOrElse(fail(s"Cannot find '$name' parameter of '${tpe.show}' in the primary constructor."))
69+
}
70+
}
71+
)
72+
})
73+
// TODO: use `fieldInfos` to generate remaining `Reflect.Record.fields` and `Reflect.Record.recordBinding`
74+
'{
75+
new Schema[A](
76+
reflect = new Reflect.Record[Binding, A](
77+
fields = Nil,
78+
typeName = TypeName(
79+
namespace = Namespace(
80+
packages = ${ Expr(tpeName._1) },
81+
values = ${ Expr(tpeName._2) }
82+
),
83+
name = ${ Expr(tpeName._3) }
84+
),
85+
recordBinding = null,
86+
doc = Doc.Empty,
87+
modifiers = Nil
88+
)
89+
)
90+
}.asExprOf[Schema[A]]
91+
} else fail(s"Cannot derive '${TypeRepr.of[Schema[_]].show}' for '${tpe.show}'.")
92+
}
93+
}

schema/shared/src/main/scala/zio/blocks/schema/Schema.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ final case class Schema[A](reflect: Reflect.Bound[A]) {
5252
def toDynamicValue(value: A): DynamicValue = ??? // TODO
5353
}
5454

55-
object Schema {
55+
object Schema extends SchemaVersionSpecific {
5656
def apply[A](implicit schema: Schema[A]): Schema[A] = schema
5757

5858
implicit val dynamic: Schema[DynamicValue] = Schema(Reflect.dynamic[Binding])
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package zio.blocks.schema
2+
3+
import zio.Scope
4+
import zio.blocks.schema.Reflect.Primitive
5+
import zio.blocks.schema.binding.RegisterOffset.RegisterOffset
6+
import zio.blocks.schema.binding._
7+
import zio.test.Assertion._
8+
import zio.test._
9+
10+
object SchemaVersionSpecificSpec extends ZIOSpecDefault {
11+
def spec: Spec[TestEnvironment with Scope, Any] = suite("SchemaVersionSpecificSpec")(
12+
suite("Reflect.Record")(
13+
test("derives schema using 'derives' keyword") {
14+
case class Record1(b: Byte, i: Int) derives Schema
15+
16+
assert(Schema[Record1])(
17+
equalTo(
18+
new Schema[Record1](
19+
reflect = Reflect.Record[Binding, Record1](
20+
fields = Nil,
21+
typeName = TypeName(
22+
namespace = Namespace(
23+
packages = Seq("zio", "blocks", "schema"),
24+
values = Seq("SchemaVersionSpecificSpec", "spec")
25+
),
26+
name = "Record1"
27+
),
28+
recordBinding = null,
29+
doc = Doc.Empty,
30+
modifiers = Nil
31+
)
32+
)
33+
)
34+
)
35+
}
36+
)
37+
)
38+
}

schema/shared/src/test/scala/zio/blocks/schema/OpticSpec.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ object OpticSpec extends ZIOSpecDefault {
657657
Reflect.boolean[Binding].asTerm("b"),
658658
Reflect.float[Binding].asTerm("f")
659659
),
660-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record1"),
660+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record1"),
661661
recordBinding = Binding.Record(
662662
constructor = new Constructor[Record1] {
663663
def usedRegisters: RegisterOffset = RegisterOffset(booleans = 1, floats = 1)
@@ -691,7 +691,7 @@ object OpticSpec extends ZIOSpecDefault {
691691
Reflect.vector(Reflect.int[Binding]).asTerm("vi"),
692692
Record1.reflect.asTerm("r1")
693693
),
694-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record2"),
694+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record2"),
695695
recordBinding = Binding.Record(
696696
constructor = new Constructor[Record2] {
697697
def usedRegisters: RegisterOffset = RegisterOffset(longs = 1, objects = 2)
@@ -734,7 +734,7 @@ object OpticSpec extends ZIOSpecDefault {
734734
Record2.reflect.asTerm("r2"),
735735
Reflect.Deferred(() => Variant1.reflect).asTerm("v1")
736736
),
737-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Record3"),
737+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Record3"),
738738
recordBinding = Binding.Record(
739739
constructor = new Constructor[Record3] {
740740
def usedRegisters: RegisterOffset = RegisterOffset(objects = 3)
@@ -779,7 +779,7 @@ object OpticSpec extends ZIOSpecDefault {
779779
Case2.reflect.asTerm("c2"),
780780
Reflect.Deferred(() => Variant2.reflect).asTerm("v2")
781781
),
782-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant1"),
782+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant1"),
783783
variantBinding = Binding.Variant(
784784
discriminator = new Discriminator[Variant1] {
785785
def discriminate(a: Variant1): Int = a match {
@@ -839,7 +839,7 @@ object OpticSpec extends ZIOSpecDefault {
839839
fields = List(
840840
Reflect.double[Binding].asTerm("d")
841841
),
842-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case1"),
842+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case1"),
843843
recordBinding = Binding.Record(
844844
constructor = new Constructor[Case1] {
845845
def usedRegisters: RegisterOffset = RegisterOffset(doubles = 1)
@@ -867,7 +867,7 @@ object OpticSpec extends ZIOSpecDefault {
867867
fields = List(
868868
Record3.reflect.asTerm("r3")
869869
),
870-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case2"),
870+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case2"),
871871
recordBinding = Binding.Record(
872872
constructor = new Constructor[Case2] {
873873
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
@@ -898,7 +898,7 @@ object OpticSpec extends ZIOSpecDefault {
898898
Case4.reflect.asTerm("c4"),
899899
Reflect.Deferred(() => Variant3.reflect).asTerm("v3")
900900
),
901-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant2"),
901+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant2"),
902902
variantBinding = Binding.Variant(
903903
discriminator = new Discriminator[Variant2] {
904904
def discriminate(a: Variant2): Int = a match {
@@ -949,7 +949,7 @@ object OpticSpec extends ZIOSpecDefault {
949949
fields = List(
950950
Reflect.Deferred(() => Variant1.reflect).asTerm("v1")
951951
),
952-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case3"),
952+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case3"),
953953
recordBinding = Binding.Record(
954954
constructor = new Constructor[Case3] {
955955
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
@@ -983,7 +983,7 @@ object OpticSpec extends ZIOSpecDefault {
983983
fields = List(
984984
Reflect.list(Record3.reflect).asTerm("lr3")
985985
),
986-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case4"),
986+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case4"),
987987
recordBinding = Binding.Record(
988988
constructor = new Constructor[Case4] {
989989
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)
@@ -1015,7 +1015,7 @@ object OpticSpec extends ZIOSpecDefault {
10151015
Case5.reflect.asTerm("c5"),
10161016
Case6.reflect.asTerm("c6")
10171017
),
1018-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Variant3"),
1018+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Variant3"),
10191019
variantBinding = Binding.Variant(
10201020
discriminator = new Discriminator[Variant3] {
10211021
def discriminate(a: Variant3): Int = a match {
@@ -1053,7 +1053,7 @@ object OpticSpec extends ZIOSpecDefault {
10531053
Reflect.set(Reflect.int[Binding]).asTerm("si"),
10541054
Reflect.array(Reflect.string[Binding]).asTerm("as")
10551055
),
1056-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case5"),
1056+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case5"),
10571057
recordBinding = Binding.Record(
10581058
constructor = new Constructor[Case5] {
10591059
def usedRegisters: RegisterOffset = RegisterOffset(objects = 2)
@@ -1088,7 +1088,7 @@ object OpticSpec extends ZIOSpecDefault {
10881088
fields = List(
10891089
Reflect.Deferred(() => Variant2.reflect).asTerm("v2")
10901090
),
1091-
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Nil), "Case6"),
1091+
typeName = TypeName(Namespace(List("zio", "blocks", "schema"), Seq("OpticSpec")), "Case6"),
10921092
recordBinding = Binding.Record(
10931093
constructor = new Constructor[Case6] {
10941094
def usedRegisters: RegisterOffset = RegisterOffset(objects = 1)

0 commit comments

Comments
 (0)