Skip to content

Commit

Permalink
Merge pull request zio#900 from balanka/feature645
Browse files Browse the repository at this point in the history
Implement Issue#899: Include batch update and batch delete into the SqlTransaction
  • Loading branch information
sviezypan authored Jul 28, 2023
2 parents 72fa738 + 8b437bf commit b63708a
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 22 deletions.
22 changes: 14 additions & 8 deletions jdbc/src/main/scala/zio/sql/SqlDriverLiveModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import zio.sql.delete._
trait SqlDriverLiveModule { self: Jdbc =>
private[sql] trait SqlDriverCore {

def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, List[Int]]
def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, Int]

def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, List[Int]]
def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, Int]

def deleteOn(delete: Delete[_], conn: Connection): IO[Exception, Int]

Expand All @@ -31,7 +31,7 @@ trait SqlDriverLiveModule { self: Jdbc =>
def delete(delete: Delete[_]): IO[Exception, Int] =
ZIO.scoped(pool.connection.flatMap(deleteOn(delete, _)))

def delete(delete: List[Delete[_]]): IO[Exception, List[Int]] =
def delete(delete: List[Delete[_]]): IO[Exception, Int] =
ZIO.scoped(pool.connection.flatMap(deleteOnBatch(delete, _)))

def deleteOn(delete: Delete[_], conn: Connection): IO[Exception, Int] =
Expand All @@ -41,11 +41,11 @@ trait SqlDriverLiveModule { self: Jdbc =>
statement.executeUpdate(query)
}.refineToOrDie[Exception]

def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, List[Int]] =
def deleteOnBatch(delete: List[Delete[_]], conn: Connection): IO[Exception, Int] =
ZIO.attemptBlocking {
val statement = conn.createStatement()
delete.map(delete_ => statement.addBatch(renderDelete(delete_)))
statement.executeBatch().toList
statement.executeBatch().sum
}.refineToOrDie[Exception]

def update(update: Update[_]): IO[Exception, Int] =
Expand All @@ -58,14 +58,14 @@ trait SqlDriverLiveModule { self: Jdbc =>
statement.executeUpdate(query)
}.refineToOrDie[Exception]

def update(update: List[Update[_]]): IO[Exception, List[Int]] =
def update(update: List[Update[_]]): IO[Exception, Int] =
ZIO.scoped(pool.connection.flatMap(updateOnBatch(update, _)))

def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, List[Int]] =
def updateOnBatch(update: List[Update[_]], conn: Connection): IO[Exception, Int] =
ZIO.attemptBlocking {
val statement = conn.createStatement()
update.map(update_ => statement.addBatch(renderUpdate(update_)))
statement.executeBatch().toList
statement.executeBatch().sum
}.refineToOrDie[Exception]

def read[A](read: Read[A]): Stream[Exception, A] =
Expand Down Expand Up @@ -136,9 +136,15 @@ trait SqlDriverLiveModule { self: Jdbc =>
def delete(delete: Delete[_]): IO[Exception, Int] =
deleteOn(delete, connection)

def delete(delete: List[Delete[_]]): IO[Exception, Int] =
deleteOnBatch(delete, connection)

def update(update: Update[_]): IO[Exception, Int] =
updateOn(update, connection)

def update(update: List[Update[_]]): IO[Exception, Int] =
updateOnBatch(update, connection)

def read[A](read: Read[A]): Stream[Exception, A] =
readOn(read, connection)

Expand Down
10 changes: 10 additions & 0 deletions jdbc/src/main/scala/zio/sql/TransactionSyntaxModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ trait TransactionSyntaxModule { self: Jdbc =>
ZIO.serviceWithZIO(_.delete(self))
}

implicit final class BatchDeleteSyntax[A: Schema](self: List[Delete[_]]) {
def run: ZIO[SqlTransaction, Exception, Int] =
ZIO.serviceWithZIO(_.delete(self))
}

implicit final class InsertSyntax[A: Schema](self: Insert[_, A]) {
def run: ZIO[SqlTransaction, Exception, Int] =
ZIO.serviceWithZIO(_.insert(self))
Expand All @@ -28,4 +33,9 @@ trait TransactionSyntaxModule { self: Jdbc =>
def run: ZIO[SqlTransaction, Exception, Int] =
ZIO.serviceWithZIO(_.update(self))
}

implicit final class BatchUpdatedSyntax(self: List[Update[_]]) {
def run: ZIO[SqlTransaction, Exception, Int] =
ZIO.serviceWithZIO(_.update(self))
}
}
12 changes: 8 additions & 4 deletions jdbc/src/main/scala/zio/sql/jdbc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra
trait SqlDriver {
def delete(delete: Delete[_]): IO[Exception, Int]

def delete(delete: List[Delete[_]]): IO[Exception, List[Int]]
def delete(delete: List[Delete[_]]): IO[Exception, Int]

def update(update: Update[_]): IO[Exception, Int]

def update(update: List[Update[_]]): IO[Exception, List[Int]]
def update(update: List[Update[_]]): IO[Exception, Int]

def read[A](read: Read[A]): Stream[Exception, A]

Expand All @@ -35,8 +35,12 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra

def delete(delete: Delete[_]): IO[Exception, Int]

def delete(delete: List[Delete[_]]): IO[Exception, Int]

def update(update: Update[_]): IO[Exception, Int]

def update(update: List[Update[_]]): IO[Exception, Int]

def read[A](read: Read[A]): Stream[Exception, A]

def insert[A: Schema](insert: Insert[_, A]): IO[Exception, Int]
Expand Down Expand Up @@ -65,7 +69,7 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra
def execute(delete: Delete[_]): ZIO[SqlDriver, Exception, Int] =
ZIO.serviceWithZIO(_.delete(delete))

def executeBatchDelete(delete: List[Delete[_]]): ZIO[SqlDriver, Exception, List[Int]] =
def executeBatchDelete(delete: List[Delete[_]]): ZIO[SqlDriver, Exception, Int] =
ZIO.serviceWithZIO(_.delete(delete))

def execute[A: Schema](insert: Insert[_, A]): ZIO[SqlDriver, Exception, Int] =
Expand All @@ -74,7 +78,7 @@ trait Jdbc extends Sql with JdbcInternalModule with SqlDriverLiveModule with Tra
def execute(update: Update[_]): ZIO[SqlDriver, Exception, Int] =
ZIO.serviceWithZIO(_.update(update))

def executeBatchUpdate(update: List[Update[_]]): ZIO[SqlDriver, Exception, List[Int]] =
def executeBatchUpdate(update: List[Update[_]]): ZIO[SqlDriver, Exception, Int] =
ZIO.serviceWithZIO(_.update(update))

val transact: ZLayer[SqlDriver, Exception, SqlTransaction] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ object DeleteBatchSpec extends PostgresRunnableSpec with DbSchema {

val assertion = for {
r <- result
} yield assert(r)(equalTo(List(1)))
} yield assert(r)(equalTo(1))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
Expand Down Expand Up @@ -83,7 +83,7 @@ object DeleteBatchSpec extends PostgresRunnableSpec with DbSchema {
_ <- insertResult
customers <- selectResult
deletes = customers.toList.map(delete_)
result <- executeBatchDelete(deletes).map(l => l.fold(0)((a, b) => a + b))
result <- executeBatchDelete(deletes)
} yield assert(result)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
Expand Down
88 changes: 81 additions & 7 deletions postgres/src/test/scala/zio/sql/postgresql/TransactionSpec.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package zio.sql.postgresql

import zio._
import zio.sql.update.Update
import zio.test.Assertion._
import zio.test._
import zio.test.TestAspect.sequential

import java.time.{ LocalDate, ZonedDateTime }
import java.util.UUID

object TransactionSpec extends PostgresRunnableSpec with DbSchema {

override val autoCommit = false
Expand Down Expand Up @@ -38,18 +42,88 @@ object TransactionSpec extends PostgresRunnableSpec with DbSchema {
assertZIO(result)(equalTo((5L, 5L))).mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("Transaction succeeded and deleted rows") {
val query = select(customerId) from customers
val deleteQuery = deleteFrom(customers).where(verified === false)
val id1 = UUID.randomUUID()
val id2 = UUID.randomUUID()

val c1 = Customer(
id1,
LocalDate.now(),
"fnameCustomer1",
"lnameCustomer1",
true,
LocalDate.now().toString,
ZonedDateTime.now()
)
val c2 = Customer(
id2,
LocalDate.now(),
"fnameCustomer2",
"lnameCustomer2",
true,
LocalDate.now().toString,
ZonedDateTime.now()
)
val allCustomer = List(c1, c2)
val data = allCustomer.map(Customer.unapply(_).get)
val insertStmt = insertInto(customers)(ALL).values(data)
val updateStmt = allCustomer.map(update_)

val tx = transact(deleteQuery.run)
val batchResult = for {
deleted <- deleteQuery.run
inserted <- insertStmt.run
updated <- updateStmt.run
} yield deleted + inserted + updated

val result = for {
tx <- transact(batchResult)
} yield tx
assertZIO(result)(equalTo(5)).mapErrorCause(cause => Cause.stackless(cause.untraced))
},
test("Transaction failed and no row was inserted updated or deleted") {
val deleteQuery = deleteFrom(customers).where(verified === false)
val id1 = UUID.randomUUID()

val c1 = Customer(
id1,
LocalDate.now(),
"fnameCustomer1",
"lnameCustomer1",
true,
LocalDate.now().toString,
ZonedDateTime.now()
)
val c2 = Customer(
id1,
LocalDate.now(),
"fnameCustomer2",
"lnameCustomer2",
true,
LocalDate.now().toString,
ZonedDateTime.now()
)
val allCustomer = List(c1, c2)
val data = allCustomer.map(Customer.unapply(_).get)
val insertStmt = insertInto(customers)(ALL).values(data)
val updateStmt = allCustomer.map(update_)

val batchResult = for {
deleted <- deleteQuery.run
_ <- ZIO.fail(insertStmt.run).exit
updated <- updateStmt.run

} yield deleted + updated

val result = (for {
allCustomersCount <- execute(query).runCount
_ <- tx
remainingCustomersCount <- execute(query).runCount
} yield (allCustomersCount, remainingCustomersCount))
tx <- transact(batchResult)
} yield tx).flip.exit
assertZIO(result)(fails((anything)))

assertZIO(result)(equalTo((5L, 4L))).mapErrorCause(cause => Cause.stackless(cause.untraced))
}
) @@ sequential

private def update_(c: Customer): Update[customers.TableType] =
update(customers)
.set(verified, !c.verified)
.where(customerId === c.id)
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ object UpdateBatchSpec extends PostgresRunnableSpec with DbSchema {
val assertion_ = for {
x <- result_
updated = x.toList.map(update_)
result <- executeBatchUpdate(updated).map(l => l.reduce((a, b) => a + b))
result <- executeBatchUpdate(updated)
} yield assert(result)(equalTo(5))
assertion_.mapErrorCause(cause => Cause.stackless(cause.untraced))
}
Expand Down

0 comments on commit b63708a

Please sign in to comment.