Skip to content

Commit

Permalink
Merge pull request #236 from softwaremill/i234
Browse files Browse the repository at this point in the history
Allow for using explicit copy methods in classes
  • Loading branch information
adamw authored Sep 2, 2024
2 parents d4c24ca + 7247a2c commit 8d0e135
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ object QuicklensMacros {
focus: Expr[S => A],
focusesExpr: Expr[Seq[S => A]]
)(using Quotes): Expr[PathModify[S, A]] = {
import quotes.reflect.*

val focuses = focusesExpr match {
case Varargs(args) => focus +: args
}
Expand All @@ -60,6 +58,9 @@ object QuicklensMacros {
def noSuchMember(tpeStr: String, name: String) =
s"$tpeStr has no member named $name"

def multipleMatchingMethods(tpeStr: String, name: String) =
s"Multiple methods named $name found in $tpeStr"

def methodSupported(method: String) =
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)

Expand Down Expand Up @@ -148,90 +149,61 @@ object QuicklensMacros {
case AndType(l, r) if r <:< l => r
case _ => tpe
}

def widenAll: TypeRepr =
tpe.widen.dealias.poorMansLUB

def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol
case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) =>
tpe.typeSymbol
case tpe =>
case tpe if isProductLike(tpe.typeSymbol) =>
tpe.typeSymbol
case _ =>
Symbol.noSymbol
}

def symbolMethodByNameUnsafe(sym: Symbol, name: String): Symbol = {
sym
.methodMember(name)
.headOption
.getOrElse(report.errorAndAbort(noSuchMember(sym.name, name)))
def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else report.errorAndAbort(noSuchMember(sym.name, name))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
symbolMethodByNameUnsafe(term.tpe.widenAll.typeSymbol, name)
def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => symbolAccessorByNameOrError(sym, name)
case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name))
}

def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = {
val typeSymbol = term.tpe.widenAll.matchingTypeSymbol
val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name)
val idx = caseParamNames.indexOf(name)
typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name)))
-> (idx + 1)
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
val typeSymbol = term.tpe.widenAll.typeSymbol
symbolMethodByNameOrError(typeSymbol, name)
}

def isProduct(sym: Symbol): Boolean = {
sym.flags.is(Flags.Case)
}

def isSum(sym: Symbol): Boolean = {
sym.flags.is(Flags.Enum) ||
(sym.flags.is(Flags.Enum) && !sym.flags.is(Flags.Case)) ||
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
}

def isProductLike(sym: Symbol): Boolean = {
sym.methodMember("copy").size == 1
}

def caseClassCopy(
owner: Symbol,
mod: Expr[A => A],
obj: Term,
fields: Seq[(PathSymbol.Field, Seq[PathTree])]
): Term = {
val objSymbol = obj.tpe.widenAll.matchingTypeSymbol
if isProduct(objSymbol) then {
val copy = symbolMethodByNameUnsafe(objSymbol, "copy")
val argsMap: Map[Int, Term] = fields.map { (field, trees) =>
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
idx -> namedArg
}.toMap

val typeParams = obj.tpe.widenAll match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef]
val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length
val fieldsIdxs = 1.to(firstParamListLength)
val args = fieldsIdxs.map { i =>
val defaultMethod = obj.select(symbolMethodByNameUnsafe(objSymbol, "copy$default$" + i.toString))
argsMap.getOrElse(
i,
typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes)
)
}.toList

if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
} else if isSum(objSymbol) then {
if isSum(objSymbol) then {
obj.tpe.widenAll match {
case AndType(_, _) =>
report.errorAndAbort(
Expand Down Expand Up @@ -260,6 +232,45 @@ object QuicklensMacros {
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val copy = symbolMethodByNameOrError(objSymbol, "copy")
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
field.name -> namedArg
}.toMap

val typeParams = obj.tpe.widenAll match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)

val args = copyParamNames.zipWithIndex.map { (n, _i) =>
val i = _i + 1
val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString))
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
argsMap.getOrElse(
n,
defaultMethod
)
}.toList

if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
} else
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.softwaremill

import scala.collection.{Factory, SortedMap}
import scala.collection.SortedMap
import scala.annotation.compileTimeOnly
import com.softwaremill.quicklens.QuicklensMacros._
import scala.reflect.ClassTag
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.softwaremill.quicklens

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class ExplicitCopyTest extends AnyFlatSpec with Matchers {
it should "modify a class with an explicit copy method" in {
case class V(x: Double, y: Double)
class Vec(val v: V) {
def x: Double = v.x
def y: Double = v.y
def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y)) {}
def show: String = s"Vec(${v.x}, ${v.y})"
}
object Vec {
def apply(x: Double, y: Double): Vec = new Vec(V(x, y)) {}
}

val vec = Vec(1, 2)
val modified = vec.modify(_.x).using(_ + 1)
val expected = Vec(2, 2)
modified.show shouldEqual expected.show
}

}

0 comments on commit 8d0e135

Please sign in to comment.