From 98626b47eed069b27d3788b35ca988435b412dc9 Mon Sep 17 00:00:00 2001 From: mario-bucev Date: Mon, 27 Mar 2023 15:31:07 +0200 Subject: [PATCH] Fix #1349 (#1393) --- .../extraction/methods/MethodLifting.scala | 13 +++++++-- .../verification/valid/i1349a.scala | 27 +++++++++++++++++++ .../verification/valid/i1349b.scala | 21 +++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 frontends/benchmarks/verification/valid/i1349a.scala create mode 100644 frontends/benchmarks/verification/valid/i1349b.scala diff --git a/core/src/main/scala/stainless/extraction/methods/MethodLifting.scala b/core/src/main/scala/stainless/extraction/methods/MethodLifting.scala index c4df022704..1f479de2ae 100644 --- a/core/src/main/scala/stainless/extraction/methods/MethodLifting.scala +++ b/core/src/main/scala/stainless/extraction/methods/MethodLifting.scala @@ -71,9 +71,18 @@ class MethodLifting(override val s: Trees, override val t: oo.Trees) override def transform(e: s.Expr): t.Expr = e match { case s.MethodInvocation(rec, id, tps, args) => - val ct = rec.getType(using symbols).asInstanceOf[s.ClassType] + given s.Symbols = symbols + val ct = rec.getType match { + case ct: s.ClassType => ct + case ta: s.TypeApply if ta.lookupTypeDef.isDefined && !ta.isAbstract => + ta.resolve match { + case ct: s.ClassType => ct + case other => context.reporter.fatalError(rec.getPos, s"Unexpected type for method invocation receiver: got $other") + } + case other => context.reporter.fatalError(rec.getPos, s"Unexpected type for method invocation receiver: got $other") + } val cid = symbols.getFunction(id).flags.collectFirst { case s.IsMethodOf(cid) => cid }.get - val tcd = (ct.tcd(using symbols) +: ct.tcd(using symbols).ancestors).find(_.id == cid).get + val tcd = (ct.tcd +: ct.tcd.ancestors).find(_.id == cid).get t.FunctionInvocation(id, (tcd.tps ++ tps) map transform, (rec +: args) map transform).copiedFrom(e) case _ => super.transform(e) diff --git a/frontends/benchmarks/verification/valid/i1349a.scala b/frontends/benchmarks/verification/valid/i1349a.scala new file mode 100644 index 0000000000..1d3c8c16bc --- /dev/null +++ b/frontends/benchmarks/verification/valid/i1349a.scala @@ -0,0 +1,27 @@ +import stainless.lang._ +import stainless.collection._ + +object i1349a { + type Index = BigInt + + type LIndex = List[Index] + + case class IndexedKey(index: BigInt, key: LIndex) { + require(0 <= index && index < key.length) + } + + def mkIndexedKey1(index: BigInt, key: LIndex): IndexedKey = { + require(0 <= index && index < key.length) + IndexedKey(index, key) + } + + def mkIndexedKey2(index: BigInt, key: List[Index]): IndexedKey = { + require(0 <= index && index < key.length) + IndexedKey(index, key) + } + + def mkIndexedKey3(index: Index, key: LIndex): IndexedKey = { + require(0 <= index && index < key.length) + IndexedKey(index, key) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/verification/valid/i1349b.scala b/frontends/benchmarks/verification/valid/i1349b.scala new file mode 100644 index 0000000000..de4b8f60bf --- /dev/null +++ b/frontends/benchmarks/verification/valid/i1349b.scala @@ -0,0 +1,21 @@ +object i1349b { + final case class Wrap[A](a: A) { + def get: A = a + } + + type WInt = Wrap[Int] + + case class IndexedKey(key: WInt) { + require(key.get < 100) + } + + def mkIndexedKey1(key: WInt): IndexedKey = { + require(key.get < 100) + IndexedKey(key) + } + + def mkIndexedKey2(key: Wrap[Int]): IndexedKey = { + require(key.get < 100) + IndexedKey(key) + } +} \ No newline at end of file