Skip to content

Commit

Permalink
Remove ensuring clause in ghost elimination (#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev authored Nov 1, 2023
1 parent d941894 commit 92ee799
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class FragmentChecker(inoxCtx: inox.Context)(using override val dottyCtx: DottyC
class Checker extends tpd.TreeTraverser {
private val ScalaEnsuringMethod = requiredMethod("scala.Predef.Ensuring")

private val StainlessLangPackage = getPackageIfDefinedOrNone("stainless.lang")
private val StainlessLangPackage = getClassIfDefinedOrNone("stainless.lang.package$")
private val ExternAnnotation = getClassIfDefinedOrNone("stainless.annotation.extern")
private val IgnoreAnnotation = getClassIfDefinedOrNone("stainless.annotation.ignore")
private val StainlessOld = StainlessLangPackage.map(_.info.decl(Names.termName("old")).symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self =>
// However, the MacroTransform class is better suited for our needs.
// Because PluginPhase extends MiniPhase (which is a class) and that MacroTransform is a class, we can't extend
// both. So we use composition instead of inheritance to achieve our goal.
private class GhostAccessMacroTransform(using override val dottyCtx: DottyContext) extends MacroTransform with ASTExtractors {
import StructuralExtractors._
import AuxiliaryExtractors._

private val ghostAnnotation = Symbols.requiredClass("stainless.annotation.ghost")
private val ghostFun = Symbols.requiredMethod("stainless.lang.ghost")

private class GhostAccessMacroTransform extends MacroTransform {
override val phaseName = self.phaseName
override val runsAfter = self.runsAfter

override protected def newTransformer(using DottyContext): Transformer = new GhostRewriteTransformer

private class GhostRewriteTransformer(using DottyContext) extends Transformer {
private val StainlessLangPackage = Symbols.requiredPackage("stainless.lang")
private val ghostAnnotation = Symbols.requiredClass("stainless.annotation.ghost")
private val ghostFun = StainlessLangPackage.info.decl(Names.termName("ghost")).alternatives.toSet

/**
* Is this symbol @ghost, or enclosed inside a ghost definition?
Expand Down Expand Up @@ -80,9 +82,23 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self =>
case vd@ValDef(name, tpt, _) if effectivelyGhost(tree.symbol) =>
cpy.ValDef(tree)(name, tpt, mkZero(vd.rhs.tpe))

case Apply(fun, args) if effectivelyGhost(fun.symbol) || ghostFun(fun.symbol) =>
case Apply(fun, args) if effectivelyGhost(fun.symbol) || fun.symbol == ghostFun =>
mkZero(tree.tpe)

case ExRequiredExpression(_, true) => tpd.Literal(Constant(()))
case ExDecreasesExpression(_) => tpd.Literal(Constant(()))
case ExAssertExpression(_, _, true) => tpd.Literal(Constant(()))
case ExEnsuredExpression(body, _, true) =>
transform(body) match {
case Apply(ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring"), Seq(unwrapped)) => unwrapped
case body => body
}

case ExWhile.WithInvariant(_, body) => transform(body)
case ExWhile.WithWeakInvariant(_, body) => transform(body)
case ExWhile.WithInline(body) => transform(body)
case ExWhile.WithOpaque(body) => transform(body)

case f@Apply(fun, args) =>
val fun1 = super.transform(fun)

Expand All @@ -103,6 +119,19 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self =>
case Assign(lhs, rhs) if effectivelyGhost(lhs.symbol) =>
cpy.Assign(tree)(lhs, mkZero(rhs.tpe))

case Block(stats, last) =>
val recStats = transform(stats).filter {
case tpd.Literal(_) => false
case _ => true
}
val recLast = transform(last)
// Transform `val v = e; v` into `e` to allow for tail recursion elimination
(recStats.lastOption, recLast) match {
case (Some(vd @ ValDef(_, _, _)), iden @ (Ident(_) | Typed(Ident(_), _))) if iden.symbol == vd.symbol =>
cpy.Block(tree)(recStats.init, vd.rhs)
case _ => cpy.Block(tree)(recStats, recLast)
}

case _ => super.transform(tree)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ import scala.tools.nsc.transform.Transform
import stainless.frontend.{CallBack, UnsupportedCodeException}

/** Extract each compilation unit and forward them to the Compiler callback */
trait GhostAccessRewriter extends Transform {
trait GhostAccessRewriter extends Transform with ASTExtractors {
import global._
import ExtractorHelpers._
import StructuralExtractors._
import ExpressionExtractors._

val pluginOptions: PluginOptions
val phaseName = "ghost-removal"
Expand All @@ -31,6 +34,7 @@ trait GhostAccessRewriter extends Transform {
}

private class GhostRewriteTransformer extends Transformer {

private val StainlessLangPackage = rootMirror.getPackageIfDefined("stainless.lang")
private val ghostFun = StainlessLangPackage.info.decl(newTermName("ghost")).alternatives.toSet

Expand Down Expand Up @@ -61,6 +65,16 @@ trait GhostAccessRewriter extends Transform {
case ValDef(mods, name, tpt, rhs) if effectivelyGhost(tree.symbol) =>
treeCopy.ValDef(tree, mods, name, tpt, gen.mkZero(rhs.tpe))

case ExRequiredExpression(_, true) => gen.mkZero(tree.tpe)
case ExDecreasesExpression(_) => gen.mkZero(tree.tpe)
case ExAssertExpression(_, _, true) => gen.mkZero(tree.tpe)
case ExEnsuredExpression(body, _, true) => transform(body)

case ExWhile.WithInvariant(_, body) => transform(body)
case ExWhile.WithWeakInvariant(_, body) => transform(body)
case ExWhile.WithInline(body) => transform(body)
case ExWhile.WithOpaque(body) => transform(body)

// labels are generated by pattern matching but they are not real applications and should not
// be touched. They are simple jumps and tampering with them may lead to runtime verification errors
case Apply(fun, args) if tree.symbol.isLabel =>
Expand Down Expand Up @@ -101,6 +115,19 @@ trait GhostAccessRewriter extends Transform {
else
super.transform(tree)

case Block(stats, last) =>
val recStats = transformTrees(stats).filter {
case Literal(_) => false
case _ => true
}
val recLast = transform(last)
// Transform `val v = e; v` into `e` to allow for tail recursion elimination
(recStats.lastOption, recLast) match {
case (Some(vd@ValDef(_, _, _, _)), iden@(Ident(_) | Typed(Ident(_), _))) if iden.symbol == vd.symbol =>
treeCopy.Block(tree, recStats.init, vd.rhs)
case _ => treeCopy.Block(tree, recStats, recLast)
}

case _ => super.transform(tree)
}
}
Expand Down
7 changes: 7 additions & 0 deletions sbt-plugin/src/sbt-test/sbt-plugin/ghost/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ lazy val basic = (project in file("basic"))
Compile / run / mainClass := Some("test.Main")
)

lazy val tailrec = (project in file("tailrec"))
.enablePlugins(StainlessPlugin)
.settings(commonSettings)
.settings(
Compile / run / mainClass := Some("test.Main")
)

lazy val `actor-tests` = (project in file("actor-tests"))
.enablePlugins(StainlessPlugin)
.settings(commonSettings)
Expand Down
44 changes: 44 additions & 0 deletions sbt-plugin/src/sbt-test/sbt-plugin/ghost/tailrec/TailRec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package test

import stainless._
import stainless.annotation.{ghost => ghostAnnot, _}
import stainless.lang._
import StaticChecks._

object Main {
import stainless.lang.WhileDecorations

def loop1(count: Long, l1: Long, l2: Long, l3: Long): Long = {
require(count >= 0)
decreases(count)
if (count == 0) l1
else loop1(count - 1, l1, l2, l3)
}.ensuring(_ == l1)

def loop2(count: Long, l1: Long, l2: Long, l3: Long): Long = {
require(count >= 0)
decreases(count)
if (count == 0) l1
else {
val myRes = loop2(count - 1, l1, l2, l3)
ghost {
assert(myRes == l1)
}
myRes
}
}.ensuring(_ == l1)

def whileLoop(upto: Long): Unit = {
var i: Long = 0
(while(i < upto) {
decreases(upto - i)
i += 1
}).invariant(i >= 0)
}

def main(args: Array[String]): Unit = {
loop1(100000, 1, 2, 3)
loop2(100000, 1, 2, 3)
whileLoop(10000)
}
}
3 changes: 2 additions & 1 deletion sbt-plugin/src/sbt-test/sbt-plugin/ghost/test
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
> + tailrec/run
> + basic/run
$ exists basic/target/scala-2.13/classes/test/Main.class
$ exists basic/target/scala-3.3.0/classes/test/Main.class
$ absent basic/target/sneakyGhostCalled basic/target/insideGhostCalled
> + actor-tests/compile
> + actor-tests/compile

0 comments on commit 92ee799

Please sign in to comment.