Skip to content

Commit

Permalink
Merge pull request #1600 from digitallyinduced/ws-issue
Browse files Browse the repository at this point in the history
Fixed GC Issue in AutoRefresh
  • Loading branch information
mpscholten authored Feb 5, 2023
2 parents aca1a71 + 06ff20c commit c0488e7
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 69 deletions.
47 changes: 21 additions & 26 deletions IHP/AutoRefresh.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import Data.String.Interpolate.IsString
initAutoRefresh :: (?context :: ControllerContext, ?applicationContext :: ApplicationContext) => IO ()
initAutoRefresh = do
putContext AutoRefreshDisabled
putContext (?applicationContext |> get #autoRefreshServer)
putContext ?applicationContext.autoRefreshServer

autoRefresh :: (
?theAction :: action
Expand Down Expand Up @@ -89,7 +89,7 @@ autoRefresh runAction = do

event <- MVar.newEmptyMVar
let session = AutoRefreshSession { id, renderView, event, tables, lastResponse, lastPing }
modifyIORef' autoRefreshServer (\s -> s { sessions = session:(get #sessions s) } )
modifyIORef' autoRefreshServer (\s -> s { sessions = session:s.sessions } )
async (gcSessions autoRefreshServer)

registerNotificationTrigger ?touchedTables autoRefreshServer
Expand All @@ -111,9 +111,7 @@ instance WSApp AutoRefreshWSApp where
sessionId <- receiveData @UUID
setState AutoRefreshActive { sessionId }

availableSessions <- ?applicationContext
|> get #autoRefreshServer
|> getAvailableSessions
availableSessions <- getAvailableSessions ?applicationContext.autoRefreshServer

when (sessionId `elem` availableSessions) do
AutoRefreshSession { renderView, event, lastResponse } <- getSessionById sessionId
Expand All @@ -129,7 +127,7 @@ instance WSApp AutoRefreshWSApp where

async $ forever do
MVar.takeMVar event
let requestContext = get #requestContext ?context
let requestContext = ?context.requestContext
(renderView requestContext) `catch` handleResponseException
pure ()

Expand All @@ -146,20 +144,20 @@ instance WSApp AutoRefreshWSApp where
onClose = do
getState >>= \case
AutoRefreshActive { sessionId } -> do
let autoRefreshServer = ?applicationContext |> get #autoRefreshServer
modifyIORef' autoRefreshServer (\server -> server { sessions = filter (\AutoRefreshSession { id } -> id /= sessionId) (get #sessions server) })
let autoRefreshServer = ?applicationContext.autoRefreshServer
modifyIORef' autoRefreshServer (\server -> server { sessions = filter (\AutoRefreshSession { id } -> id /= sessionId) server.sessions })
AwaitingSessionID -> pure ()


registerNotificationTrigger :: (?modelContext :: ModelContext) => IORef (Set ByteString) -> IORef AutoRefreshServer -> IO ()
registerNotificationTrigger touchedTablesVar autoRefreshServer = do
touchedTables <- Set.toList <$> readIORef touchedTablesVar
subscribedTables <- (get #subscribedTables) <$> (autoRefreshServer |> readIORef)
subscribedTables <- (.subscribedTables) <$> (autoRefreshServer |> readIORef)

let subscriptionRequired = touchedTables |> filter (\table -> subscribedTables |> Set.notMember table)
modifyIORef' autoRefreshServer (\server -> server { subscribedTables = get #subscribedTables server <> Set.fromList subscriptionRequired })
modifyIORef' autoRefreshServer (\server -> server { subscribedTables = server.subscribedTables <> Set.fromList subscriptionRequired })

pgListener <- get #pgListener <$> readIORef autoRefreshServer
pgListener <- (.pgListener) <$> readIORef autoRefreshServer
subscriptions <- subscriptionRequired |> mapM (\table -> do
let createTriggerSql = notificationTrigger table

Expand All @@ -169,22 +167,22 @@ registerNotificationTrigger touchedTablesVar autoRefreshServer = do
sqlExec createTriggerSql ()

pgListener |> PGListener.subscribe (channelName table) \notification -> do
sessions <- (get #sessions) <$> readIORef autoRefreshServer
sessions <- (.sessions) <$> readIORef autoRefreshServer
sessions
|> filter (\session -> table `Set.member` (get #tables session))
|> map (\session -> get #event session)
|> filter (\session -> table `Set.member` session.tables)
|> map (\session -> session.event)
|> mapM (\event -> MVar.tryPutMVar event ())
pure ())
modifyIORef' autoRefreshServer (\s -> s { subscriptions = get #subscriptions s <> subscriptions })
modifyIORef' autoRefreshServer (\s -> s { subscriptions = s.subscriptions <> subscriptions })
pure ()

-- | Returns the ids of all sessions available to the client based on what sessions are found in the session cookie
getAvailableSessions :: (?context :: ControllerContext) => IORef AutoRefreshServer -> IO [UUID]
getAvailableSessions autoRefreshServer = do
allSessions <- (get #sessions) <$> readIORef autoRefreshServer
allSessions <- (.sessions) <$> readIORef autoRefreshServer
text <- fromMaybe "" <$> getSession "autoRefreshSessions"
let uuidCharCount = Text.length (UUID.toText UUID.nil)
let allSessionIds = map (get #id) allSessions
let allSessionIds = map (.id) allSessions
text
|> Text.chunksOf uuidCharCount
|> mapMaybe UUID.fromText
Expand All @@ -194,21 +192,18 @@ getAvailableSessions autoRefreshServer = do
-- | Returns a session for a given session id. Errors in case the session does not exist.
getSessionById :: (?applicationContext :: ApplicationContext) => UUID -> IO AutoRefreshSession
getSessionById sessionId = do
autoRefreshServer <- ?applicationContext
|> get #autoRefreshServer
|> readIORef
autoRefreshServer
|> get #sessions
autoRefreshServer <- readIORef ?applicationContext.autoRefreshServer
autoRefreshServer.sessions
|> find (\AutoRefreshSession { id } -> id == sessionId)
|> Maybe.fromMaybe (error "getSessionById: Could not find the session")
|> pure

-- | Applies a update function to a session specified by its session id
updateSession :: (?applicationContext :: ApplicationContext) => UUID -> (AutoRefreshSession -> AutoRefreshSession) -> IO ()
updateSession sessionId updateFunction = do
let server = ?applicationContext |> get #autoRefreshServer
let updateSession' session = if get #id session == sessionId then updateFunction session else session
modifyIORef' server (\server -> server { sessions = map updateSession' (get #sessions server) })
let server = ?applicationContext.autoRefreshServer
let updateSession' session = if session.id == sessionId then updateFunction session else session
modifyIORef' server (\server -> server { sessions = map updateSession' server.sessions })
pure ()

-- | Removes all expired sessions
Expand All @@ -219,7 +214,7 @@ updateSession sessionId updateFunction = do
gcSessions :: IORef AutoRefreshServer -> IO ()
gcSessions autoRefreshServer = do
now <- getCurrentTime
modifyIORef' autoRefreshServer (\autoRefreshServer -> autoRefreshServer { sessions = filter (not . isSessionExpired now) (get #sessions autoRefreshServer) })
modifyIORef' autoRefreshServer (\autoRefreshServer -> autoRefreshServer { sessions = filter (not . isSessionExpired now) autoRefreshServer.sessions })

-- | A session is expired if it was not pinged in the last 60 seconds
isSessionExpired :: UTCTime -> AutoRefreshSession -> Bool
Expand Down
66 changes: 23 additions & 43 deletions IHP/WebSocket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,20 @@ class WSApp state where
onClose = pure ()

startWSApp :: forall state. (WSApp state, ?applicationContext :: ApplicationContext, ?requestContext :: RequestContext, ?context :: ControllerContext, ?modelContext :: ModelContext) => Websocket.Connection -> IO ()
startWSApp connection = do
startWSApp connection' = do
state <- newIORef (initialState @state)
lastPongAt <- getCurrentTime >>= newIORef


let connection = installPongHandler lastPongAt connection'
let ?state = state
let ?connection = connection
let pingHandler = do
seconds <- secondsSinceLastPong lastPongAt
when (seconds > pingWaitTime * 2) (throwIO PongTimeout)
onPing @state

let runWithPongChan pongChan = do
let connectionOnPong = writeChan pongChan ()
let ?connection = connection
{ WebSocket.connectionOptions = (get #connectionOptions connection) { WebSocket.connectionOnPong }
}
in
run @state

result <- Exception.try ((withPinger connection runWithPongChan) `Exception.finally` onClose @state)
result <- Exception.try ((WebSocket.withPingThread connection pingWaitTime pingHandler (run @state)) `Exception.finally` onClose @state)
case result of
Left (e@Exception.SomeException{}) ->
case Exception.fromException e of
Expand Down Expand Up @@ -114,37 +114,17 @@ instance Exception PongTimeout
pingWaitTime :: Int
pingWaitTime = 30

installPongHandler :: IORef UTCTime -> WebSocket.Connection -> WebSocket.Connection
installPongHandler lastPongAt connection =
connection { WebSocket.connectionOptions = connection.connectionOptions { WebSocket.connectionOnPong = connectionOnPong lastPongAt } }

-- | Pings the client every 30 seconds and expects a pong response within 10 secons. If no pong response
-- is received within 10 seconds, it will kill the connection.
--
-- We cannot use the withPingThread of the websockets package as this doesn't deal with pong messages. So
-- open connection will stay around forever.
--
-- This implementation is based on https://github.com/jaspervdj/websockets/issues/159#issuecomment-552776502
withPinger conn action = do
pongChan <- newChan
mainAsync <- async $ action pongChan
pingerAsync <- async $ runPinger conn pongChan

waitEitherCatch mainAsync pingerAsync >>= \case
-- If the application async died for any reason, kill the pinger async
Left result -> do
cancel pingerAsync
case result of
Left exception -> throw exception
Right result -> pure ()
-- The pinger thread should never throw an exception. If it does, kill the app thread
Right (Left exception) -> do
cancel mainAsync
throw exception
-- The pinger thread exited due to a pong timeout. Tell the app thread about it.
Right (Right ()) -> cancelWith mainAsync PongTimeout

runPinger conn pongChan = fix $ \loop -> do
Websocket.sendPing conn (mempty :: ByteString)
threadDelay pingWaitTime
-- See if we got a pong in that time
timeout 1000000 (readChan pongChan) >>= \case
Just () -> loop
Nothing -> return ()
connectionOnPong :: IORef UTCTime -> IO ()
connectionOnPong lastPongAt = do
now <- getCurrentTime
writeIORef lastPongAt now

secondsSinceLastPong :: IORef UTCTime -> IO Int
secondsSinceLastPong lastPongAt = do
now <- getCurrentTime
last <- readIORef lastPongAt
pure $ ceiling $ nominalDiffTimeToSeconds $ diffUTCTime now last

0 comments on commit c0488e7

Please sign in to comment.