Skip to content

Commit

Permalink
Merge pull request #575 from kazu-yamamoto/fix-win-fds
Browse files Browse the repository at this point in the history
CmsgIdFd -> CmsgIdFds
  • Loading branch information
kazu-yamamoto authored Mar 25, 2024
2 parents 04b1943 + aeab895 commit 2ccb4aa
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
43 changes: 32 additions & 11 deletions Network/Socket/Win32/Cmsg.hsc
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
Expand Down Expand Up @@ -70,8 +71,8 @@ pattern CmsgIdIPv6PktInfo = CmsgId (#const IPPROTO_IPV6) (#const IPV6_PKTINFO)
-- | Control message ID for POSIX file-descriptor passing.
--
-- Not supported on Windows; use WSADuplicateSocket instead
pattern CmsgIdFd :: CmsgId
pattern CmsgIdFd = CmsgId (-1) (-1)
pattern CmsgIdFds :: CmsgId
pattern CmsgIdFds = CmsgId (-1) (-1)

----------------------------------------------------------------

Expand All @@ -91,11 +92,13 @@ filterCmsg cid cmsgs = filter (\cmsg -> cmsgId cmsg == cid) cmsgs
----------------------------------------------------------------

-- | A class to encode and decode control message.
class Storable a => ControlMessage a where
class ControlMessage a where
controlMessageId :: CmsgId
encodeCmsg :: a -> Cmsg
decodeCmsg :: Cmsg -> Maybe a

encodeCmsg :: forall a. ControlMessage a => a -> Cmsg
encodeCmsg x = unsafeDupablePerformIO $ do
encodeStorableCmsg :: forall a. (ControlMessage a, Storable a) => a -> Cmsg
encodeStorableCmsg x = unsafeDupablePerformIO $ do
bs <- create siz $ \p0 -> do
let p = castPtr p0
poke p x
Expand All @@ -104,8 +107,8 @@ encodeCmsg x = unsafeDupablePerformIO $ do
where
siz = sizeOf x

decodeCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeCmsg (Cmsg cmsid (PS fptr off len))
decodeStorableCmsg :: forall a . (ControlMessage a, Storable a) => Cmsg -> Maybe a
decodeStorableCmsg (Cmsg cmsid (PS fptr off len))
| cid /= cmsid = Nothing
| len < siz = Nothing
| otherwise = unsafeDupablePerformIO $ withForeignPtr fptr $ \p0 -> do
Expand All @@ -122,6 +125,8 @@ newtype IPv4TTL = IPv4TTL DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TTL where
controlMessageId = CmsgIdIPv4TTL
decodeCmsg = decodeStorableCmsg
encodeCmsg = encodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -130,6 +135,8 @@ newtype IPv6HopLimit = IPv6HopLimit DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6HopLimit where
controlMessageId = CmsgIdIPv6HopLimit
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -138,6 +145,8 @@ newtype IPv4TOS = IPv4TOS DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv4TOS where
controlMessageId = CmsgIdIPv4TOS
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -146,6 +155,8 @@ newtype IPv6TClass = IPv6TClass DWORD deriving (Eq, Show, Storable)

instance ControlMessage IPv6TClass where
controlMessageId = CmsgIdIPv6TClass
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

----------------------------------------------------------------

Expand All @@ -158,6 +169,8 @@ instance Show IPv4PktInfo where

instance ControlMessage IPv4PktInfo where
controlMessageId = CmsgIdIPv4PktInfo
encodeCmsg = encodeStorableCmsg
decodeCmsg = decodeStorableCmsg

instance Storable IPv4PktInfo where
sizeOf ~_ = #{size IN_PKTINFO}
Expand All @@ -180,6 +193,8 @@ instance Show IPv6PktInfo where

instance ControlMessage IPv6PktInfo where
controlMessageId = CmsgIdIPv6PktInfo
decodeCmsg = decodeStorableCmsg
encodeCmsg = encodeStorableCmsg

instance Storable IPv6PktInfo where
sizeOf ~_ = #{size IN6_PKTINFO}
Expand All @@ -192,8 +207,14 @@ instance Storable IPv6PktInfo where
n :: ULONG <- (#peek IN6_PKTINFO, ipi6_ifindex) p
return $ IPv6PktInfo (fromIntegral n) ha6

instance ControlMessage Fd where
controlMessageId = CmsgIdFd
----------------------------------------------------------------

instance ControlMessage [Fd] where
controlMessageId = CmsgIdFds
encodeCmsg = \_ -> Cmsg CmsgIdFds ""
decodeCmsg = \_ -> Just []

----------------------------------------------------------------

cmsgIdBijection :: Bijection CmsgId String
cmsgIdBijection =
Expand All @@ -204,7 +225,7 @@ cmsgIdBijection =
, (CmsgIdIPv6TClass, "CmsgIdIPv6TClass")
, (CmsgIdIPv4PktInfo, "CmsgIdIPv4PktInfo")
, (CmsgIdIPv6PktInfo, "CmsgIdIPv6PktInfo")
, (CmsgIdFd, "CmsgIdFd")
, (CmsgIdFds, "CmsgIdFds")
]

instance Show CmsgId where
Expand Down
3 changes: 3 additions & 0 deletions network.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,6 @@ test-suite spec

if impl(ghc >=8)
default-extensions: Strict StrictData

if os(windows)
cpp-options: -D_WIN32
2 changes: 2 additions & 0 deletions tests/Network/SocketSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,11 @@ spec = do
let msgid = CmsgId (-300) (-300) in
show msgid `shouldBe` "CmsgId (-300) (-300)"

#if !defined(_WIN32)
describe "bijective encodeCmsg-decodeCmsg roundtrip equality" $ do
it "holds for [Fd]" $ forAll genFds $
\x -> (decodeCmsg . encodeCmsg $ x) == Just (x :: [Fd])
#endif

describe "bijective read-show roundtrip equality" $ do
it "holds for Family" $ forAll familyGen $
Expand Down

0 comments on commit 2ccb4aa

Please sign in to comment.