diff --git a/crates/sequencing/papyrus_consensus/src/lib.rs b/crates/sequencing/papyrus_consensus/src/lib.rs index 2d076720ce..274c53d046 100644 --- a/crates/sequencing/papyrus_consensus/src/lib.rs +++ b/crates/sequencing/papyrus_consensus/src/lib.rs @@ -36,7 +36,7 @@ use futures::StreamExt; #[instrument(skip(context, validator_id, network_receiver, cached_messages), level = "info")] #[allow(missing_docs)] async fn run_height>( - context: &ContextT, + context: &mut ContextT, height: BlockNumber, validator_id: ValidatorId, network_receiver: &mut BroadcastSubscriberReceiver, @@ -103,7 +103,7 @@ where #[instrument(skip(context, start_height, network_receiver), level = "info")] #[allow(missing_docs)] pub async fn run_consensus>( - context: ContextT, + mut context: ContextT, start_height: BlockNumber, validator_id: ValidatorId, mut network_receiver: BroadcastSubscriberReceiver, @@ -116,7 +116,7 @@ where let mut future_messages = Vec::new(); loop { let decision = run_height( - &context, + &mut context, current_height, validator_id, &mut network_receiver, diff --git a/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs b/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs index 8d419694fc..d4e4d22807 100644 --- a/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs +++ b/crates/sequencing/papyrus_consensus/src/papyrus_consensus_context.rs @@ -3,7 +3,6 @@ mod papyrus_consensus_context_test; use core::panic; -use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; @@ -18,7 +17,6 @@ use papyrus_storage::{StorageError, StorageReader}; use starknet_api::block::{BlockHash, BlockNumber}; use starknet_api::core::ContractAddress; use starknet_api::transaction::Transaction; -use tokio::sync::Mutex; use tracing::debug; use crate::types::{ConsensusBlock, ConsensusContext, ConsensusError, ProposalInit, ValidatorId}; @@ -47,7 +45,7 @@ impl ConsensusBlock for PapyrusConsensusBlock { pub struct PapyrusConsensusContext { storage_reader: StorageReader, - broadcast_sender: Arc>>, + broadcast_sender: BroadcastSubscriberSender, validators: Vec, } @@ -61,7 +59,7 @@ impl PapyrusConsensusContext { ) -> Self { Self { storage_reader, - broadcast_sender: Arc::new(Mutex::new(broadcast_sender)), + broadcast_sender, validators: (0..num_validators).map(ContractAddress::from).collect(), } } @@ -172,9 +170,9 @@ impl ConsensusContext for PapyrusConsensusContext { *self.validators.first().expect("validators should have at least 2 validators") } - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError> { + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError> { debug!("Broadcasting message: {message:?}"); - self.broadcast_sender.lock().await.send(message).await?; + self.broadcast_sender.send(message).await?; Ok(()) } @@ -184,7 +182,7 @@ impl ConsensusContext for PapyrusConsensusContext { mut content_receiver: mpsc::Receiver, fin_receiver: oneshot::Receiver, ) -> Result<(), ConsensusError> { - let broadcast_sender = self.broadcast_sender.clone(); + let mut broadcast_sender = self.broadcast_sender.clone(); tokio::spawn(async move { let mut transactions = Vec::new(); @@ -209,8 +207,6 @@ impl ConsensusContext for PapyrusConsensusContext { ); broadcast_sender - .lock() - .await .send(ConsensusMessage::Proposal(proposal)) .await .expect("Failed to send proposal"); diff --git a/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs b/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs index 5eca5ee6cb..4b33524ce2 100644 --- a/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs +++ b/crates/sequencing/papyrus_consensus/src/single_height_consensus.rs @@ -54,7 +54,7 @@ impl SingleHeightConsensus { #[instrument(skip_all, fields(height=self.height.0), level = "debug")] pub(crate) async fn start>( &mut self, - context: &ContextT, + context: &mut ContextT, ) -> Result>, ConsensusError> { info!("Starting consensus with validators {:?}", self.validators); let events = self.state_machine.start(); @@ -70,7 +70,7 @@ impl SingleHeightConsensus { )] pub(crate) async fn handle_proposal>( &mut self, - context: &ContextT, + context: &mut ContextT, init: ProposalInit, p2p_messages_receiver: mpsc::Receiver<::ProposalChunk>, fin_receiver: oneshot::Receiver, @@ -126,7 +126,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] pub(crate) async fn handle_message>( &mut self, - context: &ContextT, + context: &mut ContextT, message: ConsensusMessage, ) -> Result>, ConsensusError> { debug!("Received message: {:?}", message); @@ -141,7 +141,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_vote>( &mut self, - context: &ContextT, + context: &mut ContextT, vote: Vote, ) -> Result>, ConsensusError> { let (votes, sm_vote) = match vote.vote_type { @@ -174,7 +174,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_state_machine_events>( &mut self, - context: &ContextT, + context: &mut ContextT, mut events: VecDeque, ) -> Result>, ConsensusError> { while let Some(event) = events.pop_front() { @@ -210,7 +210,7 @@ impl SingleHeightConsensus { #[instrument(skip(self, context), level = "debug")] async fn handle_state_machine_start_round>( &mut self, - context: &ContextT, + context: &mut ContextT, block_hash: Option, round: Round, ) -> VecDeque { @@ -249,7 +249,7 @@ impl SingleHeightConsensus { #[instrument(skip_all)] async fn handle_state_machine_vote>( &mut self, - context: &ContextT, + context: &mut ContextT, block_hash: BlockHash, round: Round, vote_type: VoteType, diff --git a/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs b/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs index a94f714fa3..38fedb418d 100644 --- a/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs +++ b/crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs @@ -54,15 +54,21 @@ async fn proposer() { .withf(move |msg: &ConsensusMessage| msg == &prevote(block_id, 0, node_id)) .returning(move |_| Ok(())); // Sends proposal and prevote. - assert!(matches!(shc.start(&context).await, Ok(None))); + assert!(matches!(shc.start(&mut context).await, Ok(None))); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await, + Ok(None) + ); // 3 of 4 Prevotes is enough to send a Precommit. context .expect_broadcast() .withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id)) .returning(move |_| Ok(())); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await, + Ok(None) + ); let precommits = vec![ precommit(block.id(), 0, 1_u32.into()), @@ -70,9 +76,9 @@ async fn proposer() { precommit(block.id(), 0, 2_u32.into()), precommit(block.id(), 0, 3_u32.into()), ]; - assert_eq!(shc.handle_message(&context, precommits[1].clone()).await, Ok(None)); - assert_eq!(shc.handle_message(&context, precommits[2].clone()).await, Ok(None)); - let decision = shc.handle_message(&context, precommits[3].clone()).await.unwrap().unwrap(); + assert_eq!(shc.handle_message(&mut context, precommits[1].clone()).await, Ok(None)); + assert_eq!(shc.handle_message(&mut context, precommits[2].clone()).await, Ok(None)); + let decision = shc.handle_message(&mut context, precommits[3].clone()).await.unwrap().unwrap(); assert_eq!(decision.block, block); assert!( decision @@ -119,7 +125,7 @@ async fn validator() { .returning(move |_| Ok(())); let res = shc .handle_proposal( - &context, + &mut context, ProposalInit { height: BlockNumber(0), proposer }, mpsc::channel(1).1, // content - ignored by SHC. fin_receiver, @@ -127,21 +133,27 @@ async fn validator() { .await; assert_eq!(res, Ok(None)); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await, + Ok(None) + ); // 3 of 4 Prevotes is enough to send a Precommit. context .expect_broadcast() .withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id)) .returning(move |_| Ok(())); - assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None)); + assert_eq!( + shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await, + Ok(None) + ); let precommits = vec![ precommit(block.id(), 0, 2_u32.into()), precommit(block.id(), 0, 3_u32.into()), precommit(block.id(), 0, node_id), ]; - assert_eq!(shc.handle_message(&context, precommits[0].clone()).await, Ok(None)); - let decision = shc.handle_message(&context, precommits[1].clone()).await.unwrap().unwrap(); + assert_eq!(shc.handle_message(&mut context, precommits[0].clone()).await, Ok(None)); + let decision = shc.handle_message(&mut context, precommits[1].clone()).await.unwrap().unwrap(); assert_eq!(decision.block, block); assert!( decision diff --git a/crates/sequencing/papyrus_consensus/src/test_utils.rs b/crates/sequencing/papyrus_consensus/src/test_utils.rs index 042bc8c052..098dc26298 100644 --- a/crates/sequencing/papyrus_consensus/src/test_utils.rs +++ b/crates/sequencing/papyrus_consensus/src/test_utils.rs @@ -49,7 +49,7 @@ mock! { fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId; - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>; + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>; async fn propose( &self, diff --git a/crates/sequencing/papyrus_consensus/src/types.rs b/crates/sequencing/papyrus_consensus/src/types.rs index b2e2c55a30..4ac435d6a5 100644 --- a/crates/sequencing/papyrus_consensus/src/types.rs +++ b/crates/sequencing/papyrus_consensus/src/types.rs @@ -123,7 +123,7 @@ pub trait ConsensusContext { /// Calculates the ID of the Proposer based on the inputs. fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId; - async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>; + async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>; /// This should be non-blocking. Meaning it returns immediately and waits to receive from the /// input channels in parallel (ie on a separate task).