From 868cc4c0da20769d1f537153c20ea878f8c1deee Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Sat, 19 Oct 2024 11:56:06 +0900 Subject: [PATCH 1/2] revisiting gracefulClose with STM racing --- Network/Socket/Shutdown.hs | 88 ++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 17 deletions(-) diff --git a/Network/Socket/Shutdown.hs b/Network/Socket/Shutdown.hs index e6f9048a..2fd249bf 100644 --- a/Network/Socket/Shutdown.hs +++ b/Network/Socket/Shutdown.hs @@ -8,14 +8,19 @@ module Network.Socket.Shutdown ( , gracefulClose ) where +import Control.Concurrent (threadDelay, yield) import qualified Control.Exception as E import Foreign.Marshal.Alloc (mallocBytes, free) -import Control.Concurrent (threadDelay, yield) +#if !defined(mingw32_HOST_OS) +import Control.Concurrent.STM +import qualified GHC.Event as Ev +#endif import Network.Socket.Buffer import Network.Socket.Imports import Network.Socket.Internal +import Network.Socket.STM import Network.Socket.Types data ShutdownCmd = ShutdownReceive @@ -59,19 +64,68 @@ gracefulClose s tmout0 = sendRecvFIN `E.finally` close s -- FIN arrives meanwhile. yield -- Waiting TCP FIN. - E.bracket (mallocBytes bufSize) free recvEOFloop - recvEOFloop buf = loop 1 0 - where - loop delay tmout = do - -- We don't check the (positive) length. - -- In normal case, it's 0. That is, only FIN is received. - -- In error cases, data is available. But there is no - -- application which can read it. So, let's stop receiving - -- to prevent attacks. - r <- recvBufNoWait s buf bufSize - when (r == -1 && tmout < tmout0) $ do - threadDelay (delay * 1000) - loop (delay * 2) (tmout + delay) - -- Don't use 4092 here. The GHC runtime takes the global lock - -- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit. - bufSize = 1024 + E.bracket (mallocBytes bufSize) free (recvEOF s tmout0) + +recvEOF :: Socket -> Int -> Ptr Word8 -> IO () +#if !defined(mingw32_HOST_OS) +recvEOF s tmout0 buf = do + mevmgr <- Ev.getSystemEventManager + case mevmgr of + Nothing -> recvEOFloop s tmout0 buf + Just _ -> recvEOFevent s tmout0 buf +#else +recvEOF = recvEOFloop +#endif + +-- Don't use 4092 here. The GHC runtime takes the global lock +-- if the length is over 3276 bytes in 32bit or 3272 bytes in 64bit. +bufSize :: Int +bufSize = 1024 + +recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO () +recvEOFloop s tmout0 buf = loop 1 0 + where + loop delay tmout = do + -- We don't check the (positive) length. + -- In normal case, it's 0. That is, only FIN is received. + -- In error cases, data is available. But there is no + -- application which can read it. So, let's stop receiving + -- to prevent attacks. + r <- recvBufNoWait s buf bufSize + when (r == -1 && tmout < tmout0) $ do + threadDelay (delay * 1000) + loop (delay * 2) (tmout + delay) + +#if !defined(mingw32_HOST_OS) +data Wait = MoreData | TimeoutTripped + +recvEOFevent :: Socket -> Int -> Ptr Word8 -> IO () +recvEOFevent s tmout0 buf = do + tmmgr <- Ev.getSystemTimerManager + tvar <- newTVarIO False + E.bracket (setup tmmgr tvar) teardown $ \(wait, _) -> do + waitRes <- wait + case waitRes of + TimeoutTripped -> return () + -- We don't check the (positive) length. + -- In normal case, it's 0. That is, only FIN is received. + -- In error cases, data is available. But there is no + -- application which can read it. So, let's stop receiving + -- to prevent attacks. + MoreData -> void $ recvBufNoWait s buf bufSize + where + setup tmmgr tvar = do + -- millisecond to microsecond + key <- Ev.registerTimeout tmmgr (tmout0 * 1000) $ + atomically $ writeTVar tvar True + (evWait, evCancel) <- waitAndCancelReadSocketSTM s + let toWait = do + tmout <- readTVar tvar + check tmout + toCancel = Ev.unregisterTimeout tmmgr key + wait = atomically ((toWait >> return TimeoutTripped) + <|> (evWait >> return MoreData)) + cancel = evCancel >> toCancel + return (wait, cancel) + teardown (_, cancel) = cancel +#endif From 10ab2cb25a3cac19ca22eeacb54de528f3d12171 Mon Sep 17 00:00:00 2001 From: Kazu Yamamoto Date: Wed, 23 Oct 2024 11:46:09 +0900 Subject: [PATCH 2/2] using timeout for recvEOFloop of gracefulClose --- Network/Socket/Shutdown.hs | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/Network/Socket/Shutdown.hs b/Network/Socket/Shutdown.hs index 2fd249bf..9648533d 100644 --- a/Network/Socket/Shutdown.hs +++ b/Network/Socket/Shutdown.hs @@ -8,9 +8,10 @@ module Network.Socket.Shutdown ( , gracefulClose ) where -import Control.Concurrent (threadDelay, yield) +import Control.Concurrent (yield) import qualified Control.Exception as E import Foreign.Marshal.Alloc (mallocBytes, free) +import System.Timeout #if !defined(mingw32_HOST_OS) import Control.Concurrent.STM @@ -83,18 +84,7 @@ bufSize :: Int bufSize = 1024 recvEOFloop :: Socket -> Int -> Ptr Word8 -> IO () -recvEOFloop s tmout0 buf = loop 1 0 - where - loop delay tmout = do - -- We don't check the (positive) length. - -- In normal case, it's 0. That is, only FIN is received. - -- In error cases, data is available. But there is no - -- application which can read it. So, let's stop receiving - -- to prevent attacks. - r <- recvBufNoWait s buf bufSize - when (r == -1 && tmout < tmout0) $ do - threadDelay (delay * 1000) - loop (delay * 2) (tmout + delay) +recvEOFloop s tmout0 buf = void $ timeout tmout0 $ recvBuf s buf bufSize #if !defined(mingw32_HOST_OS) data Wait = MoreData | TimeoutTripped