Skip to content

Commit

Permalink
Merge pull request #2136 from typelevel/fix_auto_derivation
Browse files Browse the repository at this point in the history
Rework Read/Write & Fix derivation to use custom instances
  • Loading branch information
jatcwang authored Jan 4, 2025
2 parents 010a978 + 8a83120 commit 59bca3b
Show file tree
Hide file tree
Showing 36 changed files with 771 additions and 1,130 deletions.
20 changes: 14 additions & 6 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ lazy val weaverVersion = "0.8.4"
ThisBuild / tlBaseVersion := "1.0"
ThisBuild / tlCiReleaseBranches := Seq("main") // publish snapshots on `main`
ThisBuild / tlCiScalafmtCheck := true
//ThisBuild / scalaVersion := scala212Version
ThisBuild / scalaVersion := scala213Version
//ThisBuild / scalaVersion := scala3Version
ThisBuild / crossScalaVersions := Seq(scala212Version, scala213Version, scala3Version)
Expand Down Expand Up @@ -98,9 +99,12 @@ lazy val compilerFlags = Seq(
Compile / doc / scalacOptions --= Seq(
"-Xfatal-warnings"
),
// Test / scalacOptions --= Seq(
// "-Xfatal-warnings"
// ),
// Disable warning when @nowarn annotation isn't suppressing a warning
// to simplify cross-building
// because 2.12 @nowarn doesn't actually do anything.. https://github.com/scala/bug/issues/12313
scalacOptions ++= Seq(
"-Wconf:cat=unused-nowarn:s"
),
scalacOptions ++= (if (tlIsScala3.value)
// Handle irrefutable patterns in for comprehensions
Seq("-source:future", "-language:adhocExtensions")
Expand Down Expand Up @@ -249,8 +253,7 @@ lazy val core = project
).filterNot(_ => tlIsScala3.value) ++ Seq(
"org.tpolecat" %% "typename" % "1.1.0",
"com.h2database" % "h2" % h2Version % "test",
"org.postgresql" % "postgresql" % postgresVersion % "test",
"org.mockito" % "mockito-core" % "5.12.0" % Test
"org.postgresql" % "postgresql" % postgresVersion % "test"
),
Compile / unmanagedSourceDirectories += {
val sourceDir = (Compile / sourceDirectory).value
Expand Down Expand Up @@ -493,7 +496,12 @@ lazy val bench = project
.enablePlugins(NoPublishPlugin)
.enablePlugins(AutomateHeaderPlugin)
.enablePlugins(JmhPlugin)
.dependsOn(core, postgres)
.settings(
libraryDependencies ++= (if (scalaVersion.value == scala212Version)
Seq("org.scala-lang.modules" %% "scala-collection-compat" % "2.12.0")
else Seq.empty)
)
.dependsOn(core, postgres, hikari)
.settings(doobieSettings)

lazy val docs = project
Expand Down
6 changes: 4 additions & 2 deletions modules/core/src/main/scala-2/doobie/util/GetPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ trait GetPlatform {
import doobie.util.compat.=:=

/** @group Instances */
@deprecated("Use Get.derived instead to derive instances explicitly", "1.0.0-RC6")
def unaryProductGet[A, L <: HList, H, T <: HList](
implicit
G: Generic.Aux[A, L],
C: IsHCons.Aux[L, H, T],
H: Lazy[Get[H]],
E: (H :: HNil) =:= L
): MkGet[A] = MkGet.unaryProductGet
): Get[A] = {
void(C) // C drives inference but is not used directly
H.value.tmap[A](h => G.from(h :: HNil))
}

}
26 changes: 0 additions & 26 deletions modules/core/src/main/scala-2/doobie/util/MkGetPlatform.scala

This file was deleted.

26 changes: 0 additions & 26 deletions modules/core/src/main/scala-2/doobie/util/MkPutPlatform.scala

This file was deleted.

165 changes: 57 additions & 108 deletions modules/core/src/main/scala-2/doobie/util/MkReadPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,142 +4,91 @@

package doobie.util

import shapeless.{HList, HNil, ::, Generic, Lazy, <:!<, OrElse}
import shapeless.labelled.{field, FieldType}
import shapeless.{HList, HNil, ::, Generic, Lazy, OrElse}
import shapeless.labelled.FieldType

trait MkReadPlatform extends LowerPriorityRead {
trait MkReadPlatform extends LowerPriorityMkRead {

// Derivation base case for product types (1-element)
implicit def productBase[H](
implicit H: Read[H] OrElse MkRead[H]
): MkRead[H :: HNil] = {
val head = H.unify

new MkRead[H :: HNil](
head.gets,
(rs, n) => head.unsafeGet(rs, n) :: HNil
implicit H: Read[H] OrElse Derived[MkRead[H]]
): Derived[MkRead[H :: HNil]] = {
val headInstance = H.fold(identity, _.instance)

new Derived(
new MkRead(
headInstance.map(_ :: HNil)
)
)
}

// Derivation base case for shapeless record (1-element)
implicit def recordBase[K <: Symbol, H](
implicit H: Read[H] OrElse MkRead[H]
): MkRead[FieldType[K, H] :: HNil] = {
val head = H.unify

new MkRead[FieldType[K, H] :: HNil](
head.gets,
(rs, n) => field[K](head.unsafeGet(rs, n)) :: HNil
implicit H: Read[H] OrElse Derived[MkRead[H]]
): Derived[MkRead[FieldType[K, H] :: HNil]] = {
val headInstance = H.fold(identity, _.instance)

new Derived(
new MkRead(
new Read.Transform[FieldType[K, H] :: HNil, H](
headInstance,
h => shapeless.labelled.field[K].apply(h) :: HNil
)
)
)
}
}

trait LowerPriorityRead extends EvenLowerPriorityRead {
trait LowerPriorityMkRead {

// Derivation inductive case for product types
implicit def product[H, T <: HList](
implicit
H: Read[H] OrElse MkRead[H],
T: MkRead[T]
): MkRead[H :: T] = {
val head = H.unify

new MkRead[H :: T](
head.gets ++ T.gets,
(rs, n) => head.unsafeGet(rs, n) :: T.unsafeGet(rs, n + head.length)
H: Read[H] OrElse Derived[MkRead[H]],
T: Read[T] OrElse Derived[MkRead[T]]
): Derived[MkRead[H :: T]] = {
val headInstance = H.fold(identity, _.instance)
val tailInstance = T.fold(identity, _.instance)

new Derived(
new MkRead(
new Read.Composite[H :: T, H, T](
headInstance,
tailInstance,
(h, t) => h :: t
)
)
)
}

// Derivation inductive case for shapeless records
implicit def record[K <: Symbol, H, T <: HList](
implicit
H: Read[H] OrElse MkRead[H],
T: MkRead[T]
): MkRead[FieldType[K, H] :: T] = {
val head = H.unify

new MkRead[FieldType[K, H] :: T](
head.gets ++ T.gets,
(rs, n) => field[K](head.unsafeGet(rs, n)) :: T.unsafeGet(rs, n + head.length)
H: Read[H] OrElse Derived[MkRead[H]],
T: Read[T] OrElse Derived[MkRead[T]]
): Derived[MkRead[FieldType[K, H] :: T]] = {
val headInstance = H.fold(identity, _.instance)
val tailInstance = T.fold(identity, _.instance)

new Derived(
new MkRead(
new Read.Composite[FieldType[K, H] :: T, H, T](
headInstance,
tailInstance,
(h, t) => shapeless.labelled.field[K].apply(h) :: t
)
)
)
}

// Derivation for product types (i.e. case class)
implicit def generic[T, Repr](implicit gen: Generic.Aux[T, Repr], G: Lazy[MkRead[Repr]]): MkRead[T] =
new MkRead[T](G.value.gets, (rs, n) => gen.from(G.value.unsafeGet(rs, n)))

// Derivation base case for Option of product types (1-element)
implicit def optProductBase[H](
implicit
H: Read[Option[H]] OrElse MkRead[Option[H]],
N: H <:!< Option[α] forSome { type α }
): MkRead[Option[H :: HNil]] = {
void(N)
val head = H.unify

new MkRead[Option[H :: HNil]](
head.gets,
(rs, n) =>
head.unsafeGet(rs, n).map(_ :: HNil)
)
}

// Derivation base case for Option of product types (where the head element is Option)
implicit def optProductOptBase[H](
implicit H: Read[Option[H]] OrElse MkRead[Option[H]]
): MkRead[Option[Option[H] :: HNil]] = {
val head = H.unify

new MkRead[Option[Option[H] :: HNil]](
head.gets,
(rs, n) => head.unsafeGet(rs, n).map(h => Some(h) :: HNil)
)
}

}

trait EvenLowerPriorityRead {

// Read[Option[H]], Read[Option[T]] implies Read[Option[H *: T]]
implicit def optProduct[H, T <: HList](
implicit def genericRead[T, Repr](
implicit
H: Read[Option[H]] OrElse MkRead[Option[H]],
T: MkRead[Option[T]],
N: H <:!< Option[α] forSome { type α }
): MkRead[Option[H :: T]] = {
void(N)
val head = H.unify

new MkRead[Option[H :: T]](
head.gets ++ T.gets,
(rs, n) =>
for {
h <- head.unsafeGet(rs, n)
t <- T.unsafeGet(rs, n + head.length)
} yield h :: t
)
gen: Generic.Aux[T, Repr],
hlistRead: Lazy[Read[Repr] OrElse Derived[MkRead[Repr]]]
): Derived[MkRead[T]] = {
val hlistInstance: Read[Repr] = hlistRead.value.fold(identity, _.instance)
new Derived(new MkRead(hlistInstance.map(gen.from)))
}

// Read[Option[H]], Read[Option[T]] implies Read[Option[Option[H] *: T]]
implicit def optProductOpt[H, T <: HList](
implicit
H: Read[Option[H]] OrElse MkRead[Option[H]],
T: MkRead[Option[T]]
): MkRead[Option[Option[H] :: T]] = {
val head = H.unify

new MkRead[Option[Option[H] :: T]](
head.gets ++ T.gets,
(rs, n) => T.unsafeGet(rs, n + head.length).map(head.unsafeGet(rs, n) :: _)
)
}

// Derivation for optional of product types (i.e. case class)
implicit def ogeneric[A, Repr <: HList](
implicit
G: Generic.Aux[A, Repr],
B: Lazy[MkRead[Option[Repr]]]
): MkRead[Option[A]] =
new MkRead[Option[A]](B.value.gets, B.value.unsafeGet(_, _).map(G.from))

}
Loading

0 comments on commit 59bca3b

Please sign in to comment.