diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 633d526..90f7eac 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -34,8 +34,6 @@ object QuicklensMacros { focus: Expr[S => A], focusesExpr: Expr[Seq[S => A]] )(using Quotes): Expr[PathModify[S, A]] = { - import quotes.reflect.* - val focuses = focusesExpr match { case Varargs(args) => focus +: args } @@ -60,6 +58,9 @@ object QuicklensMacros { def noSuchMember(tpeStr: String, name: String) = s"$tpeStr has no member named $name" + def multipleMatchingMethods(tpeStr: String, name: String) = + s"Multiple methods named $name found in $tpeStr" + def methodSupported(method: String) = Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method) @@ -148,35 +149,38 @@ object QuicklensMacros { case AndType(l, r) if r <:< l => r case _ => tpe } + def widenAll: TypeRepr = tpe.widen.dealias.poorMansLUB + def matchingTypeSymbol: Symbol = tpe.widenAll match { case AndType(l, r) => val lSym = l.matchingTypeSymbol if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) => tpe.typeSymbol - case tpe => + case tpe if isProductLike(tpe.typeSymbol) => + tpe.typeSymbol + case _ => Symbol.noSymbol } - def symbolMethodByNameUnsafe(sym: Symbol, name: String): Symbol = { - sym - .methodMember(name) - .headOption - .getOrElse(report.errorAndAbort(noSuchMember(sym.name, name))) + def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = { + val mem = sym.fieldMember(name) + if mem != Symbol.noSymbol then mem + else report.errorAndAbort(noSuchMember(sym.name, name)) } - def termMethodByNameUnsafe(term: Term, name: String): Symbol = { - symbolMethodByNameUnsafe(term.tpe.widenAll.typeSymbol, name) + def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = { + sym.methodMember(name) match + case List(m) => m + case Nil => symbolAccessorByNameOrError(sym, name) + case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name)) } - def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = { - val typeSymbol = term.tpe.widenAll.matchingTypeSymbol - val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name) - val idx = caseParamNames.indexOf(name) - typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name))) - -> (idx + 1) + def termMethodByNameUnsafe(term: Term, name: String): Symbol = { + val typeSymbol = term.tpe.widenAll.typeSymbol + symbolMethodByNameOrError(typeSymbol, name) } def isProduct(sym: Symbol): Boolean = { @@ -184,10 +188,14 @@ object QuicklensMacros { } def isSum(sym: Symbol): Boolean = { - sym.flags.is(Flags.Enum) || + (sym.flags.is(Flags.Enum) && !sym.flags.is(Flags.Case)) || (sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract))) } + def isProductLike(sym: Symbol): Boolean = { + sym.methodMember("copy").size == 1 + } + def caseClassCopy( owner: Symbol, mod: Expr[A => A], @@ -195,43 +203,7 @@ object QuicklensMacros { fields: Seq[(PathSymbol.Field, Seq[PathTree])] ): Term = { val objSymbol = obj.tpe.widenAll.matchingTypeSymbol - if isProduct(objSymbol) then { - val copy = symbolMethodByNameUnsafe(objSymbol, "copy") - val argsMap: Map[Int, Term] = fields.map { (field, trees) => - val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) - val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => - mapToCopy(owner, mod, term, tree) - } - val namedArg = NamedArg(field.name, resTerm) - idx -> namedArg - }.toMap - - val typeParams = obj.tpe.widenAll match { - case AppliedType(_, typeParams) => Some(typeParams) - case _ => None - } - val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef] - val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length - val fieldsIdxs = 1.to(firstParamListLength) - val args = fieldsIdxs.map { i => - val defaultMethod = obj.select(symbolMethodByNameUnsafe(objSymbol, "copy$default$" + i.toString)) - argsMap.getOrElse( - i, - typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes) - ) - }.toList - - if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then - report.errorAndAbort( - s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." - ) - - typeParams match { - // if the object's type is parametrised, we need to call .copy with the same type parameters - case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) - case _ => Apply(Select(obj, copy), args) - } - } else if isSum(objSymbol) then { + if isSum(objSymbol) then { obj.tpe.widenAll match { case AndType(_, _) => report.errorAndAbort( @@ -260,6 +232,45 @@ object QuicklensMacros { ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } + } else if isProduct(objSymbol) || isProductLike(objSymbol) then { + val copy = symbolMethodByNameOrError(objSymbol, "copy") + val argsMap: Map[String, Term] = fields.map { (field, trees) => + val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name) + val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => + mapToCopy(owner, mod, term, tree) + } + val namedArg = NamedArg(field.name, resTerm) + field.name -> namedArg + }.toMap + + val typeParams = obj.tpe.widenAll match { + case AppliedType(_, typeParams) => Some(typeParams) + case _ => None + } + val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] + val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name) + + val args = copyParamNames.zipWithIndex.map { (n, _i) => + val i = _i + 1 + val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) + // for extension methods, might need sth more like this: (or probably some weird implicit conversion) + // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) + argsMap.getOrElse( + n, + defaultMethod + ) + }.toList + + if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then + report.errorAndAbort( + s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." + ) + + typeParams match { + // if the object's type is parametrised, we need to call .copy with the same type parameters + case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) + case _ => Apply(Select(obj, copy), args) + } } else report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol") } diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala index a5c1f9f..179801b 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala @@ -1,6 +1,6 @@ package com.softwaremill -import scala.collection.{Factory, SortedMap} +import scala.collection.SortedMap import scala.annotation.compileTimeOnly import com.softwaremill.quicklens.QuicklensMacros._ import scala.reflect.ClassTag diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala new file mode 100644 index 0000000..eb7e469 --- /dev/null +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala @@ -0,0 +1,25 @@ +package com.softwaremill.quicklens + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ExplicitCopyTest extends AnyFlatSpec with Matchers { + it should "modify a class with an explicit copy method" in { + case class V(x: Double, y: Double) + class Vec(val v: V) { + def x: Double = v.x + def y: Double = v.y + def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y)) {} + def show: String = s"Vec(${v.x}, ${v.y})" + } + object Vec { + def apply(x: Double, y: Double): Vec = new Vec(V(x, y)) {} + } + + val vec = Vec(1, 2) + val modified = vec.modify(_.x).using(_ + 1) + val expected = Vec(2, 2) + modified.show shouldEqual expected.show + } + +}