diff --git a/cabal.project b/cabal.project index 11db548b8c0..69095eb9a0d 100644 --- a/cabal.project +++ b/cabal.project @@ -54,3 +54,12 @@ package network-mux package ouroboros-network flags: +asserts +cddl + +source-repository-package + type: git + location: https://github.com/input-output-hk/typed-protocols + tag: d0c0668048be5b9878917180d7a0641861216bec + subdir: typed-protocols + typed-protocols-cborg +allow-newer: typed-protocols:io-classes + diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs index c7c163f3465..868d93dcee8 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs @@ -6,7 +6,6 @@ {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} -- @UndecidableInstances@ extensions is required for defining @Show@ instance -- of @'TraceSendRecv'@. @@ -19,10 +18,12 @@ module Ouroboros.Network.Driver.Simple -- $intro -- * Normal peers runPeer + , runAnnotatedPeer , TraceSendRecv (..) , DecoderFailure (..) -- * Pipelined peers , runPipelinedPeer + , runPipelinedAnnotatedPeer -- * Connected peers -- TODO: move these to a test lib , Role (..) @@ -43,6 +44,9 @@ import Ouroboros.Network.Channel import Control.Monad.Class.MonadAsync import Control.Monad.Class.MonadThrow import Control.Tracer (Tracer (..), contramap, traceWith) +import Data.Maybe (fromMaybe) +import Data.Functor.Identity (Identity) +import Control.Monad.Identity (Identity(..)) -- $intro @@ -107,18 +111,31 @@ instance Show DecoderFailure where instance Exception DecoderFailure where -driverSimple :: forall ps failure bytes m. - ( MonadThrow m - , Show failure - , forall (st :: ps). Show (ClientHasAgency st) - , forall (st :: ps). Show (ServerHasAgency st) - , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes - -> Channel m bytes - -> Driver ps (Maybe bytes) m -driverSimple tracer Codec{encode, decode} channel@Channel{send} = +mkSimpleDriver :: forall ps failure bytes m f annotator. + ( MonadThrow m + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => (forall a. + Channel m bytes + -> Maybe bytes + -> DecodeStep bytes failure m (f a) + -> m (Either failure (a, Maybe bytes)) + ) + -- ^ run incremental decoder against a channel + + -> (forall st. annotator st -> f (SomeMessage st)) + -- ^ transform annotator to a container holding the decoded + -- message + + -> Tracer m (TraceSendRecv ps) + -> Codec' ps failure m annotator bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m + +mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} = Driver { sendMessage, recvMessage, startDState = Nothing } where sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps). @@ -135,7 +152,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} = -> m (SomeMessage st, Maybe bytes) recvMessage stok trailing = do decoder <- decode stok - result <- runDecoderWithChannel channel trailing decoder + result <- runDecodeSteps channel trailing (nat <$> decoder) case result of Right x@(SomeMessage msg, _trailing') -> do traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg)) @@ -144,6 +161,36 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} = throwIO (DecoderFailure stok failure) +simpleDriver :: forall ps failure bytes m. + ( MonadThrow m + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> Codec ps failure m bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m +simpleDriver = mkSimpleDriver runDecoderWithChannel Identity + + +annotatedSimpleDriver + :: forall ps failure bytes m. + ( MonadThrow m + , Monoid bytes + , Show failure + , forall (st :: ps). Show (ClientHasAgency st) + , forall (st :: ps). Show (ServerHasAgency st) + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m +annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator + + -- | Run a peer with the given channel via the given codec. -- -- This runs the peer to completion (if the protocol allows for termination). @@ -164,7 +211,31 @@ runPeer runPeer tracer codec channel peer = runPeerWithDriver driver peer (startDState driver) where - driver = driverSimple tracer codec channel + driver = simpleDriver tracer codec channel + + +-- | Run a peer with the given channel via the given annotated codec. +-- +-- This runs the peer to completion (if the protocol allows for termination). +-- +runAnnotatedPeer + :: forall ps (st :: ps) pr failure bytes m a . + ( MonadThrow m + , Monoid bytes + , Show failure + , forall (st' :: ps). Show (ClientHasAgency st') + , forall (st' :: ps). Show (ServerHasAgency st') + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> Peer ps pr st m a + -> m (a, Maybe bytes) +runAnnotatedPeer tracer codec channel peer = + runPeerWithDriver driver peer (startDState driver) + where + driver = annotatedSimpleDriver tracer codec channel -- | Run a pipelined peer with the given channel via the given codec. @@ -191,7 +262,35 @@ runPipelinedPeer runPipelinedPeer tracer codec channel peer = runPipelinedPeerWithDriver driver peer (startDState driver) where - driver = driverSimple tracer codec channel + driver = simpleDriver tracer codec channel + + +-- | Run a pipelined peer with the given channel via the given annotated codec. +-- +-- This runs the peer to completion (if the protocol allows for termination). +-- +-- Unlike normal peers, running pipelined peers rely on concurrency, hence the +-- 'MonadAsync' constraint. +-- +runPipelinedAnnotatedPeer + :: forall ps (st :: ps) pr failure bytes m a. + ( MonadAsync m + , MonadThrow m + , Monoid bytes + , Show failure + , forall (st' :: ps). Show (ClientHasAgency st') + , forall (st' :: ps). Show (ServerHasAgency st') + , ShowProxy ps + ) + => Tracer m (TraceSendRecv ps) + -> AnnotatedCodec ps failure m bytes + -> Channel m bytes + -> PeerPipelined ps pr st m a + -> m (a, Maybe bytes) +runPipelinedAnnotatedPeer tracer codec channel peer = + runPipelinedPeerWithDriver driver peer (startDState driver) + where + driver = annotatedSimpleDriver tracer codec channel -- @@ -204,17 +303,36 @@ runPipelinedPeer tracer codec channel peer = runDecoderWithChannel :: Monad m => Channel m bytes -> Maybe bytes - -> DecodeStep bytes failure m a + -> DecodeStep bytes failure m (Identity a) -> m (Either failure (a, Maybe bytes)) runDecoderWithChannel Channel{recv} = go where - go _ (DecodeDone x trailing) = return (Right (x, trailing)) + go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing)) go _ (DecodeFail failure) = return (Left failure) go Nothing (DecodePartial k) = recv >>= k >>= go Nothing go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing +runAnnotatedDecoderWithChannel + :: forall m bytes failure a. + ( Monad m + , Monoid bytes + ) + => Channel m bytes + -> Maybe bytes + -> DecodeStep bytes failure m (bytes -> a) + -> m (Either failure (a, Maybe bytes)) + +runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0 + where + go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes)) + go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing) + go _bytes _ (DecodeFail failure) = return (Left failure) + go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing + go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing + + data Role = Client | Server -- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs index 7bfa2f0f806..4a91658de5c 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs @@ -70,7 +70,7 @@ codecTxSubmission2 -> (forall s . CBOR.Decoder s txid) -> (tx -> CBOR.Encoding) -> (forall s . CBOR.Decoder s tx) - -> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString + -> AnnotatedCodec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString codecTxSubmission2 encodeTxId decodeTxId encodeTx decodeTx = mkCodecCborLazyBS @@ -79,7 +79,7 @@ codecTxSubmission2 encodeTxId decodeTxId where decode :: forall (pr :: PeerRole) (st :: TxSubmission2 txid tx). PeerHasAgency pr st - -> forall s. CBOR.Decoder s (SomeMessage st) + -> forall s. CBOR.Decoder s (Annotator ByteString st) decode stok = do len <- CBOR.decodeListLen key <- CBOR.decodeWord @@ -156,26 +156,26 @@ decodeTxSubmission2 PeerHasAgency pr st -> Int -> Word - -> CBOR.Decoder s (SomeMessage st)) + -> CBOR.Decoder s (Annotator ByteString st)) decodeTxSubmission2 decodeTxId decodeTx = decode where decode :: forall (pr :: PeerRole) s (st :: TxSubmission2 txid tx). PeerHasAgency pr st -> Int -> Word - -> CBOR.Decoder s (SomeMessage st) + -> CBOR.Decoder s (Annotator ByteString st) decode stok len key = do case (stok, len, key) of (ClientAgency TokInit, 1, 6) -> - return (SomeMessage MsgInit) + return (Annotator $ \_ -> SomeMessage MsgInit) (ServerAgency TokIdle, 4, 0) -> do blocking <- CBOR.decodeBool ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16 reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16 return $! if blocking - then SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo) - else SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo) + then Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo) + else Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo) (ClientAgency (TokTxIds b), 2, 1) -> do CBOR.decodeListLenIndef @@ -187,11 +187,11 @@ decodeTxSubmission2 decodeTxId decodeTx = decode return (txid, SizeInBytes sz)) case (b, txids) of (TokBlocking, t:ts) -> - return $ + return $ Annotator $ \_ -> SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts))) (TokNonBlocking, ts) -> - return $ + return $ Annotator $ \_ -> SomeMessage (MsgReplyTxIds (NonBlockingReply ts)) (TokBlocking, []) -> @@ -201,15 +201,26 @@ decodeTxSubmission2 decodeTxId decodeTx = decode (ServerAgency TokIdle, 2, 2) -> do CBOR.decodeListLenIndef txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId - return (SomeMessage (MsgRequestTxs txids)) + return (Annotator $ \_ -> SomeMessage (MsgRequestTxs txids)) (ClientAgency TokTxs, 2, 3) -> do CBOR.decodeListLenIndef txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx - return (SomeMessage (MsgReplyTxs txids)) + -- ^ TODO: `txids -> txs` :grin: + return (Annotator $ + -- TODO: here we have access to bytes from which the message was decoded. + -- we can use `Codec.CBOR.Decoding.decodeWithByteSpan` + -- around each `tx` and wrap each `tx` in `WithBytes`. + -- + -- `decodeTxSubmission2` can be polymorphic by adding an + -- extra argument of type + -- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a` + -- this way we could wrap `tx` in `WithBytes` or just + -- return `tx`. + \_bytes -> SomeMessage (MsgReplyTxs txids)) (ClientAgency (TokTxIds TokBlocking), 1, 4) -> - return (SomeMessage MsgDone) + return (Annotator $ \_ -> SomeMessage MsgDone) -- -- failures per protocol state