Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Socket endpoints, in particular reading them from a String #464

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ module Network.Socket
, Socket
, socket
, openSocket
, socketFromEndpoint
, withFdSocket
, unsafeFdSocket
, touchSocket
Expand Down Expand Up @@ -183,8 +184,14 @@ module Network.Socket
-- ** Protocol number
, ProtocolNumber
, defaultProtocol
-- * Basic socket endpoint type
, SockEndpoint(..)
, readSockEndpoint
, showSockEndpoint
, resolveEndpoint
-- * Basic socket address type
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, getPeerName
, getSocketName
Expand Down
80 changes: 69 additions & 11 deletions Network/Socket/Info.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,19 @@

module Network.Socket.Info where

import Control.Exception (try, IOException)
import Foreign.Marshal.Alloc (alloca, allocaBytes)
import Foreign.Marshal.Utils (maybeWith, with)
import GHC.IO.Exception (IOErrorType(NoSuchThing))
import System.IO.Error (ioeSetErrorString, mkIOError)
import System.IO.Unsafe (unsafePerformIO)
import Text.Read (readEither)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Syscall
import Network.Socket.Syscall (socket)
import Network.Socket.Types

-----------------------------------------------------------------------------

-- | Either a host name e.g., @\"haskell.org\"@ or a numeric host
-- address string consisting of a dotted decimal IPv4 address or an
-- IPv6 address e.g., @\"192.168.0.1\"@.
type HostName = String
-- | Either a service name e.g., @\"http\"@ or a numeric port number.
type ServiceName = String

-----------------------------------------------------------------------------
-- Address and service lookups

Expand Down Expand Up @@ -467,10 +461,74 @@ showHostAddress6 ha6@(a1, a2, a3, a4)
scanl (\c i -> if i == 0 then c - 1 else 0) 0 fields `zip` [0..]

-----------------------------------------------------------------------------

-- | A utility function to open a socket with `AddrInfo`.
-- This is a just wrapper for the following code:
--
-- > \addr -> socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)
openSocket :: AddrInfo -> IO Socket
openSocket addr = socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)

-----------------------------------------------------------------------------
-- SockEndpoint

-- | Read a string representing a socket endpoint.
readSockEndpoint :: PortNumber -> String -> Either String SockEndpoint
readSockEndpoint defPort hostport = case hostport of
'/':_ -> Right $ EndpointByAddr $ SockAddrUnix hostport
'[':tl -> case span ((/=) ']') tl of
(_, []) -> Left $ "unterminated IPv6 address: " <> hostport
(ipv6, _:port) -> case readAddr ipv6 of
Nothing -> Left $ "invalid IPv6 address: " <> ipv6
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
_ -> case span ((/=) ':') hostport of
(host, port) -> case readAddr host of
Nothing -> EndpointByName host <$> readPort port
Just addr -> EndpointByAddr . sockAddrPort addr <$> readPort port
where
readPort "" = Right defPort
readPort ":" = Right defPort
readPort (':':port) = case readEither port of
Right p -> Right p
Left _ -> Left $ "bad port: " <> port
readPort x = Left $ "bad port: " <> x
hints = Just $ defaultHints { addrFlags = [AI_NUMERICHOST] }
readAddr host = case unsafePerformIO (try (getAddrInfo hints (Just host) Nothing)) of
Left e -> Nothing where _ = e :: IOException
Right r -> Just (addrAddress (head r))
sockAddrPort h p = case h of
SockAddrInet _ a -> SockAddrInet p a
SockAddrInet6 _ f a s -> SockAddrInet6 p f a s
x -> x

showSockEndpoint :: SockEndpoint -> String
showSockEndpoint n = case n of
EndpointByName h p -> h <> ":" <> show p
EndpointByAddr a -> show a

-- | Resolve a socket endpoint into a list of socket addresses.
-- The result is always non-empty; Haskell throws an exception if name
-- resolution fails.
resolveEndpoint :: SockEndpoint -> IO [SockAddr]
resolveEndpoint name = case name of
EndpointByAddr a -> return [a]
EndpointByName host port -> fmap addrAddress <$> getAddrInfo hints (Just host) (Just (show port))
where
hints = Just $ defaultHints { addrSocketType = Stream }
-- prevents duplicates, otherwise getAddrInfo returns all socket types

