Skip to content

Commit

Permalink
feat(starknet_l1_provider): add height arg (#3045)
Browse files Browse the repository at this point in the history
To ensure soundness. Otherwise bugs in batcher/consensus that try to
calculate blocks with future heights may cause double-proposing of a tx,
which causes a reorg for l1handler txs.

Co-authored-by: Gilad Chase <[email protected]>
  • Loading branch information
giladchase and Gilad Chase authored Jan 7, 2025
1 parent 6e57a80 commit 7e1a97b
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 41 deletions.
2 changes: 2 additions & 0 deletions crates/starknet_batcher/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ impl Batcher {
self.mempool_client.clone(),
self.l1_provider_client.clone(),
self.config.max_l1_handler_txs_per_block_proposal,
propose_block_input.block_info.block_number,
);

// A channel to receive the transactions included in the proposed block.
Expand Down Expand Up @@ -213,6 +214,7 @@ impl Batcher {
let tx_provider = ValidateTransactionProvider {
tx_receiver: input_tx_receiver,
l1_provider_client: self.l1_provider_client.clone(),
height: validate_block_input.block_info.block_number,
};

let (block_builder, abort_signal_sender) = self
Expand Down
10 changes: 8 additions & 2 deletions crates/starknet_batcher/src/transaction_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::vec;
use async_trait::async_trait;
#[cfg(test)]
use mockall::automock;
use starknet_api::block::BlockNumber;
use starknet_api::executable_transaction::Transaction;
use starknet_api::transaction::TransactionHash;
use starknet_l1_provider_types::errors::L1ProviderClientError;
Expand Down Expand Up @@ -40,6 +41,7 @@ pub struct ProposeTransactionProvider {
pub mempool_client: SharedMempoolClient,
pub l1_provider_client: SharedL1ProviderClient,
pub max_l1_handler_txs_per_block: usize,
pub height: BlockNumber,
phase: TxProviderPhase,
n_l1handler_txs_so_far: usize,
}
Expand All @@ -56,11 +58,13 @@ impl ProposeTransactionProvider {
mempool_client: SharedMempoolClient,
l1_provider_client: SharedL1ProviderClient,
max_l1_handler_txs_per_block: usize,
height: BlockNumber,
) -> Self {
Self {
mempool_client,
l1_provider_client,
max_l1_handler_txs_per_block,
height,
phase: TxProviderPhase::L1,
n_l1handler_txs_so_far: 0,
}
Expand All @@ -72,7 +76,7 @@ impl ProposeTransactionProvider {
) -> TransactionProviderResult<Vec<Transaction>> {
Ok(self
.l1_provider_client
.get_txs(n_txs)
.get_txs(n_txs, self.height)
.await?
.into_iter()
.map(Transaction::L1Handler)
Expand Down Expand Up @@ -127,6 +131,7 @@ impl TransactionProvider for ProposeTransactionProvider {
pub struct ValidateTransactionProvider {
pub tx_receiver: tokio::sync::mpsc::Receiver<Transaction>,
pub l1_provider_client: SharedL1ProviderClient,
pub height: BlockNumber,
}

#[async_trait]
Expand All @@ -142,7 +147,8 @@ impl TransactionProvider for ValidateTransactionProvider {
}
for tx in &buffer {
if let Transaction::L1Handler(tx) = tx {
let l1_validation_status = self.l1_provider_client.validate(tx.tx_hash).await?;
let l1_validation_status =
self.l1_provider_client.validate(tx.tx_hash, self.height).await?;
if l1_validation_status != L1ValidationStatus::Validated {
// TODO: add the validation status into the error.
return Err(TransactionProviderError::L1HandlerTransactionValidationFailed(
Expand Down
12 changes: 8 additions & 4 deletions crates/starknet_batcher/src/transaction_provider_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use assert_matches::assert_matches;
use mockall::predicate::eq;
use rstest::{fixture, rstest};
use starknet_api::block::BlockNumber;
use starknet_api::executable_transaction::{L1HandlerTransaction, Transaction};
use starknet_api::test_utils::invoke::{executable_invoke_tx, InvokeTxArgs};
use starknet_api::tx_hash;
Expand All @@ -18,6 +19,7 @@ use crate::transaction_provider::{
};

const MAX_L1_HANDLER_TXS_PER_BLOCK: usize = 15;
const HEIGHT: BlockNumber = BlockNumber(1);
const MAX_TXS_PER_FETCH: usize = 10;
const VALIDATE_BUFFER_SIZE: usize = 30;

Expand All @@ -32,8 +34,8 @@ impl MockDependencies {
fn expect_get_l1_handler_txs(&mut self, n_to_request: usize, n_to_return: usize) {
self.l1_provider_client
.expect_get_txs()
.with(eq(n_to_request))
.returning(move |_| Ok(vec![L1HandlerTransaction::default(); n_to_return]));
.with(eq(n_to_request), eq(HEIGHT))
.returning(move |_, _| Ok(vec![L1HandlerTransaction::default(); n_to_return]));
}

fn expect_get_mempool_txs(&mut self, n_to_request: usize) {
Expand All @@ -45,8 +47,8 @@ impl MockDependencies {
fn expect_validate_l1handler(&mut self, tx: L1HandlerTransaction, result: L1ValidationStatus) {
self.l1_provider_client
.expect_validate()
.withf(move |tx_arg| tx_arg == &tx.tx_hash)
.returning(move |_| Ok(result));
.withf(move |tx_arg, height| tx_arg == &tx.tx_hash && *height == HEIGHT)
.returning(move |_, _| Ok(result));
}

async fn simulate_input_txs(&mut self, txs: Vec<Transaction>) {
Expand All @@ -60,13 +62,15 @@ impl MockDependencies {
Arc::new(self.mempool_client),
Arc::new(self.l1_provider_client),
MAX_L1_HANDLER_TXS_PER_BLOCK,
HEIGHT,
)
}

fn validate_tx_provider(self) -> ValidateTransactionProvider {
ValidateTransactionProvider {
tx_receiver: self.tx_receiver,
l1_provider_client: Arc::new(self.l1_provider_client),
height: HEIGHT,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/starknet_l1_provider/src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ impl ComponentRequestHandler<L1ProviderRequest, L1ProviderResponse> for L1Provid
#[instrument(skip(self))]
async fn handle_request(&mut self, request: L1ProviderRequest) -> L1ProviderResponse {
match request {
L1ProviderRequest::GetTransactions(n_txs) => {
L1ProviderResponse::GetTransactions(self.get_txs(n_txs))
L1ProviderRequest::GetTransactions { n_txs, height } => {
L1ProviderResponse::GetTransactions(self.get_txs(n_txs, height))
}
}
}
Expand Down
53 changes: 33 additions & 20 deletions crates/starknet_l1_provider/src/l1_provider_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use starknet_api::block::BlockNumber;
use starknet_api::test_utils::l1_handler::executable_l1_handler_tx;
use starknet_api::transaction::TransactionHash;
use starknet_api::{l1_handler_tx_args, tx_hash};
Expand Down Expand Up @@ -30,27 +31,39 @@ fn get_txs_happy_flow() {
.build_into_l1_provider();

// Test.
assert_eq!(l1_provider.get_txs(0).unwrap(), []);
assert_eq!(l1_provider.get_txs(1).unwrap(), [txs[0].clone()]);
assert_eq!(l1_provider.get_txs(3).unwrap(), txs[1..=2]);
assert_eq!(l1_provider.get_txs(1).unwrap(), []);
assert_eq!(l1_provider.get_txs(0, BlockNumber(1)).unwrap(), []);
assert_eq!(l1_provider.get_txs(1, BlockNumber(1)).unwrap(), [txs[0].clone()]);
assert_eq!(l1_provider.get_txs(3, BlockNumber(1)).unwrap(), txs[1..=2]);
assert_eq!(l1_provider.get_txs(1, BlockNumber(1)).unwrap(), []);
}

#[test]
fn validate_happy_flow() {
// Setup.
let l1_provider = L1ProviderContentBuilder::new()
let mut l1_provider = L1ProviderContentBuilder::new()
.with_txs([tx!(tx_hash: 1)])
.with_on_l2_awaiting_l1_consumption([tx_hash!(2)])
.with_state(Validate)
.build_into_l1_provider();

// Test.
assert_eq!(l1_provider.validate(tx_hash!(1)).unwrap(), ValidationStatus::Validated);
assert_eq!(l1_provider.validate(tx_hash!(2)).unwrap(), ValidationStatus::AlreadyIncludedOnL2);
assert_eq!(l1_provider.validate(tx_hash!(3)).unwrap(), ValidationStatus::ConsumedOnL1OrUnknown);
assert_eq!(
l1_provider.validate(tx_hash!(1), BlockNumber(1)).unwrap(),
ValidationStatus::Validated
);
assert_eq!(
l1_provider.validate(tx_hash!(2), BlockNumber(1)).unwrap(),
ValidationStatus::AlreadyIncludedOnL2
);
assert_eq!(
l1_provider.validate(tx_hash!(3), BlockNumber(1)).unwrap(),
ValidationStatus::ConsumedOnL1OrUnknown
);
// Transaction wasn't deleted after the validation.
assert_eq!(l1_provider.validate(tx_hash!(1)).unwrap(), ValidationStatus::Validated);
assert_eq!(
l1_provider.validate(tx_hash!(1), BlockNumber(1)).unwrap(),
ValidationStatus::Validated
);
}

#[test]
Expand All @@ -63,12 +76,12 @@ fn pending_state_errors() {

// Test.
assert_matches!(
l1_provider.get_txs(1).unwrap_err(),
l1_provider.get_txs(1, BlockNumber(1)).unwrap_err(),
L1ProviderError::GetTransactionsInPendingState
);

assert_matches!(
l1_provider.validate(tx_hash!(1)).unwrap_err(),
l1_provider.validate(tx_hash!(1), BlockNumber(1)).unwrap_err(),
L1ProviderError::ValidateInPendingState
);
}
Expand All @@ -79,16 +92,16 @@ fn uninitialized_get_txs() {
let mut uninitialized_l1_provider = L1Provider::default();
assert_eq!(uninitialized_l1_provider.state, Uninitialized);

uninitialized_l1_provider.get_txs(1).unwrap();
uninitialized_l1_provider.get_txs(1, BlockNumber(1)).unwrap();
}

#[test]
#[should_panic(expected = "Uninitialized L1 provider")]
fn uninitialized_validate() {
let uninitialized_l1_provider = L1Provider::default();
let mut uninitialized_l1_provider = L1Provider::default();
assert_eq!(uninitialized_l1_provider.state, Uninitialized);

uninitialized_l1_provider.validate(TransactionHash::default()).unwrap();
uninitialized_l1_provider.validate(TransactionHash::default(), BlockNumber(1)).unwrap();
}

#[test]
Expand All @@ -97,14 +110,14 @@ fn proposal_start_errors() {
let mut l1_provider =
L1ProviderContentBuilder::new().with_state(Pending).build_into_l1_provider();
// Test.
l1_provider.proposal_start().unwrap();
l1_provider.proposal_start(BlockNumber(1)).unwrap();

assert_eq!(
l1_provider.proposal_start().unwrap_err(),
l1_provider.proposal_start(BlockNumber(1)).unwrap_err(),
L1ProviderError::unexpected_transition(Propose, Propose)
);
assert_eq!(
l1_provider.validation_start().unwrap_err(),
l1_provider.validation_start(BlockNumber(1)).unwrap_err(),
L1ProviderError::unexpected_transition(Propose, Validate)
);
}
Expand All @@ -116,14 +129,14 @@ fn validation_start_errors() {
L1ProviderContentBuilder::new().with_state(Pending).build_into_l1_provider();

// Test.
l1_provider.validation_start().unwrap();
l1_provider.validation_start(BlockNumber(1)).unwrap();

assert_eq!(
l1_provider.validation_start().unwrap_err(),
l1_provider.validation_start(BlockNumber(1)).unwrap_err(),
L1ProviderError::unexpected_transition(Validate, Validate)
);
assert_eq!(
l1_provider.proposal_start().unwrap_err(),
l1_provider.proposal_start(BlockNumber(1)).unwrap_err(),
L1ProviderError::unexpected_transition(Validate, Propose)
);
}
36 changes: 29 additions & 7 deletions crates/starknet_l1_provider/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use papyrus_config::converters::deserialize_milliseconds_to_duration;
use papyrus_config::dumping::{ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::{Deserialize, Serialize};
use starknet_api::block::BlockNumber;
use starknet_api::executable_transaction::L1HandlerTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_l1_provider_types::errors::L1ProviderError;
Expand All @@ -32,6 +33,7 @@ pub struct L1Provider {
// TODO(Gilad): consider transitioning to a generic phantom state once the infra is stabilized
// and we see how well it handles consuming the L1Provider when moving between states.
state: ProviderState,
current_height: BlockNumber,
}

impl L1Provider {
Expand All @@ -40,7 +42,14 @@ impl L1Provider {
}

/// Retrieves up to `n_txs` transactions that have yet to be proposed or accepted on L2.
pub fn get_txs(&mut self, n_txs: usize) -> L1ProviderResult<Vec<L1HandlerTransaction>> {
pub fn get_txs(
&mut self,
n_txs: usize,
height: BlockNumber,
) -> L1ProviderResult<Vec<L1HandlerTransaction>> {
// Reenable once `commit_block` is implemented so that height can be updated.
let _disabled = self.validate_height(height);

match self.state {
ProviderState::Propose => Ok(self.tx_manager.get_txs(n_txs)),
ProviderState::Pending => Err(L1ProviderError::GetTransactionsInPendingState),
Expand All @@ -51,7 +60,12 @@ impl L1Provider {

/// Returns true if and only if the given transaction is both not included in an L2 block, and
/// unconsumed on L1.
pub fn validate(&self, tx_hash: TransactionHash) -> L1ProviderResult<ValidationStatus> {
pub fn validate(
&mut self,
tx_hash: TransactionHash,
height: BlockNumber,
) -> L1ProviderResult<ValidationStatus> {
self.validate_height(height)?;
match self.state {
ProviderState::Validate => Ok(self.tx_manager.tx_status(tx_hash)),
ProviderState::Propose => Err(L1ProviderError::ValidateTransactionConsensusBug),
Expand All @@ -62,21 +76,21 @@ impl L1Provider {

// TODO: when deciding on consensus, if possible, have commit_block also tell the node if it's
// about to [optimistically-]propose or validate the next block.
pub fn commit_block(&mut self, _commited_txs: &[TransactionHash]) {
pub fn commit_block(&mut self, _commited_txs: &[TransactionHash], _height: BlockNumber) {
todo!(
"Purges txs from internal buffers, if was proposer clear staging buffer,
reset state to Pending until we get proposing/validating notice from consensus."
)
}

// TODO: pending formal consensus API, guessing the API here to keep things moving.
// TODO: consider adding block number, it isn't strictly necessary, but will help debugging.
pub fn validation_start(&mut self) -> L1ProviderResult<()> {
pub fn validation_start(&mut self, height: BlockNumber) -> L1ProviderResult<()> {
self.validate_height(height)?;
self.state = self.state.transition_to_validate()?;
Ok(())
}

pub fn proposal_start(&mut self) -> L1ProviderResult<()> {
pub fn proposal_start(&mut self, height: BlockNumber) -> L1ProviderResult<()> {
self.validate_height(height)?;
self.state = self.state.transition_to_propose()?;
Ok(())
}
Expand All @@ -95,6 +109,14 @@ impl L1Provider {
Then, transition to Pending."
);
}

fn validate_height(&mut self, height: BlockNumber) -> L1ProviderResult<()> {
let next_height = self.current_height.unchecked_next();
if height != next_height {
return Err(L1ProviderError::UnexpectedHeight { expected: next_height, got: height });
}
Ok(())
}
}

impl ComponentStarter for L1Provider {}
Expand Down
4 changes: 4 additions & 0 deletions crates/starknet_l1_provider/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use indexmap::{IndexMap, IndexSet};
use starknet_api::block::BlockNumber;
use starknet_api::executable_transaction::L1HandlerTransaction;
use starknet_api::transaction::TransactionHash;

Expand All @@ -10,6 +11,7 @@ use crate::{L1Provider, ProviderState, TransactionManager};
pub struct L1ProviderContent {
tx_manager_content: Option<TransactionManagerContent>,
state: Option<ProviderState>,
current_height: BlockNumber,
}

impl From<L1ProviderContent> for L1Provider {
Expand All @@ -20,6 +22,7 @@ impl From<L1ProviderContent> for L1Provider {
.map(|tm_content| tm_content.complete_to_tx_manager())
.unwrap_or_default(),
state: content.state.unwrap_or_default(),
current_height: content.current_height,
}
}
}
Expand Down Expand Up @@ -58,6 +61,7 @@ impl L1ProviderContentBuilder {
L1ProviderContent {
tx_manager_content: self.tx_manager_content_builder.build(),
state: self.state,
..Default::default()
}
}

Expand Down
3 changes: 3 additions & 0 deletions crates/starknet_l1_provider_types/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;

use serde::{Deserialize, Serialize};
use starknet_api::block::BlockNumber;
use starknet_sequencer_infra::component_client::ClientError;
use thiserror::Error;

Expand All @@ -12,6 +13,8 @@ pub enum L1ProviderError {
GetTransactionsInPendingState,
#[error("`get_txs` while in validate state")]
GetTransactionConsensusBug,
#[error("Unexpected height: expected {expected}, got {got}")]
UnexpectedHeight { expected: BlockNumber, got: BlockNumber },
#[error("Cannot transition from {from} to {to}")]
UnexpectedProviderStateTransition { from: String, to: String },
#[error(
Expand Down
Loading

0 comments on commit 7e1a97b

Please sign in to comment.