From 7247a2c8e31e08c8ed47d55e0cbc9441c099c494 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 12 Aug 2024 13:25:35 +0200 Subject: [PATCH] Use copy for modifying case classes as well --- .../quicklens/QuicklensMacros.scala | 61 ++++--------------- 1 file changed, 13 insertions(+), 48 deletions(-) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 3773dbf..90f7eac 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -149,8 +149,10 @@ object QuicklensMacros { case AndType(l, r) if r <:< l => r case _ => tpe } + def widenAll: TypeRepr = tpe.widen.dealias.poorMansLUB + def matchingTypeSymbol: Symbol = tpe.widenAll match { case AndType(l, r) => val lSym = l.matchingTypeSymbol @@ -163,23 +165,22 @@ object QuicklensMacros { Symbol.noSymbol } + def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = { + val mem = sym.fieldMember(name) + if mem != Symbol.noSymbol then mem + else report.errorAndAbort(noSuchMember(sym.name, name)) + } + def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = { sym.methodMember(name) match case List(m) => m - case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) + case Nil => symbolAccessorByNameOrError(sym, name) case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name)) } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { - symbolMethodByNameOrError(term.tpe.widenAll.typeSymbol, name) - } - - def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = { - val typeSymbol = term.tpe.widenAll.matchingTypeSymbol - val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name) - val idx = caseParamNames.indexOf(name) - typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name))) - -> (idx + 1) + val typeSymbol = term.tpe.widenAll.typeSymbol + symbolMethodByNameOrError(typeSymbol, name) } def isProduct(sym: Symbol): Boolean = { @@ -231,46 +232,10 @@ object QuicklensMacros { ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } - } else if isProduct(objSymbol) then { - val copy = symbolMethodByNameOrError(objSymbol, "copy") - val argsMap: Map[Int, Term] = fields.map { (field, trees) => - val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) - val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => - mapToCopy(owner, mod, term, tree) - } - val namedArg = NamedArg(field.name, resTerm) - idx -> namedArg - }.toMap - - val typeParams = obj.tpe.widenAll match { - case AppliedType(_, typeParams) => Some(typeParams) - case _ => None - } - val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef] - val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length - val fieldsIdxs = 1.to(firstParamListLength) - val args = fieldsIdxs.map { i => - val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) - argsMap.getOrElse( - i, - typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes) - ) - }.toList - - if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then - report.errorAndAbort( - s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." - ) - - typeParams match { - // if the object's type is parametrised, we need to call .copy with the same type parameters - case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) - case _ => Apply(Select(obj, copy), args) - } - } else if isProductLike(objSymbol) then { + } else if isProduct(objSymbol) || isProductLike(objSymbol) then { val copy = symbolMethodByNameOrError(objSymbol, "copy") val argsMap: Map[String, Term] = fields.map { (field, trees) => - val fieldMethod = termMethodByNameUnsafe(obj, field.name) + val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name) val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => mapToCopy(owner, mod, term, tree) }