Skip to content

Commit

Permalink
refactor(starknet_batcher): refactor batcher test mock dependencies (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dafnamatsry authored Dec 2, 2024
1 parent 3fb46f2 commit d1ad687
Showing 1 changed file with 93 additions and 109 deletions.
202 changes: 93 additions & 109 deletions crates/starknet_batcher/src/batcher_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use futures::future::BoxFuture;
use futures::FutureExt;
use mockall::automock;
use mockall::predicate::{always, eq};
use rstest::{fixture, rstest};
use rstest::rstest;
use starknet_api::block::BlockNumber;
use starknet_api::core::{ContractAddress, Nonce, StateDiffCommitment};
use starknet_api::executable_transaction::Transaction;
Expand Down Expand Up @@ -72,64 +72,39 @@ fn deadline() -> chrono::DateTime<Utc> {
chrono::Utc::now() + BLOCK_GENERATION_TIMEOUT
}

#[fixture]
fn storage_reader() -> MockBatcherStorageReaderTrait {
let mut storage = MockBatcherStorageReaderTrait::new();
storage.expect_height().returning(|| Ok(INITIAL_HEIGHT));
storage
}

#[fixture]
fn storage_writer() -> MockBatcherStorageWriterTrait {
MockBatcherStorageWriterTrait::new()
}

#[fixture]
fn batcher_config() -> BatcherConfig {
BatcherConfig { outstream_content_buffer_size: STREAMING_CHUNK_SIZE, ..Default::default() }
struct MockDependencies {
storage_reader: MockBatcherStorageReaderTrait,
storage_writer: MockBatcherStorageWriterTrait,
mempool_client: MockMempoolClient,
proposal_manager: MockProposalManagerTraitWrapper,
block_builder_factory: MockBlockBuilderFactoryTrait,
}

#[fixture]
fn mempool_client() -> MockMempoolClient {
MockMempoolClient::new()
impl Default for MockDependencies {
fn default() -> Self {
let mut storage_reader = MockBatcherStorageReaderTrait::new();
storage_reader.expect_height().returning(|| Ok(INITIAL_HEIGHT));
Self {
storage_reader,
storage_writer: MockBatcherStorageWriterTrait::new(),
mempool_client: MockMempoolClient::new(),
proposal_manager: MockProposalManagerTraitWrapper::new(),
block_builder_factory: MockBlockBuilderFactoryTrait::new(),
}
}
}

fn batcher(proposal_manager: MockProposalManagerTraitWrapper) -> Batcher {
fn create_batcher(mock_dependencies: MockDependencies) -> Batcher {
Batcher::new(
batcher_config(),
Arc::new(storage_reader()),
Box::new(storage_writer()),
Arc::new(mempool_client()),
Box::new(MockBlockBuilderFactoryTrait::new()),
Box::new(proposal_manager),
BatcherConfig { outstream_content_buffer_size: STREAMING_CHUNK_SIZE, ..Default::default() },
Arc::new(mock_dependencies.storage_reader),
Box::new(mock_dependencies.storage_writer),
Arc::new(mock_dependencies.mempool_client),
Box::new(mock_dependencies.block_builder_factory),
Box::new(mock_dependencies.proposal_manager),
)
}

fn create_batcher(
proposal_manager: MockProposalManagerTraitWrapper,
block_builder_factory: MockBlockBuilderFactoryTrait,
) -> Batcher {
Batcher::new(
batcher_config(),
Arc::new(storage_reader()),
Box::new(storage_writer()),
Arc::new(mempool_client()),
Box::new(block_builder_factory),
Box::new(proposal_manager),
)
}

fn mock_proposal_manager_common_expectations(
proposal_manager: &mut MockProposalManagerTraitWrapper,
) {
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());
proposal_manager
.expect_wrap_await_proposal_commitment()
.times(1)
.with(eq(PROPOSAL_ID))
.return_once(move |_| { async move { Ok(proposal_commitment()) } }.boxed());
}

fn abort_signal_sender() -> AbortSignalSender {
tokio::sync::oneshot::channel().0
}
Expand Down Expand Up @@ -168,6 +143,17 @@ fn mock_create_builder_for_propose_block(
block_builder_factory
}

fn mock_proposal_manager_common_expectations(
proposal_manager: &mut MockProposalManagerTraitWrapper,
) {
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());
proposal_manager
.expect_wrap_await_proposal_commitment()
.times(1)
.with(eq(PROPOSAL_ID))
.return_once(move |_| { async move { Ok(proposal_commitment()) } }.boxed());
}

fn mock_proposal_manager_validate_flow() -> MockProposalManagerTraitWrapper {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
mock_proposal_manager_common_expectations(&mut proposal_manager);
Expand All @@ -190,7 +176,7 @@ async fn start_height_success() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });
assert_eq!(batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await, Ok(()));
}

