diff --git a/Cargo.lock b/Cargo.lock index 1e640eae..cb826b05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -686,6 +686,19 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9" +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -826,6 +839,21 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -842,6 +870,23 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + [[package]] name = "futures-sink" version = "0.3.31" @@ -860,9 +905,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -953,6 +1001,8 @@ dependencies = [ "polynomials", "rand", "rand_chacha", + "rayon", + "serial_test", "sha2", "sumcheck", "thiserror", @@ -1057,6 +1107,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.1" @@ -1597,8 +1653,8 @@ name = "mpi_config" version = "0.1.0" dependencies = [ "arith", - "mersenne31", - "mpi", + "rayon", + "serial_test", ] [[package]] @@ -1813,6 +1869,8 @@ dependencies = [ "mpi_config", "polynomials", "rand", + "rayon", + "serial_test", "transcript", ] @@ -2077,6 +2135,31 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e56dd856803e253c8f298af3f4d7eb0ae5e23a737252cd90bb4f3b435033b2d" +dependencies = [ + "dashmap", + "futures", + "lazy_static", + "log", + "parking_lot", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"91d129178576168c589c9ec973feedf7d3126c01ac2bf08795109aa35b69fb8f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "sha1" version = "0.10.6" diff --git a/Cargo.toml b/Cargo.toml index 9b11d7ff..cf9465f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,8 +11,9 @@ members = [ "circuit", "config", # gkr_field_config + pcs_config + transcript_config "config/gkr_field_config", # definitions of all field types used in gkr and pcs - "config/mpi_config", # definitions of mpi communication toolkit + # "config/mpi_config", # definitions of mpi communication toolkit "config/config_macros", # proc macros used to declare a new config, this has to a separate crate due to rust compilation issues + "config/mpi_config", # definitions of mpi communication toolkit via rayon backend "gkr", "poly_commit", "sumcheck", @@ -40,6 +41,7 @@ mpi = "0.8.0" rand = "0.8.5" rayon = "1.10" sha2 = "0.10.8" +serial_test = "2.0" tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } tokio = { version = "1.38.0", features = ["full"] } tynm = { version = "0.1.6", default-features = false } diff --git a/arith/polynomials/src/ref_mle.rs b/arith/polynomials/src/ref_mle.rs index fab69eab..4bb13eea 100644 --- a/arith/polynomials/src/ref_mle.rs +++ b/arith/polynomials/src/ref_mle.rs @@ -9,6 +9,7 @@ pub trait MultilinearExtension: Index { fn num_vars(&self) -> usize; + #[inline] fn hypercube_size(&self) -> usize { 1 << self.num_vars() } @@ -18,6 +19,11 @@ pub trait MultilinearExtension: Index { fn hypercube_basis_ref(&self) -> &Vec; fn interpolate_over_hypercube(&self) -> Vec; + + #[inline] + fn serialized_size(&self) -> usize { + self.hypercube_size() * F::SIZE + } } #[derive(Debug, Clone)] diff --git a/config/mpi_config/Cargo.toml b/config/mpi_config/Cargo.toml index 41deab80..f7ef9742 100644 --- a/config/mpi_config/Cargo.toml +++ b/config/mpi_config/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2021" [dependencies] +rayon.workspace = true + arith = { path = "../../arith" } -mpi.workspace = true [dev-dependencies] -mersenne31 = { path = "../../arith/mersenne31"} \ No newline at end of file +serial_test.workspace = true \ No newline at end of file diff --git a/config/mpi_config/src/atomic_vec.rs b/config/mpi_config/src/atomic_vec.rs new file mode 100644 index 00000000..a726f65f --- /dev/null +++ b/config/mpi_config/src/atomic_vec.rs @@ -0,0 +1,74 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// A lock-free append-only vector implementation +/// credit: Claude +#[derive(Debug)] +pub struct AtomicVec { + // The actual data storage + data: Vec, + // Current length of valid data + len: AtomicUsize, +} + +impl AtomicVec { + #[inline] + pub fn new(capacity: usize) -> Self { + let mut data = Vec::with_capacity(capacity); + // Pre-fill with default values to avoid reallocation + data.resize_with(capacity, || unsafe { std::mem::zeroed() }); + Self { + data, + len: AtomicUsize::new(0), + } + } + + /// Append data to the vector + /// Returns the start index where data was appended + pub fn append(&self, items: &[T]) -> Option { + let old_len = self.len.fetch_add(items.len(), Ordering::AcqRel); + if old_len + items.len() > self.data.capacity() { + // Restore the length if we would exceed capacity + self.len.fetch_sub(items.len(), Ordering::Release); + return None; + } + + // Safe because: + // 1. We've pre-allocated the space + // 2. Each thread writes to its own section + // 3. 
The atomic len ensures no overlapping writes + unsafe { + let ptr = self.data.as_ptr().add(old_len) as *mut T; + for (i, item) in items.iter().enumerate() { + std::ptr::write(ptr.add(i), item.clone()); + } + } + + Some(old_len) + } + + /// Read a slice of data + #[inline] + pub fn get_slice(&self, start: usize, end: usize) -> Option<&[T]> { + let current_len = self.len.load(Ordering::Acquire); + if start >= current_len || end > current_len || start > end { + return None; + } + + // Safe because: + // 1. We've checked the bounds + // 2. No data is ever modified after being written + Some(unsafe { std::slice::from_raw_parts(self.data.as_ptr().add(start), end - start) }) + } + + /// Get current length + #[inline] + pub fn len(&self) -> usize { + self.len.load(Ordering::Acquire) + } + + /// Check if empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} diff --git a/config/mpi_config/src/lib.rs b/config/mpi_config/src/lib.rs index 2fae9994..72d16c4f 100644 --- a/config/mpi_config/src/lib.rs +++ b/config/mpi_config/src/lib.rs @@ -1,12 +1,37 @@ -use std::{cmp, fmt::Debug}; +//! This module implements a synchronized MPI configuration for Rayon. +//! +//! Assumptions +//! 1. There will NOT be a root process that collects data from all other processes and broadcast +//! it. +//! 2. Each thread writes to its own local memory. +//! 3. Each thread reads from all other threads' local memory. +//! 4. All threads have access to a same global, read-only memory. This global memory is initialized +//! before the threads start and will remain invariant during the threads' execution. +//! 5. IMPORTANT!!! The threads are synchronized by the caller; within each period of time, all +//! threads write a same amount of data -use arith::Field; -use mpi::{ - environment::Universe, - ffi, - topology::{Process, SimpleCommunicator}, - traits::*, -}; +mod atomic_vec; +pub use atomic_vec::AtomicVec; + +mod mpi_config; +pub use mpi_config::MPIConfig; + +mod thread_config; +pub use thread_config::ThreadConfig; + +/// Max number of std::hint::spin_loop() we will do before panicking +#[cfg(target_arch = "aarch64")] +const MAX_WAIT_CYCLES: usize = 1000000 * 140; // Multiply by 140 since ARM yield is ~1-2 cycles vs x86 PAUSE ~140 cycles + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +const MAX_WAIT_CYCLES: usize = 100000000; + +// Fallback for other architectures +#[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))] +const MAX_WAIT_CYCLES: usize = 1000000; + +#[cfg(test)] +mod tests; #[macro_export] macro_rules! root_println { @@ -17,275 +42,10 @@ macro_rules! 
root_println { }; } -static mut UNIVERSE: Option = None; -static mut WORLD: Option = None; - -#[derive(Clone)] -pub struct MPIConfig { - pub universe: Option<&'static mpi::environment::Universe>, - pub world: Option<&'static SimpleCommunicator>, - pub world_size: i32, - pub world_rank: i32, -} - -impl Default for MPIConfig { - fn default() -> Self { - Self { - universe: None, - world: None, - world_size: 1, - world_rank: 0, - } - } -} - -impl Debug for MPIConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let universe_fmt = if self.universe.is_none() { - Option::::None - } else { - Some(self.universe.unwrap().buffer_size()) - }; - - let world_fmt = if self.world.is_none() { - Option::::None - } else { - Some(0usize) - }; - - f.debug_struct("MPIConfig") - .field("universe", &universe_fmt) - .field("world", &world_fmt) - .field("world_size", &self.world_size) - .field("world_rank", &self.world_rank) - .finish() - } -} - -// Note: may not be correct -impl PartialEq for MPIConfig { - fn eq(&self, other: &Self) -> bool { - self.world_rank == other.world_rank && self.world_size == other.world_size - } -} - -/// MPI toolkit: -impl MPIConfig { - const ROOT_RANK: i32 = 0; - - /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. - const CHUNK_SIZE: usize = 1usize << 20; - - // OK if already initialized, mpi::initialize() will return None - #[allow(static_mut_refs)] - pub fn init() { - unsafe { - let universe = mpi::initialize(); - if universe.is_some() { - UNIVERSE = universe; - WORLD = Some(UNIVERSE.as_ref().unwrap().world()); - } - } - } - - #[inline] - pub fn finalize() { - unsafe { ffi::MPI_Finalize() }; - } - - #[allow(static_mut_refs)] - pub fn new() -> Self { - Self::init(); - let universe = unsafe { UNIVERSE.as_ref() }; - let world = unsafe { WORLD.as_ref() }; - let world_size = if let Some(world) = world { - world.size() - } else { - 1 - }; - let world_rank = if let Some(world) = world { - world.rank() - } else { - 0 - }; - Self { - universe, - world, - world_size, - world_rank, - } - } - - #[inline] - pub fn new_for_verifier(world_size: i32) -> Self { - Self { - universe: None, - world: None, - world_size, - world_rank: 0, - } - } - - /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. - #[inline] - unsafe fn elem_to_u8_bytes(elem: &V, byte_size: usize) -> Vec { - Vec::::from_raw_parts((elem as *const V) as *mut u8, byte_size, byte_size) - } - - /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. 
- #[inline] - unsafe fn vec_to_u8_bytes(vec: &Vec) -> Vec { - Vec::::from_raw_parts( - vec.as_ptr() as *mut u8, - vec.len() * F::SIZE, - vec.capacity() * F::SIZE, - ) - } - - #[allow(clippy::collapsible_else_if)] - pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { - unsafe { - if self.world_size == 1 { - *global_vec = local_vec.clone() - } else { - assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); - - let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); - let local_n_bytes = local_vec_u8.len(); - let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; - if n_chunks == 1 { - if self.world_rank == Self::ROOT_RANK { - let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - self.root_process() - .gather_into_root(&local_vec_u8, &mut global_vec_u8); - global_vec_u8.leak(); // discard control of the memory - } else { - self.root_process().gather_into(&local_vec_u8); - } - } else { - if self.world_rank == Self::ROOT_RANK { - let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; - let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); - for i in 0..n_chunks { - let local_start = i * Self::CHUNK_SIZE; - let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - self.root_process().gather_into_root( - &local_vec_u8[local_start..local_end], - &mut chunk_buffer_u8, - ); - - // distribute the data to where they belong to in global vec - let actual_chunk_size = local_end - local_start; - for j in 0..self.world_size() { - let global_start = j * local_n_bytes + local_start; - let global_end = global_start + actual_chunk_size; - global_vec_u8[global_start..global_end].copy_from_slice( - &chunk_buffer_u8[j * Self::CHUNK_SIZE - ..j * Self::CHUNK_SIZE + actual_chunk_size], - ); - } - } - global_vec_u8.leak(); // discard control of the memory - } else { - for i in 0..n_chunks { - let local_start = i * Self::CHUNK_SIZE; - let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); - self.root_process() - .gather_into(&local_vec_u8[local_start..local_end]); - } - } - } - local_vec_u8.leak(); // discard control of the memory - } - } - } - - /// Root process broadcase a value f into all the processes - #[inline] - pub fn root_broadcast_f(&self, f: &mut F) { - unsafe { - if self.world_size == 1 { - } else { - let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); - self.root_process().broadcast_into(&mut vec_u8); - vec_u8.leak(); - } - } - } - - #[inline] - pub fn root_broadcast_bytes(&self, bytes: &mut Vec) { - self.root_process().broadcast_into(bytes); - } - - /// sum up all local values - #[inline] - pub fn sum_vec(&self, local_vec: &Vec) -> Vec { - if self.world_size == 1 { - local_vec.clone() - } else if self.world_rank == Self::ROOT_RANK { - let mut global_vec = vec![F::ZERO; local_vec.len() * (self.world_size as usize)]; - self.gather_vec(local_vec, &mut global_vec); - for i in 0..local_vec.len() { - for j in 1..(self.world_size as usize) { - global_vec[i] = global_vec[i] + global_vec[j * local_vec.len() + i]; - } - } - global_vec.truncate(local_vec.len()); - global_vec - } else { - self.gather_vec(local_vec, &mut vec![]); - vec![] - } - } - - /// coef has a length of mpi_world_size - #[inline] - pub fn coef_combine_vec(&self, local_vec: &Vec, coef: &[F]) -> Vec { - if self.world_size == 1 { - // Warning: literally, it should be coef[0] * local_vec - // but coef[0] is always one in our use case of self.world_size = 1 - local_vec.clone() - } else if self.world_rank == Self::ROOT_RANK { 
- let mut global_vec = vec![F::ZERO; local_vec.len() * (self.world_size as usize)]; - let mut ret = vec![F::ZERO; local_vec.len()]; - self.gather_vec(local_vec, &mut global_vec); - for i in 0..local_vec.len() { - for j in 0..(self.world_size as usize) { - ret[i] += global_vec[j * local_vec.len() + i] * coef[j]; - } - } - ret - } else { - self.gather_vec(local_vec, &mut vec![]); - vec![] - } - } - - #[inline(always)] - pub fn world_size(&self) -> usize { - self.world_size as usize - } - - #[inline(always)] - pub fn world_rank(&self) -> usize { - self.world_rank as usize - } - - #[inline(always)] - pub fn is_root(&self) -> bool { - self.world_rank == Self::ROOT_RANK - } - - #[inline(always)] - pub fn root_process(&self) -> Process { - self.world.unwrap().process_at_rank(Self::ROOT_RANK) - } - - #[inline(always)] - pub fn barrier(&self) { - self.world.unwrap().barrier(); - } +#[macro_export] +macro_rules! thread_println { + ($config: expr, $($arg:tt)*) => { + print!("[Thread {}] ", $config.current_thread().world_rank); + println!($($arg)*); + }; } - -unsafe impl Send for MPIConfig {} diff --git a/config/mpi_config/src/mpi_config.rs b/config/mpi_config/src/mpi_config.rs new file mode 100644 index 00000000..5b89f025 --- /dev/null +++ b/config/mpi_config/src/mpi_config.rs @@ -0,0 +1,274 @@ +use std::sync::Arc; + +use arith::Field; + +use crate::{ThreadConfig, MAX_WAIT_CYCLES}; + +/// Configuration for MPI +/// Assumptions +/// 1. Each thread writes to its own local memory +/// 2. Each thread reads from all other threads' local memory +/// 3. All threads have the same global memory +/// 4. IMPORTANT!!! The threads are synchronized by the caller; within each period of time, all +/// threads write a same amount of data +/// +/// The config struct only uses pointers so we avoid cloning of all data +#[derive(Debug, Clone)] +pub struct MPIConfig { + pub world_size: i32, // Number of threads + pub global_memory: Arc<[u8]>, // Global memory shared by all threads + pub threads: Vec, // Local memory for each thread +} + +impl Default for MPIConfig { + #[inline] + fn default() -> Self { + Self { + world_size: 1, + global_memory: Arc::from(vec![]), + threads: vec![], + } + } +} + +impl PartialEq for MPIConfig { + #[inline] + fn eq(&self, other: &Self) -> bool { + // equality is based on size + // it doesn't check the memory are consistent + self.world_size == other.world_size + } +} + +impl MPIConfig { + #[inline] + pub fn single_thread() -> Self { + Self { + world_size: 1, + global_memory: Arc::from(vec![]), + threads: vec![ThreadConfig::new(0, 1024 * 1024)], + } + } + + #[inline] + pub fn new(world_size: i32, global_data: Arc<[u8]>, buffer_size: usize) -> Self { + Self { + world_size, + global_memory: global_data, + threads: (0..world_size) + .map(|rank| ThreadConfig::new(rank, buffer_size)) + .collect(), + } + } + + #[inline] + pub fn world_size(&self) -> i32 { + self.world_size + } + + #[inline] + /// check the caller's thread is the root thread + pub fn is_root(&self) -> bool { + rayon::current_thread_index().unwrap() == 0 + } + + #[inline] + /// Get the current thread + pub fn current_thread(&self) -> &ThreadConfig { + let index = rayon::current_thread_index().unwrap(); + &self.threads[index] + } + + #[inline] + /// Get the current thread + pub fn current_thread_mut(&mut self) -> &mut ThreadConfig { + let index = rayon::current_thread_index().unwrap(); + &mut self.threads[index] + } + + #[inline] + /// Get the size of the current local memory + pub fn current_size(&self) -> usize { + 
self.current_thread().size() + } + + // #[inline] + // /// Check if the current thread is synced + // pub fn is_current_thread_synced(&self) -> bool { + // self.current_thread().is_synced() + // } + + // #[inline] + // /// Check if all threads are synced + // pub fn are_all_threads_synced(&self) -> bool { + // self.threads.iter().all(|t| t.is_synced()) + // } + + // #[inline] + // /// Sync up the current thread + // /// Returns a vector of slices, one for each thread's new data + // pub fn sync_up(&mut self) -> Vec<&[u8]> { + // if self.is_current_thread_synced() { + // return vec![]; + // } + + // let start = self.current_thread().last_synced; + // let end = self.current_thread().size(); + // // update the pointer to the latest index + // self.current_thread_mut().last_synced = end; + // let result = self.read_all(start, end); + + // result + // } + + /// Read all threads' local memory by waiting until there is new data to read from all + /// threads. + /// Returns a vector of slices, one for each thread's new data + /// Update the sync pointer of the current thread + /// + /// The threads are synchronized by the caller; within each period of time, all + /// threads write a same amount of data + pub fn read_all(&self, start: usize, end: usize) -> Vec<&[u8]> { + let total = self.threads.len(); + let mut pending = (0..total).collect::>(); + let mut results: Vec<&[u8]> = vec![&[]; total]; + let mut wait_cycles = 0; + + // Keep going until we've read from all threads + while !pending.is_empty() { + // Use retain to avoid re-checking already synced threads + pending.retain(|&i| { + let len = self.threads[i].size(); + if len >= end { + results[i] = self.threads[i].read(start, end); + // println!( + // "[Thread {}] read from thread {}: {} to {}", + // rayon::current_thread_index().unwrap(), + // i, + // start, + // end + // ); + false // Remove from pending + } else { + true // Keep in pending + } + }); + + if !pending.is_empty() { + // Claude suggest to use the following approach for waiting + // + // Simple spin - Rayon manages the thread pool efficiently + // Hint to the CPU that we're spinning (reduces power consumption) + // - For AMD/Intel it delays for 140 cycles + // - For ARM it is 1~2 cycles (We may need to manually adjust MAX_WAIT_CYCLES) + std::hint::spin_loop(); + wait_cycles += 1; + if wait_cycles > MAX_WAIT_CYCLES { + panic!( + "[Thread {}] exceeded max wait cycles\nPending list: {:?}", + rayon::current_thread_index().unwrap(), + pending + ); + } + } + } + results + } + + #[inline] + // todo: add a field buffer to the thread config so we can avoid field (de)serialization + pub fn read_all_field(&self, start: usize, end: usize) -> Vec> { + let data = self.read_all(start, end); + data.iter() + .map(|x| { + x.chunks(F::SIZE) + .map(|y| F::deserialize_from(y).unwrap()) + .collect() + }) + .collect() + } + + #[inline] + // todo: add a field buffer to the thread config so we can avoid field (de)serialization + pub fn read_all_field_flat(&self, start: usize, end: usize) -> Vec { + let data = self.read_all(start, end); + data.iter() + .flat_map(|x| x.chunks(F::SIZE).map(|y| F::deserialize_from(y).unwrap())) + .collect() + } + + #[inline] + /// Append data to the current thread's local memory + pub fn append_local(&self, data: &[u8]) { + let thread = self.current_thread(); + thread.append(data).expect("Failed to append"); + } + + #[inline] + /// Append data to the current thread's local memory + pub fn append_local_field(&self, f: &F) { + let mut data = vec![]; + f.serialize_into(&mut 
data).unwrap(); + self.append_local(&data); + } + + #[inline] + /// Append data to the current thread's local memory + pub fn append_local_fields(&self, f: &[F]) { + let data = f + .iter() + .flat_map(|x| { + let mut buf = vec![]; + x.serialize_into(&mut buf).unwrap(); + buf + }) + .collect::>(); + self.append_local(&data); + } + + /// coefficient has a length of mpi_world_size + #[inline] + pub fn coef_combine_vec(&self, local_vec: &Vec, coefficient: &[F]) -> Vec { + if self.world_size == 1 { + // Warning: literally, it should be coefficient[0] * local_vec + // but coefficient[0] is always one in our use case of self.world_size = 1 + local_vec.clone() + } else { + // write local vector to the buffer, then sync up all threads + let start = self.current_thread().size(); + let data = local_vec + .iter() + .flat_map(|&x| { + let mut buf = vec![]; + x.serialize_into(&mut buf).unwrap(); + buf + }) + .collect::>(); + self.append_local(&data); + let end = self.current_thread().size(); + let all_fields = self.read_all_field::(start, end); + + // build the result via linear combination + let mut result = vec![F::zero(); local_vec.len()]; + for i in 0..local_vec.len() { + for j in 0..(self.world_size as usize) { + result[i] += all_fields[j][i] * coefficient[j]; + } + } + + result + } + } + + #[inline] + /// Finalize function does nothing except for a minimal sanity check + /// that all threads have the same amount of data + pub fn finalize(&self) { + let len = self.threads[0].size(); + self.threads.iter().skip(1).for_each(|t| { + assert_eq!(t.size(), len); + }); + + // do nothing + } +} diff --git a/config/mpi_config/src/tests.rs b/config/mpi_config/src/tests.rs new file mode 100644 index 00000000..6d6b875b --- /dev/null +++ b/config/mpi_config/src/tests.rs @@ -0,0 +1,167 @@ +use std::sync::Arc; + +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use serial_test::serial; + +use crate::{MPIConfig, ThreadConfig}; + +// Example usage +#[test] +#[serial] +fn test_single_thread() { + let config = ThreadConfig::new(0, 1024 * 1024); + + // Append some data + let pos = config.append(&[1, 2, 3, 4]).unwrap(); + println!("Appended at position: {}", pos); + + // Read it back + let data = config.read(pos, pos + 4); + println!("Read back: {:?}", data); +} + +#[test] +#[serial] +// Assuming we have the MPIConfig and AtomicVec from previous example +fn test_parallel_processing() { + // Create some test data for global memory + let global_data: Arc<[u8]> = Arc::from((0..1024).map(|i| i as u8).collect::>()); + let num_threads = rayon::current_num_threads(); + + // Create configs for all threads + let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024); + + // Process in parallel using rayon + (0..num_threads).into_par_iter().for_each(|rank| { + let thread = &mpi_config.threads[rank]; + + // Simulate some work: read from global memory and write to local + // Each thread reads a different section of global memory + let chunk_size = mpi_config.global_memory.len() / num_threads; + let start = rank * chunk_size; + let end = if rank == num_threads - 1 { + mpi_config.global_memory.len() + } else { + start + chunk_size + }; + + // Read from global memory + if let Some(global_chunk) = mpi_config.global_memory.get(start..end) { + // Process the data (example: multiply each byte by rank + 1) + let processed: Vec = global_chunk + .iter() + .map(|&x| x.wrapping_mul((rank + 1) as u8)) + .collect(); + + // Write to local memory + match thread.append(&processed) { + Ok(pos) => println!( + "Thread {} 
wrote {} bytes at position {}", + rank, + processed.len(), + pos + ), + Err(e) => eprintln!("Thread {} failed to write: {}", rank, e), + } + } + }); + + // Verify results + for rank in 0..num_threads { + let thread = &mpi_config.threads[rank]; + let data = thread.local_memory.get_slice(0, thread.local_memory.len()); + + if let Some(local_data) = data { + println!( + "Thread {} final local memory size: {}", + rank, + local_data.len() + ); + // Print first few bytes for verification + if !local_data.is_empty() { + println!( + "Thread {} first few bytes: {:?}", + rank, + &local_data[..local_data.len().min(4)] + ); + } + } + } +} + +#[test] +#[serial] +fn test_cross_thread_communication() { + // Create global data + let global_data: Arc<[u8]> = Arc::from((0..16).map(|i| i as u8).collect::>()); + let num_threads = rayon::current_num_threads(); + let data_len = 4; + + // Create configs for all threads + let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024); + + let expected_result = (0..num_threads) + .map(|i| vec![i as u8 + 1; data_len]) + .collect::>(); + + // write to its own memory, and read from all others + (0..num_threads).into_par_iter().for_each(|rank| { + let thread = &mpi_config.threads[rank]; + + let data = vec![rank as u8 + 1; data_len]; + let start = thread.size(); + let end = start + data_len; + + thread.append(&data).expect("Failed to append"); + + let results = mpi_config.read_all(start, end); + assert_eq!(results.len(), num_threads as usize); + + for (i, result) in results.iter().enumerate() { + assert_eq!(result.len(), data_len); + assert_eq!(result, &expected_result[i]); + } + }); +} + +#[test] +#[serial] +fn test_incremental_updates() { + let global_data = Arc::<[u8]>::from(vec![0u8; 64]); + let num_threads = rayon::current_num_threads(); + let data_len = 4; + + let expected_result = (0..num_threads) + .map(|i| vec![i as u8 + 1; data_len]) + .collect::>(); + + // Create configs for all threads + let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024); + + // write to its own memory, and read from all others + (0..num_threads).into_par_iter().for_each(|rank| { + // 10 interactions among the threads; without spawning and killing new threads + // during each interaction, a fixed amount of data will be written to each thead's local + // memory + for i in 0..10 { + let thread = &mpi_config.threads[rank]; + let data = vec![((rank + 1) * (i + 1)) as u8; data_len]; + let start = thread.size(); + let end = start + data_len; + + thread.append(&data).expect("Failed to append"); + + let results = mpi_config.read_all(start, end); + assert_eq!(results.len(), num_threads as usize); + + println!("Thread {} iteration {}: {:?}", rank, i, results); + + for (j, result) in results.iter().enumerate() { + assert_eq!(result.len(), data_len as usize); + result.iter().zip(&expected_result[j]).for_each(|(&a, &b)| { + assert_eq!(a, b * (i + 1) as u8); + }); + } + } + }); +} diff --git a/config/mpi_config/src/thread_config.rs b/config/mpi_config/src/thread_config.rs new file mode 100644 index 00000000..4226d1cc --- /dev/null +++ b/config/mpi_config/src/thread_config.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use crate::AtomicVec; + +/// Configuration for MPI +/// Assumptions +/// 1. Each thread writes to its own local memory +/// 2. Each thread reads from all other threads' local memory +/// 3. All threads have the same global memory +/// 4. IMPORTANT!!! 
The threads are synchronized by the caller; within each period of time, all +/// threads write a same amount of data +/// +/// A further optimization (TODO): +/// - we can have a buffer for bytes, and a buffer for field elements, this should avoid the need of +/// field (de)serializations between threads +#[derive(Debug, Clone)] +pub struct ThreadConfig { + pub world_rank: i32, // indexer for the thread + pub local_memory: Arc>, // local memory for the thread + pub last_synced: usize, // last synced index +} + +impl Default for ThreadConfig { + #[inline] + fn default() -> Self { + Self { + world_rank: 0, + local_memory: Arc::new(AtomicVec::new(0)), + last_synced: 0, + } + } +} + +impl PartialEq for ThreadConfig { + #[inline] + fn eq(&self, other: &Self) -> bool { + // equality is based on rank and size + // it doesn't check the memory are consistent + self.world_rank == other.world_rank + } +} + +impl ThreadConfig { + #[inline] + pub fn new(world_rank: i32, buffer_size: usize) -> Self { + Self { + world_rank, + local_memory: Arc::new(AtomicVec::new(buffer_size)), + last_synced: 0, + } + } + + #[inline] + pub fn is_root(&self) -> bool { + self.world_rank == 0 + } + + #[inline] + pub fn is_synced(&self) -> bool { + self.last_synced == self.local_memory.len() + } + + #[inline] + pub fn append(&self, data: &[u8]) -> Result { + self.local_memory + .append(data) + .ok_or("Failed to append: insufficient capacity") + } + + #[inline] + pub fn read(&self, start: usize, end: usize) -> &[u8] { + self.local_memory + .get_slice(start, end) + .ok_or(format!( + "failed to read between {start} and {end} for slice of length {}", + self.local_memory.len() + )) + .unwrap() + } + + #[inline] + /// Get the length of local memory + pub fn size(&self) -> usize { + self.local_memory.len() + } +} diff --git a/config/mpi_config2/Cargo.toml b/config/mpi_config2/Cargo.toml new file mode 100644 index 00000000..41deab80 --- /dev/null +++ b/config/mpi_config2/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "mpi_config" +version = "0.1.0" +edition = "2021" + +[dependencies] +arith = { path = "../../arith" } +mpi.workspace = true + +[dev-dependencies] +mersenne31 = { path = "../../arith/mersenne31"} \ No newline at end of file diff --git a/config/mpi_config2/src/lib.rs b/config/mpi_config2/src/lib.rs new file mode 100644 index 00000000..2fae9994 --- /dev/null +++ b/config/mpi_config2/src/lib.rs @@ -0,0 +1,291 @@ +use std::{cmp, fmt::Debug}; + +use arith::Field; +use mpi::{ + environment::Universe, + ffi, + topology::{Process, SimpleCommunicator}, + traits::*, +}; + +#[macro_export] +macro_rules! 
root_println { + ($config: expr, $($arg:tt)*) => { + if $config.is_root() { + println!($($arg)*); + } + }; +} + +static mut UNIVERSE: Option = None; +static mut WORLD: Option = None; + +#[derive(Clone)] +pub struct MPIConfig { + pub universe: Option<&'static mpi::environment::Universe>, + pub world: Option<&'static SimpleCommunicator>, + pub world_size: i32, + pub world_rank: i32, +} + +impl Default for MPIConfig { + fn default() -> Self { + Self { + universe: None, + world: None, + world_size: 1, + world_rank: 0, + } + } +} + +impl Debug for MPIConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let universe_fmt = if self.universe.is_none() { + Option::::None + } else { + Some(self.universe.unwrap().buffer_size()) + }; + + let world_fmt = if self.world.is_none() { + Option::::None + } else { + Some(0usize) + }; + + f.debug_struct("MPIConfig") + .field("universe", &universe_fmt) + .field("world", &world_fmt) + .field("world_size", &self.world_size) + .field("world_rank", &self.world_rank) + .finish() + } +} + +// Note: may not be correct +impl PartialEq for MPIConfig { + fn eq(&self, other: &Self) -> bool { + self.world_rank == other.world_rank && self.world_size == other.world_size + } +} + +/// MPI toolkit: +impl MPIConfig { + const ROOT_RANK: i32 = 0; + + /// The communication limit for MPI is 2^30. Save 10 bits for #parties here. + const CHUNK_SIZE: usize = 1usize << 20; + + // OK if already initialized, mpi::initialize() will return None + #[allow(static_mut_refs)] + pub fn init() { + unsafe { + let universe = mpi::initialize(); + if universe.is_some() { + UNIVERSE = universe; + WORLD = Some(UNIVERSE.as_ref().unwrap().world()); + } + } + } + + #[inline] + pub fn finalize() { + unsafe { ffi::MPI_Finalize() }; + } + + #[allow(static_mut_refs)] + pub fn new() -> Self { + Self::init(); + let universe = unsafe { UNIVERSE.as_ref() }; + let world = unsafe { WORLD.as_ref() }; + let world_size = if let Some(world) = world { + world.size() + } else { + 1 + }; + let world_rank = if let Some(world) = world { + world.rank() + } else { + 0 + }; + Self { + universe, + world, + world_size, + world_rank, + } + } + + #[inline] + pub fn new_for_verifier(world_size: i32) -> Self { + Self { + universe: None, + world: None, + world_size, + world_rank: 0, + } + } + + /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. + #[inline] + unsafe fn elem_to_u8_bytes(elem: &V, byte_size: usize) -> Vec { + Vec::::from_raw_parts((elem as *const V) as *mut u8, byte_size, byte_size) + } + + /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. 
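+ /// Note: the returned `Vec` aliases the input's allocation, so callers must `leak()` it
+ /// before it drops; otherwise the same buffer would be freed twice (every call site in this
+ /// file does call `.leak()` on the resulting bytes).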
+ #[inline] + unsafe fn vec_to_u8_bytes(vec: &Vec) -> Vec { + Vec::::from_raw_parts( + vec.as_ptr() as *mut u8, + vec.len() * F::SIZE, + vec.capacity() * F::SIZE, + ) + } + + #[allow(clippy::collapsible_else_if)] + pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { + unsafe { + if self.world_size == 1 { + *global_vec = local_vec.clone() + } else { + assert!(!self.is_root() || global_vec.len() == local_vec.len() * self.world_size()); + + let local_vec_u8 = Self::vec_to_u8_bytes(local_vec); + let local_n_bytes = local_vec_u8.len(); + let n_chunks = (local_n_bytes + Self::CHUNK_SIZE - 1) / Self::CHUNK_SIZE; + if n_chunks == 1 { + if self.world_rank == Self::ROOT_RANK { + let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + self.root_process() + .gather_into_root(&local_vec_u8, &mut global_vec_u8); + global_vec_u8.leak(); // discard control of the memory + } else { + self.root_process().gather_into(&local_vec_u8); + } + } else { + if self.world_rank == Self::ROOT_RANK { + let mut chunk_buffer_u8 = vec![0u8; Self::CHUNK_SIZE * self.world_size()]; + let mut global_vec_u8 = Self::vec_to_u8_bytes(global_vec); + for i in 0..n_chunks { + let local_start = i * Self::CHUNK_SIZE; + let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + self.root_process().gather_into_root( + &local_vec_u8[local_start..local_end], + &mut chunk_buffer_u8, + ); + + // distribute the data to where they belong to in global vec + let actual_chunk_size = local_end - local_start; + for j in 0..self.world_size() { + let global_start = j * local_n_bytes + local_start; + let global_end = global_start + actual_chunk_size; + global_vec_u8[global_start..global_end].copy_from_slice( + &chunk_buffer_u8[j * Self::CHUNK_SIZE + ..j * Self::CHUNK_SIZE + actual_chunk_size], + ); + } + } + global_vec_u8.leak(); // discard control of the memory + } else { + for i in 0..n_chunks { + let local_start = i * Self::CHUNK_SIZE; + let local_end = cmp::min(local_start + Self::CHUNK_SIZE, local_n_bytes); + self.root_process() + .gather_into(&local_vec_u8[local_start..local_end]); + } + } + } + local_vec_u8.leak(); // discard control of the memory + } + } + } + + /// Root process broadcase a value f into all the processes + #[inline] + pub fn root_broadcast_f(&self, f: &mut F) { + unsafe { + if self.world_size == 1 { + } else { + let mut vec_u8 = Self::elem_to_u8_bytes(f, F::SIZE); + self.root_process().broadcast_into(&mut vec_u8); + vec_u8.leak(); + } + } + } + + #[inline] + pub fn root_broadcast_bytes(&self, bytes: &mut Vec) { + self.root_process().broadcast_into(bytes); + } + + /// sum up all local values + #[inline] + pub fn sum_vec(&self, local_vec: &Vec) -> Vec { + if self.world_size == 1 { + local_vec.clone() + } else if self.world_rank == Self::ROOT_RANK { + let mut global_vec = vec![F::ZERO; local_vec.len() * (self.world_size as usize)]; + self.gather_vec(local_vec, &mut global_vec); + for i in 0..local_vec.len() { + for j in 1..(self.world_size as usize) { + global_vec[i] = global_vec[i] + global_vec[j * local_vec.len() + i]; + } + } + global_vec.truncate(local_vec.len()); + global_vec + } else { + self.gather_vec(local_vec, &mut vec![]); + vec![] + } + } + + /// coef has a length of mpi_world_size + #[inline] + pub fn coef_combine_vec(&self, local_vec: &Vec, coef: &[F]) -> Vec { + if self.world_size == 1 { + // Warning: literally, it should be coef[0] * local_vec + // but coef[0] is always one in our use case of self.world_size = 1 + local_vec.clone() + } else if self.world_rank == Self::ROOT_RANK { 
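+ // root branch: gather every rank's local vector, then combine them entry-wise with coef[j]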
+ let mut global_vec = vec![F::ZERO; local_vec.len() * (self.world_size as usize)]; + let mut ret = vec![F::ZERO; local_vec.len()]; + self.gather_vec(local_vec, &mut global_vec); + for i in 0..local_vec.len() { + for j in 0..(self.world_size as usize) { + ret[i] += global_vec[j * local_vec.len() + i] * coef[j]; + } + } + ret + } else { + self.gather_vec(local_vec, &mut vec![]); + vec![] + } + } + + #[inline(always)] + pub fn world_size(&self) -> usize { + self.world_size as usize + } + + #[inline(always)] + pub fn world_rank(&self) -> usize { + self.world_rank as usize + } + + #[inline(always)] + pub fn is_root(&self) -> bool { + self.world_rank == Self::ROOT_RANK + } + + #[inline(always)] + pub fn root_process(&self) -> Process { + self.world.unwrap().process_at_rank(Self::ROOT_RANK) + } + + #[inline(always)] + pub fn barrier(&self) { + self.world.unwrap().barrier(); + } +} + +unsafe impl Send for MPIConfig {} diff --git a/config/mpi_config/tests/gather_vec.rs b/config/mpi_config2/tests/gather_vec.rs similarity index 100% rename from config/mpi_config/tests/gather_vec.rs rename to config/mpi_config2/tests/gather_vec.rs diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 1c62fbfb..f51eacab 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -22,14 +22,15 @@ transcript = { path = "../transcript" } ark-std.workspace = true clap.workspace = true env_logger.workspace = true +ethnum.workspace = true +halo2curves.workspace = true log.workspace = true mpi.workspace = true rand.workspace = true +rand_chacha.workspace = true +rayon.workspace = true sha2.workspace = true -halo2curves.workspace = true thiserror.workspace = true -ethnum.workspace = true -rand_chacha.workspace = true # for the server bytes.workspace = true @@ -42,6 +43,8 @@ tiny-keccak.workspace = true [dev-dependencies] criterion = "0.5.1" +serial_test.workspace = true + [[bin]] name = "gkr-mpi" path = "src/main_mpi.rs" diff --git a/gkr/src/exec.rs b/gkr/src/exec.rs index 99307ab5..3497fce1 100644 --- a/gkr/src/exec.rs +++ b/gkr/src/exec.rs @@ -1,77 +1,79 @@ -use config::{Config, GKRScheme}; -use mpi_config::MPIConfig; +// use config::{Config, GKRScheme}; +// use mpi_config::MPIConfig; -use log::debug; +// use log::debug; -#[allow(unused_imports)] // The FiatShamirHashType import is used in the macro expansion -use config::FiatShamirHashType; -#[allow(unused_imports)] // The FieldType import is used in the macro expansion -use gkr_field_config::FieldType; +// #[allow(unused_imports)] // The FiatShamirHashType import is used in the macro expansion +// use config::FiatShamirHashType; +// #[allow(unused_imports)] // The FieldType import is used in the macro expansion +// use gkr_field_config::FieldType; -use gkr::executor::*; +// use gkr::executor::*; -#[tokio::main] -async fn main() { - // examples: - // expander-exec prove - // expander-exec verify - // expander-exec serve - let mut mpi_config = MPIConfig::new(); +// #[tokio::main] +// async fn main() { +// // examples: +// // expander-exec prove +// // expander-exec verify +// // expander-exec serve +// let mut mpi_config = MPIConfig::new(); - let args = std::env::args().collect::>(); - if args.len() < 5 { - println!( - "Usage: expander-exec prove " - ); - println!( - "Usage: expander-exec verify " - ); - println!("Usage: expander-exec serve "); - return; - } - let command = &args[1]; - if command != "prove" && command != "verify" && command != "serve" { - println!("Invalid command."); - return; - } +// let args = std::env::args().collect::>(); +// if args.len() < 5 { +// println!( +// 
"Usage: expander-exec prove " +// ); +// println!( +// "Usage: expander-exec verify +// " ); +// println!("Usage: expander-exec serve "); +// return; +// } +// let command = &args[1]; +// if command != "prove" && command != "verify" && command != "serve" { +// println!("Invalid command."); +// return; +// } - if command == "verify" && args.len() > 5 { - assert!(mpi_config.world_size == 1); // verifier should not be run with mpiexec - mpi_config.world_size = args[5].parse::().expect("Parsing mpi size fails"); - } +// if command == "verify" && args.len() > 5 { +// assert!(mpi_config.world_size == 1); // verifier should not be run with mpiexec +// mpi_config.world_size = args[5].parse::().expect("Parsing mpi size fails"); +// } - let circuit_file = &args[2]; - let field_type = detect_field_type_from_circuit_file(circuit_file); - debug!("field type: {:?}", field_type); - match field_type { - FieldType::M31 => { - run_command::( - command, - circuit_file, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - &args, - ) - .await; - } - FieldType::BN254 => { - run_command::( - command, - circuit_file, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - &args, - ) - .await; - } - FieldType::GF2 => { - run_command::( - command, - circuit_file, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - &args, - ) - .await - } - } +// let circuit_file = &args[2]; +// let field_type = detect_field_type_from_circuit_file(circuit_file); +// debug!("field type: {:?}", field_type); +// match field_type { +// FieldType::M31 => { +// run_command::( +// command, +// circuit_file, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// &args, +// ) +// .await; +// } +// FieldType::BN254 => { +// run_command::( +// command, +// circuit_file, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// &args, +// ) +// .await; +// } +// FieldType::GF2 => { +// run_command::( +// command, +// circuit_file, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// &args, +// ) +// .await +// } +// } - MPIConfig::finalize(); -} +// MPIConfig::finalize(); +// } + +fn main() {} diff --git a/gkr/src/main.rs b/gkr/src/main.rs index 2c11ef5b..486a9ae9 100644 --- a/gkr/src/main.rs +++ b/gkr/src/main.rs @@ -54,7 +54,7 @@ fn main() { let args = Args::parse(); print_info(&args); - let mpi_config = MPIConfig::new(); + let mpi_config = MPIConfig::single_thread(); declare_gkr_config!( M31ExtConfigSha2, @@ -112,7 +112,7 @@ fn main() { _ => unreachable!(), }; - MPIConfig::finalize(); + mpi_config.finalize(); } const PCS_TESTING_SEED_U64: u64 = 114514; diff --git a/gkr/src/main_mpi.rs b/gkr/src/main_mpi.rs index 2a03a453..69d6d5eb 100644 --- a/gkr/src/main_mpi.rs +++ b/gkr/src/main_mpi.rs @@ -1,209 +1,211 @@ -use circuit::Circuit; -use clap::Parser; -use config::{Config, GKRConfig, GKRScheme}; -use config_macros::declare_gkr_config; -use mpi_config::MPIConfig; - -use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; -use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR}; -use rand::SeedableRng; -use rand_chacha::ChaCha12Rng; -use transcript::{BytesHashTranscript, SHA256hasher}; - -use gkr::{ - utils::{ - KECCAK_BN254_CIRCUIT, KECCAK_BN254_WITNESS, KECCAK_GF2_CIRCUIT, KECCAK_GF2_WITNESS, - KECCAK_M31_CIRCUIT, KECCAK_M31_WITNESS, POSEIDON_M31_CIRCUIT, POSEIDON_M31_WITNESS, - }, - Prover, -}; - -#[allow(unused_imports)] -// The FiatShamirHashType and PolynomialCommitmentType import is used in the macro expansion -use config::{FiatShamirHashType, 
PolynomialCommitmentType}; -#[allow(unused_imports)] // The FieldType import is used in the macro expansion -use gkr_field_config::FieldType; - -/// ... -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Field Identifier: fr, m31, m31ext3 - #[arg(short, long,default_value_t = String::from("m31ext3"))] - field: String, - - // scheme: keccak, poseidon - #[arg(short, long, default_value_t = String::from("keccak"))] - scheme: String, - - /// number of repeat - #[arg(short, long, default_value_t = 1)] - repeats: usize, -} - -fn main() { - let args = Args::parse(); - print_info(&args); - - let mpi_config = MPIConfig::new(); - - declare_gkr_config!( - M31ExtConfigSha2, - FieldType::M31, - FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw - ); - declare_gkr_config!( - BN254ConfigSha2, - FieldType::BN254, - FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw - ); - declare_gkr_config!( - GF2ExtConfigSha2, - FieldType::GF2, - FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw - ); - - match args.field.as_str() { - "m31ext3" => match args.scheme.as_str() { - "keccak" => run_benchmark::( - &args, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - ), - "poseidon" => run_benchmark::( - &args, - Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), - ), - _ => unreachable!(), - }, - "fr" => match args.scheme.as_str() { - "keccak" => run_benchmark::( - &args, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - ), - "poseidon" => run_benchmark::( - &args, - Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), - ), - _ => unreachable!(), - }, - "gf2ext128" => match args.scheme.as_str() { - "keccak" => run_benchmark::( - &args, - Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - ), - "poseidon" => run_benchmark::( - &args, - Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), - ), - _ => unreachable!(), - }, - _ => unreachable!(), - }; - - MPIConfig::finalize(); -} - -const PCS_TESTING_SEED_U64: u64 = 114514; - -fn run_benchmark(args: &Args, config: Config) { - let pack_size = Cfg::FieldConfig::get_field_pack_size(); - - // load circuit - let mut circuit = match args.scheme.as_str() { - "keccak" => match Cfg::FieldConfig::FIELD_TYPE { - FieldType::GF2 => Circuit::::load_circuit(KECCAK_GF2_CIRCUIT), - FieldType::M31 => Circuit::::load_circuit(KECCAK_M31_CIRCUIT), - FieldType::BN254 => Circuit::::load_circuit(KECCAK_BN254_CIRCUIT), - }, - "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { - FieldType::M31 => Circuit::::load_circuit(POSEIDON_M31_CIRCUIT), - _ => unreachable!(), - }, - _ => unreachable!(), - }; - - let witness_path = match args.scheme.as_str() { - "keccak" => match Cfg::FieldConfig::FIELD_TYPE { - FieldType::GF2 => KECCAK_GF2_WITNESS, - FieldType::M31 => KECCAK_M31_WITNESS, - FieldType::BN254 => KECCAK_BN254_WITNESS, - }, - "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { - FieldType::M31 => POSEIDON_M31_WITNESS, - _ => unreachable!("not supported"), - }, - _ => unreachable!(), - }; - - match args.scheme.as_str() { - "keccak" => circuit.load_witness_file(witness_path), - "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { - FieldType::M31 => circuit.load_non_simd_witness_file(witness_path), - _ => unreachable!("not supported"), - }, - - _ => unreachable!(), - }; - - let circuit_copy_size: usize = match (Cfg::FieldConfig::FIELD_TYPE, args.scheme.as_str()) { - (FieldType::GF2, "keccak") => 1, - (FieldType::M31, "keccak") => 2, - (FieldType::BN254, "keccak") => 2, - (FieldType::M31, 
"poseidon") => 120, - _ => unreachable!(), - }; - - let mut prover = Prover::new(&config); - prover.prepare_mem(&circuit); - - let mut rng = ChaCha12Rng::seed_from_u64(PCS_TESTING_SEED_U64); - let (pcs_params, pcs_proving_key, _pcs_verification_key, mut pcs_scratch) = - expander_pcs_init_testing_only::( - circuit.log_input_size(), - &config.mpi_config, - &mut rng, - ); - - const N_PROOF: usize = 1000; - - println!("We are now calculating average throughput, please wait until {N_PROOF} proofs are computed"); - for i in 0..args.repeats { - config.mpi_config.barrier(); // wait until everyone is here - let start_time = std::time::Instant::now(); - for _j in 0..N_PROOF { - prover.prove( - &mut circuit, - &pcs_params, - &pcs_proving_key, - &mut pcs_scratch, - ); - } - let stop_time = std::time::Instant::now(); - let duration = stop_time.duration_since(start_time); - let throughput = (N_PROOF * circuit_copy_size * pack_size * config.mpi_config.world_size()) - as f64 - / duration.as_secs_f64(); - println!("{}-bench: throughput: {} hashes/s", i, throughput.round()); - } -} - -fn print_info(args: &Args) { - let prover = match args.scheme.as_str() { - "keccak" => "GKR", - "poseidon" => "GKR^2", - _ => unreachable!(), - }; - - println!("==============================="); - println!( - "benchmarking {} with {} over {}", - args.scheme, prover, args.field - ); - println!("field: {}", args.field); - println!("#bench repeats: {}", args.repeats); - println!("hash scheme: {}", args.scheme); - println!("===============================") -} +// use circuit::Circuit; +// use clap::Parser; +// use config::{Config, GKRConfig, GKRScheme}; +// use config_macros::declare_gkr_config; +// use mpi_config::MPIConfig; + +// use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; +// use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR}; +// use rand::SeedableRng; +// use rand_chacha::ChaCha12Rng; +// use transcript::{BytesHashTranscript, SHA256hasher}; + +// use gkr::{ +// utils::{ +// KECCAK_BN254_CIRCUIT, KECCAK_BN254_WITNESS, KECCAK_GF2_CIRCUIT, KECCAK_GF2_WITNESS, +// KECCAK_M31_CIRCUIT, KECCAK_M31_WITNESS, POSEIDON_M31_CIRCUIT, POSEIDON_M31_WITNESS, +// }, +// Prover, +// }; + +// #[allow(unused_imports)] +// // The FiatShamirHashType and PolynomialCommitmentType import is used in the macro expansion +// use config::{FiatShamirHashType, PolynomialCommitmentType}; +// #[allow(unused_imports)] // The FieldType import is used in the macro expansion +// use gkr_field_config::FieldType; + +// /// ... 
+// #[derive(Parser, Debug)] +// #[command(author, version, about, long_about = None)] +// struct Args { +// /// Field Identifier: fr, m31, m31ext3 +// #[arg(short, long,default_value_t = String::from("m31ext3"))] +// field: String, + +// // scheme: keccak, poseidon +// #[arg(short, long, default_value_t = String::from("keccak"))] +// scheme: String, + +// /// number of repeat +// #[arg(short, long, default_value_t = 1)] +// repeats: usize, +// } + +// fn main() { +// let args = Args::parse(); +// print_info(&args); + +// let mpi_config = MPIConfig::new(); + +// declare_gkr_config!( +// M31ExtConfigSha2, +// FieldType::M31, +// FiatShamirHashType::SHA256, +// PolynomialCommitmentType::Raw +// ); +// declare_gkr_config!( +// BN254ConfigSha2, +// FieldType::BN254, +// FiatShamirHashType::SHA256, +// PolynomialCommitmentType::Raw +// ); +// declare_gkr_config!( +// GF2ExtConfigSha2, +// FieldType::GF2, +// FiatShamirHashType::SHA256, +// PolynomialCommitmentType::Raw +// ); + +// match args.field.as_str() { +// "m31ext3" => match args.scheme.as_str() { +// "keccak" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// ), +// "poseidon" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), +// ), +// _ => unreachable!(), +// }, +// "fr" => match args.scheme.as_str() { +// "keccak" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// ), +// "poseidon" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), +// ), +// _ => unreachable!(), +// }, +// "gf2ext128" => match args.scheme.as_str() { +// "keccak" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::Vanilla, mpi_config.clone()), +// ), +// "poseidon" => run_benchmark::( +// &args, +// Config::::new(GKRScheme::GkrSquare, mpi_config.clone()), +// ), +// _ => unreachable!(), +// }, +// _ => unreachable!(), +// }; + +// MPIConfig::finalize(); +// } + +// const PCS_TESTING_SEED_U64: u64 = 114514; + +// fn run_benchmark(args: &Args, config: Config) { +// let pack_size = Cfg::FieldConfig::get_field_pack_size(); + +// // load circuit +// let mut circuit = match args.scheme.as_str() { +// "keccak" => match Cfg::FieldConfig::FIELD_TYPE { +// FieldType::GF2 => Circuit::::load_circuit(KECCAK_GF2_CIRCUIT), +// FieldType::M31 => Circuit::::load_circuit(KECCAK_M31_CIRCUIT), +// FieldType::BN254 => Circuit::::load_circuit(KECCAK_BN254_CIRCUIT), +// }, +// "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { +// FieldType::M31 => Circuit::::load_circuit(POSEIDON_M31_CIRCUIT), +// _ => unreachable!(), +// }, +// _ => unreachable!(), +// }; + +// let witness_path = match args.scheme.as_str() { +// "keccak" => match Cfg::FieldConfig::FIELD_TYPE { +// FieldType::GF2 => KECCAK_GF2_WITNESS, +// FieldType::M31 => KECCAK_M31_WITNESS, +// FieldType::BN254 => KECCAK_BN254_WITNESS, +// }, +// "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { +// FieldType::M31 => POSEIDON_M31_WITNESS, +// _ => unreachable!("not supported"), +// }, +// _ => unreachable!(), +// }; + +// match args.scheme.as_str() { +// "keccak" => circuit.load_witness_file(witness_path), +// "poseidon" => match Cfg::FieldConfig::FIELD_TYPE { +// FieldType::M31 => circuit.load_non_simd_witness_file(witness_path), +// _ => unreachable!("not supported"), +// }, + +// _ => unreachable!(), +// }; + +// let circuit_copy_size: usize = match (Cfg::FieldConfig::FIELD_TYPE, args.scheme.as_str()) { +// (FieldType::GF2, "keccak") => 1, +// 
(FieldType::M31, "keccak") => 2, +// (FieldType::BN254, "keccak") => 2, +// (FieldType::M31, "poseidon") => 120, +// _ => unreachable!(), +// }; + +// let mut prover = Prover::new(&config); +// prover.prepare_mem(&circuit); + +// let mut rng = ChaCha12Rng::seed_from_u64(PCS_TESTING_SEED_U64); +// let (pcs_params, pcs_proving_key, _pcs_verification_key, mut pcs_scratch) = +// expander_pcs_init_testing_only::( +// circuit.log_input_size(), +// &config.mpi_config, +// &mut rng, +// ); + +// const N_PROOF: usize = 1000; + +// println!("We are now calculating average throughput, please wait until {N_PROOF} proofs are +// computed"); for i in 0..args.repeats { +// config.mpi_config.barrier(); // wait until everyone is here +// let start_time = std::time::Instant::now(); +// for _j in 0..N_PROOF { +// prover.prove( +// &mut circuit, +// &pcs_params, +// &pcs_proving_key, +// &mut pcs_scratch, +// ); +// } +// let stop_time = std::time::Instant::now(); +// let duration = stop_time.duration_since(start_time); +// let throughput = (N_PROOF * circuit_copy_size * pack_size * +// config.mpi_config.world_size()) as f64 +// / duration.as_secs_f64(); +// println!("{}-bench: throughput: {} hashes/s", i, throughput.round()); +// } +// } + +// fn print_info(args: &Args) { +// let prover = match args.scheme.as_str() { +// "keccak" => "GKR", +// "poseidon" => "GKR^2", +// _ => unreachable!(), +// }; + +// println!("==============================="); +// println!( +// "benchmarking {} with {} over {}", +// args.scheme, prover, args.field +// ); +// println!("field: {}", args.field); +// println!("#bench repeats: {}", args.repeats); +// println!("hash scheme: {}", args.scheme); +// println!("===============================") +// } + +fn main() {} diff --git a/gkr/src/prover/gkr.rs b/gkr/src/prover/gkr.rs index 296ab614..98f2b3bd 100644 --- a/gkr/src/prover/gkr.rs +++ b/gkr/src/prover/gkr.rs @@ -53,19 +53,29 @@ pub fn gkr_prove>( &mut sp.eq_evals_at_r_simd0, ); - let claimed_v = if mpi_config.is_root() { - let mut claimed_v_gathering_buffer = - vec![C::ChallengeField::zero(); mpi_config.world_size()]; - mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer); - MultiLinearPoly::evaluate_with_buffer( - &claimed_v_gathering_buffer, - &r_mpi, - &mut sp.eq_evals_at_r_mpi0, - ) - } else { - mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]); - C::ChallengeField::zero() - }; + let start = mpi_config.current_size(); + let end = start + C::ChallengeField::SIZE; + mpi_config.append_local_field(&claimed_v_local); + let claimed_v_gathering_buffer = mpi_config.read_all_field_flat(start, end); + let claimed_v = MultiLinearPoly::evaluate_with_buffer( + &claimed_v_gathering_buffer, + &r_mpi, + &mut sp.eq_evals_at_r_mpi0, + ); + + // let claimed_v = if mpi_config.is_root() { + // let mut claimed_v_gathering_buffer = + // vec![C::ChallengeField::zero(); mpi_config.world_size() as usize]; + // mpi_config.gather_vec(&vec![claimed_v_local], &mut claimed_v_gathering_buffer); + // MultiLinearPoly::evaluate_with_buffer( + // &claimed_v_gathering_buffer, + // &r_mpi, + // &mut sp.eq_evals_at_r_mpi0, + // ) + // } else { + // mpi_config.gather_vec(&vec![claimed_v_local], &mut vec![]); + // C::ChallengeField::zero() + // }; for i in (0..layer_num).rev() { (rz0, rz1, r_simd, r_mpi) = sumcheck_prove_gkr_layer( @@ -84,7 +94,7 @@ pub fn gkr_prove>( if rz1.is_some() { // TODO: try broadcast beta.unwrap directly let mut tmp = transcript.generate_challenge_field_element(); - mpi_config.root_broadcast_f(&mut tmp); + 
// mpi_config.root_broadcast_f(&mut tmp); alpha = Some(tmp) } else { alpha = None; diff --git a/gkr/src/prover/linear_gkr.rs b/gkr/src/prover/linear_gkr.rs index 907ae30e..e6128e8c 100644 --- a/gkr/src/prover/linear_gkr.rs +++ b/gkr/src/prover/linear_gkr.rs @@ -5,10 +5,11 @@ use ark_std::{end_timer, start_timer}; use circuit::Circuit; use config::{Config, GKRConfig, GKRScheme}; use gkr_field_config::GKRFieldConfig; +use mpi_config::root_println; use poly_commit::{ExpanderGKRChallenge, PCSForExpanderGKR, StructuredReferenceString}; use polynomials::{MultilinearExtension, RefMultiLinearPoly}; use sumcheck::ProverScratchPad; -use transcript::{transcript_root_broadcast, Proof, Transcript}; +use transcript::{Proof, Transcript}; use crate::{gkr_prove, gkr_square_prove}; @@ -72,7 +73,7 @@ impl Prover { self.sp = ProverScratchPad::::new( max_num_input_var, max_num_output_var, - self.config.mpi_config.world_size(), + self.config.mpi_config.world_size() as usize, ); } @@ -97,11 +98,13 @@ impl Prover { pcs_scratch, ) }; + root_println!(self.config.mpi_config, "PC commit done"); + let mut buffer = vec![]; commitment.serialize_into(&mut buffer).unwrap(); // TODO: error propagation transcript.append_u8_slice(&buffer); - transcript_root_broadcast(&mut transcript, &self.config.mpi_config); + // transcript_root_broadcast(&mut transcript, &self.config.mpi_config); #[cfg(feature = "grinding")] grind::(&mut transcript, &self.config); diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index a6daed0d..93eac670 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -1,5 +1,6 @@ use std::io::Write; use std::panic::AssertUnwindSafe; +use std::sync::Arc; use std::time::Instant; use std::{fs, panic}; @@ -15,6 +16,8 @@ use poly_commit::expander_pcs_init_testing_only; use poly_commit::raw::RawExpanderGKR; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha12Rng; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use serial_test::serial; use sha2::Digest; use transcript::{BytesHashTranscript, FieldHashTranscript, Keccak256hasher, SHA256hasher}; @@ -22,9 +25,15 @@ use crate::{utils::*, Prover, Verifier}; const PCS_TESTING_SEED_U64: u64 = 114514; +const NUM_THREADS: usize = 32; + #[test] +#[serial] fn test_gkr_correctness() { - let mpi_config = MPIConfig::new(); + // Create global data + let global_data: Arc<[u8]> = Arc::from((0..16).map(|i| i as u8).collect::>()); + let mpi_config = MPIConfig::new(NUM_THREADS as i32, global_data, 1024 * 1024); + declare_gkr_config!( C0, FieldType::GF2, @@ -78,184 +87,198 @@ fn test_gkr_correctness() { &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), None, ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - Some("../data/gkr_proof.txt"), - ); - test_gkr_correctness_helper( - &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), - None, - ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); + // 
test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // Some("../data/gkr_proof.txt"), + // ); + // test_gkr_correctness_helper( + // &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + // None, + // ); - MPIConfig::finalize(); + mpi_config.finalize(); } #[allow(unreachable_patterns)] fn test_gkr_correctness_helper(config: &Config, write_proof_to: Option<&str>) { - root_println!(config.mpi_config, "============== start ==============="); - root_println!( - config.mpi_config, - "Field Type: {:?}", - ::FIELD_TYPE - ); - let circuit_copy_size: usize = match ::FIELD_TYPE { - FieldType::GF2 => 1, - FieldType::M31 => 2, - FieldType::BN254 => 2, - _ => unreachable!(), - }; - root_println!( - config.mpi_config, - "Proving {} keccak instances at once.", - circuit_copy_size * ::get_field_pack_size() - ); - root_println!(config.mpi_config, "Config created."); + (0..NUM_THREADS).into_par_iter().for_each(|_| { + root_println!(config.mpi_config, "============== start ==============="); + root_println!( + config.mpi_config, + "Field Type: {:?}", + ::FIELD_TYPE + ); + let circuit_copy_size: usize = match ::FIELD_TYPE { + FieldType::GF2 => 1, + FieldType::M31 => 2, + FieldType::BN254 => 2, + _ => unreachable!(), + }; + root_println!( + config.mpi_config, + "Proving {} keccak instances at once.", + circuit_copy_size * ::get_field_pack_size() + ); + root_println!(config.mpi_config, "Config created."); + + let circuit_path = match ::FIELD_TYPE { + FieldType::GF2 => "../".to_owned() + KECCAK_GF2_CIRCUIT, + FieldType::M31 => "../".to_owned() + KECCAK_M31_CIRCUIT, + FieldType::BN254 => "../".to_owned() + KECCAK_BN254_CIRCUIT, + _ => unreachable!(), + }; + // todo: move circuit into shared memory + let mut circuit = Circuit::::load_circuit(&circuit_path); + root_println!(config.mpi_config, "Circuit loaded."); - let circuit_path = match ::FIELD_TYPE { - FieldType::GF2 => "../".to_owned() + KECCAK_GF2_CIRCUIT, - FieldType::M31 => "../".to_owned() + KECCAK_M31_CIRCUIT, - FieldType::BN254 => "../".to_owned() + KECCAK_BN254_CIRCUIT, - _ => unreachable!(), - }; - let mut circuit = Circuit::::load_circuit(&circuit_path); - root_println!(config.mpi_config, "Circuit loaded."); + let witness_path = match ::FIELD_TYPE { + FieldType::GF2 => "../".to_owned() + KECCAK_GF2_WITNESS, + FieldType::M31 => "../".to_owned() + KECCAK_M31_WITNESS, + FieldType::BN254 => "../".to_owned() + KECCAK_BN254_WITNESS, + _ => unreachable!(), + }; + circuit.load_witness_file(&witness_path); + root_println!(config.mpi_config, "Witness loaded."); - let witness_path = match ::FIELD_TYPE { - FieldType::GF2 => "../".to_owned() + KECCAK_GF2_WITNESS, - FieldType::M31 => "../".to_owned() + KECCAK_M31_WITNESS, - FieldType::BN254 => "../".to_owned() + KECCAK_BN254_WITNESS, - _ => unreachable!(), - }; - circuit.load_witness_file(&witness_path); - root_println!(config.mpi_config, "Witness loaded."); + circuit.evaluate(); - circuit.evaluate(); - let output = &circuit.layers.last().unwrap().output_vals; - assert!(output[..circuit.expected_num_output_zeros] - .iter() - 
.all(|f| f.is_zero())); + let output = &circuit.layers.last().unwrap().output_vals; - let mut prover = Prover::new(config); - prover.prepare_mem(&circuit); + assert!(output[..circuit.expected_num_output_zeros] + .iter() + .all(|f| f.is_zero())); - let mut rng = ChaCha12Rng::seed_from_u64(PCS_TESTING_SEED_U64); - let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = - expander_pcs_init_testing_only::( - circuit.log_input_size(), - &config.mpi_config, - &mut rng, + let mut prover = Prover::new(config); + prover.prepare_mem(&circuit); + + let mut rng = ChaCha12Rng::seed_from_u64(PCS_TESTING_SEED_U64); + let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = + expander_pcs_init_testing_only::( + circuit.log_input_size(), + &config.mpi_config, + &mut rng, + ); + + root_println!(config.mpi_config, "start proving."); + let proving_start = Instant::now(); + let (claimed_v, proof) = prover.prove( + &mut circuit, + &pcs_params, + &pcs_proving_key, + &mut pcs_scratch, + ); + root_println!( + config.mpi_config, + "Proving time: {} μs", + proving_start.elapsed().as_micros() ); - let proving_start = Instant::now(); - let (claimed_v, proof) = prover.prove( - &mut circuit, - &pcs_params, - &pcs_proving_key, - &mut pcs_scratch, - ); - root_println!( - config.mpi_config, - "Proving time: {} μs", - proving_start.elapsed().as_micros() - ); + root_println!( + config.mpi_config, + "Proof generated. Size: {} bytes", + proof.bytes.len() + ); + root_println!(config.mpi_config,); - root_println!( - config.mpi_config, - "Proof generated. Size: {} bytes", - proof.bytes.len() - ); - root_println!(config.mpi_config,); + root_println!(config.mpi_config, "Proof hash: "); + sha2::Sha256::digest(&proof.bytes) + .iter() + .for_each(|b| print!("{} ", b)); + root_println!(config.mpi_config,); - root_println!(config.mpi_config, "Proof hash: "); - sha2::Sha256::digest(&proof.bytes) - .iter() - .for_each(|b| print!("{} ", b)); - root_println!(config.mpi_config,); + // let mut public_input_gathered = if config.mpi_config.is_root() { + // vec![ + // ::SimdCircuitField::ZERO; + // circuit.public_input.len() * config.mpi_config.world_size() as usize + // ] + // } else { + // vec![] + // }; + // config + // .mpi_config + // .gather_vec(&circuit.public_input, &mut public_input_gathered); - let mut public_input_gathered = if config.mpi_config.is_root() { - vec![ - ::SimdCircuitField::ZERO; - circuit.public_input.len() * config.mpi_config.world_size() - ] - } else { - vec![] - }; - config - .mpi_config - .gather_vec(&circuit.public_input, &mut public_input_gathered); + let start = config.mpi_config.current_thread().size(); + let end =start + circuit.public_input.len() + * ::SimdCircuitField::SERIALIZED_SIZE; + config.mpi_config.append_local_fields(&circuit.public_input); - // Verify - if config.mpi_config.is_root() { - if let Some(str) = write_proof_to { - let mut file = fs::OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(str) - .unwrap(); + let public_input_gathered: Vec<::SimdCircuitField> = + config.mpi_config.read_all_field_flat(start, end); - let mut buf = vec![]; - proof.serialize_into(&mut buf).unwrap(); - file.write_all(&buf).unwrap(); - } - let verifier = Verifier::new(config); - println!("Verifier created."); - let verification_start = Instant::now(); - assert!(verifier.verify( - &mut circuit, - &public_input_gathered, - &claimed_v, - &pcs_params, - &pcs_verification_key, - &proof - )); - println!( - "Verification time: {} μs", - 
verification_start.elapsed().as_micros() - ); - println!("Correct proof verified."); - let mut bad_proof = proof.clone(); - let rng = &mut rand::thread_rng(); - let random_idx = rng.gen_range(0..bad_proof.bytes.len()); - let random_change = rng.gen_range(1..256) as u8; - bad_proof.bytes[random_idx] ^= random_change; + // Verify + if config.mpi_config.is_root() { + if let Some(str) = write_proof_to { + let mut file = fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(str) + .unwrap(); - // Catch the panic and treat it as returning `false` - let result = panic::catch_unwind(AssertUnwindSafe(|| { - verifier.verify( + let mut buf = vec![]; + proof.serialize_into(&mut buf).unwrap(); + file.write_all(&buf).unwrap(); + } + let verifier = Verifier::new(config); + println!("Verifier created."); + let verification_start = Instant::now(); + assert!(verifier.verify( &mut circuit, &public_input_gathered, &claimed_v, &pcs_params, &pcs_verification_key, - &bad_proof, - ) - })); + &proof + )); + println!( + "Verification time: {} μs", + verification_start.elapsed().as_micros() + ); + println!("Correct proof verified."); + let mut bad_proof = proof.clone(); + let rng = &mut rand::thread_rng(); + let random_idx = rng.gen_range(0..bad_proof.bytes.len()); + let random_change = rng.gen_range(1..256) as u8; + bad_proof.bytes[random_idx] ^= random_change; - let final_result = result.unwrap_or_default(); + // Catch the panic and treat it as returning `false` + let result = panic::catch_unwind(AssertUnwindSafe(|| { + verifier.verify( + &mut circuit, + &public_input_gathered, + &claimed_v, + &pcs_params, + &pcs_verification_key, + &bad_proof, + ) + })); - assert!(!final_result,); - println!("Bad proof rejected."); - println!("============== end ==============="); - } + let final_result = result.unwrap_or_default(); + + assert!(!final_result,); + println!("Bad proof rejected."); + println!("============== end ==============="); + } + }); } diff --git a/gkr/src/verifier.rs b/gkr/src/verifier.rs index 77f6a3e0..5a75cd17 100644 --- a/gkr/src/verifier.rs +++ b/gkr/src/verifier.rs @@ -177,7 +177,7 @@ pub fn gkr_verify>( Option, ) { let timer = start_timer!(|| "gkr verify"); - let mut sp = VerifierScratchPad::::new(circuit, mpi_config.world_size()); + let mut sp = VerifierScratchPad::::new(circuit, mpi_config.world_size() as usize); let layer_num = circuit.layers.len(); let mut rz0 = vec![]; diff --git a/poly_commit/Cargo.toml b/poly_commit/Cargo.toml index a17872d3..58306dec 100644 --- a/poly_commit/Cargo.toml +++ b/poly_commit/Cargo.toml @@ -10,5 +10,10 @@ mpi_config = { path = "../config/mpi_config" } polynomials = { path = "../arith/polynomials"} transcript = { path = "../transcript" } -rand.workspace = true ethnum.workspace = true +rand.workspace = true +rayon.workspace = true + + +[dev-dependencies] +serial_test.workspace = true \ No newline at end of file diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs index 7d68bfd5..4cdd5a28 100644 --- a/poly_commit/src/raw.rs +++ b/poly_commit/src/raw.rs @@ -6,7 +6,7 @@ use crate::{ use arith::{BN254Fr, Field, FieldForECC, FieldSerde, FieldSerdeResult, SimdField}; use ethnum::U256; use gkr_field_config::GKRFieldConfig; -use mpi_config::MPIConfig; +use mpi_config::{thread_println, MPIConfig}; use polynomials::{MultiLinearPoly, MultilinearExtension}; use rand::RngCore; use transcript::Transcript; @@ -196,14 +196,15 @@ impl> PCSForExpanderGKR> PCSForExpanderGKR::VKey, commitment: &Self::Commitment, x: &ExpanderGKRChallenge, @@ -231,7 +232,6 
@@ impl> PCSForExpanderGKR bool { - assert!(mpi_config.is_root()); // Only the root will verify let ExpanderGKRChallenge:: { x, x_simd, x_mpi } = x; Self::eval(&commitment.evals, x, x_simd, x_mpi) == v } diff --git a/poly_commit/src/traits.rs b/poly_commit/src/traits.rs index 507ff32b..a698597a 100644 --- a/poly_commit/src/traits.rs +++ b/poly_commit/src/traits.rs @@ -7,8 +7,8 @@ use std::fmt::Debug; use transcript::Transcript; pub trait StructuredReferenceString { - type PKey: Clone + Debug + FieldSerde + Send; - type VKey: Clone + Debug + FieldSerde + Send; + type PKey: Clone + Debug + FieldSerde + Send + Sync; + type VKey: Clone + Debug + FieldSerde + Send + Sync; /// Convert the SRS into proving and verifying keys. /// Comsuming self by default. @@ -72,7 +72,7 @@ pub struct ExpanderGKRChallenge { pub trait PCSForExpanderGKR> { const NAME: &'static str; - type Params: Clone + Debug + Default + Send; + type Params: Clone + Debug + Default + Send + Sync; type ScratchPad: Clone + Debug + Default + Send; type SRS: Clone + Debug + Default + FieldSerde + StructuredReferenceString; diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index ada8f1da..9e80aa1b 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -5,8 +5,9 @@ use poly_commit::raw::RawExpanderGKR; use poly_commit::{ ExpanderGKRChallenge, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString, }; -use polynomials::MultilinearExtension; +use polynomials::{MultilinearExtension, RefMultiLinearPoly}; use rand::thread_rng; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use transcript::Transcript; pub fn test_pcs>( @@ -39,55 +40,79 @@ pub fn test_gkr_pcs< T: Transcript, P: PCSForExpanderGKR, >( - params: &P::Params, mpi_config: &MPIConfig, - transcript: &mut T, - poly: &impl MultilinearExtension, - xs: &[ExpanderGKRChallenge], + n_local_vars: usize, ) { let mut rng = thread_rng(); - let srs = P::gen_srs_for_testing(params, mpi_config, &mut rng); + let params = P::gen_params(n_local_vars); + + let srs = P::gen_srs_for_testing(¶ms, mpi_config, &mut rng); let (proving_key, verification_key) = srs.into_keys(); - let mut scratch_pad = P::init_scratch_pad(params, mpi_config); - let commitment = P::commit(params, mpi_config, &proving_key, poly, &mut scratch_pad); + let num_threads = rayon::current_num_threads(); - // PCSForExpanderGKR does not require an evaluation value for the opening function - // We use RawExpanderGKR as the golden standard for the evaluation value - // Note this test will almost always pass for RawExpanderGKR, so make sure it is correct - let mut coeffs_gathered = if mpi_config.is_root() { - vec![C::SimdCircuitField::ZERO; poly.hypercube_size() * mpi_config.world_size()] - } else { - vec![] - }; - mpi_config.gather_vec(poly.hypercube_basis_ref(), &mut coeffs_gathered); + (0..num_threads).into_par_iter().for_each(|_| { + let mut rng = thread_rng(); + let hypercube_basis = (0..(1 << n_local_vars)) + .map(|_| C::SimdCircuitField::random_unsafe(&mut rng)) + .collect(); + let poly = RefMultiLinearPoly::from_ref(&hypercube_basis); - for xx in xs { - let ExpanderGKRChallenge { x, x_simd, x_mpi } = xx; - let opening = P::open( - params, - mpi_config, - &proving_key, - poly, - xx, - transcript, - &mut scratch_pad, - ); + let xs = (0..2) + .map(|_| ExpanderGKRChallenge:: { + x: (0..n_local_vars) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + x_simd: (0..C::get_field_pack_size().trailing_zeros()) + .map(|_| 
C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + x_mpi: (0..mpi_config.world_size().trailing_zeros()) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + }) + .collect::>>(); + + let mut scratch_pad = P::init_scratch_pad(¶ms, mpi_config); + + let commitment = P::commit(¶ms, mpi_config, &proving_key, &poly, &mut scratch_pad); + let mut transcript = T::new(); + + // PCSForExpanderGKR does not require an evaluation value for the opening function + // We use RawExpanderGKR as the golden standard for the evaluation value + // Note this test will almost always pass for RawExpanderGKR, so make sure it is correct + let start = mpi_config.current_size(); + let end = start + poly.serialized_size(); + + poly.hypercube_basis_ref() + .iter() + .for_each(|f| mpi_config.append_local_field(f)); + let coeffs_gathered = mpi_config.read_all_field_flat(start, end); + + for xx in xs { + let ExpanderGKRChallenge { x, x_simd, x_mpi } = &xx; + let opening = P::open( + ¶ms, + mpi_config, + &proving_key, + &poly, + &xx, + &mut transcript, + &mut scratch_pad, + ); - if mpi_config.is_root() { // this will always pass for RawExpanderGKR, so make sure it is correct - let v = RawExpanderGKR::::eval(&coeffs_gathered, x, x_simd, x_mpi); + let v = RawExpanderGKR::::eval(&coeffs_gathered, &x, &x_simd, &x_mpi); assert!(P::verify( - params, + ¶ms, mpi_config, &verification_key, &commitment, - xx, + &xx, v, - transcript, + &mut transcript, &opening )); } - } + }); } diff --git a/poly_commit/tests/test_raw.rs b/poly_commit/tests/test_raw.rs index 6818c351..ee1b8c0c 100644 --- a/poly_commit/tests/test_raw.rs +++ b/poly_commit/tests/test_raw.rs @@ -1,15 +1,16 @@ mod common; +use std::sync::Arc; + use arith::{BN254Fr, Field}; +use common::test_gkr_pcs; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mpi_config::MPIConfig; -use poly_commit::{ - raw::{RawExpanderGKR, RawExpanderGKRParams, RawMultiLinearPCS, RawMultiLinearParams}, - ExpanderGKRChallenge, -}; -use polynomials::{MultiLinearPoly, RefMultiLinearPoly}; +use poly_commit::raw::{RawExpanderGKR, RawMultiLinearPCS, RawMultiLinearParams}; +use polynomials::MultiLinearPoly; use rand::thread_rng; -use transcript::{BytesHashTranscript, SHA256hasher, Transcript}; +use serial_test::serial; +use transcript::{BytesHashTranscript, SHA256hasher}; #[test] fn test_raw() { @@ -27,45 +28,24 @@ fn test_raw() { common::test_pcs::(¶ms, &poly, &xs); } -fn test_raw_gkr_helper>( - mpi_config: &MPIConfig, - transcript: &mut T, -) { - let params = RawExpanderGKRParams { n_local_vars: 8 }; - let mut rng = thread_rng(); - let hypercube_basis = (0..(1 << params.n_local_vars)) - .map(|_| C::SimdCircuitField::random_unsafe(&mut rng)) - .collect(); - let poly = RefMultiLinearPoly::from_ref(&hypercube_basis); - let xs = (0..100) - .map(|_| ExpanderGKRChallenge:: { - x: (0..params.n_local_vars) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - x_simd: (0..C::get_field_pack_size().trailing_zeros()) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - x_mpi: (0..mpi_config.world_size().trailing_zeros()) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - }) - .collect::>>(); - common::test_gkr_pcs::>(¶ms, mpi_config, transcript, &poly, &xs); -} - #[test] +#[serial] fn test_raw_gkr() { - let mpi_config = MPIConfig::new(); + let global_data: Arc<[u8]> = Arc::from((0..1024).map(|i| i as u8).collect::>()); + let num_threads = rayon::current_num_threads(); + + // Create 
configs for all threads + let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024); type TM31 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TM31::new()); + test_gkr_pcs::>(&mpi_config, 8); type TGF2 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TGF2::new()); + test_gkr_pcs::>(&mpi_config, 8); type TBN254 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TBN254::new()); + test_gkr_pcs::>(&mpi_config, 8); - MPIConfig::finalize(); + mpi_config.finalize(); } diff --git a/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs b/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs index 0de24734..cdcfe7a9 100644 --- a/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs +++ b/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs @@ -328,20 +328,29 @@ impl<'a, C: GKRFieldConfig> SumcheckGkrVanillaHelper<'a, C> { #[inline] pub(crate) fn prepare_mpi_var_vals(&mut self) { - self.mpi_config.gather_vec( - &vec![self.sp.simd_var_v_evals[0]], - &mut self.sp.mpi_var_v_evals, - ); - self.mpi_config.gather_vec( - &vec![self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]], - &mut self.sp.mpi_var_hg_evals, - ); + // assert!(self.mpi_config.is_current_thread_synced()); + + let start = self.mpi_config.current_size(); + self.mpi_config + .append_local_field(&self.sp.simd_var_v_evals[0]); + let end = self.mpi_config.current_size(); + self.sp.mpi_var_v_evals = self + .mpi_config + .read_all_field_flat::(start, end); + + let start = self.mpi_config.current_size(); + self.mpi_config + .append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0])); + let end = self.mpi_config.current_size(); + self.sp.mpi_var_hg_evals = self + .mpi_config + .read_all_field_flat::(start, end); } #[inline] pub(crate) fn prepare_y_vals(&mut self) { - let mut v_rx_rsimd_rw = self.sp.mpi_var_v_evals[0]; - self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw); + let v_rx_rsimd_rw = self.sp.mpi_var_v_evals[0]; + // self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw); let mul = &self.layer.mul; let eq_evals_at_rz0 = &self.sp.eq_evals_at_rz0; diff --git a/sumcheck/src/sumcheck.rs b/sumcheck/src/sumcheck.rs index 9ad374c8..f81389ea 100644 --- a/sumcheck/src/sumcheck.rs +++ b/sumcheck/src/sumcheck.rs @@ -49,21 +49,30 @@ pub fn sumcheck_prove_gkr_layer(mpi_config, &evals, transcript); + let r = transcript_io::( + //mpi_config, + &evals, transcript, + ); helper.receive_rx(i_var, r); } helper.prepare_simd_var_vals(); for i_var in 0..helper.simd_var_num { let evals = helper.poly_evals_at_r_simd_var(i_var, 3); - let r = transcript_io::(mpi_config, &evals, transcript); + let r = transcript_io::( + //mpi_config, + &evals, transcript, + ); helper.receive_r_simd_var(i_var, r); } helper.prepare_mpi_var_vals(); for i_var in 0..mpi_config.world_size().trailing_zeros() as usize { let evals = helper.poly_evals_at_r_mpi_var(i_var, 3); - let r = transcript_io::(mpi_config, &evals, transcript); + let r = transcript_io::( + //mpi_config, + &evals, transcript, + ); helper.receive_r_mpi_var(i_var, r); } @@ -75,7 +84,10 @@ pub fn sumcheck_prove_gkr_layer(mpi_config, &evals, transcript); + let r = transcript_io::( + //mpi_config, + &evals, transcript, + ); helper.receive_ry(i_var, r); } let vy_claim = helper.vy_claim(); diff --git a/sumcheck/src/utils.rs b/sumcheck/src/utils.rs index 37398a6d..e4d530ea 100644 --- a/sumcheck/src/utils.rs +++ b/sumcheck/src/utils.rs @@ -1,5 +1,4 
@@ use arith::{Field, SimdField}; -use mpi_config::MPIConfig; use transcript::Transcript; // #[inline(always)] @@ -24,8 +23,15 @@ pub fn unpack_and_combine(p: &F, coef: &[F::Scalar]) -> F::Scalar } /// Transcript IO between sumcheck steps +/// +/// The thread will push the generated challenge field element to its local memory. +/// The caller is responsible for syncing up this field element. #[inline] -pub fn transcript_io(mpi_config: &MPIConfig, ps: &[F], transcript: &mut T) -> F +pub fn transcript_io( + // mpi_config: &MPIConfig, + ps: &[F], + transcript: &mut T, +) -> F where F: Field, T: Transcript, @@ -34,7 +40,5 @@ where for p in ps { transcript.append_field_element(p); } - let mut r = transcript.generate_challenge_field_element(); - mpi_config.root_broadcast_f(&mut r); - r + transcript.generate_challenge_field_element() } diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index dcbdfaa9..0091ea3b 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -4,8 +4,8 @@ pub use fiat_shamir_hash::{FiatShamirBytesHash, Keccak256hasher, SHA256hasher}; mod transcript; pub use transcript::{BytesHashTranscript, FieldHashTranscript, Transcript}; -mod transcript_utils; -pub use transcript_utils::transcript_root_broadcast; +// mod transcript_utils; +// pub use transcript_utils::transcript_sync_up; mod proof; pub use proof::Proof; diff --git a/transcript/src/transcript_utils.rs b/transcript/src/transcript_utils.rs index bee14e21..39d7a907 100644 --- a/transcript/src/transcript_utils.rs +++ b/transcript/src/transcript_utils.rs @@ -1,16 +1,18 @@ -use crate::Transcript; -use arith::Field; -use mpi_config::MPIConfig; +// use crate::Transcript; +// use arith::Field; +// use mpi_config::MPIConfig; -/// broadcast root transcript state. incurs an additional hash if self.world_size > 1 -pub fn transcript_root_broadcast(transcript: &mut T, mpi_config: &MPIConfig) -where - F: Field, - T: Transcript, -{ - if mpi_config.world_size > 1 { - let mut state = transcript.hash_and_return_state(); - mpi_config.root_broadcast_bytes(&mut state); - transcript.set_state(&state); - } -} +// /// broadcast root transcript state. incurs an additional hash if self.world_size > 1 +// pub fn transcript_sync_up(transcript: &mut T, mpi_config: &MPIConfig) +// where +// F: Field, +// T: Transcript, +// { +// if mpi_config.world_size > 1 { +// let payload = MPIConfig::sync + +// let mut state = transcript.hash_and_return_state(); +// mpi_config.root_broadcast_bytes(&mut state); +// transcript.set_state(&state); +// } +// }
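
Note on the gather replacement used throughout this change: instead of a root-only mpi_config.gather_vec followed by a root broadcast, each rank now records its thread-local buffer offset with current_size(), appends its values via append_local_field, and then reads the flat concatenation of every rank's [start, end) window with read_all_field_flat, so all ranks end up holding the same gathered vector (see gkr_prove, prepare_mpi_var_vals, and test_gkr_pcs above). The sketch below is a minimal, self-contained model of that calling pattern only; the World type, its Mutex-backed Vec<u64> logs, and the std::thread scope are illustrative stand-ins for the thread-backed MPIConfig, not the crate's actual implementation.

use std::sync::{Barrier, Mutex};
use std::thread;

// Toy model of the append-then-read-flat "gather": every rank owns an
// append-only log; a gather is append locally, sync, then read the same
// window from every rank's log. Mutex-backed Vec<u64> logs are illustrative
// stand-ins, not the mpi_config crate's buffers or field types.
struct World {
    logs: Vec<Mutex<Vec<u64>>>,
    barrier: Barrier,
}

impl World {
    fn new(world_size: usize) -> Self {
        Self {
            logs: (0..world_size).map(|_| Mutex::new(Vec::new())).collect(),
            barrier: Barrier::new(world_size),
        }
    }

    // Analogous to current_size(): where this rank's next append will land.
    fn current_size(&self, rank: usize) -> usize {
        self.logs[rank].lock().unwrap().len()
    }

    // Analogous to append_local_field(s): publish this rank's local values.
    fn append_local(&self, rank: usize, items: &[u64]) {
        self.logs[rank].lock().unwrap().extend_from_slice(items);
    }

    // Analogous to read_all_field_flat(start, end): wait until every rank has
    // published, then concatenate the [start, end) window of all ranks' logs.
    fn read_all_flat(&self, start: usize, end: usize) -> Vec<u64> {
        self.barrier.wait();
        self.logs
            .iter()
            .flat_map(|log| log.lock().unwrap()[start..end].to_vec())
            .collect()
    }
}

fn main() {
    let world_size = 4;
    let world = World::new(world_size);

    thread::scope(|s| {
        for rank in 0..world_size {
            let world = &world;
            s.spawn(move || {
                // The per-rank call sequence the diff uses: record start,
                // append local values, then read the gathered window back.
                // As in the diff, every rank appends the same number of
                // elements, so [start, end) is identical on all ranks.
                let local = [rank as u64; 2];
                let start = world.current_size(rank);
                let end = start + local.len();
                world.append_local(rank, &local);
                let gathered = world.read_all_flat(start, end);
                // Every rank reads the same flat concatenation, in rank order.
                assert_eq!(gathered, vec![0, 0, 1, 1, 2, 2, 3, 3]);
            });
        }
    });
}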
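Two related consequences are visible in the diff: transcript_io no longer takes an MPIConfig and, per its new doc comment, the generated challenge stays in the calling thread's local memory with the caller responsible for syncing it up; and the tests that exercise this path (test_gkr_correctness, test_raw_gkr) are now marked #[serial], presumably because each test drives the whole shared-buffer world across NUM_THREADS workers and cannot safely interleave with another such test.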