Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clustering (--comparefuns for equivalence checking) #931

Open
wants to merge 7 commits into
base: scala-2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/src/main/scala/stainless/Component.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ object optFunctions extends inox.OptionDef[Seq[String]] {
val usageRhs = "f1,f2,..."
}

object optCompareFuns extends inox.OptionDef[Seq[String]] {
val name = "comparefuns"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

object optModels extends inox.OptionDef[Seq[String]] {
val name = "models"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

trait ComponentRun { self =>
val component: Component
val trees: ast.Trees
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ trait MainHelpers extends inox.MainHelpers { self =>
optVersion -> Description(General, "Display the version number"),
optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"),
optFunctions -> Description(General, "Only consider functions f1,f2,..."),
optCompareFuns -> Description(General, "Only consider functions f1,f2,... for equivalence checking"),
optModels -> Description(General, "Consider functions f1, f2, ... as model functions for equivalence checking"),
extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."),
extraction.utils.optDebugPhases -> Description(General, {
"Only print debug output for phases p1,p2,...\nAvailable: " +
Expand Down Expand Up @@ -166,6 +168,11 @@ trait MainHelpers extends inox.MainHelpers { self =>

import ctx.{ reporter, timers }

if (extraction.trace.Trace.optionsError) {
reporter.error(s"Equivalence checking for --comparefuns and --models only works in batched mode.")
System.exit(1)
}

if (!useParallelism) {
reporter.warning(s"Parallelism is disabled.")
}
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/scala/stainless/Report.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,18 @@ trait AbstractReport[SelfType <: AbstractReport[SelfType]] { self: SelfType =>
case Level.Error => Console.RED
}

def hasError(identifier: Identifier)(implicit ctx: inox.Context): Boolean = {
annotatedRows.exists(elem => elem match {
case RecordRow(id, pos, level, extra, time) => level == Level.Error && id == identifier
})
}

def hasUnknown(identifier: Identifier)(implicit ctx: inox.Context): Boolean = {
annotatedRows.exists(elem => elem match {
case RecordRow(id, pos, level, extra, time) => level == Level.Warning && id == identifier
})
}

// Emit the report table, with all VCs when full is true, otherwise only with unknown/invalid VCs.
private def emitTable(full: Boolean)(implicit ctx: inox.Context): Table = {
val rows = processRows(full)
Expand Down
243 changes: 234 additions & 9 deletions core/src/main/scala/stainless/extraction/trace/Trace.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ package stainless
package extraction
package trace

import stainless.utils.CheckFilter

trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self =>
val s: Trees
val t: termination.Trees
Expand All @@ -24,6 +26,60 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
override val t: self.t.type = self.t
}

override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): t.Symbols = {
import symbols._
import trees._

if (Trace.getModels.isEmpty) {
val models = symbols.functions.values.toList.filter(elem => isModel(elem.id)).map(elem => elem.id)
Trace.setModels(models)
Trace.nextModel
}

if (Trace.getFunctions.isEmpty) {
val functions = symbols.functions.values.toList.filter(elem => shouldBeChecked(elem.id)).map(elem => elem.id)
Trace.setFunctions(functions)
Trace.nextFunction
}

def checkPair(fd1: s.FunDef, fd2: s.FunDef): s.FunDef = {
val name = CheckFilter.fixedFullName(fd1.id)+"$"+CheckFilter.fixedFullName(fd2.id)

val newParams = fd1.params.map{param => param.freshen}
val newParamVars = newParams.map{param => param.toVariable}
val newParamTypes = fd1.tparams.map{tparam => tparam.freshen}
val newParamTps = newParamTypes.map{tparam => tparam.tp}

val vd = s.ValDef.fresh("holds", s.BooleanType())
val post = s.Lambda(Seq(vd), vd.toVariable)

val body = s.Ensuring(s.Equals(s.FunctionInvocation(fd1.id, newParamTps, newParamVars), s.FunctionInvocation(fd2.id, newParamTps, newParamVars)), post)
val flags: Seq[s.Flag] = Seq(s.Derived(fd1.id), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name))))

new s.FunDef(FreshIdentifier(name), newParamTypes, newParams, s.BooleanType(), body, flags)
}

def newFuns: List[s.FunDef] = (Trace.getModel, Trace.getFunction) match {
case (Some(model), Some(function)) => {
val m = symbols.functions(model)
val f = symbols.functions(function)
if (m != f && m.params.size == f.params.size) {
val newFun = checkPair(m, f)
Trace.setTrace(newFun.id)
List(newFun)
}
else {
Trace.reportWrong
Nil
}
}
case _ => Nil
}

val extracted = super.extractSymbols(context, symbols)
registerFunctions(extracted, newFuns.map(f => extractFunction(symbols, f)))
}

override protected def extractFunction(symbols: Symbols, fd: FunDef): t.FunDef = {
import symbols._
var funInv: Option[FunctionInvocation] = None
Expand All @@ -33,13 +89,12 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
case Annotation("traceInduct", fun) => {
exprOps.preTraversal {
case _ if funInv.isDefined => // do nothing
case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral("")))
=> {
case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral(""))) => {
val paramVars = fd.params.map(_.toVariable)
val argCheck = args.forall(paramVars.contains) && args.toSet.size == args.size
if (argCheck)
funInv = Some(fi)
}
}
case _ =>
}(fd.fullBody)
}
Expand Down Expand Up @@ -105,8 +160,8 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
val argsMap = callee.params.map(_.toVariable).zip(finv.args).toMap
val tparamMap = callee.typeArgs.zip(finv.tfd.tps).toMap
val inlinedBody = typeOps.instantiateType(exprOps.replaceFromSymbols(argsMap, callee.body.get), tparamMap)
val inductScheme = inductPattern(inlinedBody)

