diff --git a/core/src/main/scala/stainless/ast/Deconstructors.scala b/core/src/main/scala/stainless/ast/Deconstructors.scala index e3712e4361..a597b3e3b9 100644 --- a/core/src/main/scala/stainless/ast/Deconstructors.scala +++ b/core/src/main/scala/stainless/ast/Deconstructors.scala @@ -199,7 +199,8 @@ trait TreeDeconstructor extends inox.ast.TreeDeconstructor { case s.RecursiveType(id, tps, e) => (Seq(id), Seq(), Seq(e), tps, Seq(), (ids, _, es, ntps, _) => t.RecursiveType(ids(0), ntps, es(0))) case s.ValueType(tpe) => (Seq(), Seq(), Seq(), Seq(tpe), Seq(), (_, _, _, tps, _) => t.ValueType(tps(0))) case s.AnnotatedType(tpe, flags) => - (Seq(), Seq(), Seq(), Seq(tpe), flags, (_, _, _, tps, flags) => t.AnnotatedType(tps(0), flags)) + (Seq(), Seq(), Seq(), Seq(tpe), flags, + (_, _, _, tps, flags) => if (flags.nonEmpty) t.AnnotatedType(tps(0), flags) else tps(0)) case _ => super.deconstruct(tpe) } diff --git a/core/src/main/scala/stainless/ast/SymbolOps.scala b/core/src/main/scala/stainless/ast/SymbolOps.scala index 3bc6d96b74..807e42e94b 100644 --- a/core/src/main/scala/stainless/ast/SymbolOps.scala +++ b/core/src/main/scala/stainless/ast/SymbolOps.scala @@ -144,9 +144,10 @@ trait SymbolOps extends inox.ast.SymbolOps { self: TypeOps => def rewritePM(e: Expr): Option[Expr] = e match { case m @ MatchExpr(scrut, cases) => + val scrutVd = ValDef.fresh("scrut", scrut.getType).setPos(m) val condsAndRhs = for (cse <- cases) yield { - val map = mapForPattern(scrut, cse.pattern) - val patCond = conditionForPattern[Path](scrut, cse.pattern, includeBinders = false) + val map = mapForPattern(scrutVd.toVariable, cse.pattern) + val patCond = conditionForPattern[Path](scrutVd.toVariable, cse.pattern, includeBinders = false) val realCond = cse.optGuard match { case Some(g) => patCond withCond replaceFromSymbols(map, g) case None => patCond @@ -170,7 +171,7 @@ trait SymbolOps extends inox.ast.SymbolOps { self: TypeOps => } }) - Some(bigIte) + Some(Let(scrutVd, scrut, bigIte).setPos(m)) case _ => None } diff --git a/core/src/main/scala/stainless/solvers/InoxEncoder.scala b/core/src/main/scala/stainless/solvers/InoxEncoder.scala index 4735f1c0c6..fd82fe2e76 100644 --- a/core/src/main/scala/stainless/solvers/InoxEncoder.scala +++ b/core/src/main/scala/stainless/solvers/InoxEncoder.scala @@ -73,6 +73,50 @@ trait InoxEncoder extends ProgramEncoder { import sourceProgram._ import sourceProgram.symbols._ + // Lowering to Inox erases various trees, which provides many opportunities for let simplifications. + override def transform(fd: s.FunDef): t.FunDef = { + import t._ + import exprOps._ + + def singleOccurences(expr: Expr): Set[Identifier] = { + import collection.mutable.HashMap + val counts = HashMap.empty[Identifier, Int] + preTraversal { + case v: Variable => counts.update(v.id, counts.getOrElse(v.id, 0) + 1) + case _ => + } (expr) + counts.iterator.collect { case (id, 1) => id } .toSet + } + + // Copied from Inox's SymbolOps. Doesn't really depend on symbols! + def simplifyLetsExt(expr: Expr, forced: Set[Identifier]): Expr = preMap({ + case l1 @ Let(v1, Let(v2, e2, b2), b1) => Some(Let(v2, e2, Let(v1, b2, b1).copiedFrom(l1)).copiedFrom(l1)) + + case Let(v, e, v2) if v.toVariable == v2 => Some(e) + + case Let(v, ts @ ( + (_: Variable) | + TupleSelect(_: Variable, _) | + ADTSelector(_: Variable, _) | + FiniteMap(Seq(), _, _, _) | + FiniteBag(Seq(), _) | + FiniteSet(Seq(), _) | + IsConstructor(_: Variable, _) + ), b) => Some(replaceFromSymbols(Map(v -> ts), b)) + + case Let(v, e, b) if forced.contains(v.id) => Some(replaceFromSymbols(Map(v -> e), b)) + + case _ => None + }, applyRec = true)(expr) + + val newFd = super.transform(fd) + + timers.verification.simplify.run { + val simplBody = simplifyLetsExt(newFd.fullBody, singleOccurences(newFd.fullBody)) + newFd.copy(fullBody = simplBody) + } + } + override def transform(e: s.Expr): t.Expr = e match { case m: s.MatchExpr => transform(matchToIfThenElse(m)) diff --git a/core/src/main/scala/stainless/verification/AssertionInjector.scala b/core/src/main/scala/stainless/verification/AssertionInjector.scala index 57d0c11d68..73070bcf1f 100644 --- a/core/src/main/scala/stainless/verification/AssertionInjector.scala +++ b/core/src/main/scala/stainless/verification/AssertionInjector.scala @@ -10,10 +10,12 @@ package verification */ trait AssertionInjector extends transformers.TreeTransformer { val s: ast.Trees - val t: ast.Trees + val t: s.type implicit val symbols: s.Symbols + import t.dsl._ + val strictArithmetic: Boolean private[this] var inWrappingMode: Boolean = false @@ -53,11 +55,33 @@ trait AssertionInjector extends transformers.TreeTransformer { ).copiedFrom(i), transform(v)).copiedFrom(e) case sel @ s.ADTSelector(expr, _) => - t.Assert( - t.IsConstructor(transform(expr), sel.constructor.id).copiedFrom(e), - Some("Cast error"), - super.transform(e) - ).copiedFrom(e) + def isIndexedType(tpe: s.Type) = + tpe match { + case _: s.RecursiveType => true + case _ => false + } + val hasIndexedType = expr match { + case v: s.Variable => isIndexedType(v.tpe) + case fi: s.FunctionInvocation => isIndexedType(fi.tfd.returnType) + case _ => false + } + + // NOTE(gsps): Keeping the old, code-duplicating behavior here, since index annotations on + // types are not propagated through `expr.getType`, breaking the Streams example. + if (hasIndexedType) + t.Assert( + t.IsConstructor(transform(expr), sel.constructor.id).copiedFrom(e), + Some("Cast error"), + super.transform(e) + ).copiedFrom(e) + else + let("recv" :: expr.getType, transform(expr)) { recv => + t.Assert( + t.IsConstructor(recv, sel.constructor.id).copiedFrom(e), + Some("Cast error"), + super.transform(sel.copy(adt = recv).copiedFrom(e)) + ).copiedFrom(e) + }.copiedFrom(e) case BVTyped(true, size, e0 @ s.Plus(lhs0, rhs0)) if checkOverflow => val lhs = transform(lhs0)