diff --git a/Cargo.lock b/Cargo.lock index 29da347..971566e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3263,6 +3263,7 @@ dependencies = [ "anyhow", "futures-util", "ipfs-api", + "tempfile", "thiserror 2.0.9", "tokio", "tracing", @@ -4296,10 +4297,12 @@ dependencies = [ "ethereum", "garaga_rs", "guest-types", + "hasher", "ipfs-utils", "methods", "mmr", "mmr-utils", + "mockall", "risc0-ethereum-contracts", "risc0-zkvm", "serde", diff --git a/crates/guest-types/src/lib.rs b/crates/guest-types/src/lib.rs index 60f944f..867702d 100644 --- a/crates/guest-types/src/lib.rs +++ b/crates/guest-types/src/lib.rs @@ -296,3 +296,90 @@ impl BlocksValidityInput { &self.mmr_input } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_append_result() { + let result = AppendResult::new(10, 15, 5, "test_hash".to_string()); + + assert_eq!(result.leaves_count(), 10); + assert_eq!(result.last_element_idx(), 15); + assert_eq!(result.element_index(), 5); + assert_eq!(result.value(), "test_hash"); + } + + #[test] + fn test_guest_output() { + let output = GuestOutput::new( + 1, + 100, + "block_hash".to_string(), + "root_hash".to_string(), + 50, + ); + + assert_eq!(output.batch_index(), 1); + assert_eq!(output.latest_mmr_block(), 100); + assert_eq!(output.latest_mmr_block_hash(), "block_hash"); + assert_eq!(output.root_hash(), "root_hash"); + assert_eq!(output.leaves_count(), 50); + } + + #[test] + fn test_combined_input() { + let mmr_input = MMRInput::new(vec!["peak1".to_string()], 10, 5, vec!["elem1".to_string()]); + + let input = CombinedInput::new( + 1, + 100, + Vec::new(), + mmr_input.clone(), + Some("batch_link".to_string()), + Some("next_link".to_string()), + false, + ); + + assert_eq!(input.chain_id(), 1); + assert_eq!(input.batch_size(), 100); + assert!(input.headers().is_empty()); + assert_eq!(input.batch_link(), Some("batch_link")); + assert_eq!(input.next_batch_link(), Some("next_link")); + assert!(!input.skip_proof_verification()); + + // Test MMRInput getters + assert_eq!(input.mmr_input().elements_count(), 10); + assert_eq!(input.mmr_input().leaves_count(), 5); + assert_eq!(input.mmr_input().initial_peaks(), vec!["peak1"]); + } + + #[test] + fn test_final_hash() { + let hash = FinalHash::new("test_hash".to_string(), 42); + + assert_eq!(hash.hash(), "test_hash"); + assert_eq!(hash.index(), 42); + } + + #[test] + fn test_blocks_validity_input() { + let mmr_input = MMRInput::new(vec!["peak1".to_string()], 10, 5, vec!["elem1".to_string()]); + + let guest_proof = GuestProof { + element_index: 1, + element_hash: "hash".to_string(), + siblings_hashes: vec!["sibling".to_string()], + peaks_hashes: vec!["peak".to_string()], + elements_count: 10, + }; + + let input = BlocksValidityInput::new(1, Vec::new(), mmr_input, vec![guest_proof]); + + assert_eq!(input.chain_id(), 1); + assert!(input.headers().is_empty()); + assert_eq!(input.proofs().len(), 1); + assert_eq!(input.mmr_input().elements_count(), 10); + } +} diff --git a/crates/ipfs-utils/Cargo.toml b/crates/ipfs-utils/Cargo.toml index 6c17a85..1181eb3 100644 --- a/crates/ipfs-utils/Cargo.toml +++ b/crates/ipfs-utils/Cargo.toml @@ -14,5 +14,5 @@ tokio = { version = "1.28", features = ["full"] } anyhow = "1.0" futures-util = "0.3" -[features] -ipfs-integration-tests = [] # Empty feature flag for integration tests \ No newline at end of file +[dev-dependencies] +tempfile = "3.5" \ No newline at end of file diff --git a/crates/ipfs-utils/src/lib.rs b/crates/ipfs-utils/src/lib.rs index 7dc10ba..dc6b08f 100644 --- a/crates/ipfs-utils/src/lib.rs +++ b/crates/ipfs-utils/src/lib.rs @@ -154,3 +154,154 @@ impl IpfsManager { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use tempfile; + use tokio::sync::Mutex; + // Define test-specific trait + trait TestIpfsApi { + async fn add_file(&self, data: Vec) -> Result; + async fn cat_file(&self, hash: &str) -> Result, ipfs_api::Error>; + async fn get_version(&self) -> Result<(), ipfs_api::Error>; + } + + #[derive(Clone)] + struct MockIpfsClient { + stored_data: Arc>>, + } + + impl MockIpfsClient { + fn new() -> Self { + Self { + stored_data: Arc::new(Mutex::new(Vec::new())), + } + } + } + + impl TestIpfsApi for MockIpfsClient { + async fn add_file(&self, data: Vec) -> Result { + *self.stored_data.lock().await = data; + Ok("QmTestHash".to_string()) + } + + async fn cat_file(&self, _: &str) -> Result, ipfs_api::Error> { + Ok(self.stored_data.lock().await.clone()) + } + + async fn get_version(&self) -> Result<(), ipfs_api::Error> { + Ok(()) + } + } + + #[allow(dead_code)] + struct TestIpfsManager { + client: MockIpfsClient, + max_file_size: usize, + } + + impl TestIpfsManager { + fn new() -> Self { + Self { + client: MockIpfsClient::new(), + max_file_size: 1024 * 1024, // 1MB limit + } + } + + async fn upload_db(&self, file_path: &Path) -> Result { + let data = std::fs::read(file_path)?; + + // Check file size + if data.len() > self.max_file_size { + return Err(anyhow::anyhow!("File size exceeds maximum allowed size")); + } + + Ok(self.client.add_file(data).await?) + } + + async fn fetch_db(&self, hash: &str, output_path: &Path) -> Result<()> { + // Basic hash validation like the real implementation + if !hash.starts_with("Qm") { + return Err(IpfsError::InvalidHash(hash.to_string()).into()); + } + + let data = self.client.cat_file(hash).await?; + std::fs::write(output_path, data)?; + Ok(()) + } + } + + #[tokio::test] + async fn test_upload_and_fetch() { + let temp_dir = tempfile::tempdir().unwrap(); + let source_path = temp_dir.path().join("source.db"); + let dest_path = temp_dir.path().join("dest.db"); + + let test_data = b"test database content"; + std::fs::write(&source_path, test_data).unwrap(); + + let manager = TestIpfsManager::new(); + + // Test upload + let hash = manager.upload_db(&source_path).await.unwrap(); + assert_eq!(hash, "QmTestHash"); + + // Test fetch + manager.fetch_db(&hash, &dest_path).await.unwrap(); + + // Verify content + let fetched_data = std::fs::read(&dest_path).unwrap(); + assert_eq!(fetched_data, test_data); + } + + #[tokio::test] + async fn test_connection_check() { + let manager = TestIpfsManager::new(); + let result = manager.client.get_version().await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_file_size_limit() { + let temp_dir = tempfile::tempdir().unwrap(); + let large_file = temp_dir.path().join("large.db"); + + // Create file larger than max size + let large_data = vec![0u8; 2 * 1024 * 1024]; // 2MB + std::fs::write(&large_file, large_data).unwrap(); + + let manager = TestIpfsManager::new(); + let result = manager.upload_db(&large_file).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_invalid_file_path() { + let manager = TestIpfsManager::new(); + let result = manager.upload_db(Path::new("/nonexistent/path")).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_invalid_hash() { + let temp_dir = tempfile::tempdir().unwrap(); + let output_path = temp_dir.path().join("output.db"); + + let manager = TestIpfsManager::new(); + let result = manager.fetch_db("invalid-hash", &output_path).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_empty_file() { + let temp_dir = tempfile::tempdir().unwrap(); + let empty_file = temp_dir.path().join("empty.db"); + std::fs::write(&empty_file, b"").unwrap(); + + let manager = TestIpfsManager::new(); + let result = manager.upload_db(&empty_file).await; + assert!(result.is_ok()); + } +} diff --git a/crates/publisher/Cargo.toml b/crates/publisher/Cargo.toml index e995587..786cb16 100644 --- a/crates/publisher/Cargo.toml +++ b/crates/publisher/Cargo.toml @@ -45,4 +45,10 @@ starknet-types-core = "0.1.7" serde = "1.0" risc0-ethereum-contracts = { git = "https://github.com/risc0/risc0-ethereum", tag = "v1.2.0" } - +[dev-dependencies] +mockall = "0.13" +tokio = { version = "1.0", features = ["full"] } +starknet-handler = { path = "../starknet-handler" } +hasher = { git = "https://github.com/ametel01/rust-accumulators.git", branch = "feat/sha2-hasher", features = ["sha256"] } +store = { git = "https://github.com/ametel01/rust-accumulators.git", branch = "feat/sha2-hasher" } +# jsonrpc-core-client = { version = "18.0.0", features = ["http"] } \ No newline at end of file diff --git a/crates/publisher/src/core/accumulator.rs b/crates/publisher/src/core/accumulator.rs index 7d7a9c2..9212462 100644 --- a/crates/publisher/src/core/accumulator.rs +++ b/crates/publisher/src/core/accumulator.rs @@ -431,3 +431,182 @@ impl<'a> AccumulatorBuilder<'a> { .await } } + +#[cfg(test)] +mod tests { + use super::*; + use mockall::mock; + use mockall::predicate::*; + use starknet::core::types::U256; + use starknet::providers::jsonrpc::HttpTransport; + use starknet::providers::JsonRpcClient; + use starknet::providers::Url; + use starknet_handler::account::StarknetAccount; + use starknet_handler::MmrState; + use std::sync::Arc; + + mock! { + #[derive(Clone)] + pub StarknetAccount { + fn verify_mmr_proof(&self, verifier_address: &str, calldata: Vec, ipfs_hash: String) -> Result<(), AccumulatorError>; + } + } + + // Add conversion impl + impl From for StarknetAccount { + fn from(_mock: MockStarknetAccount) -> Self { + // Create a new StarknetAccount for testing + let transport = HttpTransport::new(Url::parse("http://localhost:8545").unwrap()); + let provider = Arc::new(JsonRpcClient::new(transport)); + + StarknetAccount::new(provider, "0x123", "0x456").unwrap() + } + } + + #[tokio::test] + async fn test_accumulator_builder_new() { + let account = MockStarknetAccount::new(); + // Create longer-lived String values + let rpc_url = "http://localhost:8545".to_string(); + let verifier_addr = "0x123".to_string(); + let store_addr = "0x456".to_string(); + + let result = AccumulatorBuilder::new( + &rpc_url, + 1, + &verifier_addr, + &store_addr, + account.into(), + 100, + false, + ) + .await; + + assert!(result.is_ok()); + let builder = result.unwrap(); + assert_eq!(builder.chain_id, 1); + assert_eq!(builder.current_batch, 0); + assert_eq!(builder.total_batches, 0); + } + + #[tokio::test] + async fn test_accumulator_builder_new_invalid_inputs() { + let account = MockStarknetAccount::new(); + let rpc_url = "http://localhost:8545".to_string(); + let store_addr = "0x456".to_string(); + + // Test empty verifier address + let binding = "".to_string(); + let result = AccumulatorBuilder::new( + &rpc_url, + 1, + &binding, + &store_addr, + MockStarknetAccount::new().into(), // Create new instance instead of cloning + 100, + false, + ) + .await; + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + + // Test zero batch size + let verifier_addr = "0x123".to_string(); + let result = AccumulatorBuilder::new( + &rpc_url, + 1, + &verifier_addr, + &store_addr, + account.into(), + 0, + false, + ) + .await; + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + } + + #[tokio::test] + async fn test_build_with_num_batches_invalid_input() { + let account = MockStarknetAccount::new(); + let rpc_url = "http://localhost:8545".to_string(); + let verifier_addr = "0x123".to_string(); + let store_addr = "0x456".to_string(); + + let mut builder = AccumulatorBuilder::new( + &rpc_url, + 1, + &verifier_addr, + &store_addr, + account.into(), + 100, + false, + ) + .await + .unwrap(); + + let result = builder.build_with_num_batches(0).await; + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + } + + #[tokio::test] + async fn test_update_mmr_with_new_headers_invalid_input() { + let account = MockStarknetAccount::new(); + let rpc_url = "http://localhost:8545".to_string(); + let verifier_addr = "0x123".to_string(); + let store_addr = "0x456".to_string(); + + let mut builder = AccumulatorBuilder::new( + &rpc_url, + 1, + &verifier_addr, + &store_addr, + account.into(), + 100, + false, + ) + .await + .unwrap(); + + let result = builder.update_mmr_with_new_headers(100, 50).await; + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + } + + #[tokio::test] + async fn test_handle_batch_result_skip_verification() { + let account = MockStarknetAccount::new(); + let rpc_url = "http://localhost:8545".to_string(); + let verifier_addr = "0x123".to_string(); + let store_addr = "0x456".to_string(); + + let builder = AccumulatorBuilder::new( + &rpc_url, + 1, + &verifier_addr, + &store_addr, + account.into(), + 100, + true, + ) + .await + .unwrap(); + + // Create BatchResult with all required parameters + let mmr_state = MmrState::new( + 100, // size + U256::from(0_u64), // root_hash + U256::from(0_u64), // prev_root + 0, // last_pos + None, // last_leaf + ); + + let batch_result = BatchResult::new( + 100, // start_block + 200, // end_block + mmr_state, // mmr_state + None, // proof + "test_hash".to_string(), // ipfs_hash + ); + + let result = builder.handle_batch_result(&batch_result).await; + assert!(result.is_ok()); + } +} diff --git a/crates/publisher/src/core/batch_processor.rs b/crates/publisher/src/core/batch_processor.rs index 1118df8..b5c330a 100644 --- a/crates/publisher/src/core/batch_processor.rs +++ b/crates/publisher/src/core/batch_processor.rs @@ -341,3 +341,218 @@ impl BatchRange { Ok(Self { start, end }) } } + +#[cfg(test)] +mod tests { + use super::*; + use mockall::automock; + use serde::Serialize; + use starknet::{ + core::types::U256, + providers::{jsonrpc::HttpTransport, JsonRpcClient, Url}, + }; + use starknet_handler::account::StarknetAccount; + use starknet_handler::MmrState; + use std::sync::Arc; + + // Create traits that match the structs we want to mock + #[automock] + #[allow(dead_code)] + pub trait ProofGeneratorTrait { + fn generate_groth16_proof( + &self, + input: CombinedInput, + ) -> Result; + fn decode_journal(&self, proof: &mmr::Proof) -> Result; + } + + // Mock implementation that doesn't need real ELF data + #[allow(dead_code)] + struct MockProofGen; + impl ProofGenerator { + fn mock() -> Self { + // Use a static array instead of vec for 'static lifetime + let method_elf: &'static [u8] = &[1, 2, 3, 4]; // Non-empty ELF data + let method_id = [1u32; 8]; // Non-zero method ID + + ProofGenerator::new(method_elf, method_id) + .expect("Failed to create mock ProofGenerator") + } + } + + // Create a trait without lifetime parameter for automock + #[automock] + #[allow(dead_code)] + pub trait MMRStateManagerTrait { + fn update_state<'a>( + &self, + store_manager: mmr_utils::StoreManager, + mmr: &mut mmr::MMR, + pool: &sqlx::Pool, + end_block: u64, + guest_output: Option<&'a GuestOutput>, + new_headers: &Vec, + ) -> Result; + } + + // Mock implementation that doesn't need real Starknet connection + impl<'a> MMRStateManager<'a> { + fn mock() -> Self { + let provider = Arc::new(JsonRpcClient::new(HttpTransport::new( + Url::parse("http://localhost:5050").expect("Invalid URL"), + ))); + let account = StarknetAccount::new( + provider, "0x0", "0x0", // private key as &str + ) + .expect("Failed to create StarknetAccount"); + + MMRStateManager::new( + account, "0x0", // store_address as &str + ) + } + } + + // Helper function to create test instances + fn create_test_processor() -> BatchProcessor<'static> { + let proof_gen = ProofGenerator::mock(); + let mmr_state_mgr = MMRStateManager::mock(); + + BatchProcessor::new(100, proof_gen, false, mmr_state_mgr).unwrap() + } + + #[tokio::test] + async fn test_calculate_batch_bounds() { + let processor = create_test_processor(); + + // Test normal case + let (start, end) = processor.calculate_batch_bounds(1).unwrap(); + assert_eq!(start, 100); + assert_eq!(end, 199); + + // Test batch 0 + let (start, end) = processor.calculate_batch_bounds(0).unwrap(); + assert_eq!(start, 0); + assert_eq!(end, 99); + } + + #[tokio::test] + async fn test_calculate_batch_range() { + let processor = create_test_processor(); + + // Test normal case + let range = processor.calculate_batch_range(150, 100).unwrap(); + assert_eq!(range.start, 100); + assert_eq!(range.end, 150); + + // Test when current_end is at batch boundary + let range = processor.calculate_batch_range(200, 150).unwrap(); + assert_eq!(range.start, 200); + assert_eq!(range.end, 200); + + // Test error case: current_end < start_block + let result = processor.calculate_batch_range(100, 150); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_calculate_start_block() { + let processor = create_test_processor(); + + // Test normal case + let start = processor.calculate_start_block(150).unwrap(); + assert_eq!(start, 100); + + // Test at batch boundary + let start = processor.calculate_start_block(200).unwrap(); + assert_eq!(start, 200); + + // Test error case: current_end = 0 + let result = processor.calculate_start_block(0); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_batch_processor_new() { + let proof_gen = ProofGenerator::mock(); + let mmr_state_mgr = MMRStateManager::mock(); + + // Test valid creation + let result = BatchProcessor::new(100, proof_gen, false, mmr_state_mgr); + assert!(result.is_ok()); + + // Test invalid batch size + let result = BatchProcessor::new(0, ProofGenerator::mock(), false, MMRStateManager::mock()); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_batch_range_new() { + // Test valid range + let result = BatchRange::new(100, 200); + assert!(result.is_ok()); + let range = result.unwrap(); + assert_eq!(range.start, 100); + assert_eq!(range.end, 200); + + // Test invalid range + let result = BatchRange::new(200, 100); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_process_batch_invalid_inputs() { + let processor = create_test_processor(); + + // Test end_block < start_block + let result = processor.process_batch(1, 150, 100).await; + assert!(result.is_err()); + + // Test start_block before batch start + let result = processor.process_batch(1, 50, 199).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_getters() { + let processor = create_test_processor(); + + assert_eq!(processor.batch_size(), 100); + assert!(!processor.skip_proof_verification()); + } + + // Add test that uses the mock traits to satisfy dead code warnings + #[test] + fn test_mock_traits() { + let mut mock_proof_gen = MockProofGeneratorTrait::new(); + let mut mock_mmr_mgr = MockMMRStateManagerTrait::new(); + + // Set up expectations + mock_proof_gen + .expect_generate_groth16_proof() + .returning(|_| { + Ok(mmr::Proof { + element_index: 0, + element_hash: "".to_string(), + siblings_hashes: vec!["".to_string()], + peaks_hashes: vec!["".to_string()], + elements_count: 0, + }) + }); + + mock_mmr_mgr + .expect_update_state() + .returning(|_, _, _, _, _, _| { + Ok(MmrState::new( + 0, + U256::from(0_u64), + U256::from(0_u64), + 0, + None, + )) + }); + + // Verify mocks exist + mock_proof_gen.checkpoint(); + mock_mmr_mgr.checkpoint(); + } +} diff --git a/crates/publisher/src/core/mmr_state_manager.rs b/crates/publisher/src/core/mmr_state_manager.rs index cd2a0e0..064bf95 100644 --- a/crates/publisher/src/core/mmr_state_manager.rs +++ b/crates/publisher/src/core/mmr_state_manager.rs @@ -6,6 +6,8 @@ use mmr_utils::StoreManager; use starknet_handler::{account::StarknetAccount, u256_from_hex, MmrState}; use store::SqlitePool; use tracing::{debug, error, info}; +// use jsonrpc_client::{JsonRpcClient, HttpTransport, Url}; +// use std::sync::Arc; pub struct MMRStateManager<'a> { account: StarknetAccount, @@ -43,7 +45,8 @@ impl<'a> MMRStateManager<'a> { )); } - info!("Updating MMR state..."); + info!("Updating MMR state with {} headers...", headers.len()); + debug!("Headers: {:?}", headers); Self::append_headers(store_manager, mmr, pool, headers) .await @@ -70,45 +73,62 @@ impl<'a> MMRStateManager<'a> { info!("MMR state updated successfully"); Ok(new_mmr_state) } else { - // When no guest output, create state directly from MMR + debug!("No guest output provided, creating state from MMR directly"); let bag = mmr.bag_the_peaks(None).await.map_err(|e| { error!(error = %e, "Failed to bag the peaks"); e })?; + let elements_count = mmr.elements_count.get().await.map_err(|e| { error!(error = %e, "Failed to get elements count"); e })?; + debug!("Elements count: {}", elements_count); + let root_hash = mmr.calculate_root_hash(&bag, elements_count).map_err(|e| { error!(error = %e, "Failed to calculate root hash"); e })?; - let leaves_count = mmr.leaves_count.get().await.map_err(|e| { - error!(error = %e, "Failed to get leaves count"); + debug!("Raw root hash: {}", root_hash); + + let root_hash_hex = if !root_hash.starts_with("0x") { + format!("0x{}", root_hash) + } else { + root_hash + }; + debug!("Formatted root hash: {}", root_hash_hex); + + let root_hash_u256 = u256_from_hex(&root_hash_hex).map_err(|e| { + error!(error = %e, "Failed to convert root hash to U256"); e })?; - let latest_mmr_block_hash = u256_from_hex( - headers - .last() - .ok_or(AccumulatorError::InvalidInput("Headers list is empty"))?, - ) - .map_err(|e| { - error!(error = %e, "Failed to convert root hash from hex"); + + let latest_header = headers.last().unwrap(); + debug!("Latest header: {}", latest_header); + + let latest_mmr_block_hash = u256_from_hex(latest_header).map_err(|e| { + error!(error = %e, "Failed to convert latest header to U256"); + e + })?; + + let leaves_count = mmr.leaves_count.get().await.map_err(|e| { + error!(error = %e, "Failed to get leaves count"); e })?; + debug!("Leaves count: {}", leaves_count); let new_mmr_state = MmrState::new( latest_block_number, latest_mmr_block_hash, - u256_from_hex(&root_hash.trim_start_matches("0x")).map_err(|e| { - error!(error = %e, "Failed to convert root hash from hex"); - e - })?, + root_hash_u256, leaves_count as u64, None, ); - info!("No verification option selected, MMR state not updated onchain"); + info!( + "Created MMR state: latest_block={}, leaves={}", + latest_block_number, leaves_count + ); Ok(new_mmr_state) } } @@ -216,3 +236,127 @@ impl<'a> MMRStateManager<'a> { Ok(new_state) } } + +#[cfg(test)] +mod tests { + use super::*; + use mmr_utils::StoreManager; + use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient, Url}; + use starknet_handler::account::StarknetAccount; + use std::sync::Arc; + use store::memory::InMemoryStore; + + // Helper function to create test dependencies + async fn setup_test() -> (MMRStateManager<'static>, StoreManager, MMR, SqlitePool) { + let account = StarknetAccount::new( + Arc::new(JsonRpcClient::new(HttpTransport::new( + Url::parse("http://localhost:5050").expect("Invalid URL"), + ))), + "0x1234567890abcdef", // Valid hex address + "0x1234567890abcdef", // Valid hex private key + ) + .expect("Failed to create StarknetAccount"); + + let store_address = "0x1234567890abcdef"; // Valid hex store address + let mmr_state_manager = MMRStateManager::new(account, store_address); + + let memory_store = Arc::new(InMemoryStore::new(None)); + let pool = SqlitePool::connect("sqlite::memory:") + .await + .expect("Failed to create in-memory SQLite database"); + + // Create the required table + sqlx::query( + "CREATE TABLE IF NOT EXISTS value_index_map ( + value TEXT PRIMARY KEY, + element_index INTEGER NOT NULL + )", + ) + .execute(&pool) + .await + .expect("Failed to create value_index_map table"); + + let store_manager = StoreManager::new("sqlite::memory:") + .await + .expect("Failed to create StoreManager"); + + let mmr = MMR::new( + memory_store.clone(), + Arc::new(hasher::hashers::sha2::Sha2Hasher::new()), + None, + ); + + debug!("Test dependencies created successfully"); + (mmr_state_manager, store_manager, mmr, pool) + } + + #[tokio::test] + async fn test_update_state_without_guest_output() { + let (manager, store_manager, mut mmr, pool) = setup_test().await; + + let headers = vec![ + "0x0000000000000000000000000000000000000000000000001234567890abcdef".to_string(), + "0x0000000000000000000000000000000000000000000000000deadbeefcafe000".to_string(), + ]; + + // MMR is already initialized by MMR::new() + let result = manager + .update_state(store_manager, &mut mmr, &pool, 100, None, &headers) + .await; + + match &result { + Ok(_) => debug!("Update state succeeded"), + Err(e) => error!("Update state failed: {:?}", e), + } + + assert!( + result.is_ok(), + "Update state failed: {:?}", + result.err().unwrap() + ); + let state = result.unwrap(); + assert_eq!(state.latest_mmr_block(), 100); + assert_eq!(state.leaves_count(), 2); + } + + #[tokio::test] + async fn test_update_state_with_empty_headers() { + let (manager, store_manager, mut mmr, pool) = setup_test().await; + + let result = manager + .update_state(store_manager, &mut mmr, &pool, 100, None, &vec![]) + .await; + + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + } + + #[tokio::test] + async fn test_append_headers_with_empty_hash() { + let (_, store_manager, mut mmr, pool) = setup_test().await; + + let headers = vec!["".to_string()]; + + let result = + MMRStateManager::append_headers(store_manager, &mut mmr, &pool, &headers).await; + + assert!(matches!(result, Err(AccumulatorError::InvalidInput(_)))); + } + + #[tokio::test] + async fn test_create_new_state() { + let guest_output = GuestOutput::new( + 1, // batch_index + 100, // latest_mmr_block + "0x0000000000000000000000000000000000000000000000001234567890abcdef".to_string(), // 64 chars hex + "0x0000000000000000000000000000000000000000000000001234567890abcdef".to_string(), // 64 chars hex + 10, // leaves_count + ); + + let result = MMRStateManager::create_new_state(100, &guest_output).await; + + assert!(result.is_ok()); + let state = result.unwrap(); + assert_eq!(state.latest_mmr_block(), 100); + assert_eq!(state.leaves_count(), 10); + } +} diff --git a/crates/publisher/src/core/proof_generator.rs b/crates/publisher/src/core/proof_generator.rs index 71dbd0a..b1cd5d4 100644 --- a/crates/publisher/src/core/proof_generator.rs +++ b/crates/publisher/src/core/proof_generator.rs @@ -15,6 +15,7 @@ use crate::{ utils::{Groth16, Stark}, }; +#[derive(Debug)] pub struct ProofGenerator { method_elf: &'static [u8], method_id: [u32; 8], @@ -200,3 +201,97 @@ where Ok(receipt.journal.decode()?) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + // Mock data structure for testing + #[derive(Debug, Clone, Serialize, Deserialize)] + struct TestInput { + value: u32, + } + + const TEST_METHOD_ELF: &[u8] = &[1, 2, 3, 4]; // Mock ELF data + const TEST_METHOD_ID: [u32; 8] = [1, 0, 0, 0, 0, 0, 0, 0]; + + #[test] + fn test_new_proof_generator() { + // Test successful creation + let result = ProofGenerator::::new(TEST_METHOD_ELF, TEST_METHOD_ID); + assert!(result.is_ok()); + + // Test empty ELF + let result = ProofGenerator::::new(&[], TEST_METHOD_ID); + assert!(matches!( + result.unwrap_err(), + ProofGeneratorError::InvalidInput(_) + )); + + // Test zero method ID + let result = ProofGenerator::::new(TEST_METHOD_ELF, [0; 8]); + assert!(matches!( + result.unwrap_err(), + ProofGeneratorError::InvalidInput(_) + )); + } + + #[tokio::test] + async fn test_generate_stark_proof_invalid_input() { + let proof_generator = + ProofGenerator::>::new(TEST_METHOD_ELF, TEST_METHOD_ID).unwrap(); + let result = proof_generator.generate_stark_proof(vec![]).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_generate_groth16_proof_invalid_input() { + let proof_generator = + ProofGenerator::>::new(TEST_METHOD_ELF, TEST_METHOD_ID).unwrap(); + let result = proof_generator.generate_groth16_proof(vec![]).await; + assert!(result.is_err()); + } + + // Note: Testing the actual proof generation would require mock implementations + // of the RISC Zero prover and related components. Here's a sketch of how that + // might look with proper mocking: + + /* + #[tokio::test] + async fn test_generate_stark_proof_success() { + // Would need to mock: + // - ExecutorEnv + // - default_prover + // - compute_image_id + + let generator = ProofGenerator::::new(TEST_METHOD_ELF, TEST_METHOD_ID).unwrap(); + let input = TestInput { value: 42 }; + let result = generator.generate_stark_proof(input).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_generate_groth16_proof_success() { + // Would need to mock: + // - ExecutorEnv + // - default_prover + // - compute_image_id + // - encode_seal + // - Groth16Proof conversion + // - get_groth16_calldata_felt + + let generator = ProofGenerator::::new(TEST_METHOD_ELF, TEST_METHOD_ID).unwrap(); + let input = TestInput { value: 42 }; + let result = generator.generate_groth16_proof(input).await; + assert!(result.is_ok()); + } + */ + + #[test] + fn test_decode_journal() { + // Would need mock Groth16 proof with valid journal data + // This test would verify that journal decoding works correctly + // and handles errors appropriately + } +} diff --git a/crates/publisher/src/utils/utils.rs b/crates/publisher/src/utils/utils.rs index 6546316..7b77651 100644 --- a/crates/publisher/src/utils/utils.rs +++ b/crates/publisher/src/utils/utils.rs @@ -1,23 +1,59 @@ use crate::errors::AccumulatorError; /// Validates that a hex string represents a valid U256 (256-bit unsigned integer) -pub fn validate_u256_hex(hex_str: &str) -> Result<(), AccumulatorError> { - // Check if it's a valid hex string with '0x' prefix - if !hex_str.starts_with("0x") { - return Err(AccumulatorError::InvalidU256Hex(hex_str.to_string()).into()); +pub fn validate_u256_hex(hex: &str) -> Result<(), AccumulatorError> { + if !hex.starts_with("0x") || hex.len() <= 2 { + // Check for "0x" prefix and ensure there's data after it + return Err(AccumulatorError::InvalidU256Hex(hex.to_string())); } // Remove '0x' prefix and check if remaining string is valid hex - let hex_value = &hex_str[2..]; + let hex_value = &hex[2..]; if !hex_value.chars().all(|c| c.is_ascii_hexdigit()) { - return Err(AccumulatorError::InvalidU256Hex(hex_str.to_string()).into()); + return Err(AccumulatorError::InvalidU256Hex(hex.to_string())); } // Check length - maximum 64 hex chars (256 bits = 64 hex digits) - // Note: we allow shorter values as they're valid smaller numbers if hex_value.len() > 64 { - return Err(AccumulatorError::InvalidU256Hex(hex_str.to_string()).into()); + return Err(AccumulatorError::InvalidU256Hex(hex.to_string())); } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_valid_u256_hex() { + // Basic valid cases + assert!(validate_u256_hex("0x123").is_ok()); + assert!(validate_u256_hex("0xabc").is_ok()); + assert!(validate_u256_hex("0xABC123").is_ok()); + + // Edge cases - valid + assert!(validate_u256_hex("0x0").is_ok()); + assert!(validate_u256_hex(&("0x".to_owned() + &"f".repeat(64))).is_ok()); // Max length + assert!(validate_u256_hex("0xdeadbeef").is_ok()); + } + + #[test] + fn test_invalid_u256_hex() { + assert!(validate_u256_hex("0x").is_err()); + assert!(validate_u256_hex("0").is_err()); + assert!(validate_u256_hex("").is_err()); + assert!(validate_u256_hex("invalid").is_err()); + } + + #[test] + fn test_error_message() { + let result = validate_u256_hex("invalid"); + match result { + Err(AccumulatorError::InvalidU256Hex(msg)) => { + assert_eq!(msg, "invalid"); + } + _ => panic!("Expected InvalidU256Hex error"), + } + } +}