diff --git a/core/src/main/scala/stainless/Component.scala b/core/src/main/scala/stainless/Component.scala index 3c3b445fbb..74c5db27f0 100644 --- a/core/src/main/scala/stainless/Component.scala +++ b/core/src/main/scala/stainless/Component.scala @@ -31,6 +31,20 @@ object optFunctions extends inox.OptionDef[Seq[String]] { val usageRhs = "f1,f2,..." } +object optCompareFuns extends inox.OptionDef[Seq[String]] { + val name = "comparefuns" + val default = Seq[String]() + val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) + val usageRhs = "f1,f2,..." +} + +object optModels extends inox.OptionDef[Seq[String]] { + val name = "models" + val default = Seq[String]() + val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) + val usageRhs = "f1,f2,..." +} + trait ComponentRun { self => val component: Component val trees: ast.Trees diff --git a/core/src/main/scala/stainless/MainHelpers.scala b/core/src/main/scala/stainless/MainHelpers.scala index b4b559c3e3..88bb4c1688 100644 --- a/core/src/main/scala/stainless/MainHelpers.scala +++ b/core/src/main/scala/stainless/MainHelpers.scala @@ -25,6 +25,8 @@ trait MainHelpers extends inox.MainHelpers { self => optVersion -> Description(General, "Display the version number"), optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"), optFunctions -> Description(General, "Only consider functions f1,f2,..."), + optCompareFuns -> Description(General, "Only consider functions f1,f2,... for equivalence checking"), + optModels -> Description(General, "Consider functions f1, f2, ... as model functions for equivalence checking"), extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."), extraction.utils.optDebugPhases -> Description(General, { "Only print debug output for phases p1,p2,...\nAvailable: " + @@ -166,6 +168,11 @@ trait MainHelpers extends inox.MainHelpers { self => import ctx.{ reporter, timers } + if (extraction.trace.Trace.optionsError) { + reporter.error(s"Equivalence checking for --comparefuns and --models only works in batched mode.") + System.exit(1) + } + if (!useParallelism) { reporter.warning(s"Parallelism is disabled.") } diff --git a/core/src/main/scala/stainless/Report.scala b/core/src/main/scala/stainless/Report.scala index 1f3747619d..99cdea9549 100644 --- a/core/src/main/scala/stainless/Report.scala +++ b/core/src/main/scala/stainless/Report.scala @@ -100,6 +100,18 @@ trait AbstractReport[SelfType <: AbstractReport[SelfType]] { self: SelfType => case Level.Error => Console.RED } + def hasError(identifier: Identifier)(implicit ctx: inox.Context): Boolean = { + annotatedRows.exists(elem => elem match { + case RecordRow(id, pos, level, extra, time) => level == Level.Error && id == identifier + }) + } + + def hasUnknown(identifier: Identifier)(implicit ctx: inox.Context): Boolean = { + annotatedRows.exists(elem => elem match { + case RecordRow(id, pos, level, extra, time) => level == Level.Warning && id == identifier + }) + } + // Emit the report table, with all VCs when full is true, otherwise only with unknown/invalid VCs. private def emitTable(full: Boolean)(implicit ctx: inox.Context): Table = { val rows = processRows(full) diff --git a/core/src/main/scala/stainless/extraction/trace/Trace.scala b/core/src/main/scala/stainless/extraction/trace/Trace.scala index 45f12f58fb..dadf988a17 100644 --- a/core/src/main/scala/stainless/extraction/trace/Trace.scala +++ b/core/src/main/scala/stainless/extraction/trace/Trace.scala @@ -4,6 +4,8 @@ package stainless package extraction package trace +import stainless.utils.CheckFilter + trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self => val s: Trees val t: termination.Trees @@ -24,6 +26,60 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self override val t: self.t.type = self.t } + override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): t.Symbols = { + import symbols._ + import trees._ + + if (Trace.getModels.isEmpty) { + val models = symbols.functions.values.toList.filter(elem => isModel(elem.id)).map(elem => elem.id) + Trace.setModels(models) + Trace.nextModel + } + + if (Trace.getFunctions.isEmpty) { + val functions = symbols.functions.values.toList.filter(elem => shouldBeChecked(elem.id)).map(elem => elem.id) + Trace.setFunctions(functions) + Trace.nextFunction + } + + def checkPair(fd1: s.FunDef, fd2: s.FunDef): s.FunDef = { + val name = CheckFilter.fixedFullName(fd1.id)+"$"+CheckFilter.fixedFullName(fd2.id) + + val newParams = fd1.params.map{param => param.freshen} + val newParamVars = newParams.map{param => param.toVariable} + val newParamTypes = fd1.tparams.map{tparam => tparam.freshen} + val newParamTps = newParamTypes.map{tparam => tparam.tp} + + val vd = s.ValDef.fresh("holds", s.BooleanType()) + val post = s.Lambda(Seq(vd), vd.toVariable) + + val body = s.Ensuring(s.Equals(s.FunctionInvocation(fd1.id, newParamTps, newParamVars), s.FunctionInvocation(fd2.id, newParamTps, newParamVars)), post) + val flags: Seq[s.Flag] = Seq(s.Derived(fd1.id), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name)))) + + new s.FunDef(FreshIdentifier(name), newParamTypes, newParams, s.BooleanType(), body, flags) + } + + def newFuns: List[s.FunDef] = (Trace.getModel, Trace.getFunction) match { + case (Some(model), Some(function)) => { + val m = symbols.functions(model) + val f = symbols.functions(function) + if (m != f && m.params.size == f.params.size) { + val newFun = checkPair(m, f) + Trace.setTrace(newFun.id) + List(newFun) + } + else { + Trace.reportWrong + Nil + } + } + case _ => Nil + } + + val extracted = super.extractSymbols(context, symbols) + registerFunctions(extracted, newFuns.map(f => extractFunction(symbols, f))) + } + override protected def extractFunction(symbols: Symbols, fd: FunDef): t.FunDef = { import symbols._ var funInv: Option[FunctionInvocation] = None @@ -33,13 +89,12 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self case Annotation("traceInduct", fun) => { exprOps.preTraversal { case _ if funInv.isDefined => // do nothing - case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral(""))) - => { + case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral(""))) => { val paramVars = fd.params.map(_.toVariable) val argCheck = args.forall(paramVars.contains) && args.toSet.size == args.size if (argCheck) funInv = Some(fi) - } + } case _ => }(fd.fullBody) } @@ -105,8 +160,8 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self val argsMap = callee.params.map(_.toVariable).zip(finv.args).toMap val tparamMap = callee.typeArgs.zip(finv.tfd.tps).toMap val inlinedBody = typeOps.instantiateType(exprOps.replaceFromSymbols(argsMap, callee.body.get), tparamMap) - val inductScheme = inductPattern(inlinedBody) + val inductScheme = inductPattern(inlinedBody) val prevBody = function.fullBody match { case Ensuring(body, pred) => body case _ => function.fullBody @@ -115,19 +170,86 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self // body, pre and post for the tactFun val body = andJoin(Seq(inductScheme, prevBody)) - val precondition = function.precondition - val postcondition = function.postcondition - + val precondition = exprOps.preconditionOf(function.fullBody) //function.precondition + val postcondition = exprOps.postconditionOf(function.fullBody) //function.postcondition val bodyPre = exprOps.withPrecondition(body, precondition) val bodyPost = exprOps.withPostcondition(bodyPre,postcondition) + function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags) + } - function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags) + type Path = Seq[String] + + private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optCompareFuns) map { functions => + functions map CheckFilter.fullNameToPath } -} + private def shouldBeChecked(fid: Identifier): Boolean = pathsOpt match { + case None => false + + case Some(paths) => + // Support wildcard `_` as specified in the documentation. + // A leading wildcard is always assumes. + val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid)) + paths exists { p => + if (p endsWith Seq("_")) path containsSlice p.init + else path endsWith p + } + } + + private lazy val pathsOptModels: Option[Seq[Path]] = context.options.findOption(optModels) map { functions => + functions map CheckFilter.fullNameToPath + } + private def isModel(fid: Identifier): Boolean = pathsOptModels match { + case None => false + + case Some(paths) => + // Support wildcard `_` as specified in the documentation. + // A leading wildcard is always assumes. + val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid)) + paths exists { p => + if (p endsWith Seq("_")) path containsSlice p.init + else path endsWith p + } + } + +} object Trace { + var clusters: Map[Identifier, List[Identifier]] = Map() + var errors: List[Identifier] = List() + var unknowns: List[Identifier] = List() + var wrong: List[Identifier] = List() + + def optionsError(implicit ctx: inox.Context): Boolean = + !ctx.options.findOptionOrDefault(frontend.optBatchedProgram) && + (!ctx.options.findOptionOrDefault(optModels).isEmpty || !ctx.options.findOptionOrDefault(optCompareFuns).isEmpty) + + def printEverything(implicit ctx: inox.Context) = { + import ctx.{ reporter, timers } + if(!clusters.isEmpty || !errors.isEmpty || !unknowns.isEmpty) { + reporter.info(s"Printing equivalence checking results:") + allModels.foreach(model => { + val l = clusters(model).mkString(", ") + reporter.info(s"List of functions that are equivalent to model $model: $l") + }) + val errorneous = errors.mkString(", ") + reporter.info(s"List of erroneous functions: $errorneous") + val timeouts = unknowns.mkString(", ") + reporter.info(s"List of timed-out functions: $timeouts") + } + } + + var allModels: List[Identifier] = List() + var tmpModels: List[Identifier] = List() + + var allFunctions: List[Identifier] = List() + var tmpFunctions: List[Identifier] = List() + + var model: Option[Identifier] = None + var function: Option[Identifier] = None + var trace: Option[Identifier] = None + def apply(ts: Trees, tt: termination.Trees)(implicit ctx: inox.Context): ExtractionPipeline { val s: ts.type val t: tt.type @@ -136,4 +258,107 @@ object Trace { override val t: tt.type = tt override val context = ctx } + + def setModels(m: List[Identifier]) = { + allModels = m + tmpModels = m + clusters = (m zip m.map(_ => Nil)).toMap + } + + def setFunctions(f: List[Identifier]) = { + allFunctions = f + tmpFunctions = f + } + + def getModels = allModels + + def getFunctions = allFunctions + + //model for the current iteration + def getModel = model + + //function to check in the current iteration + def getFunction = function + + def setTrace(t: Identifier) = trace = Some(t) + def getTrace = trace + + //iterate model for the current function + def nextModel = (tmpModels, allModels) match { + case (x::xs, _) => { // check the next model for the current function + tmpModels = xs + model = Some(x) + } + case (Nil, x::xs) => { + tmpModels = allModels + model = Some(x) + tmpModels = xs + function = tmpFunctions match { + case x::xs => { + tmpFunctions = xs + Some(x) + } + case Nil => None + } + } + case _ => model = None + } + + //iterate function to check; reset model + def nextFunction = tmpFunctions match { + case x::xs => { + tmpFunctions = xs + function = Some(x) + tmpModels = allModels + tmpModels match { + case Nil => model = None + case x::xs => { + model = Some(x) + tmpModels = xs + } + } + function + } + case Nil => { + function = None + } + } + + def nextIteration[T <: AbstractReport[T]](report: AbstractReport[T])(implicit context: inox.Context): Boolean = trace match { + case Some(t) => { + if (report.hasError(t)) reportError + else if (report.hasUnknown(t)) reportUnknown + else reportValid + !isDone + } + case None => { + nextFunction + !isDone + } + } + + private def isDone = function == None + + private def reportError = { + errors = function.get::errors + nextFunction + } + + private def reportUnknown = { + nextModel + if (model == None) { + unknowns = function.get::unknowns + nextFunction + } + } + + private def reportValid = { + clusters = clusters + (model.get -> (function.get::clusters(model.get))) + nextFunction + } + + private def reportWrong = { + trace = None + wrong = function.get::wrong + } } \ No newline at end of file diff --git a/core/src/main/scala/stainless/frontend/BatchedCallBack.scala b/core/src/main/scala/stainless/frontend/BatchedCallBack.scala index 9b4110f14b..d9ba2068dc 100644 --- a/core/src/main/scala/stainless/frontend/BatchedCallBack.scala +++ b/core/src/main/scala/stainless/frontend/BatchedCallBack.scala @@ -5,6 +5,7 @@ package frontend import stainless.extraction.xlang.{trees => xt, TreeSanitizer} import stainless.utils.LibraryFilter +import stainless.extraction.trace.Trace import scala.util.{Try, Success, Failure} import scala.concurrent.Await @@ -102,15 +103,21 @@ class BatchedCallBack(components: Seq[Component])(implicit val context: inox.Con reportError(defn.getPos, e.getMessage, symbols) } - val reports = runs map { run => - val ids = symbols.functions.keys.toSeq - val analysis = Await.result(run(ids, symbols, filterSymbols = true), Duration.Inf) - RunReport(run)(analysis.toReport) + var rerunPipeline = true + while (rerunPipeline) { + val reports = runs map { run => + val ids = symbols.functions.keys.toSeq + val analysis = Await.result(run(ids, symbols, filterSymbols = true), Duration.Inf) + RunReport(run)(analysis.toReport) + } + report = Report(reports) + rerunPipeline = Trace.nextIteration(report) + if (rerunPipeline) report.emit(context) + else Trace.printEverything } - - report = Report(reports) + } - + def stop(): Unit = { currentClasses = Seq() currentFunctions = Seq() diff --git a/core/src/main/scala/stainless/utils/CheckFilter.scala b/core/src/main/scala/stainless/utils/CheckFilter.scala index 79df04a800..5692d95619 100644 --- a/core/src/main/scala/stainless/utils/CheckFilter.scala +++ b/core/src/main/scala/stainless/utils/CheckFilter.scala @@ -10,25 +10,9 @@ trait CheckFilter { import trees._ type Path = Seq[String] - private def fullNameToPath(fullName: String): Path = (fullName split '.').toSeq - - // TODO this is probably done somewhere else in a cleaner fasion... - private def fixedFullName(id: Identifier): String = id.fullName - .replaceAllLiterally("$bar", "|") - .replaceAllLiterally("$up", "^") - .replaceAllLiterally("$eq", "=") - .replaceAllLiterally("$plus", "+") - .replaceAllLiterally("$minus", "-") - .replaceAllLiterally("$times", "*") - .replaceAllLiterally("$div", "/") - .replaceAllLiterally("$less", "<") - .replaceAllLiterally("$geater", ">") - .replaceAllLiterally("$colon", ":") - .replaceAllLiterally("$amp", "&") - .replaceAllLiterally("$tilde", "~") private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optFunctions) map { functions => - functions map fullNameToPath + functions map CheckFilter.fullNameToPath } private def shouldBeChecked(fid: Identifier, flags: Seq[trees.Flag]): Boolean = pathsOpt match { @@ -40,7 +24,7 @@ trait CheckFilter { case Some(paths) => // Support wildcard `_` as specified in the documentation. // A leading wildcard is always assumes. - val path: Path = fullNameToPath(fixedFullName(fid)) + val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid)) paths exists { p => if (p endsWith Seq("_")) path containsSlice p.init else path endsWith p @@ -86,5 +70,24 @@ object CheckFilter { override val context = ctx override val trees: t.type = t } + + type Path = Seq[String] + + def fullNameToPath(fullName: String): Path = (fullName split '.').toSeq + + // TODO this is probably done somewhere else in a cleaner fasion... + def fixedFullName(id: Identifier): String = id.fullName + .replaceAllLiterally("$bar", "|") + .replaceAllLiterally("$up", "^") + .replaceAllLiterally("$eq", "=") + .replaceAllLiterally("$plus", "+") + .replaceAllLiterally("$minus", "-") + .replaceAllLiterally("$times", "*") + .replaceAllLiterally("$div", "/") + .replaceAllLiterally("$less", "<") + .replaceAllLiterally("$geater", ">") + .replaceAllLiterally("$colon", ":") + .replaceAllLiterally("$amp", "&") + .replaceAllLiterally("$tilde", "~") } diff --git a/core/src/sphinx/index.rst b/core/src/sphinx/index.rst index e322207388..8d0aa2f116 100644 --- a/core/src/sphinx/index.rst +++ b/core/src/sphinx/index.rst @@ -18,6 +18,7 @@ Contents: options verification laws + trace imperative ghost wrap diff --git a/core/src/sphinx/trace.rst b/core/src/sphinx/trace.rst new file mode 100644 index 0000000000..06e010749e --- /dev/null +++ b/core/src/sphinx/trace.rst @@ -0,0 +1,84 @@ +.. _trace: + +Induction and equivalence checking +================================== + +Induction and @traceInduct annotation +------------------------------------- + +We introduce the @traceInduct annotation for automating proofs using induction. Stainless will transform the annotated lemma by generating the inductive proof, based on the structure of one of the functions that appear in the lemma. This approach is useful for functions that have simple inductive form and are easily mapped into inductive proofs. + +Here is one simple example of an equivalence lemma: + +.. code-block:: scala + + def zero1(arg: BigInt): BigInt = { + if (arg > 0) zero1(arg - 1) + else BigInt(0) + } + + def zero2(arg: BigInt): BigInt = { + BigInt(0) + } + + @traceInduct("") + def zero_check(arg: BigInt): Boolean = { + zero1(arg) == zero2(arg) + }.holds + +Without the annotation, Stainless times out when trying to prove equivalence. To help with the proof, the user would have to write more details: + +.. code-block:: scala + + def zero_check(arg: BigInt): Boolean = { + zero1(arg) == zero2(arg) + }.holds because { + if (arg > 0) zero_check(arg - 1) + else true + } + +With @traceInduct annotation, Stainless automatically comes up with a similar proof. + + +It is possible to specify which function should serve as reference implementation for the inductive proof. This can be done by specifying the reference function name as @traceInduct parameter: + +.. code-block:: scala + + def content(l: List[BigInt]): Set[BigInt] = l match { + case Nil() => Set.empty[BigInt] + case x::xs => Set(x) ++ content(xs) + } + + def reverse(l1: List[BigInt], l2: List[BigInt]): List[BigInt] = l1 match { + case Nil() => l2 + case x::xs => reverse(xs, x::l2) + } + + @traceInduct("reverse") + def revPreservesContent(l1: List[BigInt], l2: List[BigInt]): Boolean = { + content(l1) ++ content(l2) == content(reverse(l1, l2)) + }.holds + +Stainless constructs the proof based on the definition of the function reverse, by writing its name as annotation parameter. By induction on l1 and following the structure of the reverse function, Stainless manages to prove this lemma. + +Equivalence checking +-------------------- + +The first example of the previous section shows how @traceInduct annotation can be used to automate equivalence checking for pairs of functions. This way it is possible to verify given implementation by proving equivalence to some reference implementation. + +In batched mode, Stainless also supports checking equivalence of larger sets of functions. To avoid writing @traceInduct annotated lemmas for each pair, it is possible to specify the list of functions that we want to check for equivalence. Command line option --comparefuns is used for specifying the list of functions. Command line option --models is used for specifying the list of reference model functions. Those model functions are considered correct and serve as reference implementation for the inductive proof. + +Stainless can automatically generate all the equivalence lemmas and report resulting equivalence classes. This is done by checking for eqivalence of each function from the --comparefuns list and each model function from the --models list, until the proof of equvalence or a counter example is found for one of the models. + +For example, when running Stainless with the following options (assuming that the file zero.scala contains functions f1, f2, f3, m1 and m2): + +.. code-block:: bash + + $ stainless file.scala --batched=true --comparefuns=f1,f2,f3 --models=m1,m2 --timeout=5 + +Stainless will try to prove equivalence for the following pairs of functions, assuming that f1 and f3 are equivalent to m1, and f2 is equivalent to m2 (but not m1): + +- f1 == m1 (verifies, no need to check for f1 == m2) +- f2 == m1 +- f2 == m2 +- f3 == m1 (verifies, no need to check for f3 == m2)