diff --git a/IHP/AutoRefresh.hs b/IHP/AutoRefresh.hs index 1b6d446d4..a03475dc5 100644 --- a/IHP/AutoRefresh.hs +++ b/IHP/AutoRefresh.hs @@ -29,7 +29,7 @@ import Data.String.Interpolate.IsString initAutoRefresh :: (?context :: ControllerContext, ?applicationContext :: ApplicationContext) => IO () initAutoRefresh = do putContext AutoRefreshDisabled - putContext (?applicationContext |> get #autoRefreshServer) + putContext ?applicationContext.autoRefreshServer autoRefresh :: ( ?theAction :: action @@ -89,7 +89,7 @@ autoRefresh runAction = do event <- MVar.newEmptyMVar let session = AutoRefreshSession { id, renderView, event, tables, lastResponse, lastPing } - modifyIORef' autoRefreshServer (\s -> s { sessions = session:(get #sessions s) } ) + modifyIORef' autoRefreshServer (\s -> s { sessions = session:s.sessions } ) async (gcSessions autoRefreshServer) registerNotificationTrigger ?touchedTables autoRefreshServer @@ -111,9 +111,7 @@ instance WSApp AutoRefreshWSApp where sessionId <- receiveData @UUID setState AutoRefreshActive { sessionId } - availableSessions <- ?applicationContext - |> get #autoRefreshServer - |> getAvailableSessions + availableSessions <- getAvailableSessions ?applicationContext.autoRefreshServer when (sessionId `elem` availableSessions) do AutoRefreshSession { renderView, event, lastResponse } <- getSessionById sessionId @@ -129,7 +127,7 @@ instance WSApp AutoRefreshWSApp where async $ forever do MVar.takeMVar event - let requestContext = get #requestContext ?context + let requestContext = ?context.requestContext (renderView requestContext) `catch` handleResponseException pure () @@ -146,20 +144,20 @@ instance WSApp AutoRefreshWSApp where onClose = do getState >>= \case AutoRefreshActive { sessionId } -> do - let autoRefreshServer = ?applicationContext |> get #autoRefreshServer - modifyIORef' autoRefreshServer (\server -> server { sessions = filter (\AutoRefreshSession { id } -> id /= sessionId) (get #sessions server) }) + let autoRefreshServer = ?applicationContext.autoRefreshServer + modifyIORef' autoRefreshServer (\server -> server { sessions = filter (\AutoRefreshSession { id } -> id /= sessionId) server.sessions }) AwaitingSessionID -> pure () registerNotificationTrigger :: (?modelContext :: ModelContext) => IORef (Set ByteString) -> IORef AutoRefreshServer -> IO () registerNotificationTrigger touchedTablesVar autoRefreshServer = do touchedTables <- Set.toList <$> readIORef touchedTablesVar - subscribedTables <- (get #subscribedTables) <$> (autoRefreshServer |> readIORef) + subscribedTables <- (.subscribedTables) <$> (autoRefreshServer |> readIORef) let subscriptionRequired = touchedTables |> filter (\table -> subscribedTables |> Set.notMember table) - modifyIORef' autoRefreshServer (\server -> server { subscribedTables = get #subscribedTables server <> Set.fromList subscriptionRequired }) + modifyIORef' autoRefreshServer (\server -> server { subscribedTables = server.subscribedTables <> Set.fromList subscriptionRequired }) - pgListener <- get #pgListener <$> readIORef autoRefreshServer + pgListener <- (.pgListener) <$> readIORef autoRefreshServer subscriptions <- subscriptionRequired |> mapM (\table -> do let createTriggerSql = notificationTrigger table @@ -169,22 +167,22 @@ registerNotificationTrigger touchedTablesVar autoRefreshServer = do sqlExec createTriggerSql () pgListener |> PGListener.subscribe (channelName table) \notification -> do - sessions <- (get #sessions) <$> readIORef autoRefreshServer + sessions <- (.sessions) <$> readIORef autoRefreshServer sessions - |> filter (\session -> table `Set.member` (get #tables session)) - |> map (\session -> get #event session) + |> filter (\session -> table `Set.member` session.tables) + |> map (\session -> session.event) |> mapM (\event -> MVar.tryPutMVar event ()) pure ()) - modifyIORef' autoRefreshServer (\s -> s { subscriptions = get #subscriptions s <> subscriptions }) + modifyIORef' autoRefreshServer (\s -> s { subscriptions = s.subscriptions <> subscriptions }) pure () -- | Returns the ids of all sessions available to the client based on what sessions are found in the session cookie getAvailableSessions :: (?context :: ControllerContext) => IORef AutoRefreshServer -> IO [UUID] getAvailableSessions autoRefreshServer = do - allSessions <- (get #sessions) <$> readIORef autoRefreshServer + allSessions <- (.sessions) <$> readIORef autoRefreshServer text <- fromMaybe "" <$> getSession "autoRefreshSessions" let uuidCharCount = Text.length (UUID.toText UUID.nil) - let allSessionIds = map (get #id) allSessions + let allSessionIds = map (.id) allSessions text |> Text.chunksOf uuidCharCount |> mapMaybe UUID.fromText @@ -194,11 +192,8 @@ getAvailableSessions autoRefreshServer = do -- | Returns a session for a given session id. Errors in case the session does not exist. getSessionById :: (?applicationContext :: ApplicationContext) => UUID -> IO AutoRefreshSession getSessionById sessionId = do - autoRefreshServer <- ?applicationContext - |> get #autoRefreshServer - |> readIORef - autoRefreshServer - |> get #sessions + autoRefreshServer <- readIORef ?applicationContext.autoRefreshServer + autoRefreshServer.sessions |> find (\AutoRefreshSession { id } -> id == sessionId) |> Maybe.fromMaybe (error "getSessionById: Could not find the session") |> pure @@ -206,9 +201,9 @@ getSessionById sessionId = do -- | Applies a update function to a session specified by its session id updateSession :: (?applicationContext :: ApplicationContext) => UUID -> (AutoRefreshSession -> AutoRefreshSession) -> IO () updateSession sessionId updateFunction = do - let server = ?applicationContext |> get #autoRefreshServer - let updateSession' session = if get #id session == sessionId then updateFunction session else session - modifyIORef' server (\server -> server { sessions = map updateSession' (get #sessions server) }) + let server = ?applicationContext.autoRefreshServer + let updateSession' session = if session.id == sessionId then updateFunction session else session + modifyIORef' server (\server -> server { sessions = map updateSession' server.sessions }) pure () -- | Removes all expired sessions @@ -219,7 +214,7 @@ updateSession sessionId updateFunction = do gcSessions :: IORef AutoRefreshServer -> IO () gcSessions autoRefreshServer = do now <- getCurrentTime - modifyIORef' autoRefreshServer (\autoRefreshServer -> autoRefreshServer { sessions = filter (not . isSessionExpired now) (get #sessions autoRefreshServer) }) + modifyIORef' autoRefreshServer (\autoRefreshServer -> autoRefreshServer { sessions = filter (not . isSessionExpired now) autoRefreshServer.sessions }) -- | A session is expired if it was not pinged in the last 60 seconds isSessionExpired :: UTCTime -> AutoRefreshSession -> Bool diff --git a/IHP/WebSocket.hs b/IHP/WebSocket.hs index 0ec050cb7..55f3e7acc 100644 --- a/IHP/WebSocket.hs +++ b/IHP/WebSocket.hs @@ -48,20 +48,20 @@ class WSApp state where onClose = pure () startWSApp :: forall state. (WSApp state, ?applicationContext :: ApplicationContext, ?requestContext :: RequestContext, ?context :: ControllerContext, ?modelContext :: ModelContext) => Websocket.Connection -> IO () -startWSApp connection = do +startWSApp connection' = do state <- newIORef (initialState @state) + lastPongAt <- getCurrentTime >>= newIORef + + + let connection = installPongHandler lastPongAt connection' let ?state = state let ?connection = connection + let pingHandler = do + seconds <- secondsSinceLastPong lastPongAt + when (seconds > pingWaitTime * 2) (throwIO PongTimeout) + onPing @state - let runWithPongChan pongChan = do - let connectionOnPong = writeChan pongChan () - let ?connection = connection - { WebSocket.connectionOptions = (get #connectionOptions connection) { WebSocket.connectionOnPong } - } - in - run @state - - result <- Exception.try ((withPinger connection runWithPongChan) `Exception.finally` onClose @state) + result <- Exception.try ((WebSocket.withPingThread connection pingWaitTime pingHandler (run @state)) `Exception.finally` onClose @state) case result of Left (e@Exception.SomeException{}) -> case Exception.fromException e of @@ -114,37 +114,17 @@ instance Exception PongTimeout pingWaitTime :: Int pingWaitTime = 30 +installPongHandler :: IORef UTCTime -> WebSocket.Connection -> WebSocket.Connection +installPongHandler lastPongAt connection = + connection { WebSocket.connectionOptions = connection.connectionOptions { WebSocket.connectionOnPong = connectionOnPong lastPongAt } } --- | Pings the client every 30 seconds and expects a pong response within 10 secons. If no pong response --- is received within 10 seconds, it will kill the connection. --- --- We cannot use the withPingThread of the websockets package as this doesn't deal with pong messages. So --- open connection will stay around forever. --- --- This implementation is based on https://github.com/jaspervdj/websockets/issues/159#issuecomment-552776502 -withPinger conn action = do - pongChan <- newChan - mainAsync <- async $ action pongChan - pingerAsync <- async $ runPinger conn pongChan - - waitEitherCatch mainAsync pingerAsync >>= \case - -- If the application async died for any reason, kill the pinger async - Left result -> do - cancel pingerAsync - case result of - Left exception -> throw exception - Right result -> pure () - -- The pinger thread should never throw an exception. If it does, kill the app thread - Right (Left exception) -> do - cancel mainAsync - throw exception - -- The pinger thread exited due to a pong timeout. Tell the app thread about it. - Right (Right ()) -> cancelWith mainAsync PongTimeout - -runPinger conn pongChan = fix $ \loop -> do - Websocket.sendPing conn (mempty :: ByteString) - threadDelay pingWaitTime - -- See if we got a pong in that time - timeout 1000000 (readChan pongChan) >>= \case - Just () -> loop - Nothing -> return () \ No newline at end of file +connectionOnPong :: IORef UTCTime -> IO () +connectionOnPong lastPongAt = do + now <- getCurrentTime + writeIORef lastPongAt now + +secondsSinceLastPong :: IORef UTCTime -> IO Int +secondsSinceLastPong lastPongAt = do + now <- getCurrentTime + last <- readIORef lastPongAt + pure $ ceiling $ nominalDiffTimeToSeconds $ diffUTCTime now last \ No newline at end of file