Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for using explicit copy methods in classes #236

Merged
merged 2 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ object QuicklensMacros {
focus: Expr[S => A],
focusesExpr: Expr[Seq[S => A]]
)(using Quotes): Expr[PathModify[S, A]] = {
import quotes.reflect.*

val focuses = focusesExpr match {
case Varargs(args) => focus +: args
}
Expand All @@ -60,6 +58,9 @@ object QuicklensMacros {
def noSuchMember(tpeStr: String, name: String) =
s"$tpeStr has no member named $name"

def multipleMatchingMethods(tpeStr: String, name: String) =
s"Multiple methods named $name found in $tpeStr"

def methodSupported(method: String) =
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)

Expand Down Expand Up @@ -148,90 +149,61 @@ object QuicklensMacros {
case AndType(l, r) if r <:< l => r
case _ => tpe
}

def widenAll: TypeRepr =
tpe.widen.dealias.poorMansLUB

def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe =>
case tpe if isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}

def symbolMethodByNameUnsafe(sym: Symbol, name: String): Symbol = {
sym
.methodMember(name)
.headOption
.getOrElse(report.errorAndAbort(noSuchMember(sym.name, name)))
def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else report.errorAndAbort(noSuchMember(sym.name, name))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
symbolMethodByNameUnsafe(term.tpe.widenAll.typeSymbol, name)
def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => symbolAccessorByNameOrError(sym, name)
case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name))
}

def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = {
val typeSymbol = term.tpe.widenAll.matchingTypeSymbol
val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name)
val idx = caseParamNames.indexOf(name)
typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name)))
-> (idx + 1)
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
val typeSymbol = term.tpe.widenAll.typeSymbol
symbolMethodByNameOrError(typeSymbol, name)
}

def isProduct(sym: Symbol): Boolean = {
sym.flags.is(Flags.Case)
}

def isSum(sym: Symbol): Boolean = {
sym.flags.is(Flags.Enum) ||
(sym.flags.is(Flags.Enum) && !sym.flags.is(Flags.Case)) ||
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size == 1
}

def caseClassCopy(
owner: Symbol,
mod: Expr[A => A],
obj: Term,
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
): Term = {
val objSymbol = obj.tpe.widenAll.matchingTypeSymbol
if isProduct(objSymbol) then {
val copy = symbolMethodByNameUnsafe(objSymbol, "copy")
val argsMap: Map[Int, Term] = fields.map { (field, trees) =>
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
idx -> namedArg
}.toMap

val typeParams = obj.tpe.widenAll match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef]
val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length
val fieldsIdxs = 1.to(firstParamListLength)
val args = fieldsIdxs.map { i =>
val defaultMethod = obj.select(symbolMethodByNameUnsafe(objSymbol, "copy$default$" + i.toString))
argsMap.getOrElse(
i,
typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes)
)
}.toList

if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
} else if isSum(objSymbol) then {
if isSum(objSymbol) then {
obj.tpe.widenAll match {
case AndType(_, _) =>
report.errorAndAbort(
Expand Down Expand Up @@ -260,6 +232,45 @@ object QuicklensMacros {
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val copy = symbolMethodByNameOrError(objSymbol, "copy")
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap

val typeParams = obj.tpe.widenAll match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
n,
defaultMethod
)
}.toList

if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
} else
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.softwaremill

import scala.collection.{Factory, SortedMap}
import scala.collection.SortedMap
import scala.annotation.compileTimeOnly
import com.softwaremill.quicklens.QuicklensMacros._
import scala.reflect.ClassTag
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.softwaremill.quicklens

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class ExplicitCopyTest extends AnyFlatSpec with Matchers {
it should "modify a class with an explicit copy method" in {
case class V(x: Double, y: Double)
class Vec(val v: V) {
def x: Double = v.x
def y: Double = v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y)) {}
def show: String = s"Vec(${v.x}, ${v.y})"
}
object Vec {
def apply(x: Double, y: Double): Vec = new Vec(V(x, y)) {}
}

val vec = Vec(1, 2)
val modified = vec.modify(_.x).using(_ + 1)
val expected = Vec(2, 2)
modified.show shouldEqual expected.show
}

}
Loading