Skip to content

Commit

Permalink
Use copy for modifying case classes as well
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Aug 12, 2024
1 parent bb3d607 commit 7247a2c
Showing 1 changed file with 13 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ object QuicklensMacros {
case AndType(l, r) if r <:< l => r
case _ => tpe
}

def widenAll: TypeRepr =
tpe.widen.dealias.poorMansLUB

def matchingTypeSymbol: Symbol = tpe.widenAll match {
case AndType(l, r) =>
val lSym = l.matchingTypeSymbol
Expand All @@ -163,23 +165,22 @@ object QuicklensMacros {
Symbol.noSymbol
}

def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = {
val mem = sym.fieldMember(name)
if mem != Symbol.noSymbol then mem
else report.errorAndAbort(noSuchMember(sym.name, name))
}

def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = {
sym.methodMember(name) match
case List(m) => m
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
case Nil => symbolAccessorByNameOrError(sym, name)
case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name))
}

def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
symbolMethodByNameOrError(term.tpe.widenAll.typeSymbol, name)
}

def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = {
val typeSymbol = term.tpe.widenAll.matchingTypeSymbol
val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name)
val idx = caseParamNames.indexOf(name)
typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name)))
-> (idx + 1)
val typeSymbol = term.tpe.widenAll.typeSymbol
symbolMethodByNameOrError(typeSymbol, name)
}

def isProduct(sym: Symbol): Boolean = {
Expand Down Expand Up @@ -231,46 +232,10 @@ object QuicklensMacros {
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
If(ifCond, ifThen, ifElse)
}
} else if isProduct(objSymbol) then {
val copy = symbolMethodByNameOrError(objSymbol, "copy")
val argsMap: Map[Int, Term] = fields.map { (field, trees) =>
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
val namedArg = NamedArg(field.name, resTerm)
idx -> namedArg
}.toMap

val typeParams = obj.tpe.widenAll match {
case AppliedType(_, typeParams) => Some(typeParams)
case _ => None
}
val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef]
val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length
val fieldsIdxs = 1.to(firstParamListLength)
val args = fieldsIdxs.map { i =>
val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString))
argsMap.getOrElse(
i,
typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes)
)
}.toList

if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
report.errorAndAbort(
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
)

typeParams match {
// if the object's type is parametrised, we need to call .copy with the same type parameters
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
case _ => Apply(Select(obj, copy), args)
}
} else if isProductLike(objSymbol) then {
} else if isProduct(objSymbol) || isProductLike(objSymbol) then {
val copy = symbolMethodByNameOrError(objSymbol, "copy")
val argsMap: Map[String, Term] = fields.map { (field, trees) =>
val fieldMethod = termMethodByNameUnsafe(obj, field.name)
val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name)
val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) =>
mapToCopy(owner, mod, term, tree)
}
Expand Down

0 comments on commit 7247a2c

Please sign in to comment.