Skip to content

Commit

Permalink
WIP: Annotated decoder
Browse files Browse the repository at this point in the history
This is an experiment to provide `runAnnotatedPeer`, which is like
`runPeer' but allows us to run a decoder which has access to bytes used
when decoding a message.   This allows one to record ByteString from
which a piece of data was decoded, e.g. for each `tx` inside
`MsgReplyTxs`.

The `Codec` type in `typed-protocols` was generalised for this purpose.
The core functionality is implemented in
`runAnnotatedDecoderWithChannel` which runs `AnnotatedCodec` against
a `Channel` which does incremental decoding & recording bytes used so
far.  We also expose `runAnnotatedPeer` which runs a `Peer` against
`Channel` using an `AnnotatedCodec` (using `annotatedDriverSimple`).

TODO:

* `runAnnotatedPeerWithLimits`
* `runAnnotatedPipelinedPeerWithLimits`

It's actually the last one that we will need in `tx-submission`.

* Add
```
data WithBytes a {
  encoded :: ByteString,
  decoded :: a
```
  and generalise `codecTxSubmission2` so that it can be used to used
  with annotator and without it - it might require two separate
  function, but I think it can be generated from one more general
  function (so we don't need to maintain two codecs).

TODO: design & implement quickcheck properties
  • Loading branch information
coot committed Aug 21, 2024
1 parent ed11046 commit 4b47e98
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 30 deletions.
9 changes: 9 additions & 0 deletions cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ package network-mux
package ouroboros-network
flags: +asserts +cddl


source-repository-package
type: git
location: https://github.com/input-output-hk/typed-protocols
tag: d0c0668048be5b9878917180d7a0641861216bec
subdir: typed-protocols
typed-protocols-cborg
allow-newer: typed-protocols:io-classes

154 changes: 136 additions & 18 deletions ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
-- @UndecidableInstances@ extensions is required for defining @Show@ instance
-- of @'TraceSendRecv'@.
Expand All @@ -19,10 +18,12 @@ module Ouroboros.Network.Driver.Simple
-- $intro
-- * Normal peers
runPeer
, runAnnotatedPeer
, TraceSendRecv (..)
, DecoderFailure (..)
-- * Pipelined peers
, runPipelinedPeer
, runPipelinedAnnotatedPeer
-- * Connected peers
-- TODO: move these to a test lib
, Role (..)
Expand All @@ -43,6 +44,9 @@ import Ouroboros.Network.Channel
import Control.Monad.Class.MonadAsync
import Control.Monad.Class.MonadThrow
import Control.Tracer (Tracer (..), contramap, traceWith)
import Data.Maybe (fromMaybe)
import Data.Functor.Identity (Identity)
import Control.Monad.Identity (Identity(..))


-- $intro
Expand Down Expand Up @@ -107,18 +111,31 @@ instance Show DecoderFailure where
instance Exception DecoderFailure where


driverSimple :: forall ps failure bytes m.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
driverSimple tracer Codec{encode, decode} channel@Channel{send} =
mkSimpleDriver :: forall ps failure bytes m f annotator.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> (forall a.
Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m (f a)
-> m (Either failure (a, Maybe bytes))
)
-- ^ run incremental decoder against a channel

-> (forall st. annotator st -> f (SomeMessage st))
-- ^ transform annotator to a container holding the decoded
-- message

-> Tracer m (TraceSendRecv ps)
-> Codec' ps failure m annotator bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m

mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} =
Driver { sendMessage, recvMessage, startDState = Nothing }
where
sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps).
Expand All @@ -135,7 +152,7 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
-> m (SomeMessage st, Maybe bytes)
recvMessage stok trailing = do
decoder <- decode stok
result <- runDecoderWithChannel channel trailing decoder
result <- runDecodeSteps channel trailing (nat <$> decoder)
case result of
Right x@(SomeMessage msg, _trailing') -> do
traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg))
Expand All @@ -144,6 +161,36 @@ driverSimple tracer Codec{encode, decode} channel@Channel{send} =
throwIO (DecoderFailure stok failure)


simpleDriver :: forall ps failure bytes m.
( MonadThrow m
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> Codec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
simpleDriver = mkSimpleDriver runDecoderWithChannel Identity


annotatedSimpleDriver
:: forall ps failure bytes m.
( MonadThrow m
, Monoid bytes
, Show failure
, forall (st :: ps). Show (ClientHasAgency st)
, forall (st :: ps). Show (ServerHasAgency st)
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> Driver ps (Maybe bytes) m
annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator


-- | Run a peer with the given channel via the given codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
Expand All @@ -164,7 +211,31 @@ runPeer
runPeer tracer codec channel peer =
runPeerWithDriver driver peer (startDState driver)
where
driver = driverSimple tracer codec channel
driver = simpleDriver tracer codec channel


-- | Run a peer with the given channel via the given annotated codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
--
runAnnotatedPeer
:: forall ps (st :: ps) pr failure bytes m a .
( MonadThrow m
, Monoid bytes
, Show failure
, forall (st' :: ps). Show (ClientHasAgency st')
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> Peer ps pr st m a
-> m (a, Maybe bytes)
runAnnotatedPeer tracer codec channel peer =
runPeerWithDriver driver peer (startDState driver)
where
driver = annotatedSimpleDriver tracer codec channel


-- | Run a pipelined peer with the given channel via the given codec.
Expand All @@ -191,7 +262,35 @@ runPipelinedPeer
runPipelinedPeer tracer codec channel peer =
runPipelinedPeerWithDriver driver peer (startDState driver)
where
driver = driverSimple tracer codec channel
driver = simpleDriver tracer codec channel


-- | Run a pipelined peer with the given channel via the given annotated codec.
--
-- This runs the peer to completion (if the protocol allows for termination).
--
-- Unlike normal peers, running pipelined peers rely on concurrency, hence the
-- 'MonadAsync' constraint.
--
runPipelinedAnnotatedPeer
:: forall ps (st :: ps) pr failure bytes m a.
( MonadAsync m
, MonadThrow m
, Monoid bytes
, Show failure
, forall (st' :: ps). Show (ClientHasAgency st')
, forall (st' :: ps). Show (ServerHasAgency st')
, ShowProxy ps
)
=> Tracer m (TraceSendRecv ps)
-> AnnotatedCodec ps failure m bytes
-> Channel m bytes
-> PeerPipelined ps pr st m a
-> m (a, Maybe bytes)
runPipelinedAnnotatedPeer tracer codec channel peer =
runPipelinedPeerWithDriver driver peer (startDState driver)
where
driver = annotatedSimpleDriver tracer codec channel


--
Expand All @@ -204,17 +303,36 @@ runPipelinedPeer tracer codec channel peer =
runDecoderWithChannel :: Monad m
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m a
-> DecodeStep bytes failure m (Identity a)
-> m (Either failure (a, Maybe bytes))

runDecoderWithChannel Channel{recv} = go
where
go _ (DecodeDone x trailing) = return (Right (x, trailing))
go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing))
go _ (DecodeFail failure) = return (Left failure)
go Nothing (DecodePartial k) = recv >>= k >>= go Nothing
go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing


runAnnotatedDecoderWithChannel
:: forall m bytes failure a.
( Monad m
, Monoid bytes
)
=> Channel m bytes
-> Maybe bytes
-> DecodeStep bytes failure m (bytes -> a)
-> m (Either failure (a, Maybe bytes))

runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0
where
go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes))
go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing)
go _bytes _ (DecodeFail failure) = return (Left failure)
go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing
go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing


data Role = Client | Server

-- | Run two 'Peer's via a pair of connected 'Channel's and a common 'Codec'.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ codecTxSubmission2
-> (forall s . CBOR.Decoder s txid)
-> (tx -> CBOR.Encoding)
-> (forall s . CBOR.Decoder s tx)
-> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString
-> AnnotatedCodec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString
codecTxSubmission2 encodeTxId decodeTxId
encodeTx decodeTx =
mkCodecCborLazyBS
Expand All @@ -79,7 +79,7 @@ codecTxSubmission2 encodeTxId decodeTxId
where
decode :: forall (pr :: PeerRole) (st :: TxSubmission2 txid tx).
PeerHasAgency pr st
-> forall s. CBOR.Decoder s (SomeMessage st)
-> forall s. CBOR.Decoder s (Annotator ByteString st)
decode stok = do
len <- CBOR.decodeListLen
key <- CBOR.decodeWord
Expand Down Expand Up @@ -156,26 +156,26 @@ decodeTxSubmission2
PeerHasAgency pr st
-> Int
-> Word
-> CBOR.Decoder s (SomeMessage st))
-> CBOR.Decoder s (Annotator ByteString st))
decodeTxSubmission2 decodeTxId decodeTx = decode
where
decode :: forall (pr :: PeerRole) s (st :: TxSubmission2 txid tx).
PeerHasAgency pr st
-> Int
-> Word
-> CBOR.Decoder s (SomeMessage st)
-> CBOR.Decoder s (Annotator ByteString st)
decode stok len key = do
case (stok, len, key) of
(ClientAgency TokInit, 1, 6) ->
return (SomeMessage MsgInit)
return (Annotator $ \_ -> SomeMessage MsgInit)
(ServerAgency TokIdle, 4, 0) -> do
blocking <- CBOR.decodeBool
ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16
reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16
return $!
if blocking
then SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo)
else SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo)
then Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo)
else Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo)

(ClientAgency (TokTxIds b), 2, 1) -> do
CBOR.decodeListLenIndef
Expand All @@ -187,11 +187,11 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
return (txid, SizeInBytes sz))
case (b, txids) of
(TokBlocking, t:ts) ->
return $
return $ Annotator $ \_ ->
SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts)))

(TokNonBlocking, ts) ->
return $
return $ Annotator $ \_ ->
SomeMessage (MsgReplyTxIds (NonBlockingReply ts))

(TokBlocking, []) ->
Expand All @@ -201,15 +201,26 @@ decodeTxSubmission2 decodeTxId decodeTx = decode
(ServerAgency TokIdle, 2, 2) -> do
CBOR.decodeListLenIndef
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId
return (SomeMessage (MsgRequestTxs txids))
return (Annotator $ \_ -> SomeMessage (MsgRequestTxs txids))

(ClientAgency TokTxs, 2, 3) -> do
CBOR.decodeListLenIndef
txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx
return (SomeMessage (MsgReplyTxs txids))
-- ^ TODO: `txids -> txs` :grin:
return (Annotator $
-- TODO: here we have access to bytes from which the message was decoded.
-- we can use `Codec.CBOR.Decoding.decodeWithByteSpan`
-- around each `tx` and wrap each `tx` in `WithBytes`.
--
-- `decodeTxSubmission2` can be polymorphic by adding an
-- extra argument of type
-- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a`
-- this way we could wrap `tx` in `WithBytes` or just
-- return `tx`.
\_bytes -> SomeMessage (MsgReplyTxs txids))

(ClientAgency (TokTxIds TokBlocking), 1, 4) ->
return (SomeMessage MsgDone)
return (Annotator $ \_ -> SomeMessage MsgDone)

--
-- failures per protocol state
Expand Down

0 comments on commit 4b47e98

Please sign in to comment.