Skip to content

Commit

Permalink
Merge pull request haskell#587 from kazu-yamamoto/non-empty
Browse files Browse the repository at this point in the history
Making getAddrInfo polymorphic
  • Loading branch information
kazu-yamamoto authored Sep 11, 2024
2 parents a521e19 + b7ba6ee commit a7b15f8
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 64 deletions.
6 changes: 4 additions & 2 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
-- > import qualified Control.Exception as E
-- > import Control.Monad (unless, forever, void)
-- > import qualified Data.ByteString as S
-- > import qualified Data.List.NonEmpty as NE
-- > import Network.Socket
-- > import Network.Socket.ByteString (recv, sendAll)
-- >
Expand All @@ -56,7 +57,7 @@
-- > addrFlags = [AI_PASSIVE]
-- > , addrSocketType = Stream
-- > }
-- > head <$> getAddrInfo (Just hints) mhost (Just port)
-- > NE.head <$> getAddrInfo (Just hints) mhost (Just port)
-- > open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
-- > setSocketOption sock ReuseAddr 1
-- > withFdSocket sock setCloseOnExecIfNeeded
Expand All @@ -77,6 +78,7 @@
-- >
-- > import qualified Control.Exception as E
-- > import qualified Data.ByteString.Char8 as C
-- > import qualified Data.List.NonEmpty as NE
-- > import Network.Socket
-- > import Network.Socket.ByteString (recv, sendAll)
-- >
Expand All @@ -95,7 +97,7 @@
-- > where
-- > resolve = do
-- > let hints = defaultHints { addrSocketType = Stream }
-- > head <$> getAddrInfo (Just hints) (Just host) (Just port)
-- > NE.head <$> getAddrInfo (Just hints) (Just host) (Just port)
-- > open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
-- > connect sock $ addrAddress addr
-- > return sock
Expand Down
146 changes: 91 additions & 55 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

module Network.Socket.Info where

import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Marshal.Utils (maybeWith, with)
import GHC.IO.Exception (IOErrorType(NoSuchThing))
Expand Down Expand Up @@ -200,53 +202,72 @@ defaultHints = AddrInfo {
, addrCanonName = Nothing
}

-----------------------------------------------------------------------------
-- | Resolve a host or service name to one or more addresses.
-- The 'AddrInfo' values that this function returns contain 'SockAddr'
-- values that you can pass directly to 'connect' or
-- 'bind'.
--
-- This function is protocol independent. It can return both IPv4 and
-- IPv6 address information.
--
-- The 'AddrInfo' argument specifies the preferred query behaviour,
-- socket options, or protocol. You can override these conveniently
-- using Haskell's record update syntax on 'defaultHints', for example
-- as follows:
--
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST], addrSocketType = Stream }
--
-- You must provide a 'Just' value for at least one of the 'HostName'
-- or 'ServiceName' arguments. 'HostName' can be either a numeric
-- network address (dotted quad for IPv4, colon-separated hex for
-- IPv6) or a hostname. In the latter case, its addresses will be
-- looked up unless 'AI_NUMERICHOST' is specified as a hint. If you
-- do not provide a 'HostName' value /and/ do not set 'AI_PASSIVE' as
-- a hint, network addresses in the result will contain the address of
-- the loopback interface.
--
-- If the query fails, this function throws an IO exception instead of
-- returning an empty list. Otherwise, it returns a non-empty list
-- of 'AddrInfo' values.
--
-- There are several reasons why a query might result in several
-- values. For example, the queried-for host could be multihomed, or
-- the service might be available via several protocols.
--
-- Note: the order of arguments is slightly different to that defined
-- for @getaddrinfo@ in RFC 2553. The 'AddrInfo' parameter comes first
-- to make partial application easier.
--
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http")
-- >>> addrAddress addr
-- 127.0.0.1:80

