From 5773f6c0c6fc970a539e4735e3fb018c18d54752 Mon Sep 17 00:00:00 2001 From: Thijs Broersen <4889512+ThijsBroersen@users.noreply.github.com> Date: Fri, 14 Jun 2024 10:30:37 +0200 Subject: [PATCH] Feat: scala3 enumeration support (#1068) * addscala 3 enumeration decoding/encoding --- docs/decoding.md | 17 +++ docs/encoding.md | 17 +++ .../scala/zio/json/golden/GoldenSpec.scala | 6 +- .../src/main/scala-3/zio/json/macros.scala | 112 ++++++++++++++---- .../scala-3/zio/json/DerivedDecoderSpec.scala | 51 +++++--- .../scala-3/zio/json/DerivedEncoderSpec.scala | 51 +++++--- 6 files changed, 193 insertions(+), 61 deletions(-) diff --git a/docs/decoding.md b/docs/decoding.md index 7d43c1ff9..cb8d7abb1 100644 --- a/docs/decoding.md +++ b/docs/decoding.md @@ -101,6 +101,23 @@ Decoding fail because 'Pear' is not a valid value Almost all of the standard library data types are supported as fields on the case class, and it is easy to add support if one is missing. +### Sealed families and enums for Scala 3 +Sealed families where all members are only objects, or a Scala 3 enum with all cases parameterless are interpreted as enumerations and will encode 1:1 with their value-names. +```scala +enum Foo derives JsonDecoder: + case Bar + case Baz + case Qux +``` +or +```scala +sealed trait Foo derives JsonDecoder +object Foo: + case object Bar extends Foo + case object Baz extends Foo + case object Qux extends Foo +``` + ## Manual instances Sometimes it is easier to reuse an existing `JsonDecoder` rather than generate a new one. This can be accomplished using convenience methods on the `JsonDecoder` typeclass to *derive* new decoders diff --git a/docs/encoding.md b/docs/encoding.md index b43123d80..6c9f92d4a 100644 --- a/docs/encoding.md +++ b/docs/encoding.md @@ -55,6 +55,23 @@ apple.toJson Almost all of the standard library data types are supported as fields on the case class, and it is easy to add support if one is missing. +### Sealed families and enums for Scala 3 +Sealed families where all members are only objects, or a Scala 3 enum with all cases parameterless are interpreted as enumerations and will encode 1:1 with their value-names. +```scala +enum Foo derives JsonEncoder: + case Bar + case Baz + case Qux +``` +or +```scala +sealed trait Foo derives JsonEncoder +object Foo: + case object Bar extends Foo + case object Baz extends Foo + case object Qux extends Foo +``` + ## Manual instances Sometimes it is easier to reuse an existing `JsonEncoder` rather than generate a new one. This can be accomplished using convenience methods on the `JsonEncoder` typeclass to *derive* new decoders: diff --git a/zio-json-golden/src/test/scala/zio/json/golden/GoldenSpec.scala b/zio-json-golden/src/test/scala/zio/json/golden/GoldenSpec.scala index 42547fc16..01b539033 100644 --- a/zio-json-golden/src/test/scala/zio/json/golden/GoldenSpec.scala +++ b/zio-json-golden/src/test/scala/zio/json/golden/GoldenSpec.scala @@ -11,9 +11,9 @@ object GoldenSpec extends ZIOSpecDefault { sealed trait SumType object SumType { - case object Case1 extends SumType - case object Case2 extends SumType - case object Case3 extends SumType + case object Case1 extends SumType + case object Case2 extends SumType + case class Case3() extends SumType implicit val jsonCodec: JsonCodec[SumType] = DeriveJsonCodec.gen } diff --git a/zio-json/shared/src/main/scala-3/zio/json/macros.scala b/zio-json/shared/src/main/scala-3/zio/json/macros.scala index 301a02428..3392d476a 100644 --- a/zio-json/shared/src/main/scala-3/zio/json/macros.scala +++ b/zio-json/shared/src/main/scala-3/zio/json/macros.scala @@ -207,21 +207,7 @@ final class jsonNoExtraFields extends Annotation */ final class jsonExclude extends Annotation -// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296 -object DeriveJsonDecoder extends Derivation[JsonDecoder] { self => - def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = { - val (transformNames, nameTransform): (Boolean, String => String) = - ctx.annotations.collectFirst { case jsonMemberNames(format) => format } - .map(true -> _) - .getOrElse(false -> identity) - - val no_extra = ctx - .annotations - .collectFirst { case _: jsonNoExtraFields => () } - .isDefined - - if (ctx.params.isEmpty) { - new JsonDecoder[A] { +private class CaseObjectDecoder[Typeclass[*], A](val ctx: CaseClass[Typeclass, A], no_extra: Boolean) extends JsonDecoder[A] { def unsafeDecode(trace: List[JsonError], in: RetractReader): A = { if (no_extra) { Lexer.char(trace, in, '{') @@ -239,6 +225,22 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self => case _ => throw UnsafeJson(JsonError.Message("Not an object") :: trace) } } + +// TODO: implement same configuration as for Scala 2 once this issue is resolved: https://github.com/softwaremill/magnolia/issues/296 +object DeriveJsonDecoder extends Derivation[JsonDecoder] { self => + def join[A](ctx: CaseClass[Typeclass, A]): JsonDecoder[A] = { + val (transformNames, nameTransform): (Boolean, String => String) = + ctx.annotations.collectFirst { case jsonMemberNames(format) => format } + .map(true -> _) + .getOrElse(false -> identity) + + val no_extra = ctx + .annotations + .collectFirst { case _: jsonNoExtraFields => () } + .isDefined + + if (ctx.params.isEmpty) { + new CaseObjectDecoder(ctx, no_extra) } else { new JsonDecoder[A] { val (names, aliases): (Array[String], Array[(String, Int)]) = { @@ -400,9 +402,35 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self => lazy val namesMap: Map[String, Int] = names.zipWithIndex.toMap + def isEnumeration = + (ctx.isEnum && ctx.subtypes.forall(_.typeclass.isInstanceOf[CaseObjectDecoder[?, ?]])) || ( + !ctx.isEnum && ctx.subtypes.forall(_.isObject) + ) + def discrim = ctx.annotations.collectFirst { case jsonDiscriminator(n) => n } - if (discrim.isEmpty) { + if (isEnumeration) { + new JsonDecoder[A] { + def unsafeDecode(trace: List[JsonError], in: RetractReader): A = { + val typeName = Lexer.string(trace, in).toString() + namesMap.find(_._1 == typeName) match { + case Some((_, idx)) => tcs(idx).asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil) + case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace) + } + } + + override final def unsafeFromJsonAST(trace: List[JsonError], json: Json): A = { + json match { + case Json.Str(typeName) => + ctx.subtypes.find(_.typeInfo.short == typeName) match { + case Some(sub) => sub.typeclass.asInstanceOf[CaseObjectDecoder[JsonDecoder, A]].ctx.rawConstruct(Nil) + case None => throw UnsafeJson(JsonError.Message(s"Invalid enumeration value $typeName") :: trace) + } + case _ => throw UnsafeJson(JsonError.Message("Not a string") :: trace) + } + } + } + } else if (discrim.isEmpty) { // We're not allowing extra fields in this encoding new JsonDecoder[A] { val spans: Array[JsonError] = names.map(JsonError.ObjectAccess(_)) @@ -506,16 +534,18 @@ object DeriveJsonDecoder extends Derivation[JsonDecoder] { self => } } +private lazy val caseObjectEncoder = new JsonEncoder[Any] { + def unsafeEncode(a: Any, indent: Option[Int], out: Write): Unit = + out.write("{}") + + override final def toJsonAST(a: Any): Either[String, Json] = + Right(Json.Obj(Chunk.empty)) +} + object DeriveJsonEncoder extends Derivation[JsonEncoder] { self => def join[A](ctx: CaseClass[Typeclass, A]): JsonEncoder[A] = if (ctx.params.isEmpty) { - new JsonEncoder[A] { - def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = - out.write("{}") - - override final def toJsonAST(a: A): Either[String, Json] = - Right(Json.Obj(Chunk.empty)) - } + caseObjectEncoder.narrow[A] } else { new JsonEncoder[A] { val (transformNames, nameTransform): (Boolean, String => String) = @@ -612,15 +642,49 @@ object DeriveJsonEncoder extends Derivation[JsonEncoder] { self => } def split[A](ctx: SealedTrait[JsonEncoder, A]): JsonEncoder[A] = { + val isEnumeration = + (ctx.isEnum && ctx.subtypes.forall(_.typeclass == caseObjectEncoder)) || ( + !ctx.isEnum && ctx.subtypes.forall(_.isObject) + ) + val jsonHintFormat: JsonMemberFormat = ctx.annotations.collectFirst { case jsonHintNames(format) => format }.getOrElse(IdentityFormat) + val discrim = ctx .annotations .collectFirst { case jsonDiscriminator(n) => n } - if (discrim.isEmpty) { + if (isEnumeration) { + new JsonEncoder[A] { + def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = { + val typeName = ctx.choose(a) { sub => + sub + .annotations + .collectFirst { + case jsonHint(name) => name + }.getOrElse(sub.typeInfo.short) + } + + JsonEncoder.string.unsafeEncode(typeName, indent, out) + } + + override final def toJsonAST(a: A): Either[String, Json] = { + ctx.choose(a) { sub => + Right( + Json.Str( + sub + .annotations + .collectFirst { + case jsonHint(name) => name + }.getOrElse(sub.typeInfo.short) + ) + ) + } + } + } + } else if (discrim.isEmpty) { new JsonEncoder[A] { def unsafeEncode(a: A, indent: Option[Int], out: Write): Unit = { ctx.choose(a) { sub => diff --git a/zio-json/shared/src/test/scala-3/zio/json/DerivedDecoderSpec.scala b/zio-json/shared/src/test/scala-3/zio/json/DerivedDecoderSpec.scala index a1520753c..b7b6379f0 100644 --- a/zio-json/shared/src/test/scala-3/zio/json/DerivedDecoderSpec.scala +++ b/zio-json/shared/src/test/scala-3/zio/json/DerivedDecoderSpec.scala @@ -9,25 +9,42 @@ object DerivedDecoderSpec extends ZIOSpecDefault { val spec = suite("DerivedDecoderSpec")( test("Derives for a product type") { - assertZIO(typeCheck { - """ - case class Foo(bar: String) derives JsonDecoder + case class Foo(bar: String) derives JsonDecoder - "{\"bar\": \"hello\"}".fromJson[Foo] - """ - })(isRight(anything)) + val result = "{\"bar\": \"hello\"}".fromJson[Foo] + + assertTrue(result == Right(Foo("hello"))) + }, + test("Derives for a sum enum Enumeration type") { + enum Foo derives JsonDecoder: + case Bar + case Baz + case Qux + + val result = "\"Qux\"".fromJson[Foo] + + assertTrue(result == Right(Foo.Qux)) }, - test("Derives for a sum type") { - assertZIO(typeCheck { - """ - enum Foo derives JsonDecoder: - case Bar - case Baz(baz: String) - case Qux(foo: Foo) - - "{\"Qux\":{\"foo\":{\"Bar\":{}}}}".fromJson[Foo] - """ - })(isRight(anything)) + test("Derives for a sum sealed trait Enumeration type") { + sealed trait Foo derives JsonDecoder + object Foo: + case object Bar extends Foo + case object Baz extends Foo + case object Qux extends Foo + + val result = "\"Qux\"".fromJson[Foo] + + assertTrue(result == Right(Foo.Qux)) + }, + test("Derives for a sum ADT type") { + enum Foo derives JsonDecoder: + case Bar + case Baz(baz: String) + case Qux(foo: Foo) + + val result = "{\"Qux\":{\"foo\":{\"Bar\":{}}}}".fromJson[Foo] + + assertTrue(result == Right(Foo.Qux(Foo.Bar))) }, test("Derives and decodes for a union of string-based literals") { case class Foo(aOrB: "A" | "B", optA: Option["A"]) derives JsonDecoder diff --git a/zio-json/shared/src/test/scala-3/zio/json/DerivedEncoderSpec.scala b/zio-json/shared/src/test/scala-3/zio/json/DerivedEncoderSpec.scala index 9b7fc862d..2eb329c40 100644 --- a/zio-json/shared/src/test/scala-3/zio/json/DerivedEncoderSpec.scala +++ b/zio-json/shared/src/test/scala-3/zio/json/DerivedEncoderSpec.scala @@ -8,25 +8,42 @@ import zio.test._ object DerivedEncoderSpec extends ZIOSpecDefault { val spec = suite("DerivedEncoderSpec")( test("Derives for a product type") { - assertZIO(typeCheck { - """ - case class Foo(bar: String) derives JsonEncoder + case class Foo(bar: String) derives JsonEncoder - Foo("bar").toJson - """ - })(isRight(anything)) + val json = Foo("bar").toJson + + assertTrue(json == """{"bar":"bar"}""") }, - test("Derives for a sum type") { - assertZIO(typeCheck { - """ - enum Foo derives JsonEncoder: - case Bar - case Baz(baz: String) - case Qux(foo: Foo) - - (Foo.Qux(Foo.Bar): Foo).toJson - """ - })(isRight(anything)) + test("Derives for a sum enum Enumeration type") { + enum Foo derives JsonEncoder: + case Bar + case Baz + case Qux + + val json = (Foo.Qux: Foo).toJson + + assertTrue(json == """"Qux"""") + }, + test("Derives for a sum sealed trait Enumeration type") { + sealed trait Foo derives JsonEncoder + object Foo: + case object Bar extends Foo + case object Baz extends Foo + case object Qux extends Foo + + val json = (Foo.Qux: Foo).toJson + + assertTrue(json == """"Qux"""") + }, + test("Derives for a sum ADT type") { + enum Foo derives JsonEncoder: + case Bar + case Baz(baz: String) + case Qux(foo: Foo) + + val json = (Foo.Qux(Foo.Bar): Foo).toJson + + assertTrue(json == """{"Qux":{"foo":{"Bar":{}}}}""") }, test("Derives and encodes for a union of string-based literals") { case class Foo(aOrB: "A" | "B", optA: Option["A"]) derives JsonEncoder