Skip to content

Commit

Permalink
refactor(consensus): change the broadcast fn in context to take 'ref …
Browse files Browse the repository at this point in the history
…mut self' (#2241)

This allows us to remove the Arc<Mutex> and also avoid cloning the sender on every call.
The cost is that we leak the fact that futures Sender requires .
  • Loading branch information
matan-starkware authored and dan-starkware committed Jul 23, 2024
1 parent afd54d1 commit b0ddf06
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 32 deletions.
6 changes: 3 additions & 3 deletions crates/sequencing/papyrus_consensus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use futures::StreamExt;
#[instrument(skip(context, validator_id, network_receiver, cached_messages), level = "info")]
#[allow(missing_docs)]
async fn run_height<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
context: &ContextT,
context: &mut ContextT,
height: BlockNumber,
validator_id: ValidatorId,
network_receiver: &mut BroadcastSubscriberReceiver<ConsensusMessage>,
Expand Down Expand Up @@ -103,7 +103,7 @@ where
#[instrument(skip(context, start_height, network_receiver), level = "info")]
#[allow(missing_docs)]
pub async fn run_consensus<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
context: ContextT,
mut context: ContextT,
start_height: BlockNumber,
validator_id: ValidatorId,
mut network_receiver: BroadcastSubscriberReceiver<ConsensusMessage>,
Expand All @@ -116,7 +116,7 @@ where
let mut future_messages = Vec::new();
loop {
let decision = run_height(
&context,
&mut context,
current_height,
validator_id,
&mut network_receiver,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod papyrus_consensus_context_test;

use core::panic;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
Expand All @@ -18,7 +17,6 @@ use papyrus_storage::{StorageError, StorageReader};
use starknet_api::block::{BlockHash, BlockNumber};
use starknet_api::core::ContractAddress;
use starknet_api::transaction::Transaction;
use tokio::sync::Mutex;
use tracing::debug;

use crate::types::{ConsensusBlock, ConsensusContext, ConsensusError, ProposalInit, ValidatorId};
Expand Down Expand Up @@ -47,7 +45,7 @@ impl ConsensusBlock for PapyrusConsensusBlock {

pub struct PapyrusConsensusContext {
storage_reader: StorageReader,
broadcast_sender: Arc<Mutex<BroadcastSubscriberSender<ConsensusMessage>>>,
broadcast_sender: BroadcastSubscriberSender<ConsensusMessage>,
validators: Vec<ValidatorId>,
}

Expand All @@ -61,7 +59,7 @@ impl PapyrusConsensusContext {
) -> Self {
Self {
storage_reader,
broadcast_sender: Arc::new(Mutex::new(broadcast_sender)),
broadcast_sender,
validators: (0..num_validators).map(ContractAddress::from).collect(),
}
}
Expand Down Expand Up @@ -172,9 +170,9 @@ impl ConsensusContext for PapyrusConsensusContext {
*self.validators.first().expect("validators should have at least 2 validators")
}

async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError> {
async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError> {
debug!("Broadcasting message: {message:?}");
self.broadcast_sender.lock().await.send(message).await?;
self.broadcast_sender.send(message).await?;
Ok(())
}

Expand All @@ -184,7 +182,7 @@ impl ConsensusContext for PapyrusConsensusContext {
mut content_receiver: mpsc::Receiver<Transaction>,
fin_receiver: oneshot::Receiver<BlockHash>,
) -> Result<(), ConsensusError> {
let broadcast_sender = self.broadcast_sender.clone();
let mut broadcast_sender = self.broadcast_sender.clone();

tokio::spawn(async move {
let mut transactions = Vec::new();
Expand All @@ -209,8 +207,6 @@ impl ConsensusContext for PapyrusConsensusContext {
);

broadcast_sender
.lock()
.await
.send(ConsensusMessage::Proposal(proposal))
.await
.expect("Failed to send proposal");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip_all, fields(height=self.height.0), level = "debug")]
pub(crate) async fn start<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
info!("Starting consensus with validators {:?}", self.validators);
let events = self.state_machine.start();
Expand All @@ -70,7 +70,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
)]
pub(crate) async fn handle_proposal<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
init: ProposalInit,
p2p_messages_receiver: mpsc::Receiver<<BlockT as ConsensusBlock>::ProposalChunk>,
fin_receiver: oneshot::Receiver<BlockHash>,
Expand Down Expand Up @@ -126,7 +126,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip_all)]
pub(crate) async fn handle_message<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
message: ConsensusMessage,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
debug!("Received message: {:?}", message);
Expand All @@ -141,7 +141,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip_all)]
async fn handle_vote<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
vote: Vote,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
let (votes, sm_vote) = match vote.vote_type {
Expand Down Expand Up @@ -174,7 +174,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip_all)]
async fn handle_state_machine_events<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
mut events: VecDeque<StateMachineEvent>,
) -> Result<Option<Decision<BlockT>>, ConsensusError> {
while let Some(event) = events.pop_front() {
Expand Down Expand Up @@ -210,7 +210,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip(self, context), level = "debug")]
async fn handle_state_machine_start_round<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
block_hash: Option<BlockHash>,
round: Round,
) -> VecDeque<StateMachineEvent> {
Expand Down Expand Up @@ -249,7 +249,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
#[instrument(skip_all)]
async fn handle_state_machine_vote<ContextT: ConsensusContext<Block = BlockT>>(
&mut self,
context: &ContextT,
context: &mut ContextT,
block_hash: BlockHash,
round: Round,
vote_type: VoteType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,31 @@ async fn proposer() {
.withf(move |msg: &ConsensusMessage| msg == &prevote(block_id, 0, node_id))
.returning(move |_| Ok(()));
// Sends proposal and prevote.
assert!(matches!(shc.start(&context).await, Ok(None)));
assert!(matches!(shc.start(&mut context).await, Ok(None)));

assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(
shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await,
Ok(None)
);
// 3 of 4 Prevotes is enough to send a Precommit.
context
.expect_broadcast()
.withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id))
.returning(move |_| Ok(()));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
assert_eq!(
shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await,
Ok(None)
);

let precommits = vec![
precommit(block.id(), 0, 1_u32.into()),
precommit(BlockHash(Felt::TWO), 0, 4_u32.into()), // Ignores since disagrees.
precommit(block.id(), 0, 2_u32.into()),
precommit(block.id(), 0, 3_u32.into()),
];
assert_eq!(shc.handle_message(&context, precommits[1].clone()).await, Ok(None));
assert_eq!(shc.handle_message(&context, precommits[2].clone()).await, Ok(None));
let decision = shc.handle_message(&context, precommits[3].clone()).await.unwrap().unwrap();
assert_eq!(shc.handle_message(&mut context, precommits[1].clone()).await, Ok(None));
assert_eq!(shc.handle_message(&mut context, precommits[2].clone()).await, Ok(None));
let decision = shc.handle_message(&mut context, precommits[3].clone()).await.unwrap().unwrap();
assert_eq!(decision.block, block);
assert!(
decision
Expand Down Expand Up @@ -119,29 +125,35 @@ async fn validator() {
.returning(move |_| Ok(()));
let res = shc
.handle_proposal(
&context,
&mut context,
ProposalInit { height: BlockNumber(0), proposer },
mpsc::channel(1).1, // content - ignored by SHC.
fin_receiver,
)
.await;
assert_eq!(res, Ok(None));

assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
assert_eq!(
shc.handle_message(&mut context, prevote(block.id(), 0, 2_u32.into())).await,
Ok(None)
);
// 3 of 4 Prevotes is enough to send a Precommit.
context
.expect_broadcast()
.withf(move |msg: &ConsensusMessage| msg == &precommit(block_id, 0, node_id))
.returning(move |_| Ok(()));
assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
assert_eq!(
shc.handle_message(&mut context, prevote(block.id(), 0, 3_u32.into())).await,
Ok(None)
);

let precommits = vec![
precommit(block.id(), 0, 2_u32.into()),
precommit(block.id(), 0, 3_u32.into()),
precommit(block.id(), 0, node_id),
];
assert_eq!(shc.handle_message(&context, precommits[0].clone()).await, Ok(None));
let decision = shc.handle_message(&context, precommits[1].clone()).await.unwrap().unwrap();
assert_eq!(shc.handle_message(&mut context, precommits[0].clone()).await, Ok(None));
let decision = shc.handle_message(&mut context, precommits[1].clone()).await.unwrap().unwrap();
assert_eq!(decision.block, block);
assert!(
decision
Expand Down
2 changes: 1 addition & 1 deletion crates/sequencing/papyrus_consensus/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ mock! {

fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId;

async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>;
async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>;

async fn propose(
&self,
Expand Down
2 changes: 1 addition & 1 deletion crates/sequencing/papyrus_consensus/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub trait ConsensusContext {
/// Calculates the ID of the Proposer based on the inputs.
fn proposer(&self, validators: &[ValidatorId], height: BlockNumber) -> ValidatorId;

async fn broadcast(&self, message: ConsensusMessage) -> Result<(), ConsensusError>;
async fn broadcast(&mut self, message: ConsensusMessage) -> Result<(), ConsensusError>;

/// This should be non-blocking. Meaning it returns immediately and waits to receive from the
/// input channels in parallel (ie on a separate task).
Expand Down

0 comments on commit b0ddf06

Please sign in to comment.