val inductScheme = inductPattern(inlinedBody)
val prevBody = function.fullBody match {
case Ensuring(body, pred) => body
case _ => function.fullBody
Expand All @@ -115,19 +170,86 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
// body, pre and post for the tactFun

val body = andJoin(Seq(inductScheme, prevBody))
val precondition = function.precondition
val postcondition = function.postcondition

val precondition = exprOps.preconditionOf(function.fullBody) //function.precondition
val postcondition = exprOps.postconditionOf(function.fullBody) //function.postcondition
val bodyPre = exprOps.withPrecondition(body, precondition)
val bodyPost = exprOps.withPostcondition(bodyPre,postcondition)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you try to manipulate specifications (getting pre/postconditions, and changing them), using the new BodyWithSpecs API?

function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags)
}

function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags)
type Path = Seq[String]

private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optCompareFuns) map { functions =>
functions map CheckFilter.fullNameToPath
}

}
private def shouldBeChecked(fid: Identifier): Boolean = pathsOpt match {
case None => false

case Some(paths) =>
// Support wildcard `_` as specified in the documentation.
// A leading wildcard is always assumes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small typo: assumed

val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid))
paths exists { p =>
if (p endsWith Seq("_")) path containsSlice p.init
else path endsWith p
}
}

private lazy val pathsOptModels: Option[Seq[Path]] = context.options.findOption(optModels) map { functions =>
functions map CheckFilter.fullNameToPath
}

private def isModel(fid: Identifier): Boolean = pathsOptModels match {
case None => false

case Some(paths) =>
// Support wildcard `_` as specified in the documentation.
// A leading wildcard is always assumes.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here for "assumes", perhaps you could factor this code into a common function

val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid))
paths exists { p =>
if (p endsWith Seq("_")) path containsSlice p.init
else path endsWith p
}
}

}

