diff --git a/jdbc/src/main/scala/zio/sql/SqlDriverLiveModule.scala b/jdbc/src/main/scala/zio/sql/SqlDriverLiveModule.scala index af6563645..690d479c7 100644 --- a/jdbc/src/main/scala/zio/sql/SqlDriverLiveModule.scala +++ b/jdbc/src/main/scala/zio/sql/SqlDriverLiveModule.scala @@ -14,9 +14,9 @@ import zio.sql.delete._ trait SqlDriverLiveModule { self: Jdbc => private[sql] trait SqlDriverCore { - def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, List[Int]] + def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, Int] - def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, List[Int]] + def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, Int] def deleteOn(delete: Delete[_], conn: Connection): IO[Exception, Int] @@ -31,7 +31,7 @@ trait SqlDriverLiveModule { self: Jdbc => def delete(delete: Delete[_]): IO[Exception, Int] = ZIO.scoped(pool.connection.flatMap(deleteOn(delete, _))) - def delete(delete: List[Delete[_]]): IO[Exception, List[Int]] = + def delete(delete: List[Delete[_]]): IO[Exception, Int] = ZIO.scoped(pool.connection.flatMap(deleteOnBatch(delete, _))) def deleteOn(delete: Delete[_], conn: Connection): IO[Exception, Int] = @@ -41,11 +41,11 @@ trait SqlDriverLiveModule { self: Jdbc => statement.executeUpdate(query) }.refineToOrDie[Exception] - def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, List[Int]] = + def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, Int] = ZIO.attemptBlocking { val statement = conn.createStatement() delete.map(delete_ => statement.addBatch(renderDelete(delete_))) - statement.executeBatch().toList + statement.executeBatch().sum }.refineToOrDie[Exception] def update(update: Update[_]): IO[Exception, Int] = @@ -58,14 +58,14 @@ trait SqlDriverLiveModule { self: Jdbc => statement.executeUpdate(query) }.refineToOrDie[Exception] - def update(update: List[Update[_]]): IO[Exception, List[Int]] = + def update(update: List[Update[_]]): IO[Exception, Int] = ZIO.scoped(pool.connection.flatMap(updateOnBatch(update, _))) - def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, List[Int]] = + def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, Int] = ZIO.attemptBlocking { val statement = conn.createStatement() update.map(update_ => statement.addBatch(renderUpdate(update_))) - statement.executeBatch().toList + statement.executeBatch().sum }.refineToOrDie[Exception] def read[A](read: Read[A]): Stream[Exception, A] = @@ -136,9 +136,15 @@ trait SqlDriverLiveModule { self: Jdbc => def delete(delete: Delete[_]): IO[Exception, Int] = deleteOn(delete, connection) + def delete(delete: List[Delete[_]]): IO[Exception, Int] = + deleteOnBatch(delete, connection) + def update(update: Update[_]): IO[Exception, Int] = updateOn(update, connection) + def update(update: List[Update[_]]): IO[Exception, Int] = + updateOnBatch(update, connection) + def read[A](read: Read[A]): Stream[Exception, A] = readOn(read, connection) diff --git a/jdbc/src/main/scala/zio/sql/TransactionSyntaxModule.scala b/jdbc/src/main/scala/zio/sql/TransactionSyntaxModule.scala index 3548c6846..e0e79db33 100644 --- a/jdbc/src/main/scala/zio/sql/TransactionSyntaxModule.scala +++ b/jdbc/src/main/scala/zio/sql/TransactionSyntaxModule.scala @@ -19,6 +19,11 @@ trait TransactionSyntaxModule { self: Jdbc => ZIO.serviceWithZIO(_.delete(self)) } + implicit final class BatchDeleteSyntax[A: Schema](self: List[Delete[_]]) { + def run: ZIO[SqlTransaction, Exception, Int] = + ZIO.serviceWithZIO(_.delete(self)) + } + implicit final class InsertSyntax[A: Schema](self: Insert[_, A]) { def run: ZIO[SqlTransaction, Exception, Int] = ZIO.serviceWithZIO(_.insert(self)) @@ -28,4 +33,9 @@ trait TransactionSyntaxModule { self: Jdbc => def run: ZIO[SqlTransaction, Exception, Int] = ZIO.serviceWithZIO(_.update(self)) } + + implicit final class BatchUpdatedSyntax(self: List[Update[_]]) { + def run: ZIO[SqlTransaction, Exception, Int] = + ZIO.serviceWithZIO(_.update(self)) + } } diff --git a/jdbc/src/main/scala/zio/sql/jdbc.scala b/jdbc/src/main/scala/zio/sql/jdbc.scala index eb1674a93..d620fd615 100644 --- a/jdbc/src/main/scala/zio/sql/jdbc.scala +++ b/jdbc/src/main/scala/zio/sql/jdbc.scala @@ -13,11 +13,11 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra trait SqlDriver { def delete(delete: Delete[_]): IO[Exception, Int] - def delete(delete: List[Delete[_]]): IO[Exception, List[Int]] + def delete(delete: List[Delete[_]]): IO[Exception, Int] def update(update: Update[_]): IO[Exception, Int] - def update(update: List[Update[_]]): IO[Exception, List[Int]] + def update(update: List[Update[_]]): IO[Exception, Int] def read[A](read: Read[A]): Stream[Exception, A] @@ -35,8 +35,12 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra def delete(delete: Delete[_]): IO[Exception, Int] + def delete(delete: List[Delete[_]]): IO[Exception, Int] + def update(update: Update[_]): IO[Exception, Int] + def update(update: List[Update[_]]): IO[Exception, Int] + def read[A](read: Read[A]): Stream[Exception, A] def insert[A: Schema](insert: Insert[_, A]): IO[Exception, Int] @@ -65,7 +69,7 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra def execute(delete: Delete[_]): ZIO[SqlDriver, Exception, Int] = ZIO.serviceWithZIO(_.delete(delete)) - def executeBatchDelete(delete: List[Delete[_]]): ZIO[SqlDriver, Exception, List[Int]] = + def executeBatchDelete(delete: List[Delete[_]]): ZIO[SqlDriver, Exception, Int] = ZIO.serviceWithZIO(_.delete(delete)) def execute[A: Schema](insert: Insert[_, A]): ZIO[SqlDriver, Exception, Int] = @@ -74,7 +78,7 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra def execute(update: Update[_]): ZIO[SqlDriver, Exception, Int] = ZIO.serviceWithZIO(_.update(update)) - def executeBatchUpdate(update: List[Update[_]]): ZIO[SqlDriver, Exception, List[Int]] = + def executeBatchUpdate(update: List[Update[_]]): ZIO[SqlDriver, Exception, Int] = ZIO.serviceWithZIO(_.update(update)) val transact: ZLayer[SqlDriver, Exception, SqlTransaction] = diff --git a/postgres/src/test/scala/zio/sql/postgresql/DeleteBatchSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/DeleteBatchSpec.scala index 8e7efb1a5..6257c05b6 100644 --- a/postgres/src/test/scala/zio/sql/postgresql/DeleteBatchSpec.scala +++ b/postgres/src/test/scala/zio/sql/postgresql/DeleteBatchSpec.scala @@ -23,7 +23,7 @@ object DeleteBatchSpec extends PostgresRunnableSpec with DbSchema { val assertion = for { r <- result - } yield assert(r)(equalTo(List(1))) + } yield assert(r)(equalTo(1)) assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) }, @@ -83,7 +83,7 @@ object DeleteBatchSpec extends PostgresRunnableSpec with DbSchema { _ <- insertResult customers <- selectResult deletes = customers.toList.map(delete_) - result <- executeBatchDelete(deletes).map(l => l.fold(0)((a, b) => a + b)) + result <- executeBatchDelete(deletes) } yield assert(result)(equalTo(expected)) assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) diff --git a/postgres/src/test/scala/zio/sql/postgresql/TransactionSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/TransactionSpec.scala index d34dda408..ffdadc64b 100644 --- a/postgres/src/test/scala/zio/sql/postgresql/TransactionSpec.scala +++ b/postgres/src/test/scala/zio/sql/postgresql/TransactionSpec.scala @@ -1,10 +1,14 @@ package zio.sql.postgresql import zio._ +import zio.sql.update.Update import zio.test.Assertion._ import zio.test._ import zio.test.TestAspect.sequential +import java.time.{ LocalDate, ZonedDateTime } +import java.util.UUID + object TransactionSpec extends PostgresRunnableSpec with DbSchema { override val autoCommit = false @@ -38,18 +42,88 @@ object TransactionSpec extends PostgresRunnableSpec with DbSchema { assertZIO(result)(equalTo((5L, 5L))).mapErrorCause(cause => Cause.stackless(cause.untraced)) }, test("Transaction succeeded and deleted rows") { - val query = select(customerId) from customers val deleteQuery = deleteFrom(customers).where(verified === false) + val id1 = UUID.randomUUID() + val id2 = UUID.randomUUID() + + val c1 = Customer( + id1, + LocalDate.now(), + "fnameCustomer1", + "lnameCustomer1", + true, + LocalDate.now().toString, + ZonedDateTime.now() + ) + val c2 = Customer( + id2, + LocalDate.now(), + "fnameCustomer2", + "lnameCustomer2", + true, + LocalDate.now().toString, + ZonedDateTime.now() + ) + val allCustomer = List(c1, c2) + val data = allCustomer.map(Customer.unapply(_).get) + val insertStmt = insertInto(customers)(ALL).values(data) + val updateStmt = allCustomer.map(update_) - val tx = transact(deleteQuery.run) + val batchResult = for { + deleted <- deleteQuery.run + inserted <- insertStmt.run + updated <- updateStmt.run + } yield deleted + inserted + updated + + val result = for { + tx <- transact(batchResult) + } yield tx + assertZIO(result)(equalTo(5)).mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("Transaction failed and no row was inserted updated or deleted") { + val deleteQuery = deleteFrom(customers).where(verified === false) + val id1 = UUID.randomUUID() + + val c1 = Customer( + id1, + LocalDate.now(), + "fnameCustomer1", + "lnameCustomer1", + true, + LocalDate.now().toString, + ZonedDateTime.now() + ) + val c2 = Customer( + id1, + LocalDate.now(), + "fnameCustomer2", + "lnameCustomer2", + true, + LocalDate.now().toString, + ZonedDateTime.now() + ) + val allCustomer = List(c1, c2) + val data = allCustomer.map(Customer.unapply(_).get) + val insertStmt = insertInto(customers)(ALL).values(data) + val updateStmt = allCustomer.map(update_) + + val batchResult = for { + deleted <- deleteQuery.run + _ <- ZIO.fail(insertStmt.run).exit + updated <- updateStmt.run + + } yield deleted + updated val result = (for { - allCustomersCount <- execute(query).runCount - _ <- tx - remainingCustomersCount <- execute(query).runCount - } yield (allCustomersCount, remainingCustomersCount)) + tx <- transact(batchResult) + } yield tx).flip.exit + assertZIO(result)(fails((anything))) - assertZIO(result)(equalTo((5L, 4L))).mapErrorCause(cause => Cause.stackless(cause.untraced)) } ) @@ sequential + + private def update_(c: Customer): Update[customers.TableType] = + update(customers) + .set(verified, !c.verified) + .where(customerId === c.id) } diff --git a/postgres/src/test/scala/zio/sql/postgresql/UpdateBatchSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/UpdateBatchSpec.scala index cd1b1cae2..04dad2230 100644 --- a/postgres/src/test/scala/zio/sql/postgresql/UpdateBatchSpec.scala +++ b/postgres/src/test/scala/zio/sql/postgresql/UpdateBatchSpec.scala @@ -76,7 +76,7 @@ object UpdateBatchSpec extends PostgresRunnableSpec with DbSchema { val assertion_ = for { x <- result_ updated = x.toList.map(update_) - result <- executeBatchUpdate(updated).map(l => l.reduce((a, b) => a + b)) + result <- executeBatchUpdate(updated) } yield assert(result)(equalTo(5)) assertion_.mapErrorCause(cause => Cause.stackless(cause.untraced)) }