diff --git a/hercules-ci-agent/src/Hercules/Agent/Socket.hs b/hercules-ci-agent/src/Hercules/Agent/Socket.hs index cec5e717..366d10cd 100644 --- a/hercules-ci-agent/src/Hercules/Agent/Socket.hs +++ b/hercules-ci-agent/src/Hercules/Agent/Socket.hs @@ -33,7 +33,8 @@ import Network.URI (URI, uriAuthority, uriPath, uriPort, uriQuery, uriRegName, u import Network.WebSockets (Connection, runClientWith) import qualified Network.WebSockets as WS import Protolude hiding (atomically, handle, race, race_) -import UnliftIO.Async (AsyncCancelled (AsyncCancelled), race, race_) +import qualified UnliftIO +import UnliftIO.Async (race, race_) import UnliftIO.Exception (handle) import UnliftIO.STM (readTVarIO) import UnliftIO.Timeout (timeout) @@ -225,13 +226,13 @@ runReliableSocket socketConfig writeQueue serviceMessageChan highestAcked = kati pass else noAckCleanupThread' expectedN forever do - removeTimeout <- prepareTimeout handshakeTimeoutMicroseconds HandshakeTimeout handle logWarningPause $ - withConnection' socketConfig $ - \conn -> do - katipAddNamespace "Handshake" do - handshake conn removeTimeout - readThread conn `race_` writeThread conn `race_` noAckCleanupThread + withCancelableTimeout handshakeTimeoutMicroseconds HandshakeTimeout \removeTimeout -> do + withConnection' socketConfig $ + \conn -> do + katipAddNamespace "Handshake" do + handshake conn removeTimeout + readThread conn `race_` writeThread conn `race_` noAckCleanupThread handshakeTimeoutMicroseconds :: Int handshakeTimeoutMicroseconds = 30_000_000 @@ -239,17 +240,15 @@ handshakeTimeoutMicroseconds = 30_000_000 data HandshakeTimeout = HandshakeTimeout deriving (Show, Exception) -prepareTimeout :: (Exception e, MonadIO m) => Int -> e -> m (IO ()) -prepareTimeout delay exc = do +withCancelableTimeout :: (Exception e, MonadUnliftIO m) => Int -> e -> (IO () -> m a) -> m a +withCancelableTimeout delay exc cont = do requestingThread <- liftIO myThreadId - tid <- liftIO $ forkIO do - do - threadDelay delay - throwTo requestingThread exc - `catch` \(_ :: AsyncCancelled) -> - -- Removal of the timeout is normal, so do nothing - pass - pure $ throwTo tid AsyncCancelled + UnliftIO.withAsync + ( liftIO do + threadDelay delay + throwTo requestingThread exc + ) + (cont . cancel) msgN :: Frame o a -> Maybe Integer msgN Frame.Msg {n = n} = Just n