Skip to content

Commit

Permalink
Fix tparams for preconditions
Browse files Browse the repository at this point in the history
  • Loading branch information
drganam committed Sep 1, 2021
1 parent 63a066a commit 08b1fee
Showing 1 changed file with 43 additions and 42 deletions.
85 changes: 43 additions & 42 deletions core/src/main/scala/stainless/extraction/trace/Trace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ trait Trace extends CachingPhase with IdentityFunctions with IdentitySorts { sel
val newParamTps = eqLemma.tparams.map{tparam => tparam.tp}
val newParamVars = eqLemma.params.map{param => param.toVariable}

val specsMap = (fd1.params zip newParamVars).toMap
val subst = (fd1.params.map(_.id) zip newParamVars).toMap
val tsubst = (fd1.tparams zip newParamTps).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap
val specializer = new Specializer(eqLemma, eqLemma.id, tsubst, subst)

val specs = BodyWithSpecs(fd1.fullBody).specs.filter(s => s.kind == LetKind || s.kind == PreconditionKind)
val pre = specs.map(spec => spec match {
case Precondition(cond) => Precondition(exprOps.replaceFromSymbols(specsMap, cond))
case LetInSpec(vd, expr) => LetInSpec(vd, exprOps.replaceFromSymbols(specsMap, expr))
case Precondition(cond) => Precondition(specializer.transform(cond))
case LetInSpec(vd, expr) => LetInSpec(vd, specializer.transform(expr))
})

val fun1 = s.FunctionInvocation(fd1.id, newParamTps, newParamVars)
Expand All @@ -88,7 +91,7 @@ trait Trace extends CachingPhase with IdentityFunctions with IdentitySorts { sel

eqLemma.copy(
fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed,
flags = Seq(s.Derived(fd1.id), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name)))),
flags = Seq(s.Derived(Some(fd1.id)), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name)))),
returnType = s.UnitType()
).copiedFrom(eqLemma)
}
Expand Down Expand Up @@ -180,65 +183,37 @@ trait Trace extends CachingPhase with IdentityFunctions with IdentitySorts { sel
import exprOps._

val indPattern = exprOps.freshenSignature(model).copy(id = FreshIdentifier(lemma.id+"$induct"))

val newParamTps = indPattern.tparams.map{tparam => tparam.tp}
val newParamVars = indPattern.params.map{param => param.toVariable}

val fi = FunctionInvocation(model.id, newParamTps, newParamVars)

class Specializer(
origFd: FunDef,
newId: Identifier,
tsubst: Map[Identifier, Type],
vsubst: Map[Identifier, Expr]
) extends s.SelfTreeTransformer {

override def transform(expr: s.Expr): t.Expr = expr match {
case v: Variable =>
vsubst.getOrElse(v.id, super.transform(v))

case fi: FunctionInvocation if fi.id == origFd.id =>
val fi1 = FunctionInvocation(newId, tps = fi.tps, args = fi.args)
super.transform(fi1.copiedFrom(fi))

case _ => super.transform(expr)
}

override def transform(tpe: s.Type): t.Type = tpe match {
case tp: TypeParameter =>
tsubst.getOrElse(tp.id, super.transform(tp))

case _ => super.transform(tpe)
}
}

val tpairs = model.tparams zip fi.tps
val tsubst = tpairs.map { case (tparam, targ) => tparam.tp.id -> targ } .toMap

val subst = (model.params.map(_.id) zip fi.args).toMap
val specializer = new Specializer(model, indPattern.id, tsubst, subst)

val fullBodySpecialized = specializer.transform(exprOps.withoutSpecs(model.fullBody).get)

val specsMap = (lemma.params zip newParamVars).toMap ++ (model.params zip newParamVars).toMap
val specs = BodyWithSpecs(model.fullBody).specs ++ BodyWithSpecs(lemma.fullBody).specs.filterNot(_.kind == MeasureKind)
val specsSubst = (lemma.params.map(_.id) zip newParamVars).toMap ++ (model.params.map(_.id) zip newParamVars).toMap
val specsTsubst = ((lemma.tparams zip fi.tps) ++ (model.tparams zip fi.tps)).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap
val specsSpecializer = new Specializer(indPattern, indPattern.id, specsTsubst, specsSubst)

val specs = BodyWithSpecs(model.fullBody).specs ++ BodyWithSpecs(lemma.fullBody).specs.filterNot(_.kind == MeasureKind)
val pre = specs.filterNot(_.kind == PostconditionKind).map(spec => spec match {
case Precondition(cond) => Precondition(exprOps.replaceFromSymbols(specsMap, cond)).setPos(spec)
case LetInSpec(vd, expr) => LetInSpec(vd, exprOps.replaceFromSymbols(specsMap, expr)).setPos(spec)
case Measure(measure) => Measure(exprOps.replaceFromSymbols(specsMap, measure)).setPos(spec)
case Precondition(cond) => Precondition(specsSpecializer.transform(cond)).setPos(spec)
case LetInSpec(vd, expr) => LetInSpec(vd, specsSpecializer.transform(expr)).setPos(spec)
case Measure(measure) => Measure(specsSpecializer.transform(measure)).setPos(spec)
case s => context.reporter.fatalError(s"Unsupported specs: $s")
})

val withPre = exprOps.reconstructSpecs(pre, Some(fullBodySpecialized), indPattern.returnType)

val specsSpecializer = new Specializer(indPattern, indPattern.id, (lemma.tparams zip fi.tps).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap, Map())

val speccedLemma = BodyWithSpecs(lemma.fullBody).addPost
val speccedOrig = BodyWithSpecs(model.fullBody).addPost
val postLemma = speccedLemma.getSpec(PostconditionKind).map(post =>
specsSpecializer.transform(exprOps.replaceFromSymbols(specsMap, post.expr)))
val postOrig = speccedOrig.getSpec(PostconditionKind).map(post => exprOps.replaceFromSymbols(specsMap, post.expr))
specsSpecializer.transform(post.expr))
val postOrig = speccedOrig.getSpec(PostconditionKind).map(post => specsSpecializer.transform(post.expr))

(postLemma, postOrig) match {
case (Some(Lambda(Seq(res1), cond1)), Some(Lambda(Seq(res2), cond2))) =>
Expand All @@ -251,12 +226,38 @@ trait Trace extends CachingPhase with IdentityFunctions with IdentitySorts { sel

indPattern.copy(
fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed,
flags = Seq(s.Derived(lemma.id), s.Derived(model.id))
flags = Seq(s.Derived(Some(lemma.id)), s.Derived(Some(model.id)))
).copiedFrom(indPattern)
}

}

class Specializer(
origFd: FunDef,
newId: Identifier,
tsubst: Map[Identifier, Type],
vsubst: Map[Identifier, Expr]
) extends s.SelfTreeTransformer {

override def transform(expr: s.Expr): t.Expr = expr match {
case v: Variable =>
vsubst.getOrElse(v.id, super.transform(v))

case fi: FunctionInvocation if fi.id == origFd.id =>
val fi1 = FunctionInvocation(newId, tps = fi.tps, args = fi.args)
super.transform(fi1.copiedFrom(fi))

case _ => super.transform(expr)
}

override def transform(tpe: s.Type): t.Type = tpe match {
case tp: TypeParameter =>
tsubst.getOrElse(tp.id, super.transform(tp))

case _ => super.transform(tpe)
}
}

type Path = Seq[String]

private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optCompareFuns) map { functions =>
Expand Down

0 comments on commit 08b1fee

Please sign in to comment.