Skip to content

Commit

Permalink
Merge pull request #2138 from typelevel/derived_instances_more_tests
Browse files Browse the repository at this point in the history
More tests to verify derived instances
  • Loading branch information
jatcwang authored Jan 4, 2025
2 parents 59bca3b + aa39ec3 commit a0a23e0
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 42 deletions.
3 changes: 2 additions & 1 deletion modules/core/src/main/scala/doobie/util/analysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,12 @@ object analysis {
columnAlignment: List[(Get[?], NullabilityKnown) `Ior` ColumnMeta]
) {

def parameterMisalignments: List[ParameterMisalignment] =
def parameterMisalignments: List[ParameterMisalignment] = {
parameterAlignment.zipWithIndex.collect {
case (Ior.Left(_), n) => ParameterMisalignment(n + 1, None)
case (Ior.Right(p), n) => ParameterMisalignment(n + 1, Some(p))
}
}

private def hasParameterTypeErrors[A](put: Put[A], paramMeta: ParameterMeta): Boolean = {
!put.jdbcTargets.contains_(paramMeta.jdbcType) ||
Expand Down
164 changes: 146 additions & 18 deletions modules/core/src/test/scala/doobie/util/ReadSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import doobie.util.TestTypes.*
import doobie.util.transactor.Transactor
import doobie.testutils.VoidExtensions
import doobie.syntax.all.*
import doobie.Query
import doobie.{ConnectionIO, Query}
import doobie.util.analysis.{Analysis, ColumnMisalignment, ColumnTypeError, ColumnTypeWarning, NullabilityMisalignment}
import doobie.util.fragment.Fragment
import munit.Location

import scala.annotation.nowarn

class ReadSuite extends munit.FunSuite with ReadSuitePlatform {
Expand Down Expand Up @@ -57,54 +60,54 @@ class ReadSuite extends munit.FunSuite with ReadSuitePlatform {
test("Semiauto derivation selects custom Read instances when available") {
implicit val i0: Read[HasCustomReadWrite0] = Read.derived[HasCustomReadWrite0]
assertEquals(i0.length, 2)
insertTupleAndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y"))
insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y"))

implicit val i1: Read[HasCustomReadWrite1] = Read.derived[HasCustomReadWrite1]
assertEquals(i1.length, 2)
insertTupleAndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R")))
insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R")))

implicit val iOpt0: Read[HasOptCustomReadWrite0] = Read.derived[HasOptCustomReadWrite0]
assertEquals(iOpt0.length, 2)
insertTupleAndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y"))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y"))

implicit val iOpt1: Read[HasOptCustomReadWrite1] = Read.derived[HasOptCustomReadWrite1]
assertEquals(iOpt1.length, 2)
insertTupleAndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R"))))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R"))))
}

test("Semiauto derivation selects custom Get instances to use for Read when available") {
implicit val i0: Read[HasCustomGetPut0] = Read.derived[HasCustomGetPut0]
assertEquals(i0.length, 2)
insertTupleAndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y"))
insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y"))

implicit val i1: Read[HasCustomGetPut1] = Read.derived[HasCustomGetPut1]
assertEquals(i1.length, 2)
insertTupleAndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G")))
insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G")))

implicit val iOpt0: Read[HasOptCustomGetPut0] = Read.derived[HasOptCustomGetPut0]
assertEquals(iOpt0.length, 2)
insertTupleAndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y"))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y"))

implicit val iOpt1: Read[HasOptCustomGetPut1] = Read.derived[HasOptCustomGetPut1]
assertEquals(iOpt1.length, 2)
insertTupleAndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G"))))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G"))))
}

test("Automatic derivation selects custom Read instances when available") {
import doobie.implicits.*

insertTupleAndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y"))
insertTupleAndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R")))
insertTupleAndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y"))
insertTupleAndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R"))))
insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite0(CustomReadWrite("x_R"), "y"))
insertTuple2AndCheckRead(("x", "y"), HasCustomReadWrite1("x", CustomReadWrite("y_R")))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite0(Some(CustomReadWrite("x_R")), "y"))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomReadWrite1("x", Some(CustomReadWrite("y_R"))))
}

test("Automatic derivation selects custom Get instances to use for Read when available") {
import doobie.implicits.*
insertTupleAndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y"))
insertTupleAndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G")))
insertTupleAndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y"))
insertTupleAndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G"))))
insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut0(CustomGetPut("x_G"), "y"))
insertTuple2AndCheckRead(("x", "y"), HasCustomGetPut1("x", CustomGetPut("y_G")))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut0(Some(CustomGetPut("x_G")), "y"))
insertTuple2AndCheckRead(("x", "y"), HasOptCustomGetPut1("x", Some(CustomGetPut("y_G"))))
}