getAddrInfo
class GetAddrInfo t where
-----------------------------------------------------------------------------
-- | Resolve a host or service name to one or more addresses.
-- The 'AddrInfo' values that this function returns contain 'SockAddr'
-- values that you can pass directly to 'connect' or
-- 'bind'.
--
-- This function calls @getaddrinfo(3)@, which never successfully returns
-- with an empty list. If the query fails, 'getAddrInfo' throws an IO
-- exception.
--
-- For backwards-compatibility reasons, a hidden 'GetAddrInfo' class is used
-- to make the result polymorphic. It only has instances for @[]@ (lists)
-- and 'NonEmpty'. Use of 'NonEmpty' is recommended.
--
-- This function is protocol independent. It can return both IPv4 and
-- IPv6 address information.
--
-- The 'AddrInfo' argument specifies the preferred query behaviour,
-- socket options, or protocol. You can override these conveniently
-- using Haskell's record update syntax on 'defaultHints', for example
-- as follows:
--
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST], addrSocketType = Stream }
--
-- You must provide a 'Just' value for at least one of the 'HostName'
-- or 'ServiceName' arguments. 'HostName' can be either a numeric
-- network address (dotted quad for IPv4, colon-separated hex for
-- IPv6) or a hostname. In the latter case, its addresses will be
-- looked up unless 'AI_NUMERICHOST' is specified as a hint. If you
-- do not provide a 'HostName' value /and/ do not set 'AI_PASSIVE' as
-- a hint, network addresses in the result will contain the address of
-- the loopback interface.
--
-- There are several reasons why a query might result in several
-- values. For example, the queried-for host could be multihomed, or
-- the service might be available via several protocols.
--
-- Note: the order of arguments is slightly different to that defined
-- for @getaddrinfo@ in RFC 2553. The 'AddrInfo' parameter comes first
-- to make partial application easier.
--
-- >>> import qualified Data.List.NonEmpty as NE
-- >>> addr <- NE.head <$> getAddrInfo (Just hints) (Just "127.0.0.1") (Just "http")
-- >>> addrAddress addr
-- 127.0.0.1:80
--
-- Polymorphic version: @since 3.2.3.0
getAddrInfo
:: Maybe AddrInfo -- ^ preferred socket type or protocol
-> Maybe HostName -- ^ host name to look up
-> Maybe ServiceName -- ^ service name to look up
-> IO (t AddrInfo) -- ^ resolved addresses, with "best" first

instance GetAddrInfo [] where
getAddrInfo = getAddrInfoList

instance GetAddrInfo NE.NonEmpty where
getAddrInfo = getAddrInfoNE

getAddrInfoNE
:: Maybe AddrInfo -- ^ preferred socket type or protocol
-> Maybe HostName -- ^ host name to look up
-> Maybe ServiceName -- ^ service name to look up
-> IO [AddrInfo] -- ^ resolved addresses, with "best" first
getAddrInfo hints node service = alloc getaddrinfo
-> IO (NonEmpty AddrInfo) -- ^ resolved addresses, with "best" first
getAddrInfoNE hints node service = alloc getaddrinfo
where
alloc body = withSocketsDo $ maybeWith withCString node $ \c_node ->
maybeWith withCString service $ \c_service ->
Expand All @@ -258,12 +279,7 @@ getAddrInfo hints node service = alloc getaddrinfo
if ret == 0 then do
ptr_addrs <- peek ptr_ptr_addrs
ais <- followAddrInfo ptr_addrs
c_freeaddrinfo ptr_addrs
-- POSIX requires that getaddrinfo(3) returns at least one addrinfo.
-- See: http://pubs.opengroup.org/onlinepubs/9699919799/functions/getaddrinfo.html
case ais of
[] -> ioError $ mkIOError NoSuchThing message Nothing Nothing
_ -> return ais
return ais
else do
err <- gai_strerror ret
ioError $ ioeSetErrorString
Expand All @@ -290,13 +306,33 @@ getAddrInfo hints node service = alloc getaddrinfo
filteredHints = hints
#endif

followAddrInfo :: Ptr AddrInfo -> IO [AddrInfo]
getAddrInfoList
:: Maybe AddrInfo
-> Maybe HostName
-> Maybe ServiceName
-> IO [AddrInfo]
getAddrInfoList hints node service =
-- getAddrInfo never returns an empty list.
NE.toList <$> getAddrInfoNE hints node service

followAddrInfo :: Ptr AddrInfo -> IO (NonEmpty AddrInfo)
followAddrInfo ptr_ai
| ptr_ai == nullPtr = return []
-- POSIX requires that getaddrinfo(3) returns at least one addrinfo.
-- See: http://pubs.opengroup.org/onlinepubs/9699919799/functions/getaddrinfo.html
| ptr_ai == nullPtr = ioError $ mkIOError NoSuchThing "getaddrinfo must return at least one addrinfo" Nothing Nothing
| otherwise = do
a <- peek ptr_ai
as <- (# peek struct addrinfo, ai_next) ptr_ai >>= followAddrInfo
return (a : as)
a <- peek ptr_ai
ptr <- (# peek struct addrinfo, ai_next) ptr_ai
(a :|) <$> go ptr
where
go :: Ptr AddrInfo -> IO [AddrInfo]
go ptr
| ptr == nullPtr = return []
| otherwise = do
a' <- peek ptr
ptr' <- (# peek struct addrinfo, ai_next) ptr
as' <- go ptr'
return (a':as')

foreign import ccall safe "hsnet_getaddrinfo"
c_getaddrinfo :: CString -> CString -> Ptr AddrInfo -> Ptr (Ptr AddrInfo)
Expand Down
3 changes: 2 additions & 1 deletion Network/Socket/Syscall.hs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ import Network.Socket.Types
-- can be handled with one socket.
--
-- >>> import Network.Socket
-- >>> import qualified Data.List.NonEmpty as NE
-- >>> let hints = defaultHints { addrFlags = [AI_NUMERICHOST, AI_NUMERICSERV], addrSocketType = Stream }
-- >>> addr:_ <- getAddrInfo (Just hints) (Just "127.0.0.1") (Just "5000")
-- >>> addr <- NE.head <$> getAddrInfo (Just hints) (Just "127.0.0.1") (Just "5000")
-- >>> sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
-- >>> Network.Socket.bind sock (addrAddress addr)
-- >>> getSocketName sock
Expand Down
3 changes: 2 additions & 1 deletion examples/EchoClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Main (main) where

import qualified Control.Exception as E
import qualified Data.ByteString.Char8 as C
import qualified Data.List.NonEmpty as NE
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)

Expand All @@ -23,7 +24,7 @@ runTCPClient host port client = withSocketsDo $ do
where
resolve = do
let hints = defaultHints{addrSocketType = Stream}
head <$> getAddrInfo (Just hints) (Just host) (Just port)
NE.head <$> getAddrInfo (Just hints) (Just host) (Just port)
open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
connect sock $ addrAddress addr
return sock
3 changes: 2 additions & 1 deletion examples/EchoServer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Control.Concurrent (forkFinally)
import qualified Control.Exception as E
import Control.Monad (forever, unless, void)
import qualified Data.ByteString as S
import qualified Data.List.NonEmpty as NE
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)

Expand All @@ -29,7 +30,7 @@ runTCPServer mhost port server = withSocketsDo $ do
{ addrFlags = [AI_PASSIVE]
, addrSocketType = Stream
}
head <$> getAddrInfo (Just hints) mhost (Just port)
NE.head <$> getAddrInfo (Just hints) mhost (Just port)
open addr = E.bracketOnError (openSocket addr) close $ \sock -> do
setSocketOption sock ReuseAddr 1
withFdSocket sock setCloseOnExecIfNeeded
Expand Down
2 changes: 1 addition & 1 deletion network.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cabal-version: 1.18
name: network
version: 3.2.2.0
version: 3.2.3.0
license: BSD3
license-file: LICENSE
maintainer: Kazu Yamamoto, Tamar Christina
Expand Down
2 changes: 1 addition & 1 deletion tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ spec = do

it "does not cause segfault on macOS 10.8.2 due to AI_NUMERICSERV" $ do
let hints = defaultHints { addrFlags = [AI_NUMERICSERV] }
void $ getAddrInfo (Just hints) (Just "localhost") Nothing
void (getAddrInfo (Just hints) (Just "localhost") Nothing :: IO [AddrInfo])

#if defined(mingw32_HOST_OS)
let lpdevname = "loopback_0"
Expand Down
5 changes: 3 additions & 2 deletions tests/Network/Test/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import qualified Control.Exception as E
import Control.Monad
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.List.NonEmpty as NE
import Network.Socket
import System.Directory
import System.Timeout (timeout)
Expand Down Expand Up @@ -244,7 +245,7 @@ bracketWithReraise tid setup teardown thing =

resolveClient :: SocketType -> HostName -> PortNumber -> IO AddrInfo
resolveClient socketType host port =
head <$> getAddrInfo (Just hints) (Just host) (Just $ show port)
NE.head <$> getAddrInfo (Just hints) (Just host) (Just $ show port)
where
hints = defaultHints {
addrSocketType = socketType
Expand All @@ -253,7 +254,7 @@ resolveClient socketType host port =

resolveServer :: SocketType -> HostName -> IO AddrInfo
resolveServer socketType host =
head <$> getAddrInfo (Just hints) (Just host) Nothing
NE.head <$> getAddrInfo (Just hints) (Just host) Nothing
where
hints = defaultHints {
addrSocketType = socketType
Expand Down

0 comments on commit a7b15f8

Please sign in to comment.