
Commit

wip
zhenfeizhang committed Jan 7, 2025
1 parent 8352719 commit 72279a5
Showing 9 changed files with 138 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions config/mpi_config/Cargo.toml
@@ -6,5 +6,7 @@ edition = "2021"
[dependencies]
rayon.workspace = true

arith = { path = "../../arith" }

[dev-dependencies]
serial_test.workspace = true
92 changes: 90 additions & 2 deletions config/mpi_config/src/mpi_config.rs
@@ -1,5 +1,7 @@
use std::sync::Arc;

use arith::Field;

use crate::{ThreadConfig, MAX_WAIT_CYCLES};

/// Configuration for MPI
@@ -61,13 +63,26 @@ impl MPIConfig {
rayon::current_thread_index().unwrap() == 0
}

/// Sync with all threads' local memory by waiting until there is new data to read from all
#[inline]
/// Get the current thread
pub fn current_thread(&self) -> &ThreadConfig {
let index = rayon::current_thread_index().unwrap();
&self.threads[index]
}

#[inline]
/// Get the size of the current local memory
pub fn current_size(&self) -> usize {
self.current_thread().size()
}

/// Read all threads' local memory by waiting until there is new data to read from all
/// threads.
/// Returns a vector of slices, one for each thread's new data
///
/// The threads are synchronized by the caller; within each period of time, all
/// threads write the same amount of data
pub fn sync(&self, start: usize, end: usize) -> Vec<&[u8]> {
pub fn read_all(&self, start: usize, end: usize) -> Vec<&[u8]> {
let total = self.threads.len();
let mut pending = (0..total).collect::<Vec<_>>();
let mut results: Vec<&[u8]> = vec![&[]; total];
@@ -102,4 +117,77 @@ impl MPIConfig {
}
results
}

#[inline]
// todo: add a field buffer to the thread config so we can avoid field (de)serialization
pub fn read_all_field<F: Field>(&self, start: usize, end: usize) -> Vec<Vec<F>> {
let data = self.read_all(start, end);
data.iter()
.map(|x| {
x.chunks(F::SIZE)
.map(|y| F::deserialize_from(y).unwrap())
.collect()
})
.collect()
}

#[inline]
pub fn read_all_field_flat<F: Field>(&self, start: usize, end: usize) -> Vec<F> {
let data = self.read_all(start, end);
data.iter()
.flat_map(|x| {
x.chunks(F::SIZE)
.map(|y| F::deserialize_from(y).unwrap())
})
.collect()
}

#[inline]
/// Append data to the current thread's local memory
pub fn append_local(&self, data: &[u8]) {
let thread = self.current_thread();
thread.append(data).expect("Failed to append");
}

#[inline]
/// Append a field element to the current thread's local memory
pub fn append_local_field<F: Field>(&self, f: &F) {
let mut data = vec![];
f.serialize_into(&mut data).unwrap();
self.append_local(&data);
}

/// `coefficient` has length equal to the MPI world size (one entry per rank)
#[inline]
pub fn coef_combine_vec<F: Field>(&self, local_vec: &Vec<F>, coefficient: &[F]) -> Vec<F> {
if self.world_size == 1 {
// Warning: strictly speaking, the result should be coefficient[0] * local_vec,
// but coefficient[0] is always one in our use case when self.world_size == 1
local_vec.clone()
} else {
// write local vector to the buffer, then sync up all threads
let start = self.current_thread().size();
let data = local_vec
.iter()
.flat_map(|&x| {
let mut buf = vec![];
x.serialize_into(&mut buf).unwrap();
buf
})
.collect::<Vec<u8>>();
self.append_local(&data);
let end = self.current_thread().size();
let all_fields = self.read_all_field::<F>(start, end);

// build the result via linear combination
let mut result = vec![F::zero(); local_vec.len()];
for i in 0..local_vec.len() {
for j in 0..(self.world_size as usize) {
result[i] += all_fields[j][i] * coefficient[j];
}
}

result
}
}
}
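
As a usage illustration only (not code from this commit), the sketch below strings the new primitives together the way coef_combine_vec does internally: record the current buffer size, serialize and append the local data, then read back one window per thread. The helper name gather_field_vec is mine, as is the assumption that every thread calls it in the same period with a vector of the same length; only current_size, append_local, read_all_field, F::SIZE, and serialize_into are taken from the diff above.

use arith::Field;
use mpi_config::MPIConfig;

/// Sketch: gather every thread's local vector, in thread order.
/// Mirrors the first half of coef_combine_vec, without the linear combination.
fn gather_field_vec<F: Field>(mpi: &MPIConfig, local_vec: &[F]) -> Vec<Vec<F>> {
    // Window of this thread's buffer that the upcoming write will occupy.
    let start = mpi.current_size();

    // Serialize the local vector and append it to this thread's buffer.
    let mut bytes = Vec::with_capacity(local_vec.len() * F::SIZE);
    for x in local_vec {
        x.serialize_into(&mut bytes).unwrap();
    }
    mpi.append_local(&bytes);
    let end = mpi.current_size();

    // Blocks until every thread has written its window for this period, then
    // deserializes each thread's bytes back into field elements.
    mpi.read_all_field::<F>(start, end)
}

coef_combine_vec performs exactly this round trip and then folds the per-thread vectors with coefficient.
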
4 changes: 2 additions & 2 deletions config/mpi_config/src/tests.rs
@@ -114,7 +114,7 @@ fn test_cross_thread_communication() {

thread.append(&data).expect("Failed to append");

let results = mpi_config.sync(start, end);
let results = mpi_config.read_all(start, end);
assert_eq!(results.len(), num_threads as usize);

for (i, result) in results.iter().enumerate() {
@@ -151,7 +151,7 @@ fn test_incremental_updates() {

thread.append(&data).expect("Failed to append");

let results = mpi_config.sync(start, end);
let results = mpi_config.read_all(start, end);
assert_eq!(results.len(), num_threads as usize);

println!("Thread {} iteration {}: {:?}", rank, i, results);
4 changes: 4 additions & 0 deletions config/mpi_config/src/thread_config.rs
@@ -9,6 +9,10 @@ use crate::AtomicVec;
/// 3. All threads have the same global memory
/// 4. IMPORTANT!!! The threads are synchronized by the caller; within each period of time, all
/// threads write the same amount of data
///
/// A further optimization (TODO):
/// - we can have a buffer for bytes and a buffer for field elements; this should avoid the need for
/// field (de)serialization between threads
#[derive(Debug, Clone)]
pub struct ThreadConfig {
pub world_rank: i32, // indexer for the thread
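
One possible shape for the TODO above, sketched purely as an illustration and not implemented by this commit: keep a typed buffer next to the byte buffer so field elements can be shared without a serialize/deserialize round trip. The struct name and the use of std::sync::RwLock<Vec<_>> as a stand-in for the crate's AtomicVec (whose interface is not shown in this diff) are assumptions.

use std::sync::RwLock;

// Sketch only: a hypothetical ThreadConfig variant with a typed buffer.
pub struct TypedThreadConfig<F> {
    pub world_rank: i32,     // indexer for the thread, as in ThreadConfig
    bytes: RwLock<Vec<u8>>,  // byte buffer (stand-in for the existing storage)
    fields: RwLock<Vec<F>>,  // field elements stored directly, avoiding
                             // per-element (de)serialization between threads
}
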
2 changes: 1 addition & 1 deletion poly_commit/src/raw.rs
@@ -199,7 +199,7 @@ impl<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> PCSForExpanderGKR<C, T
// read the last poly.hypercube_size() from each thread
let end = mpi_config.threads[0].size();
let start = end - poly.hypercube_size();
let payloads = mpi_config.sync(start, end); // read #thread payloads
let payloads = mpi_config.read_all(start, end); // read #thread payloads

payloads
.iter()
19 changes: 10 additions & 9 deletions sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs
@@ -328,20 +328,21 @@ impl<'a, C: GKRFieldConfig> SumcheckGkrVanillaHelper<'a, C> {

#[inline]
pub(crate) fn prepare_mpi_var_vals(&mut self) {
self.mpi_config.gather_vec(
&vec![self.sp.simd_var_v_evals[0]],
&mut self.sp.mpi_var_v_evals,
);
self.mpi_config.gather_vec(
&vec![self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]],
&mut self.sp.mpi_var_hg_evals,
);
let start = self.mpi_config.current_size();
self.mpi_config.append_local_field(&self.sp.simd_var_v_evals[0]);
let end = self.mpi_config.current_size();
self.sp.mpi_var_v_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);

let start = self.mpi_config.current_size();
self.mpi_config.append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]));
let end = self.mpi_config.current_size();
self.sp.mpi_var_hg_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);
}

#[inline]
pub(crate) fn prepare_y_vals(&mut self) {
let mut v_rx_rsimd_rw = self.sp.mpi_var_v_evals[0];
self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw);
// self.mpi_config.root_broadcast_f(&mut v_rx_rsimd_rw);

let mul = &self.layer.mul;
let eq_evals_at_rz0 = &self.sp.eq_evals_at_rz0;
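
The record-size / append / read-back sequence now appears twice in prepare_mpi_var_vals. As a sketch (the helper name and its placement are assumptions, not part of this commit), the pattern factors into one call per gathered value:

use arith::Field;
use mpi_config::MPIConfig;

/// Sketch: append one field element on every thread, then read back one element
/// per thread, in thread order; this is the inline pattern used above.
fn gather_one<F: Field>(mpi: &MPIConfig, v: &F) -> Vec<F> {
    let start = mpi.current_size();
    mpi.append_local_field(v);
    let end = mpi.current_size();
    mpi.read_all_field_flat::<F>(start, end)
}

With such a helper, each of the two assignments in prepare_mpi_var_vals collapses to a single gather_one call.
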
20 changes: 16 additions & 4 deletions sumcheck/src/sumcheck.rs
@@ -49,21 +49,30 @@ pub fn sumcheck_prove_gkr_layer<C: GKRFieldConfig, T: Transcript<C::ChallengeFie
helper.prepare_x_vals();
for i_var in 0..helper.input_var_num {
let evals = helper.poly_evals_at_rx(i_var, 2);
let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
let r = transcript_io::<C::ChallengeField, T>(
//mpi_config,
&evals, transcript,
);
helper.receive_rx(i_var, r);
}

helper.prepare_simd_var_vals();
for i_var in 0..helper.simd_var_num {
let evals = helper.poly_evals_at_r_simd_var(i_var, 3);
let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
let r = transcript_io::<C::ChallengeField, T>(
//mpi_config,
&evals, transcript,
);
helper.receive_r_simd_var(i_var, r);
}

helper.prepare_mpi_var_vals();
for i_var in 0..mpi_config.world_size().trailing_zeros() as usize {
let evals = helper.poly_evals_at_r_mpi_var(i_var, 3);
let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
let r = transcript_io::<C::ChallengeField, T>(
//mpi_config,
&evals, transcript,
);
helper.receive_r_mpi_var(i_var, r);
}

@@ -75,7 +84,10 @@ pub fn sumcheck_prove_gkr_layer<C: GKRFieldConfig, T: Transcript<C::ChallengeFie
helper.prepare_y_vals();
for i_var in 0..helper.input_var_num {
let evals = helper.poly_evals_at_ry(i_var, 2);
let r = transcript_io::<C::ChallengeField, T>(mpi_config, &evals, transcript);
let r = transcript_io::<C::ChallengeField, T>(
//mpi_config,
&evals, transcript,
);
helper.receive_ry(i_var, r);
}
let vy_claim = helper.vy_claim();
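
One detail worth spelling out from the loop above: the number of MPI-variable rounds is world_size().trailing_zeros(), which equals log2 of the world size only when the world size is a power of two, so a power-of-two world size appears to be assumed. A quick standalone check, for illustration only:

fn main() {
    // For powers of two, trailing_zeros() is exactly log2: a world size of 8
    // yields 3 MPI-variable sumcheck rounds, 16 yields 4, and so on.
    for world_size in [1u32, 2, 4, 8, 16] {
        let rounds = world_size.trailing_zeros() as usize;
        assert_eq!(1usize << rounds, world_size as usize);
        println!("world_size = {world_size:2} -> {rounds} MPI var rounds");
    }
}
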
16 changes: 12 additions & 4 deletions sumcheck/src/utils.rs
@@ -1,3 +1,5 @@
use std::vec;

use arith::{Field, SimdField};
use mpi_config::MPIConfig;
use transcript::Transcript;
@@ -25,7 +27,11 @@ pub fn unpack_and_combine<F: SimdField>(p: &F, coef: &[F::Scalar]) -> F::Scalar

/// Transcript IO between sumcheck steps
#[inline]
pub fn transcript_io<F, T>(mpi_config: &MPIConfig, ps: &[F], transcript: &mut T) -> F
pub fn transcript_io<F, T>(
// mpi_config: &MPIConfig,
ps: &[F],
transcript: &mut T,
) -> F
where
F: Field,
T: Transcript<F>,
@@ -34,7 +40,9 @@ where
for p in ps {
transcript.append_field_element(p);
}
let mut r = transcript.generate_challenge_field_element();
mpi_config.root_broadcast_f(&mut r);
r
// let mut r =
transcript.generate_challenge_field_element()

// mpi_config.root_broadcast_f(&mut r);
// r
}
