This repository has been archived by the owner on Dec 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(network): add get_blocks::Handler with simple test
- Loading branch information
1 parent
ec2365a
commit a5ddf73
Showing
6 changed files
with
378 additions
and
0 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
#[cfg(test)] | ||
#[path = "handler_test.rs"] | ||
mod handler_test; | ||
|
||
use std::collections::{HashMap, VecDeque}; | ||
use std::io; | ||
use std::task::{Context, Poll}; | ||
use std::time::Duration; | ||
|
||
use futures::channel::mpsc::{TrySendError, UnboundedReceiver}; | ||
use futures::StreamExt; | ||
use libp2p::swarm::handler::{ConnectionEvent, DialUpgradeError, FullyNegotiatedOutbound}; | ||
use libp2p::swarm::{ | ||
ConnectionHandler, | ||
ConnectionHandlerEvent, | ||
KeepAlive, | ||
StreamUpgradeError, | ||
SubstreamProtocol, | ||
}; | ||
|
||
use super::protocol::{RequestProtocol, RequestProtocolError, ResponseProtocol, PROTOCOL_NAME}; | ||
use super::RequestId; | ||
use crate::messages::block::{GetBlocks, GetBlocksResponse}; | ||
|
||
// TODO(shahak): Add a FromBehaviour event for cancelling an existing request. | ||
#[derive(Debug)] | ||
pub struct NewRequestEvent { | ||
pub request: GetBlocks, | ||
pub request_id: RequestId, | ||
} | ||
|
||
#[derive(thiserror::Error, Debug)] | ||
pub enum RequestError { | ||
#[error("Connection timed out after {} seconds.", substream_timeout.as_secs())] | ||
Timeout { substream_timeout: Duration }, | ||
#[error(transparent)] | ||
IOError(#[from] io::Error), | ||
#[error(transparent)] | ||
ResponseSendError(#[from] TrySendError<GetBlocksResponse>), | ||
#[error("Remote peer doesn't support the {PROTOCOL_NAME} protocol.")] | ||
RemoteDoesntSupportProtocol, | ||
} | ||
|
||
#[derive(thiserror::Error, Debug)] | ||
#[error("Remote peer doesn't support the {PROTOCOL_NAME} protocol.")] | ||
pub struct RemoteDoesntSupportProtocolError; | ||
|
||
#[derive(Debug)] | ||
pub enum RequestProgressEvent { | ||
ReceivedResponse { request_id: RequestId, response: GetBlocksResponse }, | ||
RequestFinished { request_id: RequestId }, | ||
RequestFailed { request_id: RequestId, error: RequestError }, | ||
} | ||
|
||
type HandlerEvent<H> = ConnectionHandlerEvent< | ||
<H as ConnectionHandler>::OutboundProtocol, | ||
<H as ConnectionHandler>::OutboundOpenInfo, | ||
<H as ConnectionHandler>::ToBehaviour, | ||
<H as ConnectionHandler>::Error, | ||
>; | ||
|
||
pub struct Handler { | ||
substream_timeout: Duration, | ||
request_to_responses_receiver: HashMap<RequestId, UnboundedReceiver<GetBlocksResponse>>, | ||
pending_events: VecDeque<HandlerEvent<Self>>, | ||
ready_requests: VecDeque<(RequestId, GetBlocksResponse)>, | ||
} | ||
|
||
impl Handler { | ||
// TODO(shahak) If we'll add more parameters, consider creating a HandlerConfig struct. | ||
pub fn new(substream_timeout: Duration) -> Self { | ||
Self { | ||
substream_timeout, | ||
request_to_responses_receiver: Default::default(), | ||
pending_events: Default::default(), | ||
ready_requests: Default::default(), | ||
} | ||
} | ||
|
||
fn convert_upgrade_error( | ||
&self, | ||
error: StreamUpgradeError<RequestProtocolError>, | ||
) -> RequestError { | ||
match error { | ||
StreamUpgradeError::Timeout => { | ||
RequestError::Timeout { substream_timeout: self.substream_timeout } | ||
} | ||
StreamUpgradeError::Apply(request_protocol_error) => match request_protocol_error { | ||
RequestProtocolError::IOError(error) => RequestError::IOError(error), | ||
RequestProtocolError::ResponseSendError(error) => { | ||
RequestError::ResponseSendError(error) | ||
} | ||
}, | ||
StreamUpgradeError::NegotiationFailed => RequestError::RemoteDoesntSupportProtocol, | ||
StreamUpgradeError::Io(error) => RequestError::IOError(error), | ||
} | ||
} | ||
|
||
fn clear_pending_events_related_to_request(&mut self, request_id: RequestId) { | ||
self.pending_events.retain(|event| match event { | ||
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse { | ||
request_id: other_request_id, | ||
.. | ||
}) => request_id != *other_request_id, | ||
_ => true, | ||
}) | ||
} | ||
} | ||
|
||
impl ConnectionHandler for Handler { | ||
type FromBehaviour = NewRequestEvent; | ||
type ToBehaviour = RequestProgressEvent; | ||
type Error = RemoteDoesntSupportProtocolError; | ||
type InboundProtocol = ResponseProtocol; | ||
type OutboundProtocol = RequestProtocol; | ||
type InboundOpenInfo = (); | ||
type OutboundOpenInfo = RequestId; | ||
|
||
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> { | ||
SubstreamProtocol::new(ResponseProtocol {}, ()).with_timeout(self.substream_timeout) | ||
} | ||
|
||
fn connection_keep_alive(&self) -> KeepAlive { | ||
// TODO(shahak): Implement keep alive logic. | ||
KeepAlive::Yes | ||
} | ||
|
||
fn poll( | ||
&mut self, | ||
cx: &mut Context<'_>, | ||
) -> Poll< | ||
ConnectionHandlerEvent< | ||
Self::OutboundProtocol, | ||
Self::OutboundOpenInfo, | ||
Self::ToBehaviour, | ||
Self::Error, | ||
>, | ||
> { | ||
// TODO(shahak): Consider handling incoming messages interleaved with handling pending | ||
// events to avoid starvation. | ||
if let Some(event) = self.pending_events.pop_front() { | ||
return Poll::Ready(event); | ||
} | ||
|
||
// Handle incoming messages. | ||
for (request_id, responses_receiver) in &mut self.request_to_responses_receiver { | ||
if let Poll::Ready(Some(response)) = responses_receiver.poll_next_unpin(cx) { | ||
// Collect all ready responses to avoid starvation of the request ids at the end. | ||
self.ready_requests.push_back((*request_id, response)); | ||
} | ||
} | ||
if let Some((request_id, response)) = self.ready_requests.pop_front() { | ||
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( | ||
RequestProgressEvent::ReceivedResponse { request_id, response }, | ||
)); | ||
} | ||
|
||
Poll::Pending | ||
} | ||
|
||
fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { | ||
// There's only one type of event so we can unpack it without matching. | ||
let NewRequestEvent { request, request_id } = event; | ||
let (request_protocol, responses_receiver) = RequestProtocol::new(request); | ||
let insert_result = | ||
self.request_to_responses_receiver.insert(request_id, responses_receiver); | ||
if insert_result.is_some() { | ||
panic!("Multiple requests exist with the same ID {}", request_id); | ||
} | ||
self.pending_events.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { | ||
protocol: SubstreamProtocol::new(request_protocol, request_id) | ||
.with_timeout(self.substream_timeout), | ||
}); | ||
} | ||
|
||
fn on_connection_event( | ||
&mut self, | ||
event: ConnectionEvent< | ||
'_, | ||
Self::InboundProtocol, | ||
Self::OutboundProtocol, | ||
Self::InboundOpenInfo, | ||
Self::OutboundOpenInfo, | ||
>, | ||
) { | ||
match event { | ||
ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { | ||
protocol: _, | ||
info: request_id, | ||
}) => { | ||
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( | ||
RequestProgressEvent::RequestFinished { request_id }, | ||
)); | ||
self.request_to_responses_receiver.remove(&request_id); | ||
} | ||
ConnectionEvent::DialUpgradeError(DialUpgradeError { info: request_id, error }) => { | ||
let error = self.convert_upgrade_error(error); | ||
if matches!(error, RequestError::RemoteDoesntSupportProtocol) { | ||
// This error will happen on all future connections to the peer, so we'll close | ||
// the handle after reporting to the behaviour. | ||
self.pending_events.clear(); | ||
self.pending_events.push_front(ConnectionHandlerEvent::NotifyBehaviour( | ||
RequestProgressEvent::RequestFailed { request_id, error }, | ||
)); | ||
self.pending_events | ||
.push_back(ConnectionHandlerEvent::Close(RemoteDoesntSupportProtocolError)); | ||
} else { | ||
self.clear_pending_events_related_to_request(request_id); | ||
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour( | ||
RequestProgressEvent::RequestFailed { request_id, error }, | ||
)); | ||
} | ||
self.request_to_responses_receiver.remove(&request_id); | ||
} | ||
ConnectionEvent::FullyNegotiatedInbound(_) | ||
| ConnectionEvent::ListenUpgradeError(_) | ||
| ConnectionEvent::AddressChange(_) | ||
| ConnectionEvent::LocalProtocolsChange(_) | ||
| ConnectionEvent::RemoteProtocolsChange(_) => {} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
use std::iter::zip; | ||
use std::pin::Pin; | ||
use std::time::Duration; | ||
|
||
use assert_matches::assert_matches; | ||
use futures::channel::mpsc::UnboundedSender; | ||
use futures::task::{Context, Poll}; | ||
use futures::{Stream, StreamExt}; | ||
use libp2p::swarm::handler::{ConnectionEvent, FullyNegotiatedOutbound}; | ||
use libp2p::swarm::{ConnectionHandler, ConnectionHandlerEvent}; | ||
|
||
use super::super::RequestId; | ||
use super::{Handler, HandlerEvent, NewRequestEvent, RequestProgressEvent}; | ||
use crate::messages::block::{BlockHeader, GetBlocks, GetBlocksResponse}; | ||
use crate::messages::common::BlockId; | ||
use crate::messages::proto::p2p::proto::get_blocks_response::Response; | ||
|
||
impl Unpin for Handler {} | ||
|
||
impl Stream for Handler { | ||
type Item = HandlerEvent<Handler>; | ||
|
||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | ||
match Pin::into_inner(self).poll(cx) { | ||
Poll::Pending => Poll::Pending, | ||
Poll::Ready(event) => Poll::Ready(Some(event)), | ||
} | ||
} | ||
} | ||
|
||
const SUBSTREAM_TIMEOUT: Duration = Duration::MAX; | ||
|
||
async fn start_request_and_validate_event( | ||
handler: &mut Handler, | ||
request: &GetBlocks, | ||
request_id: RequestId, | ||
) -> UnboundedSender<GetBlocksResponse> { | ||
handler.on_behaviour_event(NewRequestEvent { request: request.clone(), request_id }); | ||
let event = handler.next().await.unwrap(); | ||
let ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } = event else { | ||
panic!("Got unexpected event"); | ||
}; | ||
assert_eq!(*request, *protocol.upgrade().request()); | ||
assert_eq!(SUBSTREAM_TIMEOUT, *protocol.timeout()); | ||
protocol.upgrade().responses_sender().clone() | ||
} | ||
|
||
async fn send_response_and_validate_event( | ||
handler: &mut Handler, | ||
response: &GetBlocksResponse, | ||
request_id: RequestId, | ||
responses_sender: &UnboundedSender<GetBlocksResponse>, | ||
) { | ||
responses_sender.unbounded_send(response.clone()).unwrap(); | ||
let event = handler.next().await.unwrap(); | ||
assert_matches!( | ||
event, | ||
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse{ | ||
request_id: event_request_id, response: event_response | ||
}) if event_request_id == request_id && event_response == *response | ||
); | ||
} | ||
|
||
async fn finish_request_and_validate_event(handler: &mut Handler, request_id: RequestId) { | ||
handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound( | ||
FullyNegotiatedOutbound { protocol: (), info: request_id }, | ||
)); | ||
let event = handler.next().await.unwrap(); | ||
assert_matches!( | ||
event, | ||
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::RequestFinished{ | ||
request_id: event_request_id | ||
}) if event_request_id == request_id | ||
); | ||
} | ||
|
||
#[tokio::test] | ||
async fn process_request() { | ||
let mut handler = Handler::new(SUBSTREAM_TIMEOUT); | ||
|
||
let request = GetBlocks::default(); | ||
let request_id = RequestId::default(); | ||
let response = GetBlocksResponse { | ||
response: Some(Response::Header(BlockHeader { | ||
parent_block: Some(BlockId { hash: None, height: 1 }), | ||
..Default::default() | ||
})), | ||
}; | ||
|
||
let responses_sender = | ||
start_request_and_validate_event(&mut handler, &request, request_id).await; | ||
|
||
send_response_and_validate_event(&mut handler, &response, request_id, &responses_sender).await; | ||
finish_request_and_validate_event(&mut handler, request_id).await; | ||
} | ||
|
||
#[tokio::test] | ||
async fn process_multiple_requests_simultaneously() { | ||
let mut handler = Handler::new(SUBSTREAM_TIMEOUT); | ||
|
||
const N_REQUESTS: usize = 20; | ||
let request_ids = (0..N_REQUESTS).map(RequestId).collect::<Vec<_>>(); | ||
let requests = (0..N_REQUESTS) | ||
.map(|i| GetBlocks { skip: i as u64, ..Default::default() }) | ||
.collect::<Vec<_>>(); | ||
let responses = (0..N_REQUESTS) | ||
.map(|i| GetBlocksResponse { | ||
response: Some(Response::Header(BlockHeader { | ||
parent_block: Some(BlockId { hash: None, height: i as u64 }), | ||
..Default::default() | ||
})), | ||
}) | ||
.collect::<Vec<_>>(); | ||
|
||
for ((request, request_id), response) in zip(zip(requests, request_ids), responses.iter()) { | ||
let responses_sender = | ||
start_request_and_validate_event(&mut handler, &request, request_id).await; | ||
responses_sender.unbounded_send(response.clone()).unwrap(); | ||
} | ||
|
||
let mut request_id_found = [false; N_REQUESTS]; | ||
for event in handler.take(N_REQUESTS).collect::<Vec<_>>().await { | ||
match event { | ||
ConnectionHandlerEvent::NotifyBehaviour(RequestProgressEvent::ReceivedResponse { | ||
request_id: RequestId(i), | ||
response: event_response, | ||
}) => { | ||
assert_eq!(responses[i], event_response); | ||
assert!(!request_id_found[i]); | ||
request_id_found[i] = true; | ||
} | ||
_ => { | ||
panic!("Got unexpected event"); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,7 @@ | ||
pub mod handler; | ||
pub mod protocol; | ||
|
||
use derive_more::Display; | ||
|
||
#[derive(Clone, Copy, Debug, Default, Display, Eq, Hash, PartialEq)] | ||
pub struct RequestId(pub usize); |
Oops, something went wrong.