diff --git a/Network/Socket/Win32/Cmsg.hsc b/Network/Socket/Win32/Cmsg.hsc index 702be7c1..435470ab 100644 --- a/Network/Socket/Win32/Cmsg.hsc +++ b/Network/Socket/Win32/Cmsg.hsc @@ -1,7 +1,8 @@ - {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -70,8 +71,8 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO) -- | Control message ID for POSIX file-descriptor passing. -- -- Not supported on Windows; use WSADuplicateSocket instead -pattern CmsgIdFd :: CmsgId -pattern CmsgIdFd = CmsgId (-1) (-1) +pattern CmsgIdFds :: CmsgId +pattern CmsgIdFds = CmsgId (-1) (-1) ---------------------------------------------------------------- @@ -91,11 +92,13 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs ---------------------------------------------------------------- -- | A class to encode and decode control message. -class Storable a => ControlMessage a where +class ControlMessage a where controlMessageId :: CmsgId + encodeCmsg :: a -> Cmsg + decodeCmsg :: Cmsg -> Maybe a -encodeCmsg :: forall a. ControlMessage a => a -> Cmsg -encodeCmsg x = unsafeDupablePerformIO $ do +encodeStorableCmsg :: forall a. (ControlMessage a, Storable a) => a -> Cmsg +encodeStorableCmsg x = unsafeDupablePerformIO $ do bs <- create siz $ \p0 -> do let p = castPtr p0 poke p x @@ -104,8 +107,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do where siz = sizeOf x -decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a -decodeCmsg (Cmsg cmsid (PS fptr off len)) +decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a +decodeStorableCmsg (Cmsg cmsid (PS fptr off len)) | cid /= cmsid = Nothing | len < siz = Nothing | otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do @@ -122,6 +125,8 @@ newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable) instance ControlMessage IPv4TTL where controlMessageId = CmsgIdIPv4TTL + decodeCmsg = decodeStorableCmsg + encodeCmsg = encodeStorableCmsg ---------------------------------------------------------------- @@ -130,6 +135,8 @@ newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable) instance ControlMessage IPv6HopLimit where controlMessageId = CmsgIdIPv6HopLimit + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -138,6 +145,8 @@ newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable) instance ControlMessage IPv4TOS where controlMessageId = CmsgIdIPv4TOS + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -146,6 +155,8 @@ newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable) instance ControlMessage IPv6TClass where controlMessageId = CmsgIdIPv6TClass + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg ---------------------------------------------------------------- @@ -158,6 +169,8 @@ instance Show IPv4PktInfo where instance ControlMessage IPv4PktInfo where controlMessageId = CmsgIdIPv4PktInfo + encodeCmsg = encodeStorableCmsg + decodeCmsg = decodeStorableCmsg instance Storable IPv4PktInfo where sizeOf ~_ = #{size IN_PKTINFO} @@ -180,6 +193,8 @@ instance Show IPv6PktInfo where instance ControlMessage IPv6PktInfo where controlMessageId = CmsgIdIPv6PktInfo + decodeCmsg = decodeStorableCmsg + encodeCmsg = encodeStorableCmsg instance Storable IPv6PktInfo where sizeOf ~_ = #{size IN6_PKTINFO} @@ -192,8 +207,14 @@ instance Storable IPv6PktInfo where n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p return $ IPv6PktInfo (fromIntegral n) ha6 -instance ControlMessage Fd where - controlMessageId = CmsgIdFd +---------------------------------------------------------------- + +instance ControlMessage [Fd] where + controlMessageId = CmsgIdFds + encodeCmsg = \_ -> Cmsg CmsgIdFds "" + decodeCmsg = \_ -> Just [] + +---------------------------------------------------------------- cmsgIdBijection :: Bijection CmsgId String cmsgIdBijection = @@ -204,7 +225,7 @@ cmsgIdBijection = , (CmsgIdIPv6TClass, "CmsgIdIPv6TClass") , (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo") , (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo") - , (CmsgIdFd, "CmsgIdFd") + , (CmsgIdFds, "CmsgIdFds") ] instance Show CmsgId where diff --git a/network.cabal b/network.cabal index 7bab60f6..b39d6b9b 100644 --- a/network.cabal +++ b/network.cabal @@ -206,3 +206,6 @@ test-suite spec if impl(ghc >=8) default-extensions: Strict StrictData + + if os(windows) + cpp-options: -D_WIN32 diff --git a/tests/Network/SocketSpec.hs b/tests/Network/SocketSpec.hs index 30294050..f2ac337f 100644 --- a/tests/Network/SocketSpec.hs +++ b/tests/Network/SocketSpec.hs @@ -379,9 +379,11 @@ spec = do let msgid = CmsgId (-300) (-300) in show msgid `shouldBe` "CmsgId (-300) (-300)" +#if !defined(_WIN32) describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do it "holds for [Fd]" $ forAll genFds $ \x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd]) +#endif describe "bijective read-show roundtrip equality" $ do it "holds for Family" $ forAll familyGen $