From e0d154892c5e0394c9c37222581fe7eee5021eec Mon Sep 17 00:00:00 2001 From: Matan Markind Date: Mon, 22 Jul 2024 11:38:32 +0300 Subject: [PATCH] refactor(consensus): change the broadcast fn in context to take 'ref mut self' This allows us to remove the Arc and also avoid cloning the sender on every call. The cost is that we leak the fact that futures Sender requires . --- .../sequencing/papyrus_consensus/src/lib.rs | 6 ++-- .../src/papyrus_consensus_context.rs | 14 +++----- .../src/single_height_consensus.rs | 14 ++++---- .../src/single_height_consensus_test.rs | 34 +++++++++++++------ .../papyrus_consensus/src/test_utils.rs | 2 +- .../sequencing/papyrus_consensus/src/types.rs | 2 +- 6 files changed, 40 insertions(+), 32 deletions(-) diff --git a/crates/sequencing/papyrus_consensus/src/lib.rs b/crates/sequencing/papyrus_consensus/src/lib.rs index 2d076720ce..274c53d046 100644 --- a/crates/sequencing/papyrus_consensus/src/lib.rs +++ b/crates/sequencing/papyrus_consensus/src/lib.rs @@ -36,7 +36,7 @@ use futures::StreamExt; #[instrument(skip(context, validator_id, network_receiver, cached_messages), level = "info")] #[allow(missing_docs)] async fn run_height>( - context: &ContextT, + context: &mut ContextT, height: BlockNumber, validator_id: ValidatorId, network_receiver: &mut BroadcastSubscriberReceiver, @@ -103,7 +103,7 @@ where #[instrument(skip(context, start_height, network_receiver), level = "info")] #[allow(missing_docs)] pub async fn run_consensus>( - context: ContextT, + mut context: ContextT, start_height: BlockNumber, validator_id: ValidatorId, mut network_receiver: BroadcastSubscriberReceiver, @@ -116,7 +116,7 @@ where let mut future_messages = Vec::new(); loop { let decision = run_height( - &context, + &mut context, current_height, validator_id, &mut network_receiver, diff --git a/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs b/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs index 8d419694fc..d4e4d22807 100644 --- a/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs @@ -3,7 +3,6 @@ mod papyrus_consensus_context_test; use core::panic; -use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; @@ -18,7 +17,6 @@ use papyrus_storage::{StorageError, StorageReader}; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; -use tokio::sync::Mutex; use tracing::debug; use crate::types::{ConsensusBlock, ConsensusContext, ConsensusError, ProposalInit, ValidatorId}; @@ -47,7 +45,7 @@ impl ConsensusBlock for PapyrusConsensusBlock { pub struct PapyrusConsensusContext { storage_reader: StorageReader, - broadcast_sender: Arc>>, + broadcast_sender: BroadcastSubscriberSender, validators: Vec, } @@ -61,7 +59,7 @@ impl PapyrusConsensusContext { ) -> Self { Self { storage_reader, - broadcast_sender: Arc::new(Mutex::new(broadcast_sender)), + broadcast_sender, validators: (0..num_validators).map(ContractAddress::from).collect(), } } @@ -172,9 +170,9 @@ impl ConsensusContext for PapyrusConsensusContext { *self.validators.first().expect("validators should have at least 2 validators") } - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError> { + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError> { debug!("Broadcasting message: {message:?}"); - self.broadcast_sender.lock().await.send(message).await?; + self.broadcast_sender.send(message).await?; Ok(()) } @@ -184,7 +182,7 @@ impl ConsensusContext for PapyrusConsensusContext { mut content_receiver: mpsc::Receiver, fin_receiver: oneshot::Receiver, ) -> Result<(), ConsensusError> { - let broadcast_sender = self.broadcast_sender.clone(); + let mut broadcast_sender = self.broadcast_sender.clone(); tokio::spawn(async move { let mut transactions = Vec::new(); @@ -209,8 +207,6 @@ impl ConsensusContext for PapyrusConsensusContext { ); broadcast_sender - .lock() - .await .send(ConsensusMessage::Proposal(proposal)) .await .expect("Failed to send proposal"); diff --git a/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs b/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs index 5eca5ee6cb..4b33524ce2 100644 --- a/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs +++ b/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs @@ -54,7 +54,7 @@ impl SingleHeightConsensus { #[instrument(skip_all, fields(height=self.height.0), level = "debug")] pub(crate) async fn start>( &mut self, - context: &ContextT, + context: &mut ContextT, ) -> Result>, ConsensusError> { info!("Starting consensus with validators {:?}", self.validators); let events = self.state_machine.start(); @@ -70,7 +70,7 @@ impl SingleHeightConsensus { )] pub(crate) async fn handle_proposal>( &mut self, - context: &ContextT, + context: &mut ContextT, init: ProposalInit, p2p_messages_receiver: mpsc::Receiver<::ProposalChunk>, fin_receiver: oneshot::Receiver, @@ -126,7 +126,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] pub(crate) async fn handle_message>( &mut self, - context: &ContextT, + context: &mut ContextT, message: ConsensusMessage, ) -> Result>, ConsensusError> { debug!("Received message: {:?}", message); @@ -141,7 +141,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_vote>( &mut self, - context: &ContextT, + context: &mut ContextT, vote: Vote, ) -> Result>, ConsensusError> { let (votes, sm_vote) = match vote.vote_type { @@ -174,7 +174,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_state_machine_events>( &mut self, - context: &ContextT, + context: &mut ContextT, mut events: VecDeque, ) -> Result>, ConsensusError> { while let Some(event) = events.pop_front() { @@ -210,7 +210,7 @@ impl SingleHeightConsensus { #[instrument(skip(self, context), level = "debug")] async fn handle_state_machine_start_round>( &mut self, - context: &ContextT, + context: &mut ContextT, block_hash: Option, round: Round, ) -> VecDeque { @@ -249,7 +249,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_state_machine_vote>( &mut self, - context: &ContextT, + context: &mut ContextT, block_hash: BlockHash, round: Round, vote_type: VoteType, diff --git a/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs b/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs index a94f714fa3..38fedb418d 100644 --- a/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs +++ b/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs @@ -54,15 +54,21 @@ async fn proposer() { .withf(move |msg: &ConsensusMessage| msg == &prevote(block_id, 0, node_id)) .returning(move |_| Ok(())); // Sends proposal and prevote. - assert!(matches!(shc.start(&context).await, Ok(None))); + assert!(matches!(shc.start(&mut context).await, Ok(None))); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await, + Ok(None) + ); // 3 of 4 Prevotes is enough to send a Precommit. context .expect_broadcast() .withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id)) .returning(move |_| Ok(())); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await, + Ok(None) + ); let precommits = vec![ precommit(block.id(), 0, 1_u32.into()), @@ -70,9 +76,9 @@ async fn proposer() { precommit(block.id(), 0, 2_u32.into()), precommit(block.id(), 0, 3_u32.into()), ]; - assert_eq!(shc.handle_message(&context, precommits[1].clone()).await, Ok(None)); - assert_eq!(shc.handle_message(&context, precommits[2].clone()).await, Ok(None)); - let decision = shc.handle_message(&context, precommits[3].clone()).await.unwrap().unwrap(); + assert_eq!(shc.handle_message(&mut context, precommits[1].clone()).await, Ok(None)); + assert_eq!(shc.handle_message(&mut context, precommits[2].clone()).await, Ok(None)); + let decision = shc.handle_message(&mut context, precommits[3].clone()).await.unwrap().unwrap(); assert_eq!(decision.block, block); assert!( decision @@ -119,7 +125,7 @@ async fn validator() { .returning(move |_| Ok(())); let res = shc .handle_proposal( - &context, + &mut context, ProposalInit { height: BlockNumber(0), proposer }, mpsc::channel(1).1, // content - ignored by SHC. fin_receiver, @@ -127,21 +133,27 @@ async fn validator() { .await; assert_eq!(res, Ok(None)); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await, + Ok(None) + ); // 3 of 4 Prevotes is enough to send a Precommit. context .expect_broadcast() .withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id)) .returning(move |_| Ok(())); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await, + Ok(None) + ); let precommits = vec![ precommit(block.id(), 0, 2_u32.into()), precommit(block.id(), 0, 3_u32.into()), precommit(block.id(), 0, node_id), ]; - assert_eq!(shc.handle_message(&context, precommits[0].clone()).await, Ok(None)); - let decision = shc.handle_message(&context, precommits[1].clone()).await.unwrap().unwrap(); + assert_eq!(shc.handle_message(&mut context, precommits[0].clone()).await, Ok(None)); + let decision = shc.handle_message(&mut context, precommits[1].clone()).await.unwrap().unwrap(); assert_eq!(decision.block, block); assert!( decision diff --git a/crates/sequencing/papyrus_consensus/src/test_utils.rs b/crates/sequencing/papyrus_consensus/src/test_utils.rs index 042bc8c052..098dc26298 100644 --- a/crates/sequencing/papyrus_consensus/src/test_utils.rs +++ b/crates/sequencing/papyrus_consensus/src/test_utils.rs @@ -49,7 +49,7 @@ mock! { fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId; - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>; + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>; async fn propose( &self, diff --git a/crates/sequencing/papyrus_consensus/src/types.rs b/crates/sequencing/papyrus_consensus/src/types.rs index b2e2c55a30..4ac435d6a5 100644 --- a/crates/sequencing/papyrus_consensus/src/types.rs +++ b/crates/sequencing/papyrus_consensus/src/types.rs @@ -123,7 +123,7 @@ pub trait ConsensusContext { /// Calculates the ID of the Proposer based on the inputs. fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId; - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>; + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>; /// This should be non-blocking. Meaning it returns immediately and waits to receive from the /// input channels in parallel (ie on a separate task).