test("Read should not be derivable for case objects") {
Expand Down Expand Up @@ -138,6 +141,16 @@ class ReadSuite extends munit.FunSuite with ReadSuitePlatform {
assertEquals(p.gets, (readInt.gets ++ readString.gets))
}

test(".map should correctly transform the value") {
import doobie.implicits.*
implicit val r: Read[WrappedSimpleCaseClass] = Read[SimpleCaseClass].map(s =>
WrappedSimpleCaseClass(
s.copy(s = "custom")
))

insertTuple3AndCheckRead((1, "s1", "s2"), WrappedSimpleCaseClass(SimpleCaseClass(Some(1), "custom", Some("s2"))))
}

/*
case class with nested Option case class field
*/
Expand Down Expand Up @@ -195,10 +208,125 @@ class ReadSuite extends munit.FunSuite with ReadSuitePlatform {
assertEquals(o, List((1, (2, 3))))
}

private def insertTupleAndCheckRead[Tup: Write, A: Read](in: Tup, expectedOut: A)(implicit loc: Location): Unit = {
test("Read typechecking should work for Tuples") {
val frag = sql"SELECT 1, 's', 3.0 :: DOUBLE"

assertSuccessTypecheckRead[(Int, String, Double)](frag)
assertSuccessTypecheckRead[(Int, (String, Double))](frag)
assertSuccessTypecheckRead[((Int, String), Double)](frag)

assertSuccessTypecheckRead[(Int, Option[String], Double)](frag)
assertSuccessTypecheckRead[(Option[Int], Option[(String, Double)])](frag)
assertSuccessTypecheckRead[Option[((Int, String), Double)]](frag)

assertWarnedTypecheckRead[(Boolean, String, Double)](frag)

assertMisalignedTypecheckRead[(Int, String)](frag)
assertMisalignedTypecheckRead[(Int, String, Double, Int)](frag)

}

test("Read typechecking should work for case classes") {
implicit val rscc: Read[SimpleCaseClass] = Read.derived[SimpleCaseClass]
implicit val rccc: Read[ComplexCaseClass] = Read.derived[ComplexCaseClass]
implicit val rwscc: Read[WrappedSimpleCaseClass] =
rscc.map(WrappedSimpleCaseClass.apply) // Test map doesn't break typechecking

assertSuccessTypecheckRead(
sql"create table tab(c1 int, c2 varchar not null, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis)
)
assertSuccessTypecheckRead(
sql"create table tab(c1 int, c2 varchar not null, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[WrappedSimpleCaseClass].analysis)
)

assertSuccessTypecheckRead(
sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[Option[SimpleCaseClass]].analysis)
)
assertSuccessTypecheckRead(
sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[Option[WrappedSimpleCaseClass]].analysis)
)

assertTypeErrorTypecheckRead(
sql"create table tab(c1 binary, c2 varchar not null, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis)
)

assertMisalignedNullabilityTypecheckRead(
sql"create table tab(c1 int, c2 varchar, c3 varchar)".update.run.flatMap(_ =>
sql"SELECT c1,c2,c3 from tab".query[SimpleCaseClass].analysis)
)

assertSuccessTypecheckRead(
sql"create table tab(c1 int, c2 varchar not null, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)"
.update.run.flatMap(_ =>
sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from tab".query[ComplexCaseClass].analysis)
)

assertTypeErrorTypecheckRead(
sql"create table tab(c1 binary, c2 varchar not null, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)"
.update.run.flatMap(_ =>
sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from tab".query[ComplexCaseClass].analysis)
)

assertMisalignedNullabilityTypecheckRead(
sql"create table tab(c1 int, c2 varchar, c3 varchar, c4 int, c5 varchar, c6 varchar, c7 int, c8 varchar not null)"
.update.run.flatMap(_ =>
sql"SELECT c1,c2,c3,c4,c5,c6,c7,c8 from tab".query[ComplexCaseClass].analysis)
)

}

private def insertTuple3AndCheckRead[Tup <: (?, ?, ?): Write, A: Read](in: Tup, expectedOut: A)(implicit
loc: Location
): Unit = {
val res = Query[Tup, A]("SELECT ?, ?, ?").unique(in).transact(xa)
.unsafeRunSync()
assertEquals(res, expectedOut)
}

private def insertTuple2AndCheckRead[Tup <: (?, ?): Write, A: Read](in: Tup, expectedOut: A)(implicit
loc: Location
): Unit = {
val res = Query[Tup, A]("SELECT ?, ?").unique(in).transact(xa)
.unsafeRunSync()
assertEquals(res, expectedOut)
}

private def assertSuccessTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = {
val analysisResult = connio.transact(xa).unsafeRunSync()
assertEquals(analysisResult.columnAlignmentErrors, Nil)
}

private def assertSuccessTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = {
assertSuccessTypecheckRead(frag.query[A].analysis)
}

private def assertWarnedTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = {
val analysisResult = frag.query[A].analysis.transact(xa).unsafeRunSync()
val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass)
assertEquals(errorClasses, List(classOf[ColumnTypeWarning]))
}

private def assertTypeErrorTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = {
val analysisResult = connio.transact(xa).unsafeRunSync()
val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass)
assertEquals(errorClasses, List(classOf[ColumnTypeError]))
}

private def assertMisalignedNullabilityTypecheckRead(connio: ConnectionIO[Analysis])(implicit loc: Location): Unit = {
val analysisResult = connio.transact(xa).unsafeRunSync()
val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass)
assertEquals(errorClasses, List(classOf[NullabilityMisalignment]))
}

private def assertMisalignedTypecheckRead[A: Read](frag: Fragment)(implicit loc: Location): Unit = {
val analysisResult = frag.query[A].analysis.transact(xa).unsafeRunSync()
val errorClasses = analysisResult.columnAlignmentErrors.map(_.getClass)
assertEquals(errorClasses, List(classOf[ColumnMisalignment]))
}

}
1 change: 1 addition & 0 deletions modules/core/src/test/scala/doobie/util/TestTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ object TestTypes {
case class TrivialCaseClass(i: Int)
case class SimpleCaseClass(i: Option[Int], s: String, os: Option[String])
case class ComplexCaseClass(sc: SimpleCaseClass, osc: Option[SimpleCaseClass], i: Option[Int], s: String)
case class WrappedSimpleCaseClass(sc: SimpleCaseClass)

case class HasCustomReadWrite0(c: CustomReadWrite, s: String)
case class HasCustomReadWrite1(s: String, c: CustomReadWrite)
Expand Down
Loading

0 comments on commit a0a23e0

Please sign in to comment.