Skip to content

Commit

Permalink
enhance union/intersection types support
Browse files Browse the repository at this point in the history
  • Loading branch information
goshacodes committed Apr 11, 2024
1 parent 1f9ee41 commit 678a03a
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 5 deletions.
16 changes: 11 additions & 5 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ private[clazz] class Utils(using val quotes: Quotes):

def resolveAndOrTypeParamRefs: TypeRepr =
tpe match {
case AndType(left: ParamRef, right: ParamRef) =>
case AndType(left @ (_: ParamRef | _: AppliedType), right @ (_: ParamRef | _: AppliedType)) =>
TypeRepr.of[Any]
case AndType(left: ParamRef, right) =>
case AndType(left @ (_: ParamRef | _: AppliedType), right) =>
right.resolveAndOrTypeParamRefs
case AndType(left, right: ParamRef) =>
case AndType(left, right @ (_: ParamRef | _: AppliedType)) =>
left.resolveAndOrTypeParamRefs
case OrType(_: ParamRef, _) =>
case OrType(_: ParamRef | _: AppliedType, _) =>
TypeRepr.of[Any]
case OrType(_, _: ParamRef) =>
case OrType(_, _: ParamRef | _: AppliedType) =>
TypeRepr.of[Any]
case other =>
other
Expand All @@ -77,6 +77,12 @@ private[clazz] class Utils(using val quotes: Quotes):
case pr@ParamRef(bindings, idx) if bindings == baseBindings =>
methodArgs.head(idx).asInstanceOf[TypeTree].tpe

case AndType(left, right) =>
AndType(loop(left), loop(right))

case OrType(left, right) =>
OrType(loop(left), loop(right))

case AppliedType(tycon, args) =>
AppliedType(loop(tycon), args.map(arg => loop(arg)))

Expand Down
69 changes: 69 additions & 0 deletions shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,75 @@ class Scala3Spec extends AnyFunSpec with MockFactory with Matchers {
m.methodWithGenericUnion(obj2)
}

it("mock union return type") {

trait A

trait B

trait TraitWithUnionReturnType {

def methodWithUnionReturnType[T](): T | A
}

val m = mock[TraitWithUnionReturnType]

val obj = new B {}

(() => m.methodWithUnionReturnType[B]()).expects().returns(obj)

m.methodWithUnionReturnType[B]() shouldBe obj
}

it("mock intersection return type") {

trait A

trait B

trait TraitWithIntersectionReturnType {

def methodWithIntersectionReturnType[T](): A & T
}

val m = mock[TraitWithIntersectionReturnType]

val obj = new A with B {}

(() => m.methodWithIntersectionReturnType[B]()).expects().returns(obj)

m.methodWithIntersectionReturnType[B]() shouldBe obj
}

it("mock intersection|union types with type constructors") {

trait A[T]

trait B

trait C

trait ComplexUnionIntersectionCases {

def complexMethod1[T](x: A[T] & T): A[T] & T
def complexMethod2[T](x: A[A[T]] | T): A[T] | T
def complexMethod3[F[_], T](x: F[A[T] & F[T]] | T & A[F[T]]): F[T] & T
def complexMethod4[T](x: A[B & C] ): A[B & C]
def complexMethod5[T](x: A[B | A[C]]): A[B | C]
}

val m = mock[ComplexUnionIntersectionCases]

val obj = new A[B] with B {}
val obj2 = new A[A[B]] with B {}

(m.complexMethod1[B] _).expects(obj).returns(obj)
(m.complexMethod2[B] _).expects(obj2).returns(new A[B] {})

m.complexMethod1[B](obj)
m.complexMethod2[B](obj2)
}

it("mock methods returning function") {
trait Test {
def method(x: Int): Int => String
Expand Down

0 comments on commit 678a03a

Please sign in to comment.