diff --git a/avro4s-core/src/main/scala/com/sksamuel/avro4s/decoders/collections.scala b/avro4s-core/src/main/scala/com/sksamuel/avro4s/decoders/collections.scala index 50a42000..8ca2f74a 100644 --- a/avro4s-core/src/main/scala/com/sksamuel/avro4s/decoders/collections.scala +++ b/avro4s-core/src/main/scala/com/sksamuel/avro4s/decoders/collections.scala @@ -29,7 +29,7 @@ trait CollectionDecoders: given[T](using decoder: Decoder[T]): Decoder[Seq[T]] = iterableDecoder(decoder, _.toSeq) given[T](using decoder: Decoder[T]): Decoder[Set[T]] = iterableDecoder(decoder, _.toSet) given[T](using decoder: Decoder[T]): Decoder[Vector[T]] = iterableDecoder(decoder, _.toVector) - given[T](using decoder: Decoder[T]): Decoder[Map[String, T]] = new MapDecoder[T](decoder) + given mapDecoder[T](using decoder: Decoder[T]): Decoder[Map[String, T]] = new MapDecoder[T](decoder) def iterableDecoder[T, C[X] <: Iterable[X]](decoder: Decoder[T], build: Iterable[T] => C[T]): Decoder[C[T]] = diff --git a/avro4s-core/src/main/scala/com/sksamuel/avro4s/schemas/collections.scala b/avro4s-core/src/main/scala/com/sksamuel/avro4s/schemas/collections.scala index 6a3d947c..cf581f83 100644 --- a/avro4s-core/src/main/scala/com/sksamuel/avro4s/schemas/collections.scala +++ b/avro4s-core/src/main/scala/com/sksamuel/avro4s/schemas/collections.scala @@ -19,6 +19,5 @@ trait CollectionSchemas: given[T](using schemaFor: SchemaFor[T]): SchemaFor[List[T]] = buildIterableSchemaFor[List, T] - given[V](using schemaFor: SchemaFor[V]): SchemaFor[Map[String, V]] = + given mapSchemaFor[V](using schemaFor: SchemaFor[V]): SchemaFor[Map[String, V]] = schemaFor.map(SchemaBuilder.map().values(_)) - diff --git a/avro4s-refined/src/main/scala/com/sksamuel/avro4s/refined/package.scala b/avro4s-refined/src/main/scala/com/sksamuel/avro4s/refined/package.scala new file mode 100644 index 00000000..81a6f606 --- /dev/null +++ b/avro4s-refined/src/main/scala/com/sksamuel/avro4s/refined/package.scala @@ -0,0 +1,29 @@ +package com.sksamuel.avro4s + +import eu.timepit.refined.api.{RefType, Validate} + +package object refined: + + given[T, P, F[_, _]](using schemaFor: SchemaFor[T]): SchemaFor[F[T, P]] = schemaFor.forType + + given[T: Encoder, P, F[_, _] : RefType]: Encoder[F[T, P]] = Encoder[T].contramap(RefType[F].unwrap) + + given[T: Decoder, P, F[_, _] : RefType](using validate: Validate[T, P]): Decoder[F[T, P]] = Decoder[T].map(RefType[F].refine[P].unsafeFrom[T]) + + given[A, P, F[_, _]: RefType, B](using schemaForA: SchemaFor[A], schemaForB: SchemaFor[B], isString: A <:< String): SchemaFor[Map[F[A, P], B]] = + SchemaFor.mapSchemaFor[B].forType + + given[A: Encoder, B: Encoder, P, F[_, _]: RefType](using isString: A <:< String): Encoder[Map[F[A, P], B]] = + Encoder.mapEncoder[B].contramap[Map[F[A, P], B]]: theMap => + theMap.map: + case (k, v) => RefType[F].unwrap(k).asInstanceOf[String] -> v + + given[A: Decoder, B: Decoder, P, F[_, _]: RefType](using validate: Validate[A, P], isString: A <:< String): Decoder[Map[F[A, P], B]] = + Decoder.mapDecoder[B].map: theMap => + theMap.map: + case (str, b) => (RefType[F].refine[P].unsafeFrom[A](str.asInstanceOf[A]), b) + +// implicit def refinedTypeGuardedDecoding[T: WeakTypeTag, P, F[_, _]: RefType]: TypeGuardedDecoding[F[T, P]] = new TypeGuardedDecoding[F[T, P]] { +// override final def guard(decoderT: Decoder[F[T, P]]): PartialFunction[Any, F[T, P]] = +// TypeGuardedDecoding[T].guard(decoderT.map(RefType[F].unwrap)).andThen(RefType[F].unsafeWrap(_)) +// } diff --git a/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedRoundtripTest.scala b/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedRoundtripTest.scala new file mode 100644 index 00000000..bcf5880c --- /dev/null +++ b/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedRoundtripTest.scala @@ -0,0 +1,53 @@ +package com.sksamuel.avro4s.refined + +import com.sksamuel.avro4s.streams.input.InputStreamTest +import eu.timepit.refined.types.numeric.PosInt +import eu.timepit.refined.types.string.NonEmptyString +import shapeless.* + +import scala.util.Failure + +class RefinedRoundtripTest extends InputStreamTest: + + type C1 = NonEmptyString :+: CNil + case class Container1(c1: C1) +// type C2 = Int :+: NonEmptyString :+: CNil +// case class Container2(c2: C2) +// type C3 = PosInt :+: NonEmptyString :+: CNil +// case class Container3(c3: C3) +// case class Container4(map: Map[String, NonEmptyString], c3: C3, list: List[(Int, PosInt)]) +// type C1b = String :+: CNil +// case class Container1b(c1: C1b) +// case class Container5(c5: Either[NonEmptyString, Int]) +// case class Container6(c6: Map[NonEmptyString, PosInt]) + +// test("a union of one refined type inside a record should rountrip"): +// writeRead(Container1(Coproduct[C1](NonEmptyString.unsafeFrom("a")))) + +// test("a union of one refined type and more standard types inside a record should rountrip") { +// writeRead(Container2(Coproduct[C2](NonEmptyString("a")))) +// } + +// test("a union of more than one refined type inside a record should rountrip") { +// writeRead(Container3(Coproduct[C3](PosInt(42)))) +// } + +// test("a more complex record should rountrip") { +// writeRead(Container4(Map("bla" -> NonEmptyString("a")), Coproduct[C3](NonEmptyString("b")), List(23 -> PosInt(42), 42 -> PosInt(23)))) +// } + +// test("a broken encoder will not decode") { +// val out = writeData(Container1b(Coproduct[C1b](""))) +// val result = tryReadData[Container1](out.toByteArray).next() +// result should matchPattern { case Failure(iae: IllegalArgumentException) if iae.getMessage == "Predicate isEmpty() did not fail." => } +// } + +// test("an either of one refined type inside a record should roundtrip") { +// writeRead(Container5(Left(NonEmptyString("a")))) +// } + +// test("a map with refined types on both key and value should roundtrip") { +// val key: NonEmptyString = NonEmptyString("foo") +// val value: PosInt = PosInt(1) +// writeRead(Container6(Map(key -> value))) +// } diff --git a/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedTest.scala b/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedTest.scala new file mode 100644 index 00000000..e5cd817e --- /dev/null +++ b/avro4s-refined/src/test/scala/com/sksamuel/avro4s/refined/RefinedTest.scala @@ -0,0 +1,99 @@ +package com.sksamuel.avro4s.refined + +import com.sksamuel.avro4s.* +import eu.timepit.refined.api.Refined +import eu.timepit.refined.collection.NonEmpty +import org.apache.avro.Schema +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec +import eu.timepit.refined.auto.* +import eu.timepit.refined.types.string.NonEmptyString +import eu.timepit.refined.types.numeric.NonNegInt + +case class Foo(nonEmptyStr: String Refined NonEmpty) +case class FooMap(nonEmptyStrKeyMap: Map[NonEmptyString, NonNegInt]) + +class RefinedTest extends AnyWordSpec with Matchers: + val fooSchema: Schema = AvroSchema[Foo] + val fooMapSchema: Schema = AvroSchema[FooMap] + + "refinedSchemaFor" should : + "use the schema for the underlying type" in: + AvroSchema[Foo] shouldBe new Schema.Parser().parse( + """ + |{ + | "type": "record", + | "name": "Foo", + | "namespace": "com.sksamuel.avro4s.refined", + | "fields": [{ + | "name": "nonEmptyStr", + | "type": "string" + | }] + |} + """.stripMargin) + + "generate correct schemas for a Map when refined instances are in scope" in: + case class Test(map: Map[String, Int], nonEmptyStr: String Refined NonEmpty) + val schema = AvroSchema[Test] + + println(s"schema: $schema") + + schema.getField("map").schema().getType shouldBe Schema.Type.MAP + schema.getField("nonEmptyStr").schema().getType shouldBe Schema.Type.STRING + + "refinedStringMapKeySchemaFor" should: + "use the schema for the underlying type" in: + AvroSchema[FooMap] shouldBe new Schema.Parser().parse( + """ + |{ + | "type": "record", + | "name": "FooMap", + | "namespace": "com.sksamuel.avro4s.refined", + | "fields": [{ + | "name": "nonEmptyStrKeyMap", + | "type": { + | "type": "map", + | "values": "int" + | } + | }] + |} + """.stripMargin + ) + + "refinedEncoder" should: + "use the encoder for the underlying type" in: + val expected: String Refined NonEmpty = NonEmptyString.unsafeFrom("foo") + val record = ToRecord[Foo](fooSchema).to(Foo(expected)) + record.get("nonEmptyStr").toString shouldBe expected.value + + "refinedStringMapKeyEncoder" should: + "use the encoder for the underlying type" in: + val key: NonEmptyString = NonEmptyString.unsafeFrom("foo") + val value: NonNegInt = NonNegInt.unsafeFrom(1) + val expected: Map[NonEmptyString, NonNegInt] = Map(key -> value) + val record = ToRecord[FooMap](fooMapSchema).to(FooMap(expected)) + val encodedMap = record.get("nonEmptyStrKeyMap").asInstanceOf[java.util.Map[String, Int]] + encodedMap.get(key.value) shouldBe value.value + + "refinedDecoder" should: + "use the decoder for the underlying type" in: + val expected: String Refined NonEmpty = NonEmptyString.unsafeFrom("foo") + val record = ImmutableRecord(AvroSchema[Foo], Vector(expected.value)) + FromRecord[Foo](fooSchema).from(record) shouldBe Foo(expected) + + "throw when the value does not conform to the refined predicate" in: + val record = ImmutableRecord(AvroSchema[Foo], Vector("")) + assertThrows[IllegalArgumentException](FromRecord[Foo](fooSchema).from(record)) + + "refinedStringMapKeyDecoder" should: + "use the decoder for the underlying type" in: + val key: NonEmptyString = NonEmptyString.unsafeFrom("foo") + val value: NonNegInt = NonNegInt.unsafeFrom(1) + + val jMap = new java.util.HashMap[String, Int]() + jMap.put(key.value, value.value) + + val expected = Map(key -> value) + val record = ImmutableRecord(AvroSchema[FooMap], Vector(jMap)) + + FromRecord[FooMap](fooMapSchema).from(record) shouldBe FooMap(expected) diff --git a/build.sbt b/build.sbt index 035012e8..e29884b2 100644 --- a/build.sbt +++ b/build.sbt @@ -12,8 +12,9 @@ lazy val root = Project("avro4s", file(".")) ) .aggregate( `avro4s-core`, - `avro4s-cats` - // `avro4s-kafka` + `avro4s-cats`, +// `avro4s-kafka` + `avro4s-refined` ) val `avro4s-core` = project.in(file("avro4s-core")) @@ -44,6 +45,14 @@ val `avro4s-cats` = project.in(file("avro4s-cats")) // ) // ) +val `avro4s-refined` = project.in(file("avro4s-refined")) + .dependsOn(`avro4s-core` % "compile->compile;test->test") + .settings( + libraryDependencies ++= Seq( + "eu.timepit" %% "refined" % RefinedVersion + ) + ) + val benchmarks = project .in(file("benchmarks")) .dependsOn(`avro4s-core`)