Skip to content
This repository has been archived by the owner on Dec 26, 2024. It is now read-only.

Commit

Permalink
feat(network): add get_blocks::Handler with simple test
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama committed Sep 4, 2023
1 parent ec2365a commit a5ddf73
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/papyrus_network/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ license-file.workspace = true

[dependencies]
bytes.workspace = true
derive_more.workspace = true
futures.workspace = true
libp2p.workspace = true
prost.workspace = true
Expand Down
222 changes: 222 additions & 0 deletions crates/papyrus_network/src/get_blocks/handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
#[cfg(test)]
#[path = "handler_test.rs"]
mod handler_test;

use std::collections::{HashMap, VecDeque};
use std::io;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::channel::mpsc::{TrySendError, UnboundedReceiver};
use futures::StreamExt;
use libp2p::swarm::handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedOutbound};
use libp2p::swarm::{
ConnectionHandler,
ConnectionHandlerEvent,
KeepAlive,
StreamUpgradeError,
SubstreamProtocol,
};

use super::protocol::{RequestProtocol, RequestProtocolError, ResponseProtocol, PROTOCOL_NAME};
use super::RequestId;
use crate::messages::block::{GetBlocks, GetBlocksResponse};

// TODO(shahak): Add a FromBehaviour event for cancelling an existing request.
#[derive(Debug)]
pub struct NewRequestEvent {
pub request: GetBlocks,
pub request_id: RequestId,
}

#[derive(thiserror::Error, Debug)]
pub enum RequestError {
#[error("Connection timed out after {} seconds.", substream_timeout.as_secs())]
Timeout { substream_timeout: Duration },
#[error(transparent)]
IOError(#[from] io::Error),
#[error(transparent)]
ResponseSendError(#[from] TrySendError<GetBlocksResponse>),
#[error("Remote peer doesn't support the {PROTOCOL_NAME} protocol.")]
RemoteDoesntSupportProtocol,
}

#[derive(thiserror::Error, Debug)]
#[error("Remote peer doesn't support the {PROTOCOL_NAME} protocol.")]
pub struct RemoteDoesntSupportProtocolError;

#[derive(Debug)]
pub enum RequestProgressEvent {
ReceivedResponse { request_id: RequestId, response: GetBlocksResponse },
RequestFinished { request_id: RequestId },
RequestFailed { request_id: RequestId, error: RequestError },
}

type HandlerEvent<H> = ConnectionHandlerEvent<
<H as ConnectionHandler>::OutboundProtocol,
<H as ConnectionHandler>::OutboundOpenInfo,
<H as ConnectionHandler>::ToBehaviour,
<H as ConnectionHandler>::Error,
>;

pub struct Handler {
substream_timeout: Duration,
request_to_responses_receiver: HashMap<RequestId, UnboundedReceiver<GetBlocksResponse>>,
pending_events: VecDeque<HandlerEvent<Self>>,
ready_requests: VecDeque<(RequestId, GetBlocksResponse)>,
}

impl Handler {
// TODO(shahak) If we'll add more parameters, consider creating a HandlerConfig struct.
pub fn new(substream_timeout: Duration) -> Self {
Self {
substream_timeout,
request_to_responses_receiver: Default::default(),
pending_events: Default::default(),
ready_requests: Default::default(),
}
}

fn convert_upgrade_error(
&self,
error: StreamUpgradeError<RequestProtocolError>,
) -> RequestError {
match error {
StreamUpgradeError::Timeout => {
RequestError::Timeout { substream_timeout: self.substream_timeout }
}
StreamUpgradeError::Apply(request_protocol_error) => match request_protocol_error {
RequestProtocolError::IOError(error) => RequestError::IOError(error),
RequestProtocolError::ResponseSendError(error) => {
RequestError::ResponseSendError(error)
}
},
StreamUpgradeError::NegotiationFailed => RequestError::RemoteDoesntSupportProtocol,
StreamUpgradeError::Io(error) => RequestError::IOError(error),
}
}

fn clear_pending_events_related_to_request(&mut self, request_id: RequestId) {
self.pending_events.retain(|event| match event {
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse {
request_id: other_request_id,
..
}) => request_id != *other_request_id,
_ => true,
})
}
}

impl ConnectionHandler for Handler {
type FromBehaviour = NewRequestEvent;
type ToBehaviour = RequestProgressEvent;
type Error = RemoteDoesntSupportProtocolError;
type InboundProtocol = ResponseProtocol;
type OutboundProtocol = RequestProtocol;
type InboundOpenInfo = ();
type OutboundOpenInfo = RequestId;

fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
SubstreamProtocol::new(ResponseProtocol {}, ()).with_timeout(self.substream_timeout)
}

fn connection_keep_alive(&self) -> KeepAlive {
// TODO(shahak): Implement keep alive logic.
KeepAlive::Yes
}

fn poll(
&mut self,
cx: &mut Context<'_>,
) -> Poll<
ConnectionHandlerEvent<
Self::OutboundProtocol,
Self::OutboundOpenInfo,
Self::ToBehaviour,
Self::Error,
>,
> {
// TODO(shahak): Consider handling incoming messages interleaved with handling pending
// events to avoid starvation.
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(event);
}

// Handle incoming messages.
for (request_id, responses_receiver) in &mut self.request_to_responses_receiver {
if let Poll::Ready(Some(response)) = responses_receiver.poll_next_unpin(cx) {
// Collect all ready responses to avoid starvation of the request ids at the end.
self.ready_requests.push_back((*request_id, response));
}
}
if let Some((request_id, response)) = self.ready_requests.pop_front() {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
RequestProgressEvent::ReceivedResponse { request_id, response },
));
}

