Skip to content

Commit

Permalink
refactor(network): move responses sender into query sender
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanm-starkware committed Jul 17, 2024
1 parent 4e3c1bd commit 750d790
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 227 deletions.
2 changes: 0 additions & 2 deletions crates/papyrus_network/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::{Deserialize, Serialize};
use validator::Validate;

pub use crate::network_manager::SqmrSubscriberChannels;

// TODO: add peer manager config to the network config
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Validate)]
pub struct NetworkConfig {
Expand Down
201 changes: 98 additions & 103 deletions crates/papyrus_network/src/network_manager/mod.rs

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions crates/papyrus_network/src/network_manager/test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand All @@ -21,9 +22,10 @@ use tokio::sync::Mutex;
use tokio::time::sleep;

use super::swarm_trait::{Event, SwarmTrait};
use super::{GenericNetworkManager, SqmrSubscriberChannels};
use super::GenericNetworkManager;
use crate::gossipsub_impl::{self, Topic};
use crate::mixed_behaviour;
use crate::network_manager::SqmrClientPayload;
use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError};
use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId};

Expand Down Expand Up @@ -219,14 +221,16 @@ async fn register_sqmr_protocol_client_and_use_channels() {
let mut network_manager = GenericNetworkManager::generic_new(mock_swarm);

// register subscriber and send query
let SqmrSubscriberChannels { mut query_sender, response_receiver } = network_manager
.register_sqmr_protocol_client::<Vec<u8>, Vec<u8>>(
SIGNED_BLOCK_HEADER_PROTOCOL.to_string(),
BUFFER_SIZE,
);
let mut query_sender = network_manager.register_sqmr_protocol_client::<Vec<u8>, Vec<u8>>(
SIGNED_BLOCK_HEADER_PROTOCOL.to_string(),
BUFFER_SIZE,
);

let response_receiver_length = Arc::new(Mutex::new(0));
let cloned_response_receiver_length = Arc::clone(&response_receiver_length);
let (responses_sender, response_receiver) =
futures::channel::mpsc::channel::<Result<Vec<u8>, Infallible>>(BUFFER_SIZE);
let responses_sender = Box::new(responses_sender);
let response_receiver_collector = response_receiver
.enumerate()
.take(VEC1.len())
Expand All @@ -237,11 +241,11 @@ async fn register_sqmr_protocol_client_and_use_channels() {
result
})
.collect::<Vec<_>>();
let (_report_callback, report_receiver) = oneshot::channel::<()>();
let (_report_sender, report_receiver) = oneshot::channel::<()>();
tokio::select! {
_ = network_manager.run() => panic!("network manager ended"),
_ = poll_fn(|cx| event_listner.poll_unpin(cx)).then(|_| async move {
query_sender.send((VEC1.clone(), report_receiver)).await.unwrap()})
query_sender.send(SqmrClientPayload{query : VEC1.clone(), report_receiver, responses_sender}).await.unwrap()})
.then(|_| async move {
*cloned_response_receiver_length.lock().await = response_receiver_collector.await.len();
}) => {},
Expand Down Expand Up @@ -364,7 +368,7 @@ async fn receive_broadcasted_message_and_report_it() {
.then(|result| {
let (message_result, report_callback) = result.unwrap().unwrap();
assert_eq!(message, message_result.unwrap());
report_callback();
report_callback.send(()).unwrap();
tokio::time::timeout(TIMEOUT, reported_peer_receiver.next())
}) => {
assert_eq!(originated_peer_id, reported_peer_result.unwrap().unwrap());
Expand Down
15 changes: 6 additions & 9 deletions crates/papyrus_node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ fn run_network(
};
let mut network_manager = network_manager::NetworkManager::new(network_config.clone());
let local_peer_id = network_manager.get_local_peer_id();
let header_client_channels = network_manager
let header_client_sender = network_manager
.register_sqmr_protocol_client(Protocol::SignedBlockHeader.into(), BUFFER_SIZE);
let state_diff_client_channels =
let state_diff_client_sender =
network_manager.register_sqmr_protocol_client(Protocol::StateDiff.into(), BUFFER_SIZE);
let transaction_client_channels =
let transaction_client_sender =
network_manager.register_sqmr_protocol_client(Protocol::Transaction.into(), BUFFER_SIZE);

let header_server_channel = network_manager
Expand All @@ -381,12 +381,9 @@ fn run_network(
None => None,
};
let p2p_sync_channels = P2PSyncClientChannels {
header_query_sender: Box::new(header_client_channels.query_sender),
header_response_receiver: Box::new(header_client_channels.response_receiver),
state_diff_query_sender: Box::new(state_diff_client_channels.query_sender),
state_diff_response_receiver: Box::new(state_diff_client_channels.response_receiver),
transaction_query_sender: Box::new(transaction_client_channels.query_sender),
transaction_response_receiver: Box::new(transaction_client_channels.response_receiver),
header_payload_sender: header_client_sender,
state_diff_payload_sender: state_diff_client_sender,
transaction_payload_sender: transaction_client_sender,
};

Ok((
Expand Down
38 changes: 23 additions & 15 deletions crates/papyrus_p2p_sync/src/client/header_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures::{SinkExt, StreamExt};
use papyrus_network::network_manager::SqmrClientPayload;
use papyrus_protobuf::sync::{
BlockHashOrNumber,
DataOrFin,
Expand Down Expand Up @@ -27,11 +28,9 @@ async fn signed_headers_basic_flow() {
let TestArgs {
p2p_sync,
storage_reader,
mut header_query_receiver,
mut headers_sender,
mut header_payload_receiver,
// The test will fail if we drop these
state_diff_query_receiver: _state_diff_query_receiver,
state_diffs_sender: _state_diffs_sender,
state_diff_payload_receiver: _state_diff_query_receiver,
..
} = setup();
let block_hashes_and_signatures =
Expand All @@ -44,7 +43,11 @@ async fn signed_headers_basic_flow() {
let end_block_number = (query_index + 1) * HEADER_QUERY_LENGTH;

// Receive query and validate it.
let (query, _report_receiver) = header_query_receiver.next().await.unwrap();
let SqmrClientPayload {
query,
report_receiver: _report_receiver,
responses_sender: mut headers_sender,
} = header_payload_receiver.next().await.unwrap();
assert_eq!(
query,
HeaderQuery(Query {
Expand Down Expand Up @@ -110,18 +113,20 @@ async fn sync_sends_new_header_query_if_it_got_partial_responses() {

let TestArgs {
p2p_sync,
mut header_query_receiver,
mut headers_sender,
mut header_payload_receiver,
// The test will fail if we drop these
state_diff_query_receiver: _state_diff_query_receiver,
state_diffs_sender: _state_diffs_sender,
state_diff_payload_receiver: _state_diff_query_receiver,
..
} = setup();
let block_hashes_and_signatures = create_block_hashes_and_signatures(NUM_ACTUAL_RESPONSES);

// Create a future that will receive a query, send partial responses and receive the next query.
let parse_queries_future = async move {
let _query = header_query_receiver.next().await.unwrap();
let SqmrClientPayload {
query: _query,
report_receiver: _report_receiver,
responses_sender: mut headers_sender,
} = header_payload_receiver.next().await.unwrap();

for (i, (block_hash, signature)) in block_hashes_and_signatures.into_iter().enumerate() {
headers_sender
Expand All @@ -140,11 +145,14 @@ async fn sync_sends_new_header_query_if_it_got_partial_responses() {
headers_sender.send(Ok(DataOrFin(None))).await.unwrap();

// First unwrap is for the timeout. Second unwrap is for the Option returned from Stream.
let (query, _report_receiver) =
timeout(TIMEOUT_FOR_NEW_QUERY_AFTER_PARTIAL_RESPONSE, header_query_receiver.next())
.await
.unwrap()
.unwrap();
let SqmrClientPayload {
query,
report_receiver: _report_receiver,
responses_sender: _responses_sender,
} = timeout(TIMEOUT_FOR_NEW_QUERY_AFTER_PARTIAL_RESPONSE, header_payload_receiver.next())
.await
.unwrap()
.unwrap();

assert_eq!(
query,
Expand Down
69 changes: 38 additions & 31 deletions crates/papyrus_p2p_sync/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ use std::time::Duration;
use futures::channel::mpsc::SendError;
use futures::future::{ready, Ready};
use futures::sink::With;
use futures::{Sink, SinkExt, Stream};
use futures::{SinkExt, Stream};
use header::HeaderStreamBuilder;
use papyrus_config::converters::deserialize_seconds_to_duration;
use papyrus_config::dumping::{ser_optional_param, ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use papyrus_network::network_manager::ReportReceiver;
use papyrus_network::network_manager::{SqmrClientPayload, SqmrClientSender};
use papyrus_protobuf::converters::ProtobufConversionError;
use papyrus_protobuf::sync::{
DataOrFin,
Expand Down Expand Up @@ -158,33 +158,29 @@ pub enum P2PSyncError {
#[error(transparent)]
SendError(#[from] SendError),
}
type SyncResponse<T> = Result<DataOrFin<T>, ProtobufConversionError>;

type Response<T> = Result<DataOrFin<T>, ProtobufConversionError>;
// TODO(Eitan): Use SqmrSubscriberChannels once there is a utility function for testing
type QuerySender<T> =
Box<dyn Sink<(T, ReportReceiver), Error = SendError> + Unpin + Send + 'static>;
type WithQuerySender<T> = With<
QuerySender<T>,
(T, ReportReceiver),
(Query, ReportReceiver),
Ready<Result<(T, ReportReceiver), SendError>>,
fn((Query, ReportReceiver)) -> Ready<Result<(T, ReportReceiver), SendError>>,

type WithPayloadSender<TQuery, Response> = With<
SqmrClientSender<TQuery, Response>,
SqmrClientPayload<TQuery, Response>,
SqmrClientPayload<Query, Response>,
Ready<Result<SqmrClientPayload<TQuery, Response>, SendError>>,
fn(
SqmrClientPayload<Query, Response>,
) -> Ready<Result<SqmrClientPayload<TQuery, Response>, SendError>>,
>;
type ResponseReceiver<T> = Box<dyn Stream<Item = Response<T>> + Unpin + Send + 'static>;
type HeaderQuerySender = QuerySender<HeaderQuery>;
type HeaderResponseReceiver = ResponseReceiver<SignedBlockHeader>;
type StateDiffQuerySender = QuerySender<StateDiffQuery>;
type StateDiffResponseReceiver = ResponseReceiver<StateDiffChunk>;
type TransactionQuerySender = QuerySender<TransactionQuery>;
type TransactionResponseReceiver = ResponseReceiver<(Transaction, TransactionOutput)>;
type ResponseReceiver<T> = Box<dyn Stream<Item = SyncResponse<T>> + Unpin + Send>;
type HeaderPayloadSender = SqmrClientSender<HeaderQuery, DataOrFin<SignedBlockHeader>>;
type StateDiffPayloadSender = SqmrClientSender<StateDiffQuery, DataOrFin<StateDiffChunk>>;
type TransactionPayloadSender =
SqmrClientSender<TransactionQuery, DataOrFin<(Transaction, TransactionOutput)>>;

pub struct P2PSyncClientChannels {
pub header_query_sender: HeaderQuerySender,
pub header_response_receiver: HeaderResponseReceiver,
pub state_diff_query_sender: StateDiffQuerySender,
pub state_diff_response_receiver: StateDiffResponseReceiver,
pub transaction_query_sender: TransactionQuerySender,
pub transaction_response_receiver: TransactionResponseReceiver,
pub header_payload_sender: HeaderPayloadSender,
pub state_diff_payload_sender: StateDiffPayloadSender,
pub transaction_payload_sender: TransactionPayloadSender,
}

impl P2PSyncClientChannels {
Expand All @@ -194,20 +190,31 @@ impl P2PSyncClientChannels {
config: P2PSyncClientConfig,
) -> impl Stream<Item = DataStreamResult> + Send + 'static {
let header_stream = HeaderStreamBuilder::create_stream(
self.header_query_sender
.with(|(query, report_receiver)| ready(Ok((HeaderQuery(query), report_receiver)))),
self.header_response_receiver,
self.header_payload_sender.with(
|SqmrClientPayload { query, report_receiver, responses_sender }| {
ready(Ok(SqmrClientPayload {
query: HeaderQuery(query),
report_receiver,
responses_sender,
}))
},
),
storage_reader.clone(),
config.wait_period_for_new_data,
config.num_headers_per_query,
config.stop_sync_at_block_number,
);

let state_diff_stream = StateDiffStreamBuilder::create_stream(
self.state_diff_query_sender.with(|(query, report_receiver)| {
ready(Ok((StateDiffQuery(query), report_receiver)))
}),
self.state_diff_response_receiver,
self.state_diff_payload_sender.with(
|SqmrClientPayload { query, report_receiver, responses_sender }| {
ready(Ok(SqmrClientPayload {
query: StateDiffQuery(query),
report_receiver,
responses_sender,
}))
},
),
storage_reader.clone(),
config.wait_period_for_new_data,
config.num_block_state_diffs_per_query,
Expand Down
47 changes: 28 additions & 19 deletions crates/papyrus_p2p_sync/src/client/state_diff_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::time::Duration;
use assert_matches::assert_matches;
use futures::{FutureExt, SinkExt, StreamExt};
use indexmap::indexmap;
use papyrus_network::network_manager::SqmrClientPayload;
use papyrus_protobuf::sync::{
BlockHashOrNumber,
ContractDiff,
Expand Down Expand Up @@ -46,13 +47,8 @@ async fn state_diff_basic_flow() {
let TestArgs {
p2p_sync,
storage_reader,
mut state_diff_query_receiver,
mut headers_sender,
mut state_diffs_sender,
// The test will fail if we drop this.
// We don't need to read the header query in order to know which headers to send, and we
// already validate the header query in a different test.
header_query_receiver: _header_query_receiver,
mut state_diff_payload_receiver,
mut header_payload_receiver,
..
} = setup();

Expand All @@ -71,7 +67,12 @@ async fn state_diff_basic_flow() {
tokio::time::sleep(SLEEP_DURATION_TO_LET_SYNC_ADVANCE).await;

// Check that before we send headers there is no state diff query.
assert!(state_diff_query_receiver.next().now_or_never().is_none());
assert!(state_diff_payload_receiver.next().now_or_never().is_none());
let SqmrClientPayload {
query: _query,
report_receiver: _report_receiver,
responses_sender: mut headers_sender,
} = header_payload_receiver.next().await.unwrap();

// Send headers for entire query.
for (i, ((block_hash, block_signature), state_diff)) in
Expand All @@ -96,7 +97,11 @@ async fn state_diff_basic_flow() {
(STATE_DIFF_QUERY_LENGTH, HEADER_QUERY_LENGTH - STATE_DIFF_QUERY_LENGTH),
] {
// Get a state diff query and validate it
let (query, _report_receiver) = state_diff_query_receiver.next().await.unwrap();
let SqmrClientPayload {
query,
report_receiver: _report_receiver,
responses_sender: mut state_diff_sender,
} = state_diff_payload_receiver.next().await.unwrap();
assert_eq!(
query,
StateDiffQuery(Query {
Expand All @@ -116,7 +121,7 @@ async fn state_diff_basic_flow() {
let txn = storage_reader.begin_ro_txn().unwrap();
assert_eq!(block_number, txn.get_state_marker().unwrap());

state_diffs_sender
state_diff_sender
.send(Ok(DataOrFin(Some(state_diff_chunk.clone()))))
.await
.unwrap();
Expand Down Expand Up @@ -164,7 +169,7 @@ async fn state_diff_basic_flow() {
};
assert_eq!(state_diff, expected_state_diff);
}
state_diffs_sender.send(Ok(DataOrFin(None))).await.unwrap();
state_diff_sender.send(Ok(DataOrFin(None))).await.unwrap();
}
};

Expand Down Expand Up @@ -307,13 +312,8 @@ async fn validate_state_diff_fails(
let TestArgs {
p2p_sync,
storage_reader,
mut state_diff_query_receiver,
mut headers_sender,
mut state_diffs_sender,
// The test will fail if we drop this.
// We don't need to read the header query in order to know which headers to send, and we
// already validate the header query in a different test.
header_query_receiver: _header_query_receiver,
mut state_diff_payload_receiver,
mut header_payload_receiver,
..
} = setup();

Expand All @@ -322,6 +322,11 @@ async fn validate_state_diff_fails(
// Create a future that will receive queries, send responses and validate the results.
let parse_queries_future = async move {
// Send a single header. There's no need to fill the entire query.
let SqmrClientPayload {
query: _query,
report_receiver: _report_receiver,
responses_sender: mut headers_sender,
} = header_payload_receiver.next().await.unwrap();
headers_sender
.send(Ok(DataOrFin(Some(SignedBlockHeader {
block_header: BlockHeader {
Expand All @@ -336,7 +341,11 @@ async fn validate_state_diff_fails(
.unwrap();

// Get a state diff query and validate it
let (query, _report_reciever) = state_diff_query_receiver.next().await.unwrap();
let SqmrClientPayload {
query,
report_receiver: _report_reciever,
responses_sender: mut state_diffs_sender,
} = state_diff_payload_receiver.next().await.unwrap();
assert_eq!(
query,
StateDiffQuery(Query {
Expand Down
Loading

0 comments on commit 750d790

Please sign in to comment.