From 92ee79963258a2187b253e0a3f5bec77bf66a878 Mon Sep 17 00:00:00 2001 From: mario-bucev Date: Wed, 1 Nov 2023 13:50:56 +0100 Subject: [PATCH] Remove ensuring clause in ghost elimination (#1454) --- .../frontends/dotc/FragmentChecker.scala | 2 +- .../frontends/dotc/GhostAccessRewriter.scala | 39 +++++++++++++--- .../scalac/GhostAccessRewriter.scala | 29 +++++++++++- .../src/sbt-test/sbt-plugin/ghost/build.sbt | 7 +++ .../sbt-plugin/ghost/tailrec/TailRec.scala | 44 +++++++++++++++++++ sbt-plugin/src/sbt-test/sbt-plugin/ghost/test | 3 +- 6 files changed, 116 insertions(+), 8 deletions(-) create mode 100644 sbt-plugin/src/sbt-test/sbt-plugin/ghost/tailrec/TailRec.scala diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/FragmentChecker.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/FragmentChecker.scala index abf0299322..149f11a764 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/FragmentChecker.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/FragmentChecker.scala @@ -272,7 +272,7 @@ class FragmentChecker(inoxCtx: inox.Context)(using override val dottyCtx: DottyC class Checker extends tpd.TreeTraverser { private val ScalaEnsuringMethod = requiredMethod("scala.Predef.Ensuring") - private val StainlessLangPackage = getPackageIfDefinedOrNone("stainless.lang") + private val StainlessLangPackage = getClassIfDefinedOrNone("stainless.lang.package$") private val ExternAnnotation = getClassIfDefinedOrNone("stainless.annotation.extern") private val IgnoreAnnotation = getClassIfDefinedOrNone("stainless.annotation.ignore") private val StainlessOld = StainlessLangPackage.map(_.info.decl(Names.termName("old")).symbol) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/GhostAccessRewriter.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/GhostAccessRewriter.scala index cd3430a80b..16ec6ca877 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/GhostAccessRewriter.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/GhostAccessRewriter.scala @@ -24,17 +24,19 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self => // However, the MacroTransform class is better suited for our needs. // Because PluginPhase extends MiniPhase (which is a class) and that MacroTransform is a class, we can't extend // both. So we use composition instead of inheritance to achieve our goal. + private class GhostAccessMacroTransform(using override val dottyCtx: DottyContext) extends MacroTransform with ASTExtractors { + import StructuralExtractors._ + import AuxiliaryExtractors._ + + private val ghostAnnotation = Symbols.requiredClass("stainless.annotation.ghost") + private val ghostFun = Symbols.requiredMethod("stainless.lang.ghost") - private class GhostAccessMacroTransform extends MacroTransform { override val phaseName = self.phaseName override val runsAfter = self.runsAfter override protected def newTransformer(using DottyContext): Transformer = new GhostRewriteTransformer private class GhostRewriteTransformer(using DottyContext) extends Transformer { - private val StainlessLangPackage = Symbols.requiredPackage("stainless.lang") - private val ghostAnnotation = Symbols.requiredClass("stainless.annotation.ghost") - private val ghostFun = StainlessLangPackage.info.decl(Names.termName("ghost")).alternatives.toSet /** * Is this symbol @ghost, or enclosed inside a ghost definition? @@ -80,9 +82,23 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self => case vd@ValDef(name, tpt, _) if effectivelyGhost(tree.symbol) => cpy.ValDef(tree)(name, tpt, mkZero(vd.rhs.tpe)) - case Apply(fun, args) if effectivelyGhost(fun.symbol) || ghostFun(fun.symbol) => + case Apply(fun, args) if effectivelyGhost(fun.symbol) || fun.symbol == ghostFun => mkZero(tree.tpe) + case ExRequiredExpression(_, true) => tpd.Literal(Constant(())) + case ExDecreasesExpression(_) => tpd.Literal(Constant(())) + case ExAssertExpression(_, _, true) => tpd.Literal(Constant(())) + case ExEnsuredExpression(body, _, true) => + transform(body) match { + case Apply(ExSymbol("stainless", "lang", "StaticChecks$", "Ensuring"), Seq(unwrapped)) => unwrapped + case body => body + } + + case ExWhile.WithInvariant(_, body) => transform(body) + case ExWhile.WithWeakInvariant(_, body) => transform(body) + case ExWhile.WithInline(body) => transform(body) + case ExWhile.WithOpaque(body) => transform(body) + case f@Apply(fun, args) => val fun1 = super.transform(fun) @@ -103,6 +119,19 @@ class GhostAccessRewriter(afterPhase: String) extends PluginPhase { self => case Assign(lhs, rhs) if effectivelyGhost(lhs.symbol) => cpy.Assign(tree)(lhs, mkZero(rhs.tpe)) + case Block(stats, last) => + val recStats = transform(stats).filter { + case tpd.Literal(_) => false + case _ => true + } + val recLast = transform(last) + // Transform `val v = e; v` into `e` to allow for tail recursion elimination + (recStats.lastOption, recLast) match { + case (Some(vd @ ValDef(_, _, _)), iden @ (Ident(_) | Typed(Ident(_), _))) if iden.symbol == vd.symbol => + cpy.Block(tree)(recStats.init, vd.rhs) + case _ => cpy.Block(tree)(recStats, recLast) + } + case _ => super.transform(tree) } } diff --git a/frontends/scalac/src/main/scala/stainless/frontends/scalac/GhostAccessRewriter.scala b/frontends/scalac/src/main/scala/stainless/frontends/scalac/GhostAccessRewriter.scala index 2991e7971b..b4e640f449 100644 --- a/frontends/scalac/src/main/scala/stainless/frontends/scalac/GhostAccessRewriter.scala +++ b/frontends/scalac/src/main/scala/stainless/frontends/scalac/GhostAccessRewriter.scala @@ -10,8 +10,11 @@ import scala.tools.nsc.transform.Transform import stainless.frontend.{CallBack, UnsupportedCodeException} /** Extract each compilation unit and forward them to the Compiler callback */ -trait GhostAccessRewriter extends Transform { +trait GhostAccessRewriter extends Transform with ASTExtractors { import global._ + import ExtractorHelpers._ + import StructuralExtractors._ + import ExpressionExtractors._ val pluginOptions: PluginOptions val phaseName = "ghost-removal" @@ -31,6 +34,7 @@ trait GhostAccessRewriter extends Transform { } private class GhostRewriteTransformer extends Transformer { + private val StainlessLangPackage = rootMirror.getPackageIfDefined("stainless.lang") private val ghostFun = StainlessLangPackage.info.decl(newTermName("ghost")).alternatives.toSet @@ -61,6 +65,16 @@ trait GhostAccessRewriter extends Transform { case ValDef(mods, name, tpt, rhs) if effectivelyGhost(tree.symbol) => treeCopy.ValDef(tree, mods, name, tpt, gen.mkZero(rhs.tpe)) + case ExRequiredExpression(_, true) => gen.mkZero(tree.tpe) + case ExDecreasesExpression(_) => gen.mkZero(tree.tpe) + case ExAssertExpression(_, _, true) => gen.mkZero(tree.tpe) + case ExEnsuredExpression(body, _, true) => transform(body) + + case ExWhile.WithInvariant(_, body) => transform(body) + case ExWhile.WithWeakInvariant(_, body) => transform(body) + case ExWhile.WithInline(body) => transform(body) + case ExWhile.WithOpaque(body) => transform(body) + // labels are generated by pattern matching but they are not real applications and should not // be touched. They are simple jumps and tampering with them may lead to runtime verification errors case Apply(fun, args) if tree.symbol.isLabel => @@ -101,6 +115,19 @@ trait GhostAccessRewriter extends Transform { else super.transform(tree) + case Block(stats, last) => + val recStats = transformTrees(stats).filter { + case Literal(_) => false + case _ => true + } + val recLast = transform(last) + // Transform `val v = e; v` into `e` to allow for tail recursion elimination + (recStats.lastOption, recLast) match { + case (Some(vd@ValDef(_, _, _, _)), iden@(Ident(_) | Typed(Ident(_), _))) if iden.symbol == vd.symbol => + treeCopy.Block(tree, recStats.init, vd.rhs) + case _ => treeCopy.Block(tree, recStats, recLast) + } + case _ => super.transform(tree) } } diff --git a/sbt-plugin/src/sbt-test/sbt-plugin/ghost/build.sbt b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/build.sbt index 9592c23e84..14125c85d5 100644 --- a/sbt-plugin/src/sbt-test/sbt-plugin/ghost/build.sbt +++ b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/build.sbt @@ -16,6 +16,13 @@ lazy val basic = (project in file("basic")) Compile / run / mainClass := Some("test.Main") ) +lazy val tailrec = (project in file("tailrec")) + .enablePlugins(StainlessPlugin) + .settings(commonSettings) + .settings( + Compile / run / mainClass := Some("test.Main") + ) + lazy val `actor-tests` = (project in file("actor-tests")) .enablePlugins(StainlessPlugin) .settings(commonSettings) diff --git a/sbt-plugin/src/sbt-test/sbt-plugin/ghost/tailrec/TailRec.scala b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/tailrec/TailRec.scala new file mode 100644 index 0000000000..fdd8eac0b4 --- /dev/null +++ b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/tailrec/TailRec.scala @@ -0,0 +1,44 @@ +package test + +import stainless._ +import stainless.annotation.{ghost => ghostAnnot, _} +import stainless.lang._ +import StaticChecks._ + +object Main { + import stainless.lang.WhileDecorations + + def loop1(count: Long, l1: Long, l2: Long, l3: Long): Long = { + require(count >= 0) + decreases(count) + if (count == 0) l1 + else loop1(count - 1, l1, l2, l3) + }.ensuring(_ == l1) + + def loop2(count: Long, l1: Long, l2: Long, l3: Long): Long = { + require(count >= 0) + decreases(count) + if (count == 0) l1 + else { + val myRes = loop2(count - 1, l1, l2, l3) + ghost { + assert(myRes == l1) + } + myRes + } + }.ensuring(_ == l1) + + def whileLoop(upto: Long): Unit = { + var i: Long = 0 + (while(i < upto) { + decreases(upto - i) + i += 1 + }).invariant(i >= 0) + } + + def main(args: Array[String]): Unit = { + loop1(100000, 1, 2, 3) + loop2(100000, 1, 2, 3) + whileLoop(10000) + } +} diff --git a/sbt-plugin/src/sbt-test/sbt-plugin/ghost/test b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/test index 172e6266e5..e20dd6fdce 100644 --- a/sbt-plugin/src/sbt-test/sbt-plugin/ghost/test +++ b/sbt-plugin/src/sbt-test/sbt-plugin/ghost/test @@ -1,5 +1,6 @@ +> + tailrec/run > + basic/run $ exists basic/target/scala-2.13/classes/test/Main.class $ exists basic/target/scala-3.3.0/classes/test/Main.class $ absent basic/target/sneakyGhostCalled basic/target/insideGhostCalled -> + actor-tests/compile +> + actor-tests/compile \ No newline at end of file