Skip to content

Commit

Permalink
Enable 'rng.gen()'ing field types (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored May 1, 2024
1 parent 0533122 commit 89b4fcd
Show file tree
Hide file tree
Showing 17 changed files with 97 additions and 273 deletions.
38 changes: 0 additions & 38 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ itertools.workspace = true
num-traits.workspace = true
thiserror.workspace = true
bytemuck = { workspace = true, features = ["derive"] }
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
tracing.workspace = true

[dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
rand = { version = "0.8.5", features = ["small_rng"] }
test-log = { version = "0.2.15", features = ["trace"] }
tracing-subscriber = "0.3.18"

Expand Down
36 changes: 8 additions & 28 deletions crates/prover/benches/eval_at_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{black_box, Criterion};

#[cfg(target_arch = "x86_64")]
pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {
use rand::rngs::StdRng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::CPUBackend;
use stwo_prover::core::circle::CirclePoint;
Expand All @@ -11,7 +11,7 @@ pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use stwo_prover::core::poly::NaturalOrder;
let log_size = 20;
let rng = &mut StdRng::seed_from_u64(0);
let mut rng = SmallRng::seed_from_u64(0);

let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<CPUBackend, _, NaturalOrder>::new(
Expand All @@ -21,18 +21,8 @@ pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {
.collect(),
);
let poly = evaluation.bit_reverse().interpolate();
let x = QM31::from_u32_unchecked(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
);
let y = QM31::from_u32_unchecked(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
);
let x: QM31 = rng.gen();
let y: QM31 = rng.gen();

let point = CirclePoint { x, y };
c.bench_function("cpu eval_at_secure_field_point 2^20", |b| {
Expand All @@ -44,7 +34,7 @@ pub fn cpu_eval_at_secure_point(c: &mut criterion::Criterion) {

#[cfg(target_arch = "x86_64")]
pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
use rand::rngs::StdRng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::avx512::AVX512Backend;
use stwo_prover::core::circle::CirclePoint;
Expand All @@ -53,7 +43,7 @@ pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use stwo_prover::core::poly::NaturalOrder;
let log_size = 20;
let rng = &mut StdRng::seed_from_u64(0);
let mut rng = SmallRng::seed_from_u64(0);

let domain = CanonicCoset::new(log_size as u32).circle_domain();
let evaluation = CircleEvaluation::<AVX512Backend, BaseField, NaturalOrder>::new(
Expand All @@ -63,18 +53,8 @@ pub fn avx512_eval_at_secure_point(c: &mut criterion::Criterion) {
.collect(),
);
let poly = evaluation.bit_reverse().interpolate();
let x = QM31::from_u32_unchecked(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
);
let y = QM31::from_u32_unchecked(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
);
let x: QM31 = rng.gen();
let y: QM31 = rng.gen();

let point = CirclePoint { x, y };
c.bench_function("avx eval_at_secure_field_point 2^20", |b| {
Expand Down
63 changes: 15 additions & 48 deletions crates/prover/benches/field.rs
Original file line number Diff line number Diff line change
@@ -1,38 +1,17 @@
use criterion::Criterion;
use rand::rngs::ThreadRng;
use rand::Rng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::fields::cm31::CM31;
use stwo_prover::core::fields::m31::{M31, P};
use stwo_prover::core::fields::m31::M31;
use stwo_prover::core::fields::qm31::SecureField;

pub const N_ELEMENTS: usize = 1 << 16;
pub const N_STATE_ELEMENTS: usize = 8;

pub fn get_random_m31_element(rng: &mut ThreadRng) -> M31 {
M31::from_u32_unchecked(rng.gen::<u32>() % P)
}

pub fn get_random_cm31_element(rng: &mut ThreadRng) -> CM31 {
CM31::from_m31(get_random_m31_element(rng), get_random_m31_element(rng))
}

pub fn get_random_qm31_element(rng: &mut ThreadRng) -> SecureField {
SecureField::from_m31(
get_random_m31_element(rng),
get_random_m31_element(rng),
get_random_m31_element(rng),
get_random_m31_element(rng),
)
}

pub fn m31_operations_bench(c: &mut criterion::Criterion) {
let mut rng = rand::thread_rng();
let mut elements: Vec<M31> = Vec::new();
let mut state: [M31; N_STATE_ELEMENTS] =
[(); N_STATE_ELEMENTS].map(|_| get_random_m31_element(&mut rng));

for _ in 0..(N_ELEMENTS) {
elements.push(get_random_m31_element(&mut rng));
}
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<M31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [M31; N_STATE_ELEMENTS] = rng.gen();

c.bench_function("M31 mul", |b| {
b.iter(|| {
Expand Down Expand Up @@ -60,14 +39,9 @@ pub fn m31_operations_bench(c: &mut criterion::Criterion) {
}

pub fn cm31_operations_bench(c: &mut criterion::Criterion) {
let mut rng = rand::thread_rng();
let mut elements: Vec<CM31> = Vec::new();
let mut state: [CM31; N_STATE_ELEMENTS] =
[(); N_STATE_ELEMENTS].map(|_| get_random_cm31_element(&mut rng));

for _ in 0..(N_ELEMENTS) {
elements.push(get_random_cm31_element(&mut rng));
}
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<CM31> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen();

c.bench_function("CM31 mul", |b| {
b.iter(|| {
Expand Down Expand Up @@ -95,14 +69,9 @@ pub fn cm31_operations_bench(c: &mut criterion::Criterion) {
}

pub fn qm31_operations_bench(c: &mut criterion::Criterion) {
let mut rng = rand::thread_rng();
let mut elements: Vec<SecureField> = Vec::new();
let mut state: [SecureField; N_STATE_ELEMENTS] =
[(); N_STATE_ELEMENTS].map(|_| get_random_qm31_element(&mut rng));

for _ in 0..(N_ELEMENTS) {
elements.push(get_random_qm31_element(&mut rng));
}
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<SecureField> = (0..N_ELEMENTS).map(|_| rng.gen()).collect();
let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen();

c.bench_function("SecureField mul", |b| {
b.iter(|| {
Expand Down Expand Up @@ -138,15 +107,13 @@ pub fn avx512_m31_operations_bench(c: &mut criterion::Criterion) {
return;
}

let mut rng = rand::thread_rng();
let mut rng = SmallRng::seed_from_u64(0);
let mut elements: Vec<PackedBaseField> = Vec::new();
let mut states: Vec<PackedBaseField> =
vec![PackedBaseField::from_array([1.into(); K_BLOCK_SIZE]); N_STATE_ELEMENTS];

for _ in 0..(N_ELEMENTS / K_BLOCK_SIZE) {
elements.push(PackedBaseField::from_array(
[get_random_m31_element(&mut rng); K_BLOCK_SIZE],
));
elements.push(PackedBaseField::from_array(rng.gen()));
}

c.bench_function("mul_avx512", |b| {
Expand Down
19 changes: 6 additions & 13 deletions crates/prover/benches/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::Rng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::fields::m31::{M31, P};
use stwo_prover::core::fields::qm31::QM31;
use stwo_prover::math::matrix::{RowMajorMatrix, SquareMatrix};
Expand All @@ -9,30 +10,22 @@ const QM31_MATRIX_SIZE: usize = 6;

// TODO(ShaharS): Share code with other benchmarks.
fn row_major_matrix_multiplication_bench(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let mut rng = SmallRng::seed_from_u64(0);

let matrix_m31 = RowMajorMatrix::<M31, MATRIX_SIZE>::new(
(0..MATRIX_SIZE.pow(2))
.map(|_| M31::from_u32_unchecked(rng.gen::<u32>() % P))
.map(|_| rng.gen())
.collect::<Vec<M31>>(),
);

let matrix_qm31 = RowMajorMatrix::<QM31, QM31_MATRIX_SIZE>::new(
(0..QM31_MATRIX_SIZE.pow(2))
.map(|_| {
QM31::from_u32_unchecked(
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
rng.gen::<u32>() % P,
)
})
.map(|_| rng.gen())
.collect::<Vec<QM31>>(),
);

// Create vector M31.
let vec: [M31; MATRIX_SIZE] =
[(); MATRIX_SIZE].map(|_| M31::from_u32_unchecked(rng.gen::<u32>() % P));
let vec: [M31; MATRIX_SIZE] = rng.gen();

// Create vector QM31.
let vec_qm31: [QM31; QM31_MATRIX_SIZE] = [(); QM31_MATRIX_SIZE].map(|_| {
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ mod tests {
use std::array;

use num_traits::Zero;
use rand::rngs::StdRng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::*;
Expand All @@ -171,7 +171,7 @@ mod tests {
#[test]
fn test_point_evaluation_accumulator() {
// Generate a vector of random sizes with a constant seed.
let rng = &mut StdRng::seed_from_u64(0);
let mut rng = SmallRng::seed_from_u64(0);
const MAX_LOG_SIZE: u32 = 10;
const MASK: u32 = P;
let log_sizes = (0..100)
Expand Down Expand Up @@ -204,7 +204,7 @@ mod tests {
#[test]
fn test_domain_evaluation_accumulator() {
// Generate a vector of random sizes with a constant seed.
let rng = &mut StdRng::seed_from_u64(0);
let mut rng = SmallRng::seed_from_u64(0);
const LOG_SIZE_MIN: u32 = 4;
const LOG_SIZE_BOUND: u32 = 10;
const MASK: u32 = P;
Expand Down
20 changes: 5 additions & 15 deletions crates/prover/src/core/backend/avx512/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ fn slow_eval_at_point(
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use rand::rngs::StdRng;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use crate::core::backend::avx512::circle::slow_eval_at_point;
Expand All @@ -341,9 +341,9 @@ mod tests {
use crate::core::backend::Column;
use crate::core::circle::CirclePoint;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps};
use crate::core::poly::{BitReversedOrder, NaturalOrder};
use crate::qm31;

#[test]
fn test_interpolate_and_eval() {
Expand Down Expand Up @@ -426,7 +426,7 @@ mod tests {
#[test]
fn test_eval_securefield() {
use crate::core::backend::avx512::fft::MIN_FFT_LOG_SIZE;
let rng = &mut StdRng::seed_from_u64(0);
let mut rng = SmallRng::seed_from_u64(0);

for log_size in MIN_FFT_LOG_SIZE..(CACHED_FFT_LOG_SIZE + 2) {
let domain = CanonicCoset::new(log_size as u32).circle_domain();
Expand All @@ -438,18 +438,8 @@ mod tests {
);
let poly = evaluation.bit_reverse().interpolate();

let x = qm31!(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>()
);
let y = qm31!(
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>(),
rng.gen::<u32>()
);
let x: SecureField = rng.gen();
let y: SecureField = rng.gen();

let p = CirclePoint { x, y };

Expand Down
Loading

0 comments on commit 89b4fcd

Please sign in to comment.