Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance union/intersection type support #515

Merged
merged 1 commit into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions shared/src/main/scala-3/org/scalamock/clazz/Utils.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.scalamock.clazz

import scala.quoted.*
import org.scalamock.context.MockContext

import scala.annotation.{experimental, tailrec}
private[clazz] class Utils(using val quotes: Quotes):
Expand Down Expand Up @@ -53,6 +52,22 @@ private[clazz] class Utils(using val quotes: Quotes):
case _ =>
tpe

def resolveAndOrTypeParamRefs: TypeRepr =
tpe match {
case AndType(left: ParamRef, right: ParamRef) =>
TypeRepr.of[Any]
case AndType(left: ParamRef, right) =>
right.resolveAndOrTypeParamRefs
case AndType(left, right: ParamRef) =>
left.resolveAndOrTypeParamRefs
case OrType(_: ParamRef, _) =>
TypeRepr.of[Any]
case OrType(_, _: ParamRef) =>
TypeRepr.of[Any]
case other =>
other
}

@experimental
def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) =
tpe match
Expand Down Expand Up @@ -117,7 +132,7 @@ private[clazz] class Utils(using val quotes: Quotes):
.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true))
.map { typeRepr =>
val adjusted =
typeRepr.widen.mapParamRefWithWildcard match
typeRepr.widen.mapParamRefWithWildcard.resolveAndOrTypeParamRefs match
case TypeBounds(lower, upper) => upper
case AppliedType(TypeRef(_, "<repeated>"), elemTyps) =>
TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps)
Expand Down
167 changes: 167 additions & 0 deletions shared/src/test/scala-3/com/paulbutcher/test/Scala3Spec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,173 @@ class Scala3Spec extends AnyFunSpec with MockFactory with Matchers {
m.method(1, new A with B) shouldBe 0
}

it("mock intersection type with type parameter from trait") {

trait B

trait C

trait TraitWithGenericIntersection[A] {
def methodWithGenericIntersection(x: A & B): Unit
}

val m = mock[TraitWithGenericIntersection[C]]

val obj = new B with C {}

(m.methodWithGenericIntersection _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with left type parameter from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A](x: A & B): Unit

def methodWithGenericUnion[A](x: A | B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new C with B {}

(m.methodWithGenericIntersection[C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with right type parameter from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A](x: B & A): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C {}

(m.methodWithGenericIntersection[C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with both type parameters from method") {

trait B

trait C

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B](x: A & B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C {}

(m.methodWithGenericIntersection[B, C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}


it("mock intersection type with more then two types from method") {

trait B

trait C

trait D

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B, C](x: A & B & C): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C with D {}

(m.methodWithGenericIntersection[B, C, D] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock intersection type with more then two types from method, one of witch is stable") {

trait B

trait C

trait D

trait TraitWithGenericIntersection {
def methodWithGenericIntersection[A, B](x: A & D & B): Unit
}

val m = mock[TraitWithGenericIntersection]

val obj = new B with C with D {}

(m.methodWithGenericIntersection[B, C] _).expects(obj).returns(())

m.methodWithGenericIntersection(obj)
}

it("mock union type with left type parameter from method") {

trait B

trait C

trait TraitWithGenericUnion {

def methodWithGenericUnion[A](x: A | B): Unit
}

val m = mock[TraitWithGenericUnion]

val obj1 = new C {}
val obj2 = new B {}

(m.methodWithGenericUnion[C] _).expects(obj1).returns(())
(m.methodWithGenericUnion[C] _).expects(obj2).returns(())

m.methodWithGenericUnion(obj1)
m.methodWithGenericUnion(obj2)
}

it("mock union type with right type parameter from method") {

trait B

trait C

trait TraitWithGenericUnion {

def methodWithGenericUnion[A](x: B | A): Unit
}

val m = mock[TraitWithGenericUnion]

val obj1 = new C {}
val obj2 = new B {}

(m.methodWithGenericUnion[C] _).expects(obj1).returns(())
(m.methodWithGenericUnion[C] _).expects(obj2).returns(())

m.methodWithGenericUnion(obj1)
m.methodWithGenericUnion(obj2)
}

it("mock methods returning function") {
trait Test {
def method(x: Int): Int => String
Expand Down
Loading