Expand All @@ -214,7 +200,7 @@ async fn start_height_fail(#[case] height: BlockNumber, #[case] expected_error:
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_reset().never();

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });
assert_eq!(batcher.start_height(StartHeightInput { height }).await, Err(expected_error));
}

Expand All @@ -224,7 +210,7 @@ async fn duplicate_start_height() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_reset().times(1).return_once(|| async {}.boxed());

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

let initial_height = StartHeightInput { height: INITIAL_HEIGHT };
assert_eq!(batcher.start_height(initial_height.clone()).await, Ok(()));
Expand All @@ -235,7 +221,7 @@ async fn duplicate_start_height() {
#[tokio::test]
async fn no_active_height() {
let proposal_manager = MockProposalManagerTraitWrapper::new();
let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

// Calling `propose_block` and `validate_block` without starting a height should fail.

Expand Down Expand Up @@ -263,7 +249,11 @@ async fn no_active_height() {
async fn validate_block_full_flow() {
let block_builder_factory = mock_create_builder_for_validate_block();
let proposal_manager = mock_proposal_manager_validate_flow();
let mut batcher = create_batcher(proposal_manager, block_builder_factory);
let mut batcher = create_batcher(MockDependencies {
proposal_manager,
block_builder_factory,
..Default::default()
});

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

Expand Down Expand Up @@ -304,7 +294,7 @@ async fn send_content_after_proposal_already_finished() {
.times(1)
.returning(|_| async move { InternalProposalStatus::Finished }.boxed());

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

// Send transactions after the proposal has finished.
let send_proposal_input_txs = SendProposalContentInput {
Expand All @@ -325,7 +315,7 @@ async fn send_content_to_unknown_proposal() {
.with(eq(PROPOSAL_ID))
.return_once(move |_| async move { InternalProposalStatus::NotFound }.boxed());

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

// Send transactions to an unknown proposal.
let send_proposal_input_txs = SendProposalContentInput {
Expand All @@ -352,7 +342,7 @@ async fn send_txs_to_an_invalid_proposal() {
.with(eq(PROPOSAL_ID))
.return_once(move |_| async move { InternalProposalStatus::Failed }.boxed());

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

let send_proposal_input_txs = SendProposalContentInput {
proposal_id: PROPOSAL_ID,
Expand Down Expand Up @@ -383,7 +373,11 @@ async fn send_finish_to_an_invalid_proposal() {
.with(eq(PROPOSAL_ID))
.return_once(move |_| { async move { Err(proposal_error) } }.boxed());

let mut batcher = create_batcher(proposal_manager, block_builder_factory);
let mut batcher = create_batcher(MockDependencies {
proposal_manager,
block_builder_factory,
..Default::default()
});
batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();

let validate_block_input = ValidateBlockInput {
Expand Down Expand Up @@ -414,7 +408,11 @@ async fn propose_block_full_flow() {
.times(1)
.return_once(|_, _, _| { async move { Ok(()) } }.boxed());

let mut batcher = create_batcher(proposal_manager, block_builder_factory);
let mut batcher = create_batcher(MockDependencies {
proposal_manager,
block_builder_factory,
..Default::default()
});

batcher.start_height(StartHeightInput { height: INITIAL_HEIGHT }).await.unwrap();
batcher
Expand Down Expand Up @@ -454,6 +452,7 @@ async fn propose_block_full_flow() {
assert_matches!(exhausted, Err(BatcherError::ProposalNotFound { .. }));
}

#[rstest]
#[tokio::test]
async fn propose_block_without_retrospective_block_hash() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
Expand All @@ -464,14 +463,8 @@ async fn propose_block_without_retrospective_block_hash() {
.expect_height()
.returning(|| Ok(BlockNumber(constants::STORED_BLOCK_HASH_BUFFER)));

let mut batcher = Batcher::new(
batcher_config(),
Arc::new(storage_reader),
Box::new(storage_writer()),
Arc::new(mempool_client()),
Box::new(MockBlockBuilderFactoryTrait::new()),
Box::new(proposal_manager),
);
let mut batcher =
create_batcher(MockDependencies { proposal_manager, storage_reader, ..Default::default() });

batcher
.start_height(StartHeightInput { height: BlockNumber(constants::STORED_BLOCK_HASH_BUFFER) })
Expand All @@ -494,7 +487,7 @@ async fn get_content_from_unknown_proposal() {
let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_await_proposal_commitment().times(0);

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });

let get_proposal_content_input = GetProposalContentInput { proposal_id: PROPOSAL_ID };
let result = batcher.get_proposal_content(get_proposal_content_input).await;
Expand All @@ -503,52 +496,43 @@ async fn get_content_from_unknown_proposal() {

#[rstest]
#[tokio::test]
async fn decision_reached(
batcher_config: BatcherConfig,
storage_reader: MockBatcherStorageReaderTrait,
mut storage_writer: MockBatcherStorageWriterTrait,
mut mempool_client: MockMempoolClient,
) {
let expected_state_diff = ThinStateDiff::default();
let state_diff_clone = expected_state_diff.clone();
let expected_proposal_commitment = ProposalCommitment::default();
let tx_hashes = test_tx_hashes(0..5);
let tx_hashes_clone = tx_hashes.clone();
let address_to_nonce = test_contract_nonces(0..3);
let nonces_clone = address_to_nonce.clone();
async fn decision_reached() {
let mut mock_dependencies = MockDependencies::default();

let mut proposal_manager = MockProposalManagerTraitWrapper::new();
proposal_manager.expect_wrap_take_proposal_result().times(1).with(eq(PROPOSAL_ID)).return_once(
move |_| {
mock_dependencies
.proposal_manager
.expect_wrap_take_proposal_result()
.times(1)
.with(eq(PROPOSAL_ID))
.return_once(move |_| {
async move {
Ok(ProposalOutput {
state_diff: state_diff_clone,
commitment: expected_proposal_commitment,
tx_hashes: tx_hashes_clone,
nonces: nonces_clone,
state_diff: ThinStateDiff::default(),
commitment: ProposalCommitment::default(),
tx_hashes: test_tx_hashes(),
nonces: test_contract_nonces(),
})
}
.boxed()
},
);
mempool_client
});

mock_dependencies
.mempool_client
.expect_commit_block()
.with(eq(CommitBlockArgs { address_to_nonce, tx_hashes }))
.with(eq(CommitBlockArgs {
address_to_nonce: test_contract_nonces(),
tx_hashes: test_tx_hashes(),
}))
.returning(|_| Ok(()));

storage_writer
mock_dependencies
.storage_writer
.expect_commit_proposal()
.with(eq(INITIAL_HEIGHT), eq(expected_state_diff))
.with(eq(INITIAL_HEIGHT), eq(ThinStateDiff::default()))
.returning(|_, _| Ok(()));

let mut batcher = Batcher::new(
batcher_config,
Arc::new(storage_reader),
Box::new(storage_writer),
Arc::new(mempool_client),
Box::new(MockBlockBuilderFactoryTrait::new()),
Box::new(proposal_manager),
);
let mut batcher = create_batcher(mock_dependencies);

batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await.unwrap();
}

Expand All @@ -564,7 +548,7 @@ async fn decision_reached_no_executed_proposal() {
},
);

let mut batcher = batcher(proposal_manager);
let mut batcher = create_batcher(MockDependencies { proposal_manager, ..Default::default() });
let decision_reached_result =
batcher.decision_reached(DecisionReachedInput { proposal_id: PROPOSAL_ID }).await;
assert_eq!(decision_reached_result, Err(expected_error));
Expand Down Expand Up @@ -638,10 +622,10 @@ impl<T: ProposalManagerTraitWrapper> ProposalManagerTrait for T {
}
}

fn test_tx_hashes(range: std::ops::Range<u128>) -> HashSet<TransactionHash> {
range.map(|i| tx_hash!(i)).collect()
fn test_tx_hashes() -> HashSet<TransactionHash> {
(0..5u8).map(|i| tx_hash!(i + 12)).collect()
}

fn test_contract_nonces(range: std::ops::Range<u128>) -> HashMap<ContractAddress, Nonce> {
HashMap::from_iter(range.map(|i| (contract_address!(i), nonce!(i))))
fn test_contract_nonces() -> HashMap<ContractAddress, Nonce> {
HashMap::from_iter((0..3u8).map(|i| (contract_address!(i + 33), nonce!(i + 9))))
}

0 comments on commit d1ad687

Please sign in to comment.