refactor(consensus): the Context is passed as a param instead of being held as a field by SHC (#2238)

matan-starkware authored Jul 22, 2024
1 parent 537e11e · commit 2db883b
Showing 6 changed files with 59 additions and 80 deletions.
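
The gist of the refactor, as a minimal self-contained sketch (the trait, methods, and types below are simplified stand-ins, not the real API): `SingleHeightConsensus` (SHC) stops owning an `Arc<dyn ConsensusContext>` and instead borrows a generic context for the duration of each call.

```rust
use std::sync::Arc;

// Stand-in for the real `ConsensusContext` trait (assumed simplified;
// the actual trait is async and has more methods).
trait ConsensusContext {
    fn proposer(&self) -> u64;
}

// Before: the context lived inside SHC as a shared trait object,
// which forced `Send + Sync` bounds and refcounting.
#[allow(dead_code)]
struct ShcBefore {
    context: Arc<dyn ConsensusContext + Send + Sync>,
}

impl ShcBefore {
    #[allow(dead_code)]
    fn start(&mut self) -> u64 {
        self.context.proposer()
    }
}

// After: SHC holds no context; callers lend one to each method, and the
// concrete type is known at compile time (monomorphized, no vtable).
struct ShcAfter;

impl ShcAfter {
    fn start<C: ConsensusContext>(&mut self, context: &C) -> u64 {
        context.proposer()
    }
}

struct FakeContext;
impl ConsensusContext for FakeContext {
    fn proposer(&self) -> u64 {
        42
    }
}

fn main() {
    let mut shc = ShcAfter;
    let context = FakeContext;
    // The caller keeps ownership of `context` between calls.
    assert_eq!(shc.start(&context), 42);
}
```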
crates/papyrus_node/src/main.rs (1 addition & 1 deletion)

@@ -125,7 +125,7 @@ fn run_consensus(
     let start_height = config.start_height;

     Ok(tokio::spawn(papyrus_consensus::run_consensus(
-        Arc::new(context),
+        context,
         start_height,
         validator_id,
         consensus_channels.broadcasted_messages_receiver,
crates/sequencing/papyrus_consensus/src/lib.rs (9 additions & 12 deletions)

@@ -3,8 +3,6 @@
 // TODO(Matan): fix #[allow(missing_docs)].
 //! A consensus implementation for a [`Starknet`](https://www.starknet.io/) node.
-use std::sync::Arc;
-
 use futures::channel::{mpsc, oneshot};
 use papyrus_common::metrics as papyrus_metrics;
 use papyrus_network::network_manager::BroadcastSubscriberReceiver;
@@ -37,8 +35,8 @@ use futures::StreamExt;

 #[instrument(skip(context, validator_id, network_receiver, cached_messages), level = "info")]
 #[allow(missing_docs)]
-async fn run_height<BlockT: ConsensusBlock>(
-    context: Arc<dyn ConsensusContext<Block = BlockT>>,
+async fn run_height<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
+    context: &ContextT,
     height: BlockNumber,
     validator_id: ValidatorId,
     network_receiver: &mut BroadcastSubscriberReceiver<ConsensusMessage>,
@@ -49,10 +47,9 @@ where
         Into<(ProposalInit, mpsc::Receiver<BlockT::ProposalChunk>, oneshot::Receiver<BlockHash>)>,
 {
     let validators = context.validators(height).await;
-    let mut shc =
-        SingleHeightConsensus::new(Arc::clone(&context), height, validator_id, validators);
+    let mut shc = SingleHeightConsensus::new(height, validator_id, validators);

-    if let Some(decision) = shc.start().await? {
+    if let Some(decision) = shc.start(context).await? {
         return Ok(decision);
     }
@@ -91,9 +88,9 @@ where
                 // Special case due to fake streaming.
                 let (proposal_init, content_receiver, fin_receiver) =
                     ProposalWrapper(proposal).into();
-                shc.handle_proposal(proposal_init, content_receiver, fin_receiver).await?
+                shc.handle_proposal(context, proposal_init, content_receiver, fin_receiver).await?
             }
-            _ => shc.handle_message(message).await?,
+            _ => shc.handle_message(context, message).await?,
         };

         if let Some(decision) = maybe_decision {
@@ -105,8 +102,8 @@
 // TODO(dvir): add test for this.
 #[instrument(skip(context, start_height, network_receiver), level = "info")]
 #[allow(missing_docs)]
-pub async fn run_consensus<BlockT: ConsensusBlock>(
-    context: Arc<dyn ConsensusContext<Block = BlockT>>,
+pub async fn run_consensus<BlockT: ConsensusBlock, ContextT: ConsensusContext<Block = BlockT>>(
+    context: ContextT,
     start_height: BlockNumber,
     validator_id: ValidatorId,
     mut network_receiver: BroadcastSubscriberReceiver<ConsensusMessage>,
@@ -119,7 +116,7 @@ where
     let mut future_messages = Vec::new();
    loop {
         let decision = run_height(
-            Arc::clone(&context),
+            &context,
             current_height,
             validator_id,
             &mut network_receiver,
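
The ownership pattern this produces in `run_consensus`: it takes the context by value and lends `&context` to `run_height` once per height, where it previously did an `Arc::clone(&context)` per iteration. A synchronous toy version of that loop (the real functions are async and take more parameters; this trait is a simplified assumption):

```rust
// Toy stand-in for ConsensusContext.
trait ConsensusContext {
    fn validators(&self, height: u64) -> Vec<u64>;
}

// Borrows the context for one height; no refcount traffic, no heap indirection.
fn run_height<C: ConsensusContext>(context: &C, height: u64) -> usize {
    context.validators(height).len()
}

// Owns the context for the lifetime of the whole loop.
fn run_consensus<C: ConsensusContext>(context: C, start_height: u64) -> usize {
    let mut total = 0;
    for height in start_height..start_height + 3 {
        total += run_height(&context, height);
    }
    total
}

struct Fixed;
impl ConsensusContext for Fixed {
    fn validators(&self, _height: u64) -> Vec<u64> {
        vec![1, 2, 3, 4]
    }
}

fn main() {
    assert_eq!(run_consensus(Fixed, 0), 12);
}
```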
crates/sequencing/papyrus_consensus/src/single_height_consensus.rs (37 additions & 34 deletions)

@@ -3,7 +3,6 @@
 mod single_height_consensus_test;

 use std::collections::{HashMap, VecDeque};
-use std::sync::Arc;

 use futures::channel::{mpsc, oneshot};
 use papyrus_protobuf::consensus::{ConsensusMessage, Vote, VoteType};
@@ -29,7 +28,6 @@ const ROUND_ZERO: Round = 0;
 /// out messages "directly" to the network, and returning a decision to the caller.
 pub(crate) struct SingleHeightConsensus<BlockT: ConsensusBlock> {
     height: BlockNumber,
-    context: Arc<dyn ConsensusContext<Block = BlockT>>,
     validators: Vec<ValidatorId>,
     id: ValidatorId,
     state_machine: StateMachine,
@@ -39,17 +37,11 @@ pub(crate) struct SingleHeightConsensus<BlockT: ConsensusBlock> {
 }

 impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
-    pub(crate) fn new(
-        context: Arc<dyn ConsensusContext<Block = BlockT>>,
-        height: BlockNumber,
-        id: ValidatorId,
-        validators: Vec<ValidatorId>,
-    ) -> Self {
+    pub(crate) fn new(height: BlockNumber, id: ValidatorId, validators: Vec<ValidatorId>) -> Self {
         // TODO(matan): Use actual weights, not just `len`.
         let state_machine = StateMachine::new(validators.len() as u32);
         Self {
             height,
-            context,
             validators,
             id,
             state_machine,
@@ -59,22 +51,26 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
         }
     }

-    #[instrument(skip(self), fields(height=self.height.0), level = "debug")]
-    pub(crate) async fn start(&mut self) -> Result<Option<Decision<BlockT>>, ConsensusError> {
+    #[instrument(skip_all, fields(height=self.height.0), level = "debug")]
+    pub(crate) async fn start<ContextT: ConsensusContext<Block = BlockT>>(
+        &mut self,
+        context: &ContextT,
+    ) -> Result<Option<Decision<BlockT>>, ConsensusError> {
         info!("Starting consensus with validators {:?}", self.validators);
         let events = self.state_machine.start();
-        self.handle_state_machine_events(events).await
+        self.handle_state_machine_events(context, events).await
     }

     /// Receive a proposal from a peer node. Returns only once the proposal has been fully received
     /// and processed.
     #[instrument(
-        skip(self, init, p2p_messages_receiver, fin_receiver),
+        skip_all,
         fields(height = %self.height),
         level = "debug",
     )]
-    pub(crate) async fn handle_proposal(
+    pub(crate) async fn handle_proposal<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         init: ProposalInit,
         p2p_messages_receiver: mpsc::Receiver<<BlockT as ConsensusBlock>::ProposalChunk>,
         fin_receiver: oneshot::Receiver<BlockHash>,
@@ -83,7 +79,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
             "Received proposal: proposal_height={}, proposer={:?}",
             init.height.0, init.proposer
         );
-        let proposer_id = self.context.proposer(&self.validators, self.height);
+        let proposer_id = context.proposer(&self.validators, self.height);
         if init.height != self.height {
             let msg = format!("invalid height: expected {:?}, got {:?}", self.height, init.height);
             return Err(ConsensusError::InvalidProposal(proposer_id, self.height, msg));
@@ -94,8 +90,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
             return Err(ConsensusError::InvalidProposal(proposer_id, self.height, msg));
         }

-        let block_receiver =
-            self.context.validate_proposal(self.height, p2p_messages_receiver).await;
+        let block_receiver = context.validate_proposal(self.height, p2p_messages_receiver).await;
         // TODO(matan): Actual Tendermint should handle invalid proposals.
         let block = block_receiver.await.map_err(|_| {
             ConsensusError::InvalidProposal(
@@ -124,27 +119,29 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
         // TODO(matan): Handle multiple rounds.
         self.proposals.insert(ROUND_ZERO, block);
         let sm_events = self.state_machine.handle_event(sm_proposal);
-        self.handle_state_machine_events(sm_events).await
+        self.handle_state_machine_events(context, sm_events).await
     }

     /// Handle messages from peer nodes.
     #[instrument(skip_all)]
-    pub(crate) async fn handle_message(
+    pub(crate) async fn handle_message<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         message: ConsensusMessage,
     ) -> Result<Option<Decision<BlockT>>, ConsensusError> {
         debug!("Received message: {:?}", message);
         match message {
             ConsensusMessage::Proposal(_) => {
                 unimplemented!("Proposals should use `handle_proposal` due to fake streaming")
             }
-            ConsensusMessage::Vote(vote) => self.handle_vote(vote).await,
+            ConsensusMessage::Vote(vote) => self.handle_vote(context, vote).await,
         }
     }

     #[instrument(skip_all)]
-    async fn handle_vote(
+    async fn handle_vote<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         vote: Vote,
     ) -> Result<Option<Decision<BlockT>>, ConsensusError> {
         let (votes, sm_vote) = match vote.vote_type {
@@ -170,21 +167,24 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {

         votes.insert((ROUND_ZERO, vote.voter), vote);
         let sm_events = self.state_machine.handle_event(sm_vote);
-        self.handle_state_machine_events(sm_events).await
+        self.handle_state_machine_events(context, sm_events).await
     }

     // Handle events output by the state machine.
     #[instrument(skip_all)]
-    async fn handle_state_machine_events(
+    async fn handle_state_machine_events<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         mut events: VecDeque<StateMachineEvent>,
     ) -> Result<Option<Decision<BlockT>>, ConsensusError> {
         while let Some(event) = events.pop_front() {
             trace!("Handling event: {:?}", event);
             match event {
                 StateMachineEvent::StartRound(block_hash, round) => {
                     events.append(
-                        &mut self.handle_state_machine_start_round(block_hash, round).await,
+                        &mut self
+                            .handle_state_machine_start_round(context, block_hash, round)
+                            .await,
                     );
                 }
                 StateMachineEvent::Proposal(_, _) => {
@@ -195,37 +195,39 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
                     return self.handle_state_machine_decision(block_hash, round).await;
                 }
                 StateMachineEvent::Prevote(block_hash, round) => {
-                    self.handle_state_machine_vote(block_hash, round, VoteType::Prevote).await?;
+                    self.handle_state_machine_vote(context, block_hash, round, VoteType::Prevote)
+                        .await?;
                 }
                 StateMachineEvent::Precommit(block_hash, round) => {
-                    self.handle_state_machine_vote(block_hash, round, VoteType::Precommit).await?;
+                    self.handle_state_machine_vote(context, block_hash, round, VoteType::Precommit)
+                        .await?;
                 }
             }
         }
         Ok(None)
     }

-    #[instrument(skip(self), level = "debug")]
-    async fn handle_state_machine_start_round(
+    #[instrument(skip(self, context), level = "debug")]
+    async fn handle_state_machine_start_round<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         block_hash: Option<BlockHash>,
         round: Round,
     ) -> VecDeque<StateMachineEvent> {
         // TODO(matan): Support re-proposing validValue.
         assert!(block_hash.is_none(), "Reproposing is not yet supported");
-        let proposer_id = self.context.proposer(&self.validators, self.height);
+        let proposer_id = context.proposer(&self.validators, self.height);
         if proposer_id != self.id {
             debug!("Validator");
             return self.state_machine.handle_event(StateMachineEvent::StartRound(None, round));
         }
         debug!("Proposer");

-        let (p2p_messages_receiver, block_receiver) =
-            self.context.build_proposal(self.height).await;
+        let (p2p_messages_receiver, block_receiver) = context.build_proposal(self.height).await;
         let (fin_sender, fin_receiver) = oneshot::channel();
         let init = ProposalInit { height: self.height, proposer: self.id };
         // Peering is a permanent component, so if sending to it fails we cannot continue.
-        self.context
+        context
             .propose(init, p2p_messages_receiver, fin_receiver)
             .await
             .expect("Failed sending Proposal to Peering");
@@ -245,8 +247,9 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
     }

     #[instrument(skip_all)]
-    async fn handle_state_machine_vote(
+    async fn handle_state_machine_vote<ContextT: ConsensusContext<Block = BlockT>>(
         &mut self,
+        context: &ContextT,
         block_hash: BlockHash,
         round: Round,
         vote_type: VoteType,
@@ -260,7 +263,7 @@ impl<BlockT: ConsensusBlock> SingleHeightConsensus<BlockT> {
             // TODO(matan): Consider refactoring not to panic, rather log and return the error.
             panic!("State machine should not send repeat votes: old={:?}, new={:?}", old, vote);
         }
-        self.context.broadcast(ConsensusMessage::Vote(vote)).await?;
+        context.broadcast(ConsensusMessage::Vote(vote)).await?;
         Ok(None)
     }
crates/sequencing/papyrus_consensus/src/single_height_consensus_test.rs

@@ -54,27 +54,26 @@ async fn proposer() {
         .returning(move |_| Ok(()));

     let mut shc = SingleHeightConsensus::new(
-        Arc::new(context),
         BlockNumber(0),
         node_id,
         vec![node_id, 2_u32.into(), 3_u32.into(), 4_u32.into()],
     );

     // Sends proposal and prevote.
-    assert!(matches!(shc.start().await, Ok(None)));
+    assert!(matches!(shc.start(&context).await, Ok(None)));

-    assert_eq!(shc.handle_message(prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
-    assert_eq!(shc.handle_message(prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
+    assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
+    assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));

     let precommits = vec![
         precommit(block.id(), 0, 1_u32.into()),
         precommit(BlockHash(Felt::TWO), 0, 4_u32.into()), // Ignores since disagrees.
         precommit(block.id(), 0, 2_u32.into()),
         precommit(block.id(), 0, 3_u32.into()),
     ];
-    assert_eq!(shc.handle_message(precommits[1].clone()).await, Ok(None));
-    assert_eq!(shc.handle_message(precommits[2].clone()).await, Ok(None));
-    let decision = shc.handle_message(precommits[3].clone()).await.unwrap().unwrap();
+    assert_eq!(shc.handle_message(&context, precommits[1].clone()).await, Ok(None));
+    assert_eq!(shc.handle_message(&context, precommits[2].clone()).await, Ok(None));
+    let decision = shc.handle_message(&context, precommits[3].clone()).await.unwrap().unwrap();
     assert_eq!(decision.block, block);
     assert!(
         decision
@@ -116,7 +115,6 @@ async fn validator() {

     // Creation calls to `context.validators`.
     let mut shc = SingleHeightConsensus::new(
-        Arc::new(context),
         BlockNumber(0),
         node_id,
         vec![node_id, proposer, 3_u32.into(), 4_u32.into()],
@@ -128,23 +126,24 @@ async fn validator() {

     let res = shc
         .handle_proposal(
+            &context,
             ProposalInit { height: BlockNumber(0), proposer },
             mpsc::channel(1).1, // content - ignored by SHC.
             fin_receiver,
         )
         .await;
     assert_eq!(res, Ok(None));

-    assert_eq!(shc.handle_message(prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
-    assert_eq!(shc.handle_message(prevote(block.id(), 0, 3_u32.into())).await, Ok(None));
+    assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 2_u32.into())).await, Ok(None));
+    assert_eq!(shc.handle_message(&context, prevote(block.id(), 0, 3_u32.into())).await, Ok(None));

     let precommits = vec![
         precommit(block.id(), 0, 2_u32.into()),
         precommit(block.id(), 0, 3_u32.into()),
         precommit(block.id(), 0, node_id),
     ];
-    assert_eq!(shc.handle_message(precommits[0].clone()).await, Ok(None));
-    let decision = shc.handle_message(precommits[1].clone()).await.unwrap().unwrap();
+    assert_eq!(shc.handle_message(&context, precommits[0].clone()).await, Ok(None));
+    let decision = shc.handle_message(&context, precommits[1].clone()).await.unwrap().unwrap();
     assert_eq!(decision.block, block);
     assert!(
         decision
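
The test-side effect of the change, as a hand-rolled illustration (the real tests use a generated mock; all names below are hypothetical): the test keeps ownership of the context and can inspect it between calls, with no `Arc::new(context)` handoff into the constructor.

```rust
use std::cell::Cell;

// Toy stand-in for ConsensusContext.
trait ConsensusContext {
    fn proposer(&self) -> u64;
}

// Hypothetical hand-rolled mock that counts how often it is consulted.
struct CountingContext {
    calls: Cell<u32>,
}

impl ConsensusContext for CountingContext {
    fn proposer(&self) -> u64 {
        self.calls.set(self.calls.get() + 1);
        7
    }
}

struct Shc;

impl Shc {
    // Mirrors the new API shape: the context is a per-call borrow.
    fn handle_message<C: ConsensusContext>(&mut self, context: &C, _msg: u32) -> u64 {
        context.proposer()
    }
}

fn main() {
    let context = CountingContext { calls: Cell::new(0) };
    let mut shc = Shc;
    shc.handle_message(&context, 1);
    shc.handle_message(&context, 2);
    // The test still owns `context` and can check expectations between calls.
    assert_eq!(context.calls.get(), 2);
}
```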
crates/sequencing/papyrus_consensus/src/types.rs (1 addition & 11 deletions)

@@ -1,7 +1,3 @@
-#[cfg(test)]
-#[path = "types_test.rs"]
-mod types_test;
-
 use std::fmt::Debug;

 use async_trait::async_trait;
@@ -68,14 +64,8 @@ pub trait ConsensusBlock: Send {
 }

 /// Interface for consensus to call out to the node.
-// Why `Send + Sync`?
-// 1. We expect multiple components within consensus to concurrently access the context.
-// 2. The other option is for each component to have its own copy (i.e. clone) of the context, but
-//    this is object unsafe (Clone requires Sized).
-// 3. Given that we see the context as basically a connector to other components in the node, the
-//    limitation of Sync to keep functions `&self` shouldn't be a problem.
 #[async_trait]
-pub trait ConsensusContext: Send + Sync {
+pub trait ConsensusContext {
     /// The [block](`ConsensusBlock`) type built by `ConsensusContext` from a proposal.
     // We use an associated type since consensus is indifferent to the actual content of a proposal,
     // but we cannot use generics due to object safety.
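
Why the `Send + Sync` supertraits can go: the deleted comment explains they were only needed because one `Arc<dyn ConsensusContext>` was shared across concurrently running components. With a generic parameter, thread-safety bounds move to the call sites that actually cross threads. A sketch under that assumption (`spawn_consensus` is a hypothetical helper, not part of the crate):

```rust
// Toy stand-in for ConsensusContext: note no Send/Sync supertraits.
trait ConsensusContext {
    fn validators(&self, height: u64) -> Vec<u64>;
}

// Single-threaded use now needs no thread-safety bounds at all.
fn run_locally<C: ConsensusContext>(context: &C) -> usize {
    context.validators(0).len()
}

// Only a call site that moves the context across threads asks for
// `Send + 'static` itself.
fn spawn_consensus<C: ConsensusContext + Send + 'static>(context: C) {
    std::thread::spawn(move || {
        let _ = context.validators(0);
    })
    .join()
    .unwrap();
}

struct Ctx;
impl ConsensusContext for Ctx {
    fn validators(&self, _height: u64) -> Vec<u64> {
        vec![0]
    }
}

fn main() {
    assert_eq!(run_locally(&Ctx), 1);
    spawn_consensus(Ctx);
}
```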
crates/sequencing/papyrus_consensus/src/types_test.rs (10 deletions)

This file was deleted.