-- | Shortcut for creating a socket from a socket endpoint.
--
-- >>> import Network.Socket
-- >>> let Right sn = readSockEndpoint 0 "0.0.0.0:0"
-- >>> (s, a) <- socketFromEndpoint sn head Stream defaultProtocol
-- >>> bind s a
socketFromEndpoint
:: SockEndpoint
-> ([SockAddr] -> SockAddr)
-> SocketType
-> ProtocolNumber
-> IO (Socket, SockAddr)
socketFromEndpoint end select stype protocol = do
a <- select <$> resolveEndpoint end
s <- socket (sockAddrFamily a) stype protocol
return (s, a)
34 changes: 34 additions & 0 deletions Network/Socket/Types.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ module Network.Socket.Types (
, withNewSocketAddress

-- * Socket address type
, SockEndpoint(..)
, SockAddr(..)
, sockAddrFamily
, isSupportedSockAddr
, HostAddress
, hostAddressToTuple
Expand All @@ -72,6 +74,8 @@ module Network.Socket.Types (
, defaultProtocol
, PortNumber
, defaultPort
, HostName
, ServiceName

-- * Low-level helpers
, zeroMemory
Expand Down Expand Up @@ -289,6 +293,13 @@ type ProtocolNumber = CInt
defaultProtocol :: ProtocolNumber
defaultProtocol = 0

-- | Either a host name e.g., @\"haskell.org\"@ or a numeric host
-- address string consisting of a dotted decimal IPv4 address or an
-- IPv6 address e.g., @\"192.168.0.1\"@.
type HostName = String
-- | Either a service name e.g., @\"http\"@ or a numeric port number.
type ServiceName = String

-----------------------------------------------------------------------------
-- Socket types

Expand Down Expand Up @@ -1047,6 +1058,23 @@ type FlowInfo = Word32
-- | Scope identifier.
type ScopeID = Word32

-- | Socket endpoints.
--
-- A wrapper around socket addresses that also accommodates the
-- popular usage of specifying them by name, e.g. "example.com:80".
-- We don't support service names here (string aliases for port
-- numbers) because they also imply a particular socket type, which
-- is outside of the scope of this data type.
--
-- This roughly corresponds to the "authority" part of a URI, as
-- defined here: https://tools.ietf.org/html/rfc3986#section-3.2
--
-- See also 'Network.Socket.socketFromEndpoint'.
data SockEndpoint
= EndpointByName !HostName !PortNumber
| EndpointByAddr !SockAddr
deriving (Eq, Ord)

-- | Socket addresses.
-- The existence of a constructor does not necessarily imply that
-- that socket address type is supported on your system: see
Expand All @@ -1070,6 +1098,12 @@ instance NFData SockAddr where
rnf (SockAddrInet6 _ _ _ _) = ()
rnf (SockAddrUnix str) = rnf str

sockAddrFamily :: SockAddr -> Family
sockAddrFamily addr = case addr of
SockAddrInet _ _ -> AF_INET
SockAddrInet6 _ _ _ _ -> AF_INET6
SockAddrUnix _ -> AF_UNIX

-- | Is the socket address type supported on this system?
isSupportedSockAddr :: SockAddr -> Bool
isSupportedSockAddr addr = case addr of
Expand Down
21 changes: 21 additions & 0 deletions tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,27 @@ spec = do
-- check if an exception is not thrown.
isSupportedSockAddr addr `shouldBe` True

it "endpoints API, IPv4" $ do
let Right end = readSockEndpoint 0 "127.0.0.1:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

it "endpoints API, IPv6" $ do
let Right end = readSockEndpoint 0 "[::1]:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

it "endpoints API, DNS" $ do
let Right end = readSockEndpoint 0 "localhost:6001"
(sock, addr) <- socketFromEndpoint end head Stream defaultProtocol
bind sock addr
listen sock 1
close sock

#if !defined(mingw32_HOST_OS)
when isUnixDomainSocketAvailable $ do
context "unix sockets" $ do
Expand Down