diff --git a/Cargo.lock b/Cargo.lock
index ad285dd9..092d816e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1650,6 +1650,7 @@ dependencies = [
 name = "mpi_config"
 version = "0.1.0"
 dependencies = [
+ "arith",
  "rayon",
  "serial_test",
 ]
diff --git a/config/mpi_config/Cargo.toml b/config/mpi_config/Cargo.toml
index 7646c664..f7ef9742 100644
--- a/config/mpi_config/Cargo.toml
+++ b/config/mpi_config/Cargo.toml
@@ -6,5 +6,7 @@ edition = "2021"
 
 [dependencies]
 rayon.workspace = true
+arith = { path = "../../arith" }
+
 [dev-dependencies]
 serial_test.workspace = true
\ No newline at end of file
diff --git a/config/mpi_config/src/mpi_config.rs b/config/mpi_config/src/mpi_config.rs
index e9b726a2..e4bc7182 100644
--- a/config/mpi_config/src/mpi_config.rs
+++ b/config/mpi_config/src/mpi_config.rs
@@ -1,5 +1,7 @@
 use std::sync::Arc;
 
+use arith::Field;
+
 use crate::{ThreadConfig, MAX_WAIT_CYCLES};
 
 /// Configuration for MPI
@@ -61,13 +63,26 @@ impl MPIConfig {
         rayon::current_thread_index().unwrap() == 0
     }
 
-    /// Sync with all threads' local memory by waiting until there is new data to read from all
+    #[inline]
+    /// Get the current thread
+    pub fn current_thread(&self) -> &ThreadConfig {
+        let index = rayon::current_thread_index().unwrap();
+        &self.threads[index]
+    }
+
+    #[inline]
+    /// Get the size of the current local memory
+    pub fn current_size(&self) -> usize {
+        self.current_thread().size()
+    }
+
+    /// Read all threads' local memory by waiting until there is new data to read from all
     /// threads.
     /// Returns a vector of slices, one for each thread's new data
     ///
     /// The threads are synchronized by the caller; within each period of time, all
     /// threads write a same amount of data
-    pub fn sync(&self, start: usize, end: usize) -> Vec<&[u8]> {
+    pub fn read_all(&self, start: usize, end: usize) -> Vec<&[u8]> {
         let total = self.threads.len();
         let mut pending = (0..total).collect::<Vec<_>>();
         let mut results: Vec<&[u8]> = vec![&[]; total];
@@ -102,4 +117,77 @@ impl MPIConfig {
         }
         results
     }
+
+    #[inline]
+    // todo: add a field buffer to the thread config so we can avoid field (de)serialization
+    pub fn read_all_field<F: Field>(&self, start: usize, end: usize) -> Vec<Vec<F>> {
+        let data = self.read_all(start, end);
+        data.iter()
+            .map(|x| {
+                x.chunks(F::SIZE)
+                    .map(|y| F::deserialize_from(y).unwrap())
+                    .collect()
+            })
+            .collect()
+    }
+
+    #[inline]
+    pub fn read_all_field_flat<F: Field>(&self, start: usize, end: usize) -> Vec<F> {
+        let data = self.read_all(start, end);
+        data.iter()
+            .flat_map(|x| {
+                x.chunks(F::SIZE)
+                    .map(|y| F::deserialize_from(y).unwrap())
+            })
+            .collect()
+    }
+
+    #[inline]
+    /// Append data to the current thread's local memory
+    pub fn append_local(&self, data: &[u8]) {
+        let thread = self.current_thread();
+        thread.append(data).expect("Failed to append");
+    }
+
+    #[inline]
+    /// Append a field element to the current thread's local memory
+    pub fn append_local_field<F: Field>(&self, f: &F) {
+        let mut data = vec![];
+        f.serialize_into(&mut data).unwrap();
+        self.append_local(&data);
+    }
+
+    /// coefficient has a length of mpi_world_size
+    #[inline]
+    pub fn coef_combine_vec<F: Field>(&self, local_vec: &Vec<F>, coefficient: &[F]) -> Vec<F> {
+        if self.world_size == 1 {
+            // Warning: strictly speaking, this should be coefficient[0] * local_vec,
+            // but coefficient[0] is always one in our use case where self.world_size == 1
+            local_vec.clone()
+        } else {
+            // write local vector to the buffer, then sync up all threads
+            let start = self.current_thread().size();
+            let data = local_vec
+                .iter()
+                .flat_map(|&x| {
+                    let mut buf = vec![];
+                    x.serialize_into(&mut buf).unwrap();
+                    buf
+                })
+                .collect::<Vec<_>>();
+            self.append_local(&data);
+            let end = self.current_thread().size();
+            let all_fields = self.read_all_field::<F>(start, end);
+
+            // build the result via linear combination
+            let mut result = vec![F::zero(); local_vec.len()];
+            for i in 0..local_vec.len() {
+                for j in 0..(self.world_size as usize) {
+                    result[i] += all_fields[j][i] * coefficient[j];
+                }
+            }
+
+            result
+        }
+    }
 }
diff --git a/config/mpi_config/src/tests.rs b/config/mpi_config/src/tests.rs
index 00a1cbd7..6d6b875b 100644
--- a/config/mpi_config/src/tests.rs
+++ b/config/mpi_config/src/tests.rs
@@ -114,7 +114,7 @@ fn test_cross_thread_communication() {
 
         thread.append(&data).expect("Failed to append");
 
-        let results = mpi_config.sync(start, end);
+        let results = mpi_config.read_all(start, end);
         assert_eq!(results.len(), num_threads as usize);
 
         for (i, result) in results.iter().enumerate() {
@@ -151,7 +151,7 @@ fn test_incremental_updates() {
 
            thread.append(&data).expect("Failed to append");
 
-            let results = mpi_config.sync(start, end);
+            let results = mpi_config.read_all(start, end);
            assert_eq!(results.len(), num_threads as usize);
 
            println!("Thread {} iteration {}: {:?}", rank, i, results);
diff --git a/config/mpi_config/src/thread_config.rs b/config/mpi_config/src/thread_config.rs
index 64b94ebb..50226948 100644
--- a/config/mpi_config/src/thread_config.rs
+++ b/config/mpi_config/src/thread_config.rs
@@ -9,6 +9,10 @@ use crate::AtomicVec;
 /// 3. All threads have the same global memory
 /// 4. IMPORTANT!!! The threads are synchronized by the caller; within each period of time, all
 ///    threads write a same amount of data
+///
+/// A further optimization (TODO):
+/// - keep one buffer for bytes and another for field elements; this should avoid the need for
+///   field (de)serialization between threads
 #[derive(Debug, Clone)]
 pub struct ThreadConfig {
     pub world_rank: i32, // indexer for the thread
diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs
index 80d5ccaf..4332aa91 100644
--- a/poly_commit/src/raw.rs
+++ b/poly_commit/src/raw.rs
@@ -199,7 +199,7 @@ impl<C: GKRConfig, T: Transcript<C::ChallengeField>> PCSForExpanderGKR<C, T>
@@ ... @@ impl<'a, C: GKRConfig> SumcheckGkrVanillaHelper<'a, C> {
     #[inline]
     pub(crate) fn prepare_mpi_var_vals(&mut self) {
-        self.mpi_config.gather_vec(
-            &vec![self.sp.simd_var_v_evals[0]],
-            &mut self.sp.mpi_var_v_evals,
-        );
-        self.mpi_config.gather_vec(
-            &vec![self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]],
-            &mut self.sp.mpi_var_hg_evals,
-        );
+        let start = self.mpi_config.current_size();
+        self.mpi_config.append_local_field(&self.sp.simd_var_v_evals[0]);
+        let end = self.mpi_config.current_size();
+        self.sp.mpi_var_v_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);
+
+        let start = self.mpi_config.current_size();
+        self.mpi_config.append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]));
+        let end = self.mpi_config.current_size();
+        self.sp.mpi_var_hg_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);
     }
 
     #[inline]
     pub(crate) fn prepare_y_vals(&mut self) {
         let mut v_rx_rsimd_rw = self.sp.mpi_var_v_evals[0];
-        self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw);
+        // self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw);
 
         let mul = &self.layer.mul;
         let eq_evals_at_rz0 = &self.sp.eq_evals_at_rz0;
diff --git a/sumcheck/src/sumcheck.rs b/sumcheck/src/sumcheck.rs
index 9ad374c8..f81389ea 100644
--- a/sumcheck/src/sumcheck.rs
+++ b/sumcheck/src/sumcheck.rs
@@ -49,21 +49,30 @@ pub fn sumcheck_prove_gkr_layer<C: GKRConfig, T: Transcript<C::ChallengeField>>(
     helper.prepare_x_vals();
     for i_var in 0..helper.input_var_num {
         let evals = helper.poly_evals_at_rx(i_var, 2);
-        let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
+        let r = transcript_io::<C::ChallengeField, T>(
+            //mpi_config,
+            &evals, transcript,
+        );
         helper.receive_rx(i_var, r);
     }
 
     helper.prepare_simd_var_vals();
     for i_var in 0..helper.simd_var_num {
         let evals = helper.poly_evals_at_r_simd_var(i_var, 3);
-        let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
+        let r = transcript_io::<C::ChallengeField, T>(
+            //mpi_config,
+            &evals, transcript,
+        );
         helper.receive_r_simd_var(i_var, r);
     }
 
     helper.prepare_mpi_var_vals();
     for i_var in 0..mpi_config.world_size().trailing_zeros() as usize {
         let evals = helper.poly_evals_at_r_mpi_var(i_var, 3);
-        let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
+        let r = transcript_io::<C::ChallengeField, T>(
+            //mpi_config,
+            &evals, transcript,
+        );
         helper.receive_r_mpi_var(i_var, r);
     }
 
@@ -75,7 +84,10 @@ pub fn sumcheck_prove_gkr_layer<C: GKRConfig, T: Transcript<C::ChallengeField>>(
     helper.prepare_y_vals();
     for i_var in 0..helper.input_var_num {
         let evals = helper.poly_evals_at_ry(i_var, 2);
-        let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
+        let r = transcript_io::<C::ChallengeField, T>(
+            //mpi_config,
+            &evals, transcript,
+        );
         helper.receive_ry(i_var, r);
     }
     let vy_claim = helper.vy_claim();
diff --git a/sumcheck/src/utils.rs b/sumcheck/src/utils.rs
index 37398a6d..8f69a7d9 100644
--- a/sumcheck/src/utils.rs
+++ b/sumcheck/src/utils.rs
@@ -1,3 +1,5 @@
+use std::vec;
+
 use arith::{Field, SimdField};
 use mpi_config::MPIConfig;
 use transcript::Transcript;
@@ -25,7 +27,11 @@ pub fn unpack_and_combine<F: SimdField>(p: &F, coef: &[F::Scalar]) -> F::Scalar
 /// Transcript IO between sumcheck steps
 #[inline]
-pub fn transcript_io<F, T>(mpi_config: &MPIConfig, ps: &[F], transcript: &mut T) -> F
+pub fn transcript_io<F, T>(
+    // mpi_config: &MPIConfig,
+    ps: &[F],
+    transcript: &mut T,
+) -> F
 where
     F: Field,
     T: Transcript<F>,
 {
@@ -34,7 +40,9 @@ where
     for p in ps {
         transcript.append_field_element(p);
     }
-    let mut r = transcript.generate_challenge_field_element();
-    mpi_config.root_broadcast_f(&mut r);
-    r
+    // let mut r =
+    transcript.generate_challenge_field_element()
+
+    // mpi_config.root_broadcast_f(&mut r);
+    // r
 }
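Usage sketch (not part of the patch): the new API replaces MPI gather/broadcast with "append to the current thread's local buffer, then wait until every thread has written the same window", as prepare_mpi_var_vals does above. A minimal sketch under those assumptions; the helper name share_one_field is hypothetical, and it must run inside the MPIConfig's rayon thread pool:

    use arith::Field;
    use mpi_config::MPIConfig;

    /// Hypothetical helper: every thread contributes one field element and gets back
    /// one element per thread, ordered by thread index.
    fn share_one_field<F: Field>(mpi_config: &MPIConfig, v: F) -> Vec<F> {
        // Remember where this thread's local buffer currently ends.
        let start = mpi_config.current_size();
        // Serialize and append the local value; the caller must ensure that every
        // thread appends the same amount of data in the same period.
        mpi_config.append_local_field(&v);
        let end = mpi_config.current_size();
        // Wait until all threads have written their [start, end) window, then
        // deserialize one F per thread.
        mpi_config.read_all_field_flat::<F>(start, end)
    }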