diff --git a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala index ec6c4d4..feac364 100644 --- a/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala +++ b/quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala @@ -54,8 +54,8 @@ object QuicklensMacros { def unsupportedShapeInfo(tree: Tree) = s"Unsupported path element. Path must have shape: _.field1.field2.each.field3.(...), got: ${tree.show}" - def noSuchMember(term: Term, name: String) = - s"${term.tpe} has no member named $name" + def noSuchMember(tpeStr: String, name: String) = + s"$tpeStr has no member named $name" def methodSupported(method: String) = Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method) @@ -136,33 +136,65 @@ object QuicklensMacros { case i: Ident if i.name.startsWith("_") => Seq.empty case _ => - report.throwError(unsupportedShapeInfo(focus.asTerm)) + report.errorAndAbort(unsupportedShapeInfo(focus.asTerm)) } } - def termMethodByNameUnsafe(term: Term, name: String): Symbol = { - term.tpe.widen.dealias.typeSymbol - .memberMethod(name) + extension (tpe: TypeRepr) + def poorMansLUB: TypeRepr = tpe match { + case AndType(l, r) if l <:< r => l + case AndType(l, r) if r <:< l => r + case _ => tpe + } + def widenAll: TypeRepr = + tpe.widen.dealias.poorMansLUB + def matchingTypeSymbol: Symbol = tpe.widenAll match { + case AndType(l, r) => + val lSym = l.matchingTypeSymbol + if l.matchingTypeSymbol != Symbol.noSymbol then lSym else r.matchingTypeSymbol + case tpe if isProduct(tpe.typeSymbol) || isSum(tpe.typeSymbol) => + tpe.typeSymbol + case tpe => + Symbol.noSymbol + } + + def symbolMethodByNameUnsafe(sym: Symbol, name: String): Symbol = { + sym + .methodMember(name) .headOption - .getOrElse(report.errorAndAbort(noSuchMember(term, name))) + .getOrElse(report.errorAndAbort(noSuchMember(sym.name, name))) + } + + def termMethodByNameUnsafe(term: Term, name: String): Symbol = { + symbolMethodByNameUnsafe(term.tpe.widenAll.typeSymbol, name) } def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = { - val caseParamNames = term.tpe.widen.dealias.typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name) + val typeSymbol = term.tpe.widenAll.matchingTypeSymbol + val caseParamNames = typeSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).map(_.name) val idx = caseParamNames.indexOf(name) - term.tpe.widen.dealias.typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term, name))) + typeSymbol.caseFields.find(_.name == name).getOrElse(report.errorAndAbort(noSuchMember(term.tpe.show, name))) -> (idx + 1) } + def isProduct(sym: Symbol): Boolean = { + sym.flags.is(Flags.Case) + } + + def isSum(sym: Symbol): Boolean = { + sym.flags.is(Flags.Enum) || + (sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract))) + } + def caseClassCopy( owner: Symbol, mod: Expr[A => A], obj: Term, fields: Seq[(PathSymbol.Field, Seq[PathTree])] ): Term = { - val objSymbol = obj.tpe.widen.dealias.typeSymbol - if objSymbol.flags.is(Flags.Case) then { - val copy = termMethodByNameUnsafe(obj, "copy") + val objSymbol = obj.tpe.widenAll.matchingTypeSymbol + if isProduct(objSymbol) then { + val copy = symbolMethodByNameUnsafe(objSymbol, "copy") val argsMap: Map[Int, Term] = fields.map { (field, trees) => val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name) val resTerm: Term = trees.foldLeft[Term](Select(obj, fieldMethod)) { (term, tree) => @@ -172,36 +204,38 @@ object QuicklensMacros { idx -> namedArg }.toMap + val typeParams = obj.tpe.widenAll match { + case AppliedType(_, typeParams) => Some(typeParams) + case _ => None + } + val fieldsIdxs = 1.to(objSymbol.primaryConstructor.paramSymss.flatten.filter(_.isTerm).length) val args = fieldsIdxs.map { i => + val defaultMethod = obj.select(symbolMethodByNameUnsafe(objSymbol, "copy$default$" + i.toString)) argsMap.getOrElse( i, - Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString)) + typeParams.fold(defaultMethod)(defaultMethod.appliedToTypes) ) }.toList - obj.tpe.widen match { + typeParams match { // if the object's type is parametrised, we need to call .copy with the same type parameters - case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) + case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args) case _ => Apply(Select(obj, copy), args) } - } else if objSymbol.flags.is(Flags.Enum) || - (objSymbol.flags.is(Flags.Sealed) && (objSymbol.flags.is(Flags.Trait) || objSymbol.flags.is(Flags.Abstract))) - then { - // if the source is a sealed trait / sealed abstract class / enum, generating a if-then-else with a .copy for each child (implementing case class) - val cases = obj.tpe.widen.dealias.typeSymbol.children.map { child => - val subtype = TypeIdent(child) - val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe) - CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), fields)) + } else if isSum(objSymbol) then { + obj.tpe.widenAll match { + case AndType(_, _) => + report.errorAndAbort(s"Implementation limitation: Cannot modify sealed hierarchies mixed with & types. Try providing a more specific type.") + case _ => } - /* if (obj.isInstanceOf[Child1]) caseClassCopy(obj.asInstanceOf[Child1]) else if (obj.isInstanceOf[Child2]) caseClassCopy(obj.asInstanceOf[Child2]) else ... else throw new IllegalStateException() */ - val ifThens = obj.tpe.widen.dealias.typeSymbol.children.map { child => + val ifThens = objSymbol.children.map { child => val ifCond = TypeApply(Select.unique(obj, "isInstanceOf"), List(TypeIdent(child))) val ifThen = ValDef.let(owner, TypeApply(Select.unique(obj, "asInstanceOf"), List(TypeIdent(child)))) { @@ -217,7 +251,7 @@ object QuicklensMacros { If(ifCond, ifThen, ifElse) } } else - report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol") + report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol") } def applyFunctionDelegate( @@ -289,14 +323,14 @@ object QuicklensMacros { case Block(List(DefDef(_, _, _, Some(p))), _) => toPath(p, focus) case _ => - report.throwError(unsupportedShapeInfo(tree)) + report.errorAndAbort(unsupportedShapeInfo(tree)) } val pathTree: PathTree = paths.foldLeft(PathTree.empty) { (tree, path) => tree <> path } val res: (Expr[A => A] => Expr[S]) = (mod: Expr[A => A]) => - mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, pathTree).asExpr.asInstanceOf[Expr[S]] + Typed(mapToCopy(Symbol.spliceOwner, mod, obj.asTerm, pathTree), TypeTree.of[S]).asExpr.asInstanceOf[Expr[S]] to(res) } } diff --git a/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ModifyAndTypeTest.scala b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ModifyAndTypeTest.scala new file mode 100644 index 0000000..ff229a3 --- /dev/null +++ b/quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ModifyAndTypeTest.scala @@ -0,0 +1,88 @@ +package com.softwaremill.quicklens + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import ModifyAndTypeTest._ + +object ModifyAndTypeTest { + case class A(a: Int) extends B + trait B { + def a: Int + } + + case class A1(a: Int) + + sealed trait T + case class C(a: Int) extends T with B + + sealed trait T1 + case class C1(a: Int) extends T1 +} + +class ModifyAndTypeTest extends AnyFlatSpec with Matchers { + it should "modify an & type object" in { + val ab: A & B = A(0) + + val modified = ab.modify(_.a).setTo(1) + + modified.a shouldBe 1 + } + + it should "modify an & type object 1" in { + val ab: B & A = A(0) + + val modified = ab.modify(_.a).setTo(1) + + modified.a shouldBe 1 + } + + it should "modify an & type object 2" in { + val ab: B & A1 = new A1(0) with B + + val modified = ab.modify(_.a).setTo(1) + + modified.a shouldBe 1 + } + + it should "modify an & type object 3" in { + val ab: A1 & B = new A1(0) with B + + val modified = ab.modify(_.a).setTo(1) + + modified.a shouldBe 1 + } + + // TODO this is an implemenation limitation for now, since anonymous classes crash on runtime + // it should "modify an & type object with a sealed trait" in { + // val tb: T & B = C(0) + + // val modified = tb.modify(_.a).setTo(1) + + // modified.a shouldBe 1 + // } + + // it should "modify an & type object with a sealed trait 1" in { + // val tb: B & T = C(0) + + // val modified = tb.modify(_.a).setTo(1) + + // modified.a shouldBe 1 + // } + + // it should "modify an & type object with a sealed trait 2" in { + // val tb: B & T1 = new C1(0) with B + + // val modified = tb.modify(_.a).setTo(1) + + // modified.a shouldBe 1 + // } + + // it should "modify an & type object with a sealed trait 3" in { + // val tb: T1 & B = new C1(0) with B + + // val modified = tb.modify(_.a).setTo(1) + + // modified.a shouldBe 1 + // } +} diff --git a/quicklens/src/test/scala/com/softwaremill/quicklens/ModifySelfThisTest.scala b/quicklens/src/test/scala/com/softwaremill/quicklens/ModifySelfThisTest.scala new file mode 100644 index 0000000..d175946 --- /dev/null +++ b/quicklens/src/test/scala/com/softwaremill/quicklens/ModifySelfThisTest.scala @@ -0,0 +1,41 @@ +package com.softwaremill.quicklens + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import ModifySelfThisTest._ + +object ModifySelfThisTest { + + case class State(x: Int) { self => + + def mod: State = this.modify(_.x).setTo(1) + } + + trait A { + def a: Unit + } + + case class State1(x: Int) extends A { self: A => + + def mod: State1 = this.modify(_.x).setTo(1) + + def a: Unit = () + } +} + +class ModifySelfThisTest extends AnyFlatSpec with Matchers { + it should "modify an object even in presence of self alias" in { + val s = State(0) + val modified = s.mod + + modified.x shouldBe 1 + } + + it should "modify an object even in presence of self type" in { + val s = State(0) + val modified = s.mod + + modified.x shouldBe 1 + } +}