From bb3d607cfdd4ff31c83b83b500ebfdb6ee7e235c Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Wed, 10 Jul 2024 18:30:20 +0200 Subject: [PATCH 1/2] Allow for using explicit copy methods in classes fixes part of softwaremill#234 --- .../quicklens/QuicklensMacros.scala | 120 ++++++++++++------ .../com/softwaremill/quicklens/package.scala | 2 +- .../quicklens/test/ExplicitCopyTest.scala | 25 ++++ 3 files changed, 109 insertions(+), 38 deletions(-) create mode 100644 quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 633d526..3773dbf 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -34,8 +34,6 @@ object QuicklensMacros { focus: Expr[S => A], focusesExpr: Expr[Seq[S => A]] )(using Quotes): Expr[PathModify[S, A]] = { - import quotes.reflect.* - val focuses = focusesExpr match { case Varargs(args) => focus +: args } @@ -60,6 +58,9 @@ object QuicklensMacros { def noSuchMember(tpeStr: String, name: String) = s"$tpeStr has no member named $name" + def multipleMatchingMethods(tpeStr: String, name: String) = + s"Multiple methods named $name found in $tpeStr" + def methodSupported(method: String) = Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method) @@ -156,19 +157,21 @@ object QuicklensMacros { if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) => tpe.typeSymbol - case tpe => + case tpe if isProductLike(tpe.typeSymbol) => + tpe.typeSymbol + case _ => Symbol.noSymbol } - def symbolMethodByNameUnsafe(sym: Symbol, name: String): Symbol = { - sym - .methodMember(name) - .headOption - .getOrElse(report.errorAndAbort(noSuchMember(sym.name, name))) + def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = { + sym.methodMember(name) match + case List(m) => m + case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) + case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name)) } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { - symbolMethodByNameUnsafe(term.tpe.widenAll.typeSymbol, name) + symbolMethodByNameOrError(term.tpe.widenAll.typeSymbol, name) } def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = { @@ -184,10 +187,14 @@ object QuicklensMacros { } def isSum(sym: Symbol): Boolean = { - sym.flags.is(Flags.Enum) || + (sym.flags.is(Flags.Enum) && !sym.flags.is(Flags.Case)) || (sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract))) } + def isProductLike(sym: Symbol): Boolean = { + sym.methodMember("copy").size == 1 + } + def caseClassCopy( owner: Symbol, mod: Expr[A => A], @@ -195,8 +202,37 @@ object QuicklensMacros { fields: Seq[(PathSymbol.Field, Seq[PathTree])] ): Term = { val objSymbol = obj.tpe.widenAll.matchingTypeSymbol - if isProduct(objSymbol) then { - val copy = symbolMethodByNameUnsafe(objSymbol, "copy") + if isSum(objSymbol) then { + obj.tpe.widenAll match { + case AndType(_, _) => + report.errorAndAbort( + s"Implementation limitation: Cannot modify sealed hierarchies mixed with & types. Try providing a more specific type." + ) + case _ => + } + /* + if (obj.isInstanceOf[Child1]) caseClassCopy(obj.asInstanceOf[Child1]) else + if (obj.isInstanceOf[Child2]) caseClassCopy(obj.asInstanceOf[Child2]) else + ... + else throw new IllegalStateException() + */ + val ifThens = objSymbol.children.map { child => + val ifCond = TypeApply(Select.unique(obj, "isInstanceOf"), List(TypeIdent(child))) + + val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) { + castToChildVal => + caseClassCopy(owner, mod, castToChildVal, fields) + } + + ifCond -> ifThen + } + + val elseThrow = '{ throw new IllegalStateException() }.asTerm + ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => + If(ifCond, ifThen, ifElse) + } + } else if isProduct(objSymbol) then { + val copy = symbolMethodByNameOrError(objSymbol, "copy") val argsMap: Map[Int, Term] = fields.map { (field, trees) => val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => @@ -214,7 +250,7 @@ object QuicklensMacros { val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length val fieldsIdxs = 1.to(firstParamListLength) val args = fieldsIdxs.map { i => - val defaultMethod = obj.select(symbolMethodByNameUnsafe(objSymbol, "copy$default$" + i.toString)) + val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) argsMap.getOrElse( i, typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes) @@ -231,34 +267,44 @@ object QuicklensMacros { case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) case _ => Apply(Select(obj, copy), args) } - } else if isSum(objSymbol) then { - obj.tpe.widenAll match { - case AndType(_, _) => - report.errorAndAbort( - s"Implementation limitation: Cannot modify sealed hierarchies mixed with & types. Try providing a more specific type." - ) - case _ => - } - /* - if (obj.isInstanceOf[Child1]) caseClassCopy(obj.asInstanceOf[Child1]) else - if (obj.isInstanceOf[Child2]) caseClassCopy(obj.asInstanceOf[Child2]) else - ... - else throw new IllegalStateException() - */ - val ifThens = objSymbol.children.map { child => - val ifCond = TypeApply(Select.unique(obj, "isInstanceOf"), List(TypeIdent(child))) - - val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) { - castToChildVal => - caseClassCopy(owner, mod, castToChildVal, fields) + } else if isProductLike(objSymbol) then { + val copy = symbolMethodByNameOrError(objSymbol, "copy") + val argsMap: Map[String, Term] = fields.map { (field, trees) => + val fieldMethod = termMethodByNameUnsafe(obj, field.name) + val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => + mapToCopy(owner, mod, term, tree) } + val namedArg = NamedArg(field.name, resTerm) + field.name -> namedArg + }.toMap - ifCond -> ifThen + val typeParams = obj.tpe.widenAll match { + case AppliedType(_, typeParams) => Some(typeParams) + case _ => None } + val copyTree: DefDef = copy.tree.asInstanceOf[DefDef] + val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name) + + val args = copyParamNames.zipWithIndex.map { (n, _i) => + val i = _i + 1 + val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) + // for extension methods, might need sth more like this: (or probably some weird implicit conversion) + // val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n)) + argsMap.getOrElse( + n, + defaultMethod + ) + }.toList - val elseThrow = '{ throw new IllegalStateException() }.asTerm - ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => - If(ifCond, ifThen, ifElse) + if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then + report.errorAndAbort( + s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." + ) + + typeParams match { + // if the object's type is parametrised, we need to call .copy with the same type parameters + case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) + case _ => Apply(Select(obj, copy), args) } } else report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol") diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala index a5c1f9f..179801b 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala @@ -1,6 +1,6 @@ package com.softwaremill -import scala.collection.{Factory, SortedMap} +import scala.collection.SortedMap import scala.annotation.compileTimeOnly import com.softwaremill.quicklens.QuicklensMacros._ import scala.reflect.ClassTag diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala new file mode 100644 index 0000000..eb7e469 --- /dev/null +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExplicitCopyTest.scala @@ -0,0 +1,25 @@ +package com.softwaremill.quicklens + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ExplicitCopyTest extends AnyFlatSpec with Matchers { + it should "modify a class with an explicit copy method" in { + case class V(x: Double, y: Double) + class Vec(val v: V) { + def x: Double = v.x + def y: Double = v.y + def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y)) {} + def show: String = s"Vec(${v.x}, ${v.y})" + } + object Vec { + def apply(x: Double, y: Double): Vec = new Vec(V(x, y)) {} + } + + val vec = Vec(1, 2) + val modified = vec.modify(_.x).using(_ + 1) + val expected = Vec(2, 2) + modified.show shouldEqual expected.show + } + +} From 7247a2c8e31e08c8ed47d55e0cbc9441c099c494 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 12 Aug 2024 13:25:35 +0200 Subject: [PATCH 2/2] Use copy for modifying case classes as well --- .../quicklens/QuicklensMacros.scala | 61 ++++--------------- 1 file changed, 13 insertions(+), 48 deletions(-) diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index 3773dbf..90f7eac 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -149,8 +149,10 @@ object QuicklensMacros { case AndType(l, r) if r <:< l => r case _ => tpe } + def widenAll: TypeRepr = tpe.widen.dealias.poorMansLUB + def matchingTypeSymbol: Symbol = tpe.widenAll match { case AndType(l, r) => val lSym = l.matchingTypeSymbol @@ -163,23 +165,22 @@ object QuicklensMacros { Symbol.noSymbol } + def symbolAccessorByNameOrError(sym: Symbol, name: String): Symbol = { + val mem = sym.fieldMember(name) + if mem != Symbol.noSymbol then mem + else report.errorAndAbort(noSuchMember(sym.name, name)) + } + def symbolMethodByNameOrError(sym: Symbol, name: String): Symbol = { sym.methodMember(name) match case List(m) => m - case Nil => report.errorAndAbort(noSuchMember(sym.name, name)) + case Nil => symbolAccessorByNameOrError(sym, name) case _ => report.errorAndAbort(multipleMatchingMethods(sym.name, name)) } def termMethodByNameUnsafe(term: Term, name: String): Symbol = { - symbolMethodByNameOrError(term.tpe.widenAll.typeSymbol, name) - } - - def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = { - val typeSymbol = term.tpe.widenAll.matchingTypeSymbol - val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name) - val idx = caseParamNames.indexOf(name) - typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name))) - -> (idx + 1) + val typeSymbol = term.tpe.widenAll.typeSymbol + symbolMethodByNameOrError(typeSymbol, name) } def isProduct(sym: Symbol): Boolean = { @@ -231,46 +232,10 @@ object QuicklensMacros { ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) => If(ifCond, ifThen, ifElse) } - } else if isProduct(objSymbol) then { - val copy = symbolMethodByNameOrError(objSymbol, "copy") - val argsMap: Map[Int, Term] = fields.map { (field, trees) => - val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) - val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => - mapToCopy(owner, mod, term, tree) - } - val namedArg = NamedArg(field.name, resTerm) - idx -> namedArg - }.toMap - - val typeParams = obj.tpe.widenAll match { - case AppliedType(_, typeParams) => Some(typeParams) - case _ => None - } - val constructorTree: DefDef = objSymbol.primaryConstructor.tree.asInstanceOf[DefDef] - val firstParamListLength: Int = constructorTree.termParamss.headOption.map(_.params).toList.flatten.length - val fieldsIdxs = 1.to(firstParamListLength) - val args = fieldsIdxs.map { i => - val defaultMethod = obj.select(symbolMethodByNameOrError(objSymbol, "copy$default$" + i.toString)) - argsMap.getOrElse( - i, - typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes) - ) - }.toList - - if constructorTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then - report.errorAndAbort( - s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit." - ) - - typeParams match { - // if the object's type is parametrised, we need to call .copy with the same type parameters - case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) - case _ => Apply(Select(obj, copy), args) - } - } else if isProductLike(objSymbol) then { + } else if isProduct(objSymbol) || isProductLike(objSymbol) then { val copy = symbolMethodByNameOrError(objSymbol, "copy") val argsMap: Map[String, Term] = fields.map { (field, trees) => - val fieldMethod = termMethodByNameUnsafe(obj, field.name) + val fieldMethod = symbolMethodByNameOrError(objSymbol, field.name) val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => mapToCopy(owner, mod, term, tree) }