From 65df9e6da1cb306d7ef7495132a69ed090583749 Mon Sep 17 00:00:00 2001 From: Gosha Kovalyov <134854076+goshacodes@users.noreply.github.com> Date: Sat, 6 Apr 2024 16:01:55 +0500 Subject: [PATCH] add support for intersection/union types with type parameters from method definition (#515) --- .../scala-3/org/scalamock/clazz/Utils.scala | 19 +- .../com/paulbutcher/test/Scala3Spec.scala | 167 ++++++++++++++++++ 2 files changed, 184 insertions(+), 2 deletions(-) diff --git a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala index 60bfb4dd..4333635c 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala @@ -1,7 +1,6 @@ package org.scalamock.clazz import scala.quoted.* -import org.scalamock.context.MockContext import scala.annotation.{experimental, tailrec} private[clazz] class Utils(using val quotes: Quotes): @@ -53,6 +52,22 @@ private[clazz] class Utils(using val quotes: Quotes): case _ => tpe + def resolveAndOrTypeParamRefs: TypeRepr = + tpe match { + case AndType(left: ParamRef, right: ParamRef) => + TypeRepr.of[Any] + case AndType(left: ParamRef, right) => + right.resolveAndOrTypeParamRefs + case AndType(left, right: ParamRef) => + left.resolveAndOrTypeParamRefs + case OrType(_: ParamRef, _) => + TypeRepr.of[Any] + case OrType(_, _: ParamRef) => + TypeRepr.of[Any] + case other => + other + } + @experimental def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) = tpe match @@ -117,7 +132,7 @@ private[clazz] class Utils(using val quotes: Quotes): .map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true)) .map { typeRepr => val adjusted = - typeRepr.widen.mapParamRefWithWildcard match + typeRepr.widen.mapParamRefWithWildcard.resolveAndOrTypeParamRefs match case TypeBounds(lower, upper) => upper case AppliedType(TypeRef(_, ""), elemTyps) => TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps) diff --git a/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala b/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala index b7bdb78c..67412e04 100644 --- a/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala +++ b/shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala @@ -28,6 +28,173 @@ class Scala3Spec extends AnyFunSpec with MockFactory with Matchers { m.method(1, new A with B) shouldBe 0 } + it("mock intersection type with type parameter from trait") { + + trait B + + trait C + + trait TraitWithGenericIntersection[A] { + def methodWithGenericIntersection(x: A & B): Unit + } + + val m = mock[TraitWithGenericIntersection[C]] + + val obj = new B with C {} + + (m.methodWithGenericIntersection _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + it("mock intersection type with left type parameter from method") { + + trait B + + trait C + + trait TraitWithGenericIntersection { + def methodWithGenericIntersection[A](x: A & B): Unit + + def methodWithGenericUnion[A](x: A | B): Unit + } + + val m = mock[TraitWithGenericIntersection] + + val obj = new C with B {} + + (m.methodWithGenericIntersection[C] _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + it("mock intersection type with right type parameter from method") { + + trait B + + trait C + + trait TraitWithGenericIntersection { + def methodWithGenericIntersection[A](x: B & A): Unit + } + + val m = mock[TraitWithGenericIntersection] + + val obj = new B with C {} + + (m.methodWithGenericIntersection[C] _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + it("mock intersection type with both type parameters from method") { + + trait B + + trait C + + trait TraitWithGenericIntersection { + def methodWithGenericIntersection[A, B](x: A & B): Unit + } + + val m = mock[TraitWithGenericIntersection] + + val obj = new B with C {} + + (m.methodWithGenericIntersection[B, C] _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + + it("mock intersection type with more then two types from method") { + + trait B + + trait C + + trait D + + trait TraitWithGenericIntersection { + def methodWithGenericIntersection[A, B, C](x: A & B & C): Unit + } + + val m = mock[TraitWithGenericIntersection] + + val obj = new B with C with D {} + + (m.methodWithGenericIntersection[B, C, D] _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + it("mock intersection type with more then two types from method, one of witch is stable") { + + trait B + + trait C + + trait D + + trait TraitWithGenericIntersection { + def methodWithGenericIntersection[A, B](x: A & D & B): Unit + } + + val m = mock[TraitWithGenericIntersection] + + val obj = new B with C with D {} + + (m.methodWithGenericIntersection[B, C] _).expects(obj).returns(()) + + m.methodWithGenericIntersection(obj) + } + + it("mock union type with left type parameter from method") { + + trait B + + trait C + + trait TraitWithGenericUnion { + + def methodWithGenericUnion[A](x: A | B): Unit + } + + val m = mock[TraitWithGenericUnion] + + val obj1 = new C {} + val obj2 = new B {} + + (m.methodWithGenericUnion[C] _).expects(obj1).returns(()) + (m.methodWithGenericUnion[C] _).expects(obj2).returns(()) + + m.methodWithGenericUnion(obj1) + m.methodWithGenericUnion(obj2) + } + + it("mock union type with right type parameter from method") { + + trait B + + trait C + + trait TraitWithGenericUnion { + + def methodWithGenericUnion[A](x: B | A): Unit + } + + val m = mock[TraitWithGenericUnion] + + val obj1 = new C {} + val obj2 = new B {} + + (m.methodWithGenericUnion[C] _).expects(obj1).returns(()) + (m.methodWithGenericUnion[C] _).expects(obj2).returns(()) + + m.methodWithGenericUnion(obj1) + m.methodWithGenericUnion(obj2) + } + it("mock methods returning function") { trait Test { def method(x: Int): Int => String