Poll::Pending
}

fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
// There's only one type of event so we can unpack it without matching.
let NewRequestEvent { request, request_id } = event;
let (request_protocol, responses_receiver) = RequestProtocol::new(request);
let insert_result =
self.request_to_responses_receiver.insert(request_id, responses_receiver);
if insert_result.is_some() {
panic!("Multiple requests exist with the same ID {}", request_id);
}
self.pending_events.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(request_protocol, request_id)
.with_timeout(self.substream_timeout),
});
}

fn on_connection_event(
&mut self,
event: ConnectionEvent<
'_,
Self::InboundProtocol,
Self::OutboundProtocol,
Self::InboundOpenInfo,
Self::OutboundOpenInfo,
>,
) {
match event {
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
protocol: _,
info: request_id,
}) => {
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestProgressEvent::RequestFinished { request_id },
));
self.request_to_responses_receiver.remove(&request_id);
}
ConnectionEvent::DialUpgradeError(DialUpgradeError { info: request_id, error }) => {
let error = self.convert_upgrade_error(error);
if matches!(error, RequestError::RemoteDoesntSupportProtocol) {
// This error will happen on all future connections to the peer, so we'll close
// the handle after reporting to the behaviour.
self.pending_events.clear();
self.pending_events.push_front(ConnectionHandlerEvent::NotifyBehaviour(
RequestProgressEvent::RequestFailed { request_id, error },
));
self.pending_events
.push_back(ConnectionHandlerEvent::Close(RemoteDoesntSupportProtocolError));
} else {
self.clear_pending_events_related_to_request(request_id);
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestProgressEvent::RequestFailed { request_id, error },
));
}
self.request_to_responses_receiver.remove(&request_id);
}
ConnectionEvent::FullyNegotiatedInbound(_)
| ConnectionEvent::ListenUpgradeError(_)
| ConnectionEvent::AddressChange(_)
| ConnectionEvent::LocalProtocolsChange(_)
| ConnectionEvent::RemoteProtocolsChange(_) => {}
}
}
}
137 changes: 137 additions & 0 deletions crates/papyrus_network/src/get_blocks/handler_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
use std::iter::zip;
use std::pin::Pin;
use std::time::Duration;

use assert_matches::assert_matches;
use futures::channel::mpsc::UnboundedSender;
use futures::task::{Context, Poll};
use futures::{Stream, StreamExt};
use libp2p::swarm::handler::{ConnectionEvent, FullyNegotiatedOutbound};
use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent};

