diff --git a/Network/Socket/SockAddr.hs b/Network/Socket/SockAddr.hs index 4005184d..31e43c02 100644 --- a/Network/Socket/SockAddr.hs +++ b/Network/Socket/SockAddr.hs @@ -20,6 +20,7 @@ import qualified Network.Socket.Name as G import qualified Network.Socket.Syscall as G import Network.Socket.Flag import Network.Socket.Imports +import Network.Socket.Info () #if !defined(mingw32_HOST_OS) import Network.Socket.Posix.Cmsg #else diff --git a/Network/Socket/Syscall.hs b/Network/Socket/Syscall.hs index c45bda10..673332f0 100644 --- a/Network/Socket/Syscall.hs +++ b/Network/Socket/Syscall.hs @@ -135,14 +135,14 @@ bind s sa = withSocketAddress sa $ \p_sa siz -> void $ withFdSocket s $ \fd -> d -- Connecting a socket -- | Connect to a remote socket at address. -connect :: SocketAddress sa => Socket -> sa -> IO () +connect :: (Show sa, SocketAddress sa) => Socket -> sa -> IO () connect s sa = withSocketsDo $ withSocketAddress sa $ \p_sa sz -> - connectLoop s p_sa (fromIntegral sz) + connectLoop (show sa) s p_sa (fromIntegral sz) -connectLoop :: SocketAddress sa => Socket -> Ptr sa -> CInt -> IO () -connectLoop s p_sa sz = withFdSocket s $ \fd -> loop fd +connectLoop :: SocketAddress sa => String -> Socket -> Ptr sa -> CInt -> IO () +connectLoop show_sa s p_sa sz = withFdSocket s $ \fd -> loop fd where - errLoc = "Network.Socket.connect: " ++ show s + errLoc = "Network.Socket.connect: " ++ show s ++ " " ++ show_sa loop fd = do r <- c_connect fd p_sa sz when (r == -1) $ do diff --git a/tests/Network/SocketSpec.hs b/tests/Network/SocketSpec.hs index 743db7a8..ff9008dd 100644 --- a/tests/Network/SocketSpec.hs +++ b/tests/Network/SocketSpec.hs @@ -1,13 +1,15 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} module Network.SocketSpec (main, spec) where import Control.Concurrent (threadDelay, forkIO) import Control.Concurrent.MVar (readMVar) +import Control.Exception (SomeException) import Control.Monad import Data.Maybe (fromJust) -import Data.List (nub) +import Data.List (isInfixOf, nub) import Network.Socket import Network.Socket.ByteString import Network.Test.Common @@ -32,9 +34,16 @@ spec = do sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) connect sock (addrAddress addr) return sock + assumedFreePort = 8003 :: Int it "fails to connect and throws an IOException" $ do - connect' (8003 :: Int) `shouldThrow` anyIOException + connect' assumedFreePort `shouldThrow` anyIOException + + it "fails to connect exception contains target port" $ do + connect' assumedFreePort `shouldThrow` \(e :: SomeException) -> show assumedFreePort `isInfixOf` show e + + it "fails to connect exception contains target host" $ do + connect' assumedFreePort `shouldThrow` \(e :: SomeException) -> serverAddr `isInfixOf` show e it "successfully connects to a socket with no exception" $ do withPort $ \portVar -> test (tcp serverAddr return portVar)