object Trace {
var clusters: Map[Identifier, List[Identifier]] = Map()
var errors: List[Identifier] = List()
var unknowns: List[Identifier] = List()
var wrong: List[Identifier] = List()

def optionsError(implicit ctx: inox.Context): Boolean =
!ctx.options.findOptionOrDefault(frontend.optBatchedProgram) &&
(!ctx.options.findOptionOrDefault(optModels).isEmpty || !ctx.options.findOptionOrDefault(optCompareFuns).isEmpty)

def printEverything(implicit ctx: inox.Context) = {
import ctx.{ reporter, timers }
if(!clusters.isEmpty || !errors.isEmpty || !unknowns.isEmpty) {
reporter.info(s"Printing equivalence checking results:")
allModels.foreach(model => {
val l = clusters(model).mkString(", ")
reporter.info(s"List of functions that are equivalent to model $model: $l")
})
val errorneous = errors.mkString(", ")
reporter.info(s"List of erroneous functions: $errorneous")
val timeouts = unknowns.mkString(", ")
reporter.info(s"List of timed-out functions: $timeouts")
}
}

var allModels: List[Identifier] = List()
var tmpModels: List[Identifier] = List()

var allFunctions: List[Identifier] = List()
var tmpFunctions: List[Identifier] = List()

var model: Option[Identifier] = None
var function: Option[Identifier] = None
var trace: Option[Identifier] = None

def apply(ts: Trees, tt: termination.Trees)(implicit ctx: inox.Context): ExtractionPipeline {
val s: ts.type
val t: tt.type
Expand All @@ -136,4 +258,107 @@ object Trace {
override val t: tt.type = tt
override val context = ctx
}

def setModels(m: List[Identifier]) = {
allModels = m
tmpModels = m
clusters = (m zip m.map(_ => Nil)).toMap
}

def setFunctions(f: List[Identifier]) = {
allFunctions = f
tmpFunctions = f
}

def getModels = allModels

def getFunctions = allFunctions

//model for the current iteration
def getModel = model

//function to check in the current iteration
def getFunction = function

def setTrace(t: Identifier) = trace = Some(t)
def getTrace = trace

//iterate model for the current function
def nextModel = (tmpModels, allModels) match {
case (x::xs, _) => { // check the next model for the current function
tmpModels = xs
model = Some(x)
}
case (Nil, x::xs) => {
tmpModels = allModels
model = Some(x)
tmpModels = xs
function = tmpFunctions match {
case x::xs => {
tmpFunctions = xs
Some(x)
}
case Nil => None
}
}
case _ => model = None
}

//iterate function to check; reset model
def nextFunction = tmpFunctions match {
case x::xs => {
tmpFunctions = xs
function = Some(x)
tmpModels = allModels
tmpModels match {
case Nil => model = None
case x::xs => {
model = Some(x)
tmpModels = xs
}
}
function
}
case Nil => {
function = None
}
}

def nextIteration[T <: AbstractReport[T]](report: AbstractReport[T])(implicit context: inox.Context): Boolean = trace match {
case Some(t) => {
if (report.hasError(t)) reportError
else if (report.hasUnknown(t)) reportUnknown
else reportValid
!isDone
}
case None => {
nextFunction
!isDone
}
}

private def isDone = function == None

private def reportError = {
errors = function.get::errors
nextFunction
}

private def reportUnknown = {
nextModel
if (model == None) {
unknowns = function.get::unknowns
nextFunction
}
}

private def reportValid = {
clusters = clusters + (model.get -> (function.get::clusters(model.get)))
nextFunction
}

private def reportWrong = {
trace = None
wrong = function.get::wrong
}
}
21 changes: 14 additions & 7 deletions core/src/main/scala/stainless/frontend/BatchedCallBack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package frontend

import stainless.extraction.xlang.{trees => xt, TreeSanitizer}
import stainless.utils.LibraryFilter
import stainless.extraction.trace.Trace

import scala.util.{Try, Success, Failure}
import scala.concurrent.Await
Expand Down Expand Up @@ -102,15 +103,21 @@ class BatchedCallBack(components: Seq[Component])(implicit val context: inox.Con
reportError(defn.getPos, e.getMessage, symbols)
}

val reports = runs map { run =>
val ids = symbols.functions.keys.toSeq
val analysis = Await.result(run(ids, symbols, filterSymbols = true), Duration.Inf)
RunReport(run)(analysis.toReport)
var rerunPipeline = true
while (rerunPipeline) {
val reports = runs map { run =>
val ids = symbols.functions.keys.toSeq
val analysis = Await.result(run(ids, symbols, filterSymbols = true), Duration.Inf)
RunReport(run)(analysis.toReport)
}
report = Report(reports)
rerunPipeline = Trace.nextIteration(report)
if (rerunPipeline) report.emit(context)
else Trace.printEverything
}

report = Report(reports)

}

def stop(): Unit = {
currentClasses = Seq()
currentFunctions = Seq()
Expand Down
Loading