use super::super::RequestId;
use super::{Handler, HandlerEvent, NewRequestEvent, RequestProgressEvent};
use crate::messages::block::{BlockHeader, GetBlocks, GetBlocksResponse};
use crate::messages::common::BlockId;
use crate::messages::proto::p2p::proto::get_blocks_response::Response;

impl Unpin for Handler {}

impl Stream for Handler {
type Item = HandlerEvent<Handler>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::into_inner(self).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(event) => Poll::Ready(Some(event)),
}
}
}

const SUBSTREAM_TIMEOUT: Duration = Duration::MAX;

async fn start_request_and_validate_event(
handler: &mut Handler,
request: &GetBlocks,
request_id: RequestId,
) -> UnboundedSender<GetBlocksResponse> {
handler.on_behaviour_event(NewRequestEvent { request: request.clone(), request_id });
let event = handler.next().await.unwrap();
let ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } = event else {
panic!("Got unexpected event");
};
assert_eq!(*request, *protocol.upgrade().request());
assert_eq!(SUBSTREAM_TIMEOUT, *protocol.timeout());
protocol.upgrade().responses_sender().clone()
}

async fn send_response_and_validate_event(
handler: &mut Handler,
response: &GetBlocksResponse,
request_id: RequestId,
responses_sender: &UnboundedSender<GetBlocksResponse>,
) {
responses_sender.unbounded_send(response.clone()).unwrap();
let event = handler.next().await.unwrap();
assert_matches!(
event,
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse{
request_id: event_request_id, response: event_response
}) if event_request_id == request_id && event_response == *response
);
}

async fn finish_request_and_validate_event(handler: &mut Handler, request_id: RequestId) {
handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
FullyNegotiatedOutbound { protocol: (), info: request_id },
));
let event = handler.next().await.unwrap();
assert_matches!(
event,
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::RequestFinished{
request_id: event_request_id
}) if event_request_id == request_id
);
}

#[tokio::test]
async fn process_request() {
let mut handler = Handler::new(SUBSTREAM_TIMEOUT);

let request = GetBlocks::default();
let request_id = RequestId::default();
let response = GetBlocksResponse {
response: Some(Response::Header(BlockHeader {
parent_block: Some(BlockId { hash: None, height: 1 }),
..Default::default()
})),
};

let responses_sender =
start_request_and_validate_event(&mut handler, &request, request_id).await;

send_response_and_validate_event(&mut handler, &response, request_id, &responses_sender).await;
finish_request_and_validate_event(&mut handler, request_id).await;
}

#[tokio::test]
async fn process_multiple_requests_simultaneously() {
let mut handler = Handler::new(SUBSTREAM_TIMEOUT);

const N_REQUESTS: usize = 20;
let request_ids = (0..N_REQUESTS).map(RequestId).collect::<Vec<_>>();
let requests = (0..N_REQUESTS)
.map(|i| GetBlocks { skip: i as u64, ..Default::default() })
.collect::<Vec<_>>();
let responses = (0..N_REQUESTS)
.map(|i| GetBlocksResponse {
response: Some(Response::Header(BlockHeader {
parent_block: Some(BlockId { hash: None, height: i as u64 }),
..Default::default()
})),
})
.collect::<Vec<_>>();

for ((request, request_id), response) in zip(zip(requests, request_ids), responses.iter()) {
let responses_sender =
start_request_and_validate_event(&mut handler, &request, request_id).await;
responses_sender.unbounded_send(response.clone()).unwrap();
}

let mut request_id_found = [false; N_REQUESTS];
for event in handler.take(N_REQUESTS).collect::<Vec<_>>().await {
match event {
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse {
request_id: RequestId(i),
response: event_response,
}) => {
assert_eq!(responses[i], event_response);
assert!(!request_id_found[i]);
request_id_found[i] = true;
}
_ => {
panic!("Got unexpected event");
}
}
}
}
6 changes: 6 additions & 0 deletions crates/papyrus_network/src/get_blocks/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
pub mod handler;
pub mod protocol;

use derive_more::Display;

#[derive(Clone, Copy, Debug, Default, Display, Eq, Hash, PartialEq)]
pub struct RequestId(pub usize);
Loading

0 comments on commit a5ddf73

Please sign in to comment.