Skip to content
This repository has been archived by the owner on Dec 26, 2024. It is now read-only.

Commit

Permalink
feat(network): add get blocks inbound response logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nagmo-starkware committed Sep 4, 2023
1 parent b71e367 commit f291cee
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 50 deletions.
19 changes: 10 additions & 9 deletions crates/papyrus_network/src/get_blocks/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use libp2p::swarm::{
SubstreamProtocol,
};

use super::protocol::{RequestProtocol, RequestProtocolError, ResponseProtocol, PROTOCOL_NAME};
use super::protocol::{InboundProtocol, OutboundProtocol, OutboundProtocolError, PROTOCOL_NAME};
use super::RequestId;
use crate::messages::block::{GetBlocks, GetBlocksResponse};

Expand Down Expand Up @@ -79,15 +79,15 @@ impl Handler {

fn convert_upgrade_error(
&self,
error: StreamUpgradeError<RequestProtocolError>,
error: StreamUpgradeError<OutboundProtocolError>,
) -> RequestError {
match error {
StreamUpgradeError::Timeout => {
RequestError::Timeout { substream_timeout: self.substream_timeout }
}
StreamUpgradeError::Apply(request_protocol_error) => match request_protocol_error {
RequestProtocolError::IOError(error) => RequestError::IOError(error),
RequestProtocolError::ResponseSendError(error) => {
OutboundProtocolError::IOError(error) => RequestError::IOError(error),
OutboundProtocolError::ResponseSendError(error) => {
RequestError::ResponseSendError(error)
}
},
Expand All @@ -111,13 +111,14 @@ impl ConnectionHandler for Handler {
type FromBehaviour = NewRequestEvent;
type ToBehaviour = RequestProgressEvent;
type Error = RemoteDoesntSupportProtocolError;
type InboundProtocol = ResponseProtocol;
type OutboundProtocol = RequestProtocol;
type InboundProtocol = InboundProtocol;
type OutboundProtocol = OutboundProtocol;
type InboundOpenInfo = ();
type OutboundOpenInfo = RequestId;

fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
SubstreamProtocol::new(ResponseProtocol {}, ()).with_timeout(self.substream_timeout)
let (inbound_protocol, _) = InboundProtocol::new();
SubstreamProtocol::new(inbound_protocol, ()).with_timeout(self.substream_timeout)
}

fn connection_keep_alive(&self) -> KeepAlive {
Expand Down Expand Up @@ -157,14 +158,14 @@ impl ConnectionHandler for Handler {

fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
let NewRequestEvent { request, request_id } = event;
let (request_protocol, responses_receiver) = RequestProtocol::new(request);
let (outbound_protocol, responses_receiver) = OutboundProtocol::new(request);
let insert_result =
self.request_to_responses_receiver.insert(request_id, responses_receiver);
if insert_result.is_some() {
panic!("Multiple requests exist with the same ID {}", request_id);
}
self.pending_events.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(request_protocol, request_id)
protocol: SubstreamProtocol::new(outbound_protocol, request_id)
.with_timeout(self.substream_timeout),
});
}
Expand Down
87 changes: 67 additions & 20 deletions crates/papyrus_network/src/get_blocks/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{io, iter};

use futures::channel::mpsc::{unbounded, TrySendError, UnboundedReceiver, UnboundedSender};
use futures::future::BoxFuture;
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt};
use futures::{AsyncRead, AsyncWrite, AsyncWriteExt, FutureExt, StreamExt};
use libp2p::core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use libp2p::swarm::StreamProtocol;

Expand All @@ -18,9 +18,32 @@ pub const PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/get_blocks/1.0.0
/// Substream upgrade protocol for sending data on blocks.
///
/// Receives a request to get a range of blocks and sends a stream of data on the blocks.
pub struct ResponseProtocol;
pub struct InboundProtocol {
request_relay_sender: UnboundedSender<GetBlocks>,
response_relay_receiver: UnboundedReceiver<Option<GetBlocksResponse>>,
}

impl InboundProtocol {
pub fn new()
-> (Self, (UnboundedReceiver<GetBlocks>, UnboundedSender<Option<GetBlocksResponse>>)) {
let (request_relay_sender, request_relay_receiver) = unbounded();
let (response_relay_sender, response_relay_receiver) = unbounded();
(
Self { request_relay_sender, response_relay_receiver },
(request_relay_receiver, response_relay_sender),
)
}
}

impl UpgradeInfo for ResponseProtocol {
#[derive(thiserror::Error, Debug)]
pub enum InboundProtocolError {
#[error(transparent)]
IOError(#[from] io::Error),
#[error(transparent)]
RequestSendError(#[from] TrySendError<GetBlocks>),
}

impl UpgradeInfo for InboundProtocol {
type Info = StreamProtocol;
type InfoIter = iter::Once<Self::Info>;

Expand All @@ -29,37 +52,61 @@ impl UpgradeInfo for ResponseProtocol {
}
}

impl<Stream> InboundUpgrade<Stream> for ResponseProtocol
impl<Stream> InboundUpgrade<Stream> for InboundProtocol
where
Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = ();
type Error = io::Error;
type Error = InboundProtocolError;
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

fn upgrade_inbound(self, mut io: Stream, _: Self::Info) -> Self::Future {
async move {
read_message::<GetBlocks, _>(&mut io).await?;
for response in hardcoded_responses() {
write_message(response, &mut io).await?;
fn upgrade_inbound(mut self, mut io: Stream, _: Self::Info) -> Self::Future {
Box::pin(
async move {
if let Ok(get_blocks_msg) = read_message::<GetBlocks, _>(&mut io).await {
self.request_relay_sender.unbounded_send(get_blocks_msg)?;
}
let mut expect_end_of_stream = false;
loop {
match self.response_relay_receiver.next().await {
Some(response) => match response {
Some(res) => write_message(res, &mut io).await?,
None => {
expect_end_of_stream = true;
write_message(
GetBlocksResponse { response: Some(Response::Fin(Fin {})) },
&mut io,
)
.await?;
}
},
None => {
if expect_end_of_stream {
return Ok(());
}
return Err(InboundProtocolError::IOError(io::Error::new(
io::ErrorKind::UnexpectedEof,
"Unexpected end of stream",
)));
}
};
}
}
io.close().await?;
Ok(())
}
.boxed()
.boxed(),
)
}
}

/// Substream upgrade protocol for requesting data on blocks.
///
/// Sends a request to get a range of blocks and receives a stream of data on the blocks.
#[derive(Debug)]
pub struct RequestProtocol {
pub struct OutboundProtocol {
request: GetBlocks,
responses_sender: UnboundedSender<GetBlocksResponse>,
}

impl RequestProtocol {
impl OutboundProtocol {
pub fn new(request: GetBlocks) -> (Self, UnboundedReceiver<GetBlocksResponse>) {
let (responses_sender, responses_receiver) = unbounded();
(Self { request, responses_sender }, responses_receiver)
Expand All @@ -77,14 +124,14 @@ impl RequestProtocol {
}

#[derive(thiserror::Error, Debug)]
pub enum RequestProtocolError {
pub enum OutboundProtocolError {
#[error(transparent)]
IOError(#[from] io::Error),
#[error(transparent)]
ResponseSendError(#[from] TrySendError<GetBlocksResponse>),
}

impl UpgradeInfo for RequestProtocol {
impl UpgradeInfo for OutboundProtocol {
type Info = StreamProtocol;
type InfoIter = iter::Once<Self::Info>;

Expand All @@ -93,12 +140,12 @@ impl UpgradeInfo for RequestProtocol {
}
}

impl<Stream> OutboundUpgrade<Stream> for RequestProtocol
impl<Stream> OutboundUpgrade<Stream> for OutboundProtocol
where
Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = ();
type Error = RequestProtocolError;
type Error = OutboundProtocolError;
type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

fn upgrade_outbound(self, mut io: Stream, _: Self::Info) -> Self::Future {
Expand Down
101 changes: 80 additions & 21 deletions crates/papyrus_network/src/get_blocks/protocol_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use assert_matches::assert_matches;
use futures::{AsyncRead, AsyncWrite, Future, StreamExt};
use futures::{AsyncRead, AsyncWrite, Future, SinkExt, StreamExt};
use libp2p::core::multiaddr::multiaddr;
use libp2p::core::transport::memory::MemoryTransport;
use libp2p::core::transport::{ListenerId, Transport};
Expand All @@ -9,19 +9,20 @@ use pretty_assertions::assert_eq;

use super::{
hardcoded_responses,
RequestProtocol,
RequestProtocolError,
ResponseProtocol,
InboundProtocol,
OutboundProtocol,
OutboundProtocolError,
PROTOCOL_NAME,
};
use crate::get_blocks::protocol::InboundProtocolError;
use crate::messages::block::{GetSignatures, NewBlock};
use crate::messages::common::BlockId;
use crate::messages::write_message;

#[test]
fn both_protocols_have_same_info() {
let (outbound_protocol, _) = RequestProtocol::new(Default::default());
let inbound_protocol = ResponseProtocol;
let (outbound_protocol, _) = OutboundProtocol::new(Default::default());
let (inbound_protocol, _) = InboundProtocol::new();
assert_eq!(
outbound_protocol.protocol_info().collect::<Vec<_>>(),
inbound_protocol.protocol_info().collect::<Vec<_>>()
Expand Down Expand Up @@ -55,8 +56,9 @@ async fn get_connected_io_futures() -> (
async fn positive_flow() {
let (inbound_io_future, outbound_io_future) = get_connected_io_futures().await;

let (outbound_protocol, mut responses_receiver) = RequestProtocol::new(Default::default());
let inbound_protocol = ResponseProtocol;
let (outbound_protocol, mut responses_receiver) = OutboundProtocol::new(Default::default());
let (inbound_protocol, (mut request_relay_receiver, response_relay_sender)) =
InboundProtocol::new();

tokio::join!(
async move {
Expand All @@ -68,14 +70,25 @@ async fn positive_flow() {
.await
.unwrap();
},
// plays the role of the network and DB handlers
async move {
// ignore block query for now, just send hardcoded responses
let _blocks_query = request_relay_receiver.next().await.unwrap();
request_relay_receiver.close();
for expected_response in hardcoded_responses() {
let result = responses_receiver.next().await;
if expected_response.is_fin() {
assert!(result.is_none());
break;
} else {
assert_eq!(result.unwrap(), expected_response);
let msg =
if expected_response.is_fin() { None } else { Some(expected_response.clone()) };
match response_relay_sender.unbounded_send(msg) {
Ok(_) => {
let result = responses_receiver.next().await;
if expected_response.is_fin() {
assert!(result.is_none());
break;
} else {
assert_eq!(result.unwrap(), expected_response);
}
}
Err(err) => panic!("Failed to send response with err: {err}"),
}
}
}
Expand All @@ -86,7 +99,7 @@ async fn positive_flow() {
async fn inbound_sends_invalid_response() {
let (inbound_io_future, outbound_io_future) = get_connected_io_futures().await;

let (outbound_protocol, mut responses_receiver) = RequestProtocol::new(Default::default());
let (outbound_protocol, mut responses_receiver) = OutboundProtocol::new(Default::default());

tokio::join!(
async move {
Expand All @@ -103,7 +116,7 @@ async fn inbound_sends_invalid_response() {
.upgrade_outbound(outbound_io_future.await, PROTOCOL_NAME)
.await
.unwrap_err();
assert_matches!(err, RequestProtocolError::IOError(_));
assert_matches!(err, OutboundProtocolError::IOError(_));
},
async move { assert!(responses_receiver.next().await.is_none()) }
);
Expand All @@ -112,7 +125,7 @@ async fn inbound_sends_invalid_response() {
#[tokio::test]
async fn outbound_sends_invalid_request() {
let (inbound_io_future, outbound_io_future) = get_connected_io_futures().await;
let inbound_protocol = ResponseProtocol;
let (inbound_protocol, _) = InboundProtocol::new();

tokio::join!(
async move {
Expand All @@ -137,20 +150,66 @@ async fn outbound_sends_invalid_request() {
async fn outbound_receiver_closed() {
let (inbound_io_future, outbound_io_future) = get_connected_io_futures().await;

let (outbound_protocol, mut responses_receiver) = RequestProtocol::new(Default::default());
let inbound_protocol = ResponseProtocol;
let (outbound_protocol, mut responses_receiver) = OutboundProtocol::new(Default::default());
let (inbound_protocol, _) = InboundProtocol::new();
responses_receiver.close();

tokio::join!(
async move {
inbound_protocol.upgrade_inbound(inbound_io_future.await, PROTOCOL_NAME).await.unwrap();
inbound_protocol
.upgrade_inbound(inbound_io_future.await, PROTOCOL_NAME)
.await
.unwrap_err();
},
async move {
let err = outbound_protocol
.upgrade_outbound(outbound_io_future.await, PROTOCOL_NAME)
.await
.unwrap_err();
assert_matches!(err, RequestProtocolError::ResponseSendError(_));
assert_matches!(err, OutboundProtocolError::ResponseSendError(_));
},
);
}

#[tokio::test]
async fn response_relay_stops_unexpectedly() {
let (inbound_io_future, outbound_io_future) = get_connected_io_futures().await;

let (outbound_protocol, mut responses_receiver) = OutboundProtocol::new(Default::default());
let (inbound_protocol, (mut request_relay_receiver, mut response_relay_sender)) =
InboundProtocol::new();

tokio::join!(
async move {
match inbound_protocol
.upgrade_inbound(inbound_io_future.await, PROTOCOL_NAME)
.await
.unwrap_err()
{
InboundProtocolError::IOError(_) => {}
err => panic!("Unexpected error: {:?}", err),
};
},
async move {
outbound_protocol
.upgrade_outbound(outbound_io_future.await, PROTOCOL_NAME)
.await
.unwrap_err();
},
// plays the role of the network handler
async move {
let _blocks_query = request_relay_receiver.next().await.unwrap();
request_relay_receiver.close();
let responses = hardcoded_responses();
match response_relay_sender.unbounded_send(Some(responses[0].clone())) {
Ok(_) => {
let result = responses_receiver.next().await;
assert_eq!(result.unwrap(), responses[0]);
}
Err(err) => panic!("Failed to send response with err: {err}"),
}
let _res = response_relay_sender.close().await;
assert!(responses_receiver.try_next().is_err());
}
);
}

0 comments on commit f291cee

Please sign in to comment.