Skip to content

Commit

Permalink
Improves transaction rollback handling
Browse files Browse the repository at this point in the history
I added exception handling around the success event SQL execution
in `finishTransaction` in `withTransaction`. If a `SqlExecutionError`
occurs, it will run the rollback callback and rethrow the exception.

I also added libpq transaction status checks to `beginTransaction` and
`finishTransaction` in `withTransaction`. If the status is unexpected
for a given transaction event, `withTransaction` will throw an
`UnexpectedTransactionStatusError`.
  • Loading branch information
jlavelle committed Dec 19, 2024
1 parent 96d388c commit e08919c
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 15 deletions.
90 changes: 75 additions & 15 deletions orville-postgresql/src/Orville/PostgreSQL/Execution/Transaction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ module Orville.PostgreSQL.Execution.Transaction
( withTransaction
, inWithTransaction
, InWithTransaction (InOutermostTransaction, InSavepointTransaction)
, UnexpectedTransactionStatusError (..)
)
where

import Control.Exception (Exception, throwIO, try)
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Database.PostgreSQL.LibPQ as LibPQ
import Numeric.Natural (Natural)

import qualified Orville.PostgreSQL.Execution.Execute as Execute
Expand All @@ -25,6 +28,7 @@ import qualified Orville.PostgreSQL.Internal.Bracket as Bracket
import qualified Orville.PostgreSQL.Internal.MonadOrville as MonadOrville
import qualified Orville.PostgreSQL.Internal.OrvilleState as OrvilleState
import qualified Orville.PostgreSQL.Monad as Monad
import qualified Orville.PostgreSQL.Raw.Connection as Connection
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql

{- | Performs an action in an Orville monad within a database transaction. The transaction
Expand Down Expand Up @@ -64,35 +68,91 @@ withTransaction action =
OrvilleState.orvilleTransactionCallback state

beginTransaction :: Monad.MonadOrville m => m ()
beginTransaction = do
beginTransaction =
liftIO $ do
status <- Connection.transactionStatusOrThrow conn
let
openEvent = OrvilleState.openTransactionEvent transaction
executeTransactionSql (transactionEventSql state openEvent)
callback openEvent
beginAction = do
executeTransactionSql (transactionEventSql state openEvent)
callback openEvent
transactionError = UnexpectedTransactionStatusError status openEvent
case status of
LibPQ.TransIdle -> case openEvent of
OrvilleState.BeginTransaction ->
beginAction
_ ->
throwIO transactionError
LibPQ.TransInTrans -> case openEvent of
OrvilleState.NewSavepoint _ ->
beginAction
_ ->
throwIO transactionError
LibPQ.TransActive ->
throwIO transactionError
LibPQ.TransInError ->
throwIO transactionError
LibPQ.TransUnknown ->
throwIO transactionError

doAction () =
Monad.localOrvilleState
(OrvilleState.connectState innerConnectedState)
action

finishTransaction :: MonadIO m => () -> Bracket.BracketResult -> m ()
finishTransaction () result =
liftIO $
case result of
finishTransaction () result = liftIO $ do
status <- Connection.transactionStatusOrThrow conn
let
successEvent = OrvilleState.transactionSuccessEvent transaction
rollbackEvent = OrvilleState.rollbackTransactionEvent transaction
rollback = do
executeTransactionSql (transactionEventSql state rollbackEvent)
callback rollbackEvent
transactionError = UnexpectedTransactionStatusError status $ case result of
Bracket.BracketSuccess -> successEvent
Bracket.BracketError -> rollbackEvent

case status of
LibPQ.TransInTrans -> case result of
Bracket.BracketSuccess -> do
let
successEvent = OrvilleState.transactionSuccessEvent transaction
executeTransactionSql (transactionEventSql state successEvent)
callback successEvent
Bracket.BracketError -> do
let
rollbackEvent = OrvilleState.rollbackTransactionEvent transaction
executeTransactionSql (transactionEventSql state rollbackEvent)
callback rollbackEvent
eSuccess <- try $ executeTransactionSql (transactionEventSql state successEvent)
case eSuccess of
Right () ->
callback successEvent
Left ex -> do
callback rollbackEvent
throwIO (ex :: Connection.SqlExecutionError)
Bracket.BracketError ->
rollback
LibPQ.TransInError ->
rollback
LibPQ.TransActive ->
throwIO transactionError
LibPQ.TransIdle ->
throwIO transactionError
LibPQ.TransUnknown ->
throwIO transactionError

Bracket.bracketWithResult beginTransaction finishTransaction doAction

{- |
'withTransaction' will throw this exception if libpq reports a transaction status on the underlying
connection that is incompatible with the current transaction event.
@since 1.1.0.0
-}
data UnexpectedTransactionStatusError = UnexpectedTransactionStatusError
{ unexpectedTransactionStatusErrorTransactionStatus :: LibPQ.TransactionStatus
, unexpectedTransactionStatusErrorTransactionEvent :: OrvilleState.TransactionEvent
}

instance Show UnexpectedTransactionStatusError where
show (UnexpectedTransactionStatusError status event) =
"Unexpected transaction status during event " <> show event <> ": " <> show status

instance Exception UnexpectedTransactionStatusError

transactionEventSql ::
OrvilleState.OrvilleState ->
OrvilleState.TransactionEvent ->
Expand Down
10 changes: 10 additions & 0 deletions orville-postgresql/src/Orville/PostgreSQL/Raw/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ module Orville.PostgreSQL.Raw.Connection
, ConnectionError
, SqlExecutionError (..)
, transactionStatus
, transactionStatusOrThrow
)
where

Expand Down Expand Up @@ -441,6 +442,15 @@ transactionStatus (Connection handle) = do
OpenConnection conn ->
fmap Just (LibPQ.transactionStatus conn)

{- |
Similar to 'transactionStatus', but throws a 'ConnectionUsedAfterCloseError' if the connection is closed.
@since 1.1.0.0
-}
transactionStatusOrThrow :: Connection -> IO LibPQ.TransactionStatus
transactionStatusOrThrow conn =
withLibPQConnectionOrFailIfClosed conn LibPQ.transactionStatus

throwConnectionError :: String -> LibPQ.Connection -> IO a
throwConnectionError message conn = do
mbLibPQError <- LibPQ.errorMessage conn
Expand Down
18 changes: 18 additions & 0 deletions orville-postgresql/test/Test/Transaction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Test.Transaction
)
where

import Control.Exception (SomeException (..), catch)
import qualified Control.Monad as Monad
import qualified Data.ByteString as BS
import qualified Data.IORef as IORef
Expand All @@ -12,6 +13,7 @@ import qualified Hedgehog as HH
import qualified Hedgehog.Gen as Gen

import qualified Orville.PostgreSQL as Orville
import qualified Orville.PostgreSQL.Execution as Execution
import qualified Orville.PostgreSQL.Expr as Expr
import qualified Orville.PostgreSQL.OrvilleState as OrvilleState
import qualified Orville.PostgreSQL.Raw.Connection as Conn
Expand All @@ -31,6 +33,7 @@ transactionTests pool =
, prop_callbacksMadeForTransactionRollback pool
, prop_usesCustomBeginTransactionSql pool
, prop_inWithTransaction pool
, prop_rollbackCallbackInInvalidTransaction pool
]

prop_transactionsWithoutExceptionsCommit :: Property.NamedDBProperty
Expand Down Expand Up @@ -177,6 +180,21 @@ prop_inWithTransaction =
outsideBefore === Nothing
outsideAfter === Nothing

prop_rollbackCallbackInInvalidTransaction :: Property.NamedDBProperty
prop_rollbackCallbackInInvalidTransaction =
Property.singletonNamedDBProperty "withTransaction triggers the rollback callback if the LibPQ transaction status is TransInError" $ \pool -> do
let
badQuery = RawSql.fromString "bad"

allEvents <- captureTransactionCallbackEvents pool $
Orville.withTransaction $ do
Orville.liftCatch
catch
(Execution.executeVoid Execution.OtherQuery badQuery)
(\(SomeException _) -> pure ())

allEvents === [Orville.BeginTransaction, Orville.RollbackTransaction]

captureTransactionCallbackEvents ::
Orville.ConnectionPool ->
Orville.Orville () ->
Expand Down

0 comments on commit e08919c

Please sign in to comment.