diff --git a/Network/Socket/Shutdown.hs b/Network/Socket/Shutdown.hs index 9648533d..b3452029 100644 --- a/Network/Socket/Shutdown.hs +++ b/Network/Socket/Shutdown.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE ScopedTypeVariables #-} #include "HsNetDef.h" @@ -59,7 +60,8 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s -- Sending TCP FIN. ex <- E.try $ shutdown s ShutdownSend case ex of - Left (E.SomeException _) -> return () + -- Don't catch asynchronous exceptions + Left (_ :: E.IOException) -> return () Right () -> do -- Giving CPU time to other threads hoping that -- FIN arrives meanwhile. @@ -93,29 +95,26 @@ recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO () recvEOFevent s tmout0 buf = do tmmgr <- Ev.getSystemTimerManager tvar <- newTVarIO False - E.bracket (setup tmmgr tvar) teardown $ \(wait, _) -> do - waitRes <- wait - case waitRes of - TimeoutTripped -> return () - -- We don't check the (positive) length. - -- In normal case, it's 0. That is, only FIN is received. - -- In error cases, data is available. But there is no - -- application which can read it. So, let's stop receiving - -- to prevent attacks. - MoreData -> void $ recvBufNoWait s buf bufSize + E.bracket (setupTimeout tmmgr tvar) (cancelTimeout tmmgr) $ \_ -> do + E.bracket (setupRead s) cancelRead $ \(rxWait,_) -> do + let toWait = readTVar tvar >>= check + wait = atomically ((toWait >> return TimeoutTripped) + <|> (rxWait >> return MoreData)) + waitRes <- wait + case waitRes of + TimeoutTripped -> return () + -- We don't check the (positive) length. + -- In normal case, it's 0. That is, only FIN is received. + -- In error cases, data is available. But there is no + -- application which can read it. So, let's stop receiving + -- to prevent attacks. + MoreData -> void $ recvBufNoWait s buf bufSize where - setup tmmgr tvar = do - -- millisecond to microsecond - key <- Ev.registerTimeout tmmgr (tmout0 * 1000) $ - atomically $ writeTVar tvar True - (evWait, evCancel) <- waitAndCancelReadSocketSTM s - let toWait = do - tmout <- readTVar tvar - check tmout - toCancel = Ev.unregisterTimeout tmmgr key - wait = atomically ((toWait >> return TimeoutTripped) - <|> (evWait >> return MoreData)) - cancel = evCancel >> toCancel - return (wait, cancel) - teardown (_, cancel) = cancel + -- millisecond to microsecond + tmout = tmout0 * 1000 + setupTimeout tmmgr tvar = + Ev.registerTimeout tmmgr tmout $ atomically $ writeTVar tvar True + cancelTimeout = Ev.unregisterTimeout + setupRead = waitAndCancelReadSocketSTM + cancelRead (_,cancel) = cancel #endif