Skip to content

Commit 917a875

Browse files
committed
[type checking] Solve meta right before failing subsumption check
1 parent e3e6921 commit 917a875

File tree

8 files changed

+120
-34
lines changed

8 files changed

+120
-34
lines changed

.idea/modules/archon.archon-build.iml

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/modules/archon.iml

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build.sbt

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ lazy val root = (project in file("."))
2121
"com.lihaoyi" %% "pprint" % "0.9.0",
2222
"com.lihaoyi" %% "fastparse" % "3.1.0",
2323
"com.lihaoyi" %% "os-lib" % "0.10.2",
24+
"org.scala-lang.modules" %% "scala-collection-contrib" % "0.3.0"
2425
),
2526
Compile / unmanagedJars += {
2627
baseDirectory.value / "unmanaged" / "scalaz3_3-4.8.14-macos-aarch64.jar"

src/main/scala/com/github/tgeng/archon/core/ir/pprint.scala

+23-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import com.github.tgeng.archon.core.ir.Pattern.*
1212
import com.github.tgeng.archon.core.ir.VTerm.*
1313
import com.github.tgeng.archon.core.ir.Variance.*
1414

15+
import scala.collection.decorators.*
1516
import scala.collection.mutable
1617
import scala.collection.mutable.ArrayBuffer
1718

@@ -219,6 +220,27 @@ object PrettyPrinter extends Visitor[PPrintContext, Block]:
219220
: Block =
220221
Block(usageLiteral.usage.toString)
221222

223+
override def visitUsageProd
224+
(usageProd: VTerm.UsageProd)
225+
(using ctx: PPrintContext)
226+
(using Σ: Signature)
227+
: Block =
228+
Block(usageProd.operands.toSeq.map(visitVTerm).intersperse("*"), Whitespace, Aligned, Wrap)
229+
230+
override def visitUsageSum
231+
(usageSum: VTerm.UsageSum)
232+
(using ctx: PPrintContext)
233+
(using Σ: Signature)
234+
: Block =
235+
Block(usageSum.operands.multiToSeq.map(visitVTerm).intersperse("+"), Whitespace, Aligned, Wrap)
236+
237+
override def visitUsageJoin
238+
(usageJoin: VTerm.UsageJoin)
239+
(using ctx: PPrintContext)
240+
(using Σ: Signature)
241+
: Block =
242+
Block(usageJoin.operands.toSeq.map(visitVTerm).intersperse("|"), Whitespace, Aligned, Wrap)
243+
222244
override def visitTContext
223245
(tTelescope: List[(Binding[VTerm], Variance)])
224246
(using PPrintContext)
@@ -434,7 +456,7 @@ object PrettyPrinter extends Visitor[PPrintContext, Block]:
434456
(using
435457
Σ: Signature,
436458
)
437-
: Block = app(".return", r.v)
459+
: Block = app(".return", "[", r.usage, "]", r.v)
438460

439461
override def visitLet
440462
(let: Let)

src/main/scala/com/github/tgeng/archon/core/ir/reduction.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ extension (v: VTerm)
544544

545545
def sumToTerm(sum: USum[VTerm]): VTerm = UsageSum(sum.map(prodToTerm)*)
546546

547-
def prodToTerm(prod: UProd[VTerm]): VTerm = UsageProd(prod.map(varOrUsageToTerm).toSeq*)
547+
def prodToTerm(prod: UProd[VTerm]): VTerm = UsageProd(prod.map(varOrUsageToTerm)*)
548548

549549
def varOrUsageToTerm(t: VTerm | Usage): VTerm = t match
550550
case v: VTerm => v

src/main/scala/com/github/tgeng/archon/core/ir/subtyping.scala

+41-10
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,11 @@ private def checkHandlerTypeSubsumption
368368
Set(Constraint.HandlerTypeSubsumption(Γ, handlerType1, handlerType2))
369369
case (_: ResolvedMetaVariable, _: ResolvedMetaVariable) =>
370370
Set(Constraint.HandlerTypeSubsumption(Γ, handlerType1, handlerType2))
371-
case _ => throw NotHandlerTypeSubsumption(handlerType1, handlerType2)
371+
case (sub: VTerm, sup: VTerm) =>
372+
val solvedSub = ctx.solveTerm(sub)
373+
val solvedSup = ctx.solveTerm(sup)
374+
if solvedSub == sub && solvedSup == sup then throw NotHandlerTypeSubsumption(sub, sup)
375+
else checkHandlerTypeSubsumption(solvedSub, solvedSup)
372376

373377
@throws(classOf[IrError])
374378
private def checkEqDecidabilitySubsumption
@@ -390,7 +394,11 @@ private def checkEqDecidabilitySubsumption
390394
Set(Constraint.EqDecidabilitySubsumption(Γ, eqD1, eqD2))
391395
case (_: ResolvedMetaVariable, _: ResolvedMetaVariable) =>
392396
Set(Constraint.EqDecidabilitySubsumption(Γ, eqD1, eqD2))
393-
case _ => throw NotEqDecidabilitySubsumption(eqD1, eqD2)
397+
case (sub: VTerm, sup: VTerm) =>
398+
val solvedSub = ctx.solveTerm(sub)
399+
val solvedSup = ctx.solveTerm(sup)
400+
if solvedSub == sub && solvedSup == sup then throw NotEqDecidabilitySubsumption(sub, sup)
401+
else checkEqDecidabilitySubsumption(solvedSub, solvedSup)
394402

395403
/** @param invert
396404
* useful when checking patterns where the consumed usages are actually provided usages because
@@ -442,7 +450,11 @@ def checkUsageSubsumption
442450
// hence we can't decide subsumption yet.
443451
else if spuriousOperands.forall(isMeta) || operands1.exists(isMeta) then
444452
Set(Constraint.UsageSubsumption(Γ, sub, sup))
445-
else throw NotUsageSubsumption(sub, sup)
453+
else
454+
val solvedSub = ctx.solveTerm(sub)
455+
val solvedSup = ctx.solveTerm(sup)
456+
if solvedSub == sub && solvedSup == sup then throw NotUsageSubsumption(sub, sup)
457+
else checkUsageSubsumption(solvedSub, solvedSup)
446458
// Handle the special case that the right hand side simply contains the left hand side as an operand.
447459
case (UsageJoin(operands), RUnsolved(_, _, _, tm, _)) if operands.contains(Collapse(tm)) =>
448460
Set.empty
@@ -512,11 +524,15 @@ def checkUsageSubsumption
512524
Set(Constraint.UsageSubsumption(Γ, rawSub, sup))
513525
case (_: ResolvedMetaVariable, _: ResolvedMetaVariable) =>
514526
Set(Constraint.UsageSubsumption(Γ, rawSub, rawSup))
515-
case _ =>
527+
case (sub: VTerm, sup: VTerm) =>
516528
if isMeta(rawSub) || isMeta(rawSup) then
517529
// We can't decide if the terms are unsolved.
518530
Set(Constraint.UsageSubsumption(Γ, rawSub, rawSup))
519-
else throw NotUsageSubsumption(rawSub, rawSup)
531+
else
532+
val solvedSub = ctx.solveTerm(sub)
533+
val solvedSup = ctx.solveTerm(sup)
534+
if solvedSub == sub && solvedSup == sup then throw NotUsageSubsumption(sub, sup)
535+
else checkUsageSubsumption(solvedSub, solvedSup)
520536

521537
@throws(classOf[IrError])
522538
private def checkEffSubsumption
@@ -600,15 +616,23 @@ private def checkEffSubsumption
600616
).normalized
601617
ctx.updateConstraint(u, UmcEffSubsumption(newLowerBound))
602618
Set.empty
603-
case _ => throw NotEffectSubsumption(sub, sup)
619+
case _ =>
620+
val solvedSub = ctx.solveTerm(sub)
621+
val solvedSup = ctx.solveTerm(sup)
622+
if solvedSub == sub && solvedSup == sup then throw NotEffectSubsumption(sub, sup)
623+
else checkEffSubsumption(solvedSub, solvedSup)
604624
// If spurious operands are all stuck computation, it's possible for sub to be if all of these stuck computation
605625
// ends up being assigned values that are part of sup
606626
// Also, if sup contains stuck computation, it's possible for sup to end up including arbitrary effects and hence
607627
// we can't decide subsumption yet.
608628
else if spuriousOperands.keys.forall(isMeta) || unionOperands2.keys.exists(isMeta) then
609629
Set(Constraint.EffSubsumption(Γ, sub, sup))
610630
else throw NotEffectSubsumption(sub, sup)
611-
case _ => throw NotEffectSubsumption(rawSub, rawSup)
631+
case (sub: VTerm, sup: VTerm) =>
632+
val solvedSub = ctx.solveTerm(sub)
633+
val solvedSup = ctx.solveTerm(sup)
634+
if solvedSub == sub && solvedSup == sup then throw NotEffectSubsumption(sub, sup)
635+
else checkEffSubsumption(solvedSub, solvedSup)
612636

613637
/** Checks if l1 is smaller than l2.
614638
*/
@@ -647,7 +671,11 @@ private def checkLevelSubsumption
647671
// if sup contains unsolved meta variables, it's possible for sup to end up including
648672
// arbitrary large level and hence we can't decide subsumption yet.
649673
Set(Constraint.LevelSubsumption(Γ, sub, sup))
650-
else throw NotLevelSubsumption(sub, sup)
674+
else
675+
val solvedSub = ctx.solveTerm(sub)
676+
val solvedSup = ctx.solveTerm(sup)
677+
if solvedSub == sub && solvedSup == sup then throw NotLevelSubsumption(sub, sup)
678+
else checkLevelSubsumption(solvedSub, solvedSup)
651679
// Handle the special case that the right hand side simply contains the left hand side as an operand.
652680
case (RUnsolved(_, _, _, tm, _), Level(_, operands)) if operands.contains(Collapse(tm)) =>
653681
Set.empty
@@ -676,8 +704,11 @@ private def checkLevelSubsumption
676704
Set(Constraint.LevelSubsumption(Γ, rawSub, sup))
677705
case (_: ResolvedMetaVariable, _: ResolvedMetaVariable) =>
678706
Set(Constraint.LevelSubsumption(Γ, rawSub, rawSup))
679-
case (sub, sup) =>
680-
throw NotLevelSubsumption(rawSub, rawSup)
707+
case (sub: VTerm, sup: VTerm) =>
708+
val solvedSub = ctx.solveTerm(sub)
709+
val solvedSup = ctx.solveTerm(sup)
710+
if solvedSub == sub && solvedSup == sup then throw NotLevelSubsumption(sub, sup)
711+
else checkLevelSubsumption(solvedSub, solvedSup)
681712

682713
private def getSpurious[T, E: PartialOrdering](sub: SeqMap[T, E], sup: SeqMap[T, E]): SeqMap[T, E] =
683714
sub.filter { case (operand1, e1) =>

src/main/scala/com/github/tgeng/archon/core/ir/usages.scala

+27-10
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@ def collectUsages
3636
(using Σ: Signature)
3737
(using ctx: TypingContext)
3838
: Usages =
39-
ctx.trace("collectUsages", PrettyPrinter.pprint(tm)):
39+
ctx.trace[Usages](
40+
"collectUsages",
41+
ty match
42+
case Some(ty) => Block(PrettyPrinter.pprint(tm), ":", PrettyPrinter.pprint(ty))
43+
case None => PrettyPrinter.pprint(tm),
44+
successMsg = PrettyPrinter.pprint,
45+
):
4046
tm match
4147
case Type(upperBound) => collectUsages(upperBound, None)
4248
case Top(level, eqDecidability) =>
@@ -98,7 +104,13 @@ def collectUsages
98104
(using Σ: Signature)
99105
(using ctx: TypingContext)
100106
: Usages =
101-
ctx.trace("collectUsages", PrettyPrinter.pprint(tm)):
107+
ctx.trace[Usages](
108+
"collectUsages",
109+
ty match
110+
case Some(ty) => Block(PrettyPrinter.pprint(tm), ":", PrettyPrinter.pprint(ty))
111+
case None => PrettyPrinter.pprint(tm),
112+
successMsg = PrettyPrinter.pprint,
113+
):
102114
tm match
103115
case Hole => throw IllegalStateException("Hole should not appear here.")
104116
case CapturedContinuationTip(_) =>
@@ -127,17 +139,22 @@ def collectUsages
127139
case F(ty, _, _) => ty
128140
case ty => throw IllegalStateException(s"bad type, expect F but got $ty")
129141
},
130-
) * usage
142+
) * usage.normalized
131143
case Let(t, tBinding, eff, body) =>
132144
val tUsages = collectUsages(t, Some(F(tBinding.ty, eff, tBinding.usage)))
133-
val bodyUsages = collectUsages(body, ty.map(_.weakened))(using Γ :+ tBinding)
134-
val actualTUsage = bodyUsages.last
135-
ctx.checkSolved(
136-
checkUsageSubsumption(actualTUsage, tBinding.usage),
137-
NotUsageSubsumption(actualTUsage, tBinding.usage),
138-
)
145+
val bodyUsages = {
146+
given Context = Γ :+ tBinding
147+
148+
val bodyUsages = collectUsages(body, ty.map(_.weakened))
149+
val actualTUsage = bodyUsages.last
150+
ctx.checkSolved(
151+
checkUsageSubsumption(actualTUsage, tBinding.usage),
152+
NotUsageSubsumption(actualTUsage, tBinding.usage),
153+
)
154+
bodyUsages.dropRight(1)
155+
}
139156
val continuationUsage = getEffectsContinuationUsage(eff)
140-
tUsages + bodyUsages.dropRight(1).map { t =>
157+
tUsages + bodyUsages.map { t =>
141158
// A variable's usage may reference the variable bound to the value returned from `t`. In
142159
// this case, strength would fail and the referenced usage can take any value. Hence, we
143160
// just approximate it with `uAny`.

src/test/scala/com/github/tgeng/archon/core/ir/typing/BasicTypeCheckSpec.scala

+25-11
Original file line numberDiff line numberDiff line change
@@ -128,26 +128,40 @@ class BasicTypeCheckSpec extends AnyFreeSpec:
128128
"with nat" in:
129129
decls"""
130130
data Nat: Type 0L
131-
Zero: Nat
132-
Succ: Nat -> Nat
133-
131+
Z: Nat
132+
S: Nat -> Nat
133+
134134
def prec: Nat -> <> Nat
135-
Zero = Zero
136-
(Succ m) = m
135+
Z = Z
136+
(S m) = m
137+
138+
def plus: Nat -> Nat -> <> Nat
139+
Z n = n
140+
(S m) n = S (plus m n)
137141

138142
data Vec (l: Level) +(t: Type l): Nat -> Type l
139-
Nil: Vec l t Zero
140-
Succ: n: Nat -> t -> Vec l t n -> Vec l t (Succ n)
143+
Nil: Vec l t Z
144+
Suc: n: Nat -> t -> Vec l t n -> Vec l t (S n)
145+
141146
""".inUse:
142147
assertVType(vt"Nat", Type(Top(LevelLiteral(0))))
143-
assertVType(vt"Zero", vt"Nat")
148+
assertVType(vt"Z", vt"Nat")
149+
assertCType(
150+
ct"prec (S Z)",
151+
ct"<> Nat",
152+
)
153+
assertCConvertible(
154+
ct"prec (S Z)",
155+
ct"Z",
156+
Some(ct"<> Nat"),
157+
)
144158
assertCType(
145-
ct"prec (Succ Zero)",
159+
ct"plus (S Z) (S (S Z))",
146160
ct"<> Nat",
147161
)
148162
assertCConvertible(
149-
ct"prec (Succ Zero)",
150-
ct"Zero",
163+
ct"plus (S Z) (S (S Z))",
164+
ct"S (S Z)",
151165
Some(ct"<> Nat"),
152166
)
153167
}

0 commit comments

Comments
 (0)