Skip to content

Commit

Permalink
Fix #1377 (#1392)
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev authored Mar 24, 2023
1 parent f95184b commit 9e62648
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .larabot.conf
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
commands = [
"sbt -batch -Dtest-parallelism=5 test"
"sbt -batch -Dtest-parallelism=5 \"it:testOnly stainless.GhostRewriteSuite stainless.GenCSuite stainless.ScalacExtractionSuite stainless.LibrarySuite stainless.verification.SMTZ3VerificationSuite stainless.verification.SMTZ3UncheckedSuite stainless.verification.TerminationVerificationSuite stainless.verification.ImperativeSuite stainless.verification.FullImperativeSuite stainless.verification.StrictArithmeticSuite stainless.verification.CodeGenVerificationSuite stainless.verification.SMTCVC4VerificationSuite stainless.verificatoin.SMTCVC4UncheckedSuite stainless.termination.TerminationSuite\""
"sbt -batch -Dtest-parallelism=5 \"it:testOnly stainless.GhostRewriteSuite stainless.GenCSuite stainless.ScalacExtractionSuite stainless.LibrarySuite stainless.verification.SMTZ3VerificationSuite stainless.verification.SMTZ3UncheckedSuite stainless.verification.TerminationVerificationSuite stainless.verification.ImperativeSuite stainless.verification.FullImperativeSuite stainless.verification.StrictArithmeticSuite stainless.verification.CodeGenVerificationSuite stainless.verification.SMTCVC4VerificationSuite stainless.verificatoin.SMTCVC4UncheckedSuite stainless.termination.TerminationSuite stainless.evaluators.EvaluatorComponentTest\""
]

nightly {
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ val scriptSettings: Seq[Setting[_]] = Seq(
def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${version}"))

// lazy val inox = RootProject(file("../inox"))
lazy val inox = ghProject("https://github.com/epfl-lara/inox.git", "6efba92979cc420fd73dda8bfafa05102f0f4047")
lazy val inox = ghProject("https://github.com/epfl-lara/inox.git", "41ffe806b04769c0d6757ebfeb17a96c7d5efd8a")
lazy val cafebabe = ghProject("https://github.com/epfl-lara/cafebabe.git", "616e639b34379e12b8ac202849de3ebbbd0848bc")

// Allow integration test to use facilities from regular tests
Expand Down
88 changes: 60 additions & 28 deletions core/src/main/scala/stainless/codegen/CodeGeneration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ trait CodeGeneration { self: CompilationUnit =>
import program.symbols.{given, _}
import program.trees.exprOps._

lazy val ignoreContracts = options.findOptionOrDefault(inox.evaluators.optIgnoreContracts)
lazy val globallyIgnoreContracts = options.findOptionOrDefault(inox.evaluators.optIgnoreContracts)
lazy val doInstrument = options.findOptionOrDefault(optInstrumentFields)
lazy val smallArrays = options.findOptionOrDefault(optSmallArrays)
lazy val recordInvocations = maxSteps >= 0
Expand Down Expand Up @@ -74,6 +74,8 @@ trait CodeGeneration { self: CompilationUnit =>

object NoLocals extends Locals(Map.empty, Map.empty, Map.empty, Seq.empty, Seq.empty)

case class ContractsCtx(locallyIgnored: Boolean)

lazy val monitorID = FreshIdentifier("__$monitor")
lazy val tpsID = FreshIdentifier("__$tps")

Expand Down Expand Up @@ -306,7 +308,7 @@ trait CodeGeneration { self: CompilationUnit =>
Seq(monitor) ++ tpsOpt ++ params
}.toMap

val body = if (!ignoreContracts) {
val body = if (!globallyIgnoreContracts) {
funDef.fullBody
} else {
funDef.body.getOrElse(
Expand All @@ -324,11 +326,11 @@ trait CodeGeneration { self: CompilationUnit =>
.withTypeParameters(funDef.tparams.map(_.tp))

if (recordInvocations) {
loadImpl(monitorID, ch)(using locals)
loadImpl(monitorID, ch)(using locals, ContractsCtx(locallyIgnored = false))
ch << InvokeVirtual(MonitorClass, "onInvocation", "()V")
}

mkExpr(body, ch)(using locals)
mkExpr(body, ch)(using locals, ContractsCtx(locallyIgnored = false))

funDef.getType match {
case JvmIType() =>
Expand All @@ -348,7 +350,7 @@ trait CodeGeneration { self: CompilationUnit =>

private val typeParams: ListBuffer[TypeParameter] = new ListBuffer[TypeParameter]

protected def compileLambda(l: Lambda, params: Seq[ValDef]):
protected def compileLambda(l: Lambda, params: Seq[ValDef])(using contractsCtx: ContractsCtx):
(String, Seq[(Identifier, String)], Seq[TypeParameter], String) = synchronized {
assert(normalizeStructure(l)._1 == l)

Expand Down Expand Up @@ -547,7 +549,7 @@ trait CodeGeneration { self: CompilationUnit =>
}

// also makes tuples with 0/1 args
private def mkTuple(es: Seq[Expr], ch: CodeHandler)(using locals: Locals) : Unit = {
private def mkTuple(es: Seq[Expr], ch: CodeHandler)(using locals: Locals, contractsCtx: ContractsCtx) : Unit = {
ch << New(TupleClass) << DUP
ch << Ldc(es.size)
ch << NewArray(s"$ObjectClass")
Expand All @@ -560,7 +562,7 @@ trait CodeGeneration { self: CompilationUnit =>
ch << InvokeSpecial(TupleClass, constructorName, s"([L$ObjectClass;)V")
}

private def loadTypes(tps: Seq[Type], ch: CodeHandler)(using locals: Locals): Unit = {
private def loadTypes(tps: Seq[Type], ch: CodeHandler)(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
if (tps.nonEmpty) {
ch << Ldc(tps.size)
ch << NewArray.primitive("T_INT")
Expand Down Expand Up @@ -605,21 +607,27 @@ trait CodeGeneration { self: CompilationUnit =>
}

private[codegen] def mkExpr(e: Expr, ch: CodeHandler, canDelegateToMkBranch: Boolean = true)
(using locals: Locals): Unit = e match {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = e match {
case v: Variable =>
load(v, ch)

case Assert(cond, oerr, body) =>
mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch)
if (!ignoreContractsOn(e))
mkExpr(IfExpr(Not(cond), Error(body.getType, oerr.getOrElse("Assertion failed @"+e.getPos)), body), ch)
else mkExpr(body, ch)

case Assume(cond, body) =>
mkExpr(IfExpr(Not(cond), Error(body.getType, "Assumption failed @"+e.getPos), body), ch)
if (!ignoreContractsOn(e))
mkExpr(IfExpr(Not(cond), Error(body.getType, "Assumption failed @"+e.getPos), body), ch)
else mkExpr(body, ch)

case en @ Ensuring(_, _) =>
mkExpr(en.toAssert, ch)

case Require(pre, body) =>
mkExpr(IfExpr(pre, body, Error(body.getType, "Precondition failed")), ch)
if (!ignoreContractsOn(e))
mkExpr(IfExpr(pre, body, Error(body.getType, "Precondition failed")), ch)
else mkExpr(body, ch)

case Decreases(measure, body) =>
mkExpr(body, ch)
Expand Down Expand Up @@ -707,7 +715,7 @@ trait CodeGeneration { self: CompilationUnit =>
ch << InvokeSpecial(adtName, constructorName, adtApplySig)

// check invariant (if it exists)
if (!ignoreContracts && cons.getSort.hasInvariant) {
if (!ignoreContractsOn(adt) && cons.getSort.hasInvariant) {
ch << DUP

val tfd = tcons.sort.invariant.get
Expand Down Expand Up @@ -1236,13 +1244,17 @@ trait CodeGeneration { self: CompilationUnit =>
case m: Max =>
mkExpr(maxToIfThenElse(m), ch)

case Annotated(body, _) =>
mkExpr(body, ch)
case Annotated(body, flags) =>
val nctx = {
if (flags.contains(DropVCs) || flags.contains(DropConjunct)) contractsCtx.copy(locallyIgnored = true)
else contractsCtx
}
mkExpr(body, ch)(using locals, nctx)

case _ => throw CompilationException("Unsupported expr " + e + " : " + e.getClass)
}

private[codegen] def mkLambda(lambda: Lambda, ch: CodeHandler)(using locals: Locals): Unit = {
private[codegen] def mkLambda(lambda: Lambda, ch: CodeHandler)(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
val vars = variablesOf(lambda).toSeq
val freshVars = vars.map(_.freshen)

Expand Down Expand Up @@ -1308,7 +1320,7 @@ trait CodeGeneration { self: CompilationUnit =>

// Leaves on the stack a value equal to `e`, always of a type compatible with java.lang.Object.
private[codegen] def mkBoxedExpr(e: Expr, ch: CodeHandler)
(using locals: Locals): Unit = e.getType match {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = e.getType match {
case Int8Type() =>
ch << New(BoxedByteClass) << DUP
mkExpr(e, ch)
Expand Down Expand Up @@ -1449,7 +1461,7 @@ trait CodeGeneration { self: CompilationUnit =>
}

private[codegen] def mkBranch(cond: Expr, thenn: String, elze: String, ch: CodeHandler, canDelegateToMkExpr: Boolean = true)
(using locals: Locals): Unit = cond match {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = cond match {
case BooleanLiteral(true) =>
ch << Goto(thenn)

Expand Down Expand Up @@ -1517,8 +1529,12 @@ trait CodeGeneration { self: CompilationUnit =>
mkExpr(other, ch, canDelegateToMkBranch = false)
ch << IfEq(elze) << Goto(thenn)

case Annotated(condition, _) =>
mkBranch(condition, thenn, elze, ch)
case Annotated(condition, flags) =>
val nctx = {
if (flags.contains(DropVCs) || flags.contains(DropConjunct)) contractsCtx.copy(locallyIgnored = true)
else contractsCtx
}
mkBranch(condition, thenn, elze, ch)(using locals, nctx)

case other => throw CompilationException("Unsupported branching expr. : " + other)
}
Expand Down Expand Up @@ -1579,7 +1595,7 @@ trait CodeGeneration { self: CompilationUnit =>

private def mkBVShift(l: Expr, r: Expr, ch: CodeHandler,
iop: ByteCode, lop: ByteCode, op: String)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
// NOTE for shift operations on Byte/Short/Int/Long:
// the lhs operand can be either Int or Long,
// the rhs operand must be an Int.
Expand All @@ -1600,7 +1616,7 @@ trait CodeGeneration { self: CompilationUnit =>

private def mkCmpJump(cond: Expr, thenn: String, elze: String, l: Expr, r: Expr, ch: CodeHandler,
iop: String => ControlOperator, lop: String => ControlOperator, op: String)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
mkExpr(l, ch)
mkExpr(r, ch)
l.getType match {
Expand All @@ -1625,13 +1641,13 @@ trait CodeGeneration { self: CompilationUnit =>

private def mkArithmeticBinary(l: Expr, r: Expr, ch: CodeHandler,
iop: ByteCode, lop: ByteCode, op: String, bvOnly: Boolean)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
mkArithmeticBinaryImpl(l, r, ch, { ch => ch << iop }, { ch => ch << lop }, op, bvOnly)
}

private def mkArithmeticBinaryImpl(l: Expr, r: Expr, ch: CodeHandler,
iopGen: AbstractByteCodeGenerator, lopGen: AbstractByteCodeGenerator, op: String, bvOnly: Boolean)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
ch << Comment(s"mkArithmeticBinary($op)")
mkExpr(l, ch)
mkExpr(r, ch)
Expand Down Expand Up @@ -1688,13 +1704,13 @@ trait CodeGeneration { self: CompilationUnit =>

private def mkArithmeticUnary(e: Expr, ch: CodeHandler,
iop: ByteCode, lop: ByteCode, op: String, bvOnly: Boolean)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
mkArithmeticUnaryImpl(e, ch, { ch => ch << iop }, { ch => ch << lop }, op, bvOnly)
}

private def mkArithmeticUnaryImpl(e: Expr, ch: CodeHandler,
iopGen: AbstractByteCodeGenerator, lopGen: AbstractByteCodeGenerator, op: String, bvOnly: Boolean)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
ch << Comment(s"mkArithmeticUnary($op)")
mkExpr(e, ch)

Expand Down Expand Up @@ -1742,9 +1758,9 @@ trait CodeGeneration { self: CompilationUnit =>
}


private def load(v: Variable, ch: CodeHandler)(using locals: Locals): Unit = loadImpl(v.id, ch, Some(v.getType))
private def load(v: Variable, ch: CodeHandler)(using locals: Locals, contractsCtx: ContractsCtx): Unit = loadImpl(v.id, ch, Some(v.getType))

private def loadImpl(id: Identifier, ch: CodeHandler, tpe: Option[Type] = None)(using locals: Locals): Unit = {
private def loadImpl(id: Identifier, ch: CodeHandler, tpe: Option[Type] = None)(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
locals.varToArg(id) match {
case Some(slot) =>
ch << ALoad(1) << Ldc(slot) << AALOAD
Expand Down Expand Up @@ -1816,7 +1832,7 @@ trait CodeGeneration { self: CompilationUnit =>
val instrumentedField = "__read"

def instrumentedGetField(ch: CodeHandler, cons: TypedADTConstructor, id: Identifier)
(using locals: Locals): Unit = {
(using locals: Locals, contractsCtx: ContractsCtx): Unit = {
cons.definition.fields.zipWithIndex.find(_._1.id == id) match {
case Some((f, i)) =>
val expType = cons.fields(i).getType
Expand Down Expand Up @@ -1844,6 +1860,7 @@ trait CodeGeneration { self: CompilationUnit =>
}

def compileADTConstructor(cons: ADTConstructor): Unit = {
given ContractsCtx = ContractsCtx(locallyIgnored = false)
val cName = defConsToJVMName(cons)
val pName = defSortToJVMName(cons.getSort)
val tcons = cons.typed
Expand Down Expand Up @@ -2041,6 +2058,21 @@ trait CodeGeneration { self: CompilationUnit =>

}

private def ignoreContractsOn(expr: Expr)(using contractsCtx: ContractsCtx): Boolean = {
def isAnnotatedDrop(e: Expr) = e match {
case Annotated(_, flags) => flags.contains(DropVCs) || flags.contains(DropConjunct)
case _ => false
}

globallyIgnoreContracts || contractsCtx.locallyIgnored || (expr match {
case Assume(pred, _) => isAnnotatedDrop(pred)
case Assert(pred, _, _) => isAnnotatedDrop(pred)
case Require(pred, _) => isAnnotatedDrop(pred)
case Ensuring(_, pred) => isAnnotatedDrop(pred) || isAnnotatedDrop(pred.body)
case _ => isAnnotatedDrop(expr)
})
}

private def internalErrorWithByteOrShort(e: Expr) =
throw CompilationException(s"Unexpected expression involving Byte or Short: $e")
}
4 changes: 2 additions & 2 deletions core/src/main/scala/stainless/codegen/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class CompilationUnit(val program: Program, val context: inox.Context)(using val
case lambda: Lambda =>
val (l: Lambda, deps) = normalizeStructure(matchToIfThenElse(lambda, assumeExhaustive = false)): @unchecked
if (deps.forall { case (_, e, conds) => isValue(e) && conds.isEmpty }) {
val (afName, closures, tparams, consSig) = compileLambda(l, Seq.empty)
val (afName, closures, tparams, consSig) = compileLambda(l, Seq.empty)(using ContractsCtx(locallyIgnored = false))
val depsMap = deps.map { case (v, dep, _) => v.id -> valueToJVM(dep) }.toMap

val args = closures.map { case (id, _) =>
Expand Down Expand Up @@ -444,7 +444,7 @@ class CompilationUnit(val program: Program, val context: inox.Context)(using val
case (v, i) => v.id -> (i + 1)
}.toMap

mkExpr(e, ch)(using NoLocals.withVars(newMapping))
mkExpr(e, ch)(using NoLocals.withVars(newMapping), ContractsCtx(locallyIgnored = false))

e.getType match {
case JvmIType() =>
Expand Down
27 changes: 23 additions & 4 deletions core/src/main/scala/stainless/evaluators/RecursiveEvaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ abstract class RecursiveEvaluator(override val program: Program,

override def e(expr: Expr)(using rctx: RC, gctx: GC): Expr = expr match {
case Require(pred, body) =>
if (!ignoreContracts && e(pred) != BooleanLiteral(true))
if (!ignoreContractsOn(expr) && e(pred) != BooleanLiteral(true))
throw RuntimeError("Requirement did not hold @" + expr.getPos)
e(body)

Expand All @@ -25,7 +25,7 @@ abstract class RecursiveEvaluator(override val program: Program,
e(body)

case Assert(pred, err, body) =>
if (!ignoreContracts && e(pred) != BooleanLiteral(true))
if (!ignoreContractsOn(expr) && e(pred) != BooleanLiteral(true))
throw RuntimeError(err.getOrElse("Assertion failed @" + expr.getPos))
e(body)

Expand Down Expand Up @@ -77,12 +77,31 @@ abstract class RecursiveEvaluator(override val program: Program,
case NoTree(tpe) =>
throw RuntimeError("Reached empty tree in evaluation @" + expr.getPos)

case Annotated(body, _) =>
e(body)
case Annotated(body, flags) =>
val nrctx = {
if (flags.contains(DropVCs) || flags.contains(DropConjunct)) rctx.withLocallyIgnoredContracts
else rctx
}
e(body)(using nrctx)

case _ => super.e(expr)
}

override def ignoreContractsOn(expr: Expr)(using rctx: RC, gctx: GC): Boolean = {
def isAnnotatedDrop(e: Expr) = e match {
case Annotated(_, flags) => flags.contains(DropVCs) || flags.contains(DropConjunct)
case _ => false
}

super.ignoreContractsOn(expr) || (expr match {
case Assume(pred, _) => isAnnotatedDrop(pred)
case Assert(pred, _, _) => isAnnotatedDrop(pred)
case Require(pred, _) => isAnnotatedDrop(pred)
case Ensuring(_, pred) => isAnnotatedDrop(pred) || isAnnotatedDrop(pred.body)
case _ => isAnnotatedDrop(expr)
})
}

protected def matchesCase(scrut: Expr, caze: MatchCase)
(using rctx: RC, gctx: GC): Option[Map[ValDef, Expr]] = {

Expand Down
8 changes: 8 additions & 0 deletions frontends/benchmarks/evaluators/adtFailure/ADTFailure.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
object ADTFailure {

case class MyADT(x: BigInt) {
require(0 <= x && x < 100)
}

def ohno: BigInt = MyADT(42).x + MyADT(101).x
}
9 changes: 9 additions & 0 deletions frontends/benchmarks/evaluators/adtFailure/test_conf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"recursive": true,
"codegen": true,
"failure": [
{
"function": "ohno"
}
]
}
Loading

0 comments on commit 9e62648

Please sign in to comment.