diff --git a/Cargo.lock b/Cargo.lock index 6e0fc4d6b..c63918921 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "aligned" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "377e4c0ba83e4431b10df45c1d4666f178ea9c552cac93e60c3a88bf32785923" +dependencies = [ + "as-slice", +] + [[package]] name = "anes" version = "0.1.6" @@ -35,6 +44,15 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +[[package]] +name = "as-slice" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516b6b4f0e40d50dcda9365d53964ec74560ad4284da2e7fc97122cd83174516" +dependencies = [ + "stable_deref_trait", +] + [[package]] name = "autocfg" version = "1.2.0" @@ -636,13 +654,21 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "stwo-prover" version = "0.1.1" dependencies = [ + "aligned", "blake2", "blake3", "bytemuck", + "cfg-if", "criterion", "derivative", "hex", diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 3a9be22b6..fed9ad7ce 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -8,16 +8,18 @@ edition.workspace = true [dependencies] blake2.workspace = true blake3.workspace = true +bytemuck = { workspace = true, features = ["derive"] } +cfg-if = "1.0.0" derivative.workspace = true hex.workspace = true itertools.workspace = true num-traits.workspace = true -thiserror.workspace = true -bytemuck = { workspace = true, features = ["derive"] } rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } +thiserror.workspace = true tracing.workspace = true [dev-dependencies] +aligned = "0.4.2" criterion = { version = "0.5.1", features = ["html_reports"] } test-log = { version = "0.2.15", features = ["trace"] } tracing-subscriber = "0.3.18" diff --git a/crates/prover/src/core/backend/mod.rs b/crates/prover/src/core/backend/mod.rs index d7be7dab5..b8bd2f534 100644 --- a/crates/prover/src/core/backend/mod.rs +++ b/crates/prover/src/core/backend/mod.rs @@ -13,6 +13,7 @@ use super::poly::circle::PolyOps; #[cfg(target_arch = "x86_64")] pub mod avx512; pub mod cpu; +pub mod simd; pub trait Backend: Copy diff --git a/crates/prover/src/core/backend/simd/cm31.rs b/crates/prover/src/core/backend/simd/cm31.rs new file mode 100644 index 000000000..fa3b86489 --- /dev/null +++ b/crates/prover/src/core/backend/simd/cm31.rs @@ -0,0 +1,221 @@ +use std::array; +use std::ops::{Add, Mul, MulAssign, Neg, Sub}; + +use num_traits::{One, Zero}; + +use super::m31::{PackedM31, N_LANES}; +use crate::core::fields::cm31::CM31; +use crate::core::fields::FieldExpOps; + +/// SIMD implementation of [`CM31`]. +#[derive(Copy, Clone, Debug)] +pub struct PackedCM31(pub [PackedM31; 2]); + +impl PackedCM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(value: CM31) -> Self { + Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)]) + } + + /// Returns all `a` values such that each vector element is represented as `a + bi`. + pub fn a(&self) -> PackedM31 { + self.0[0] + } + + /// Returns all `b` values such that each vector element is represented as `a + bi`. + pub fn b(&self) -> PackedM31 { + self.0[1] + } + + pub fn to_array(&self) -> [CM31; N_LANES] { + let a = self.a().to_array(); + let b = self.b().to_array(); + array::from_fn(|i| CM31(a[i], b[i])) + } + + pub fn from_array(values: [CM31; N_LANES]) -> Self { + Self([ + PackedM31::from_array(values.map(|v| v.0)), + PackedM31::from_array(values.map(|v| v.1)), + ]) + } + + /// Interleaves two vectors. + pub fn interleave(self, other: Self) -> (Self, Self) { + let Self([a_evens, b_evens]) = self; + let Self([a_odds, b_odds]) = other; + let (a_lhs, a_rhs) = a_evens.interleave(a_odds); + let (b_lhs, b_rhs) = b_evens.interleave(b_odds); + (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) + } + + /// Deinterleaves two vectors. + pub fn deinterleave(self, other: Self) -> (Self, Self) { + let Self([a_self, b_self]) = self; + let Self([a_other, b_other]) = other; + let (a_evens, a_odds) = a_self.deinterleave(a_other); + let (b_evens, b_odds) = b_self.deinterleave(b_other); + (Self([a_evens, b_evens]), Self([a_odds, b_odds])) + } + + /// Doubles each element in the vector. + pub fn double(self) -> Self { + let Self([a, b]) = self; + Self([a.double(), b.double()]) + } +} + +impl Add for PackedCM31 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self([self.a() + rhs.a(), self.b() + rhs.b()]) + } +} + +impl Sub for PackedCM31 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self([self.a() - rhs.a(), self.b() - rhs.b()]) + } +} + +impl Mul for PackedCM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // Compute using Karatsuba. + let ac = self.a() * rhs.a(); + let bd = self.b() * rhs.b(); + // Computes (a + b) * (c + d). + let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b()); + // (ac - bd) + (ad + bc)i. + Self([ac - bd, ab_t_cd - ac - bd]) + } +} + +impl Zero for PackedCM31 { + fn zero() -> Self { + Self([PackedM31::zero(), PackedM31::zero()]) + } + + fn is_zero(&self) -> bool { + self.a().is_zero() && self.b().is_zero() + } +} + +impl One for PackedCM31 { + fn one() -> Self { + Self([PackedM31::one(), PackedM31::zero()]) + } +} + +impl MulAssign for PackedCM31 { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl FieldExpOps for PackedCM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // 1 / (a + bi) = (a - bi) / (a^2 + b^2). + Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() + } +} + +impl Add for PackedCM31 { + type Output = Self; + + fn add(self, rhs: PackedM31) -> Self::Output { + Self([self.a() + rhs, self.b()]) + } +} + +impl Sub for PackedCM31 { + type Output = Self; + + fn sub(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a - rhs, b]) + } +} + +impl Mul for PackedCM31 { + type Output = Self; + + fn mul(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a * rhs, b * rhs]) + } +} + +impl Neg for PackedCM31 { + type Output = Self; + + fn neg(self) -> Self::Output { + let Self([a, b]) = self; + Self([-a, -b]) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::cm31::PackedCM31; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedCM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), values.map(|v| -v)); + } +} diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs new file mode 100644 index 000000000..0a31bab13 --- /dev/null +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -0,0 +1,590 @@ +use std::mem::transmute; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ptr; +use std::simd::cmp::SimdOrd; +use std::simd::{u32x16, Simd, Swizzle}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; +use rand::distributions::{Distribution, Standard}; + +use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; +use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; +use crate::core::fields::FieldExpOps; + +pub const LOG_N_LANES: u32 = 4; + +pub const N_LANES: usize = 1 << LOG_N_LANES; + +pub const MODULUS: Simd = Simd::from_array([P; N_LANES]); + +pub type PackedBaseField = PackedM31; + +/// Holds a vector of unreduced [`M31`] elements in the range `[0, P]`. +/// +/// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.). +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct PackedM31(Simd); + +impl PackedM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(M31(value): M31) -> Self { + Self(Simd::splat(value)) + } + + pub fn from_array(values: [M31; N_LANES]) -> PackedM31 { + Self(Simd::from_array(values.map(|M31(v)| v))) + } + + pub fn to_array(self) -> [M31; N_LANES] { + self.reduce().0.to_array().map(M31) + } + + /// Reduces each element of the vector to the range `[0, P)`. + fn reduce(self) -> PackedM31 { + Self(Simd::simd_min(self.0, self.0 - MODULUS)) + } + + /// Interleaves two vectors. + pub fn interleave(self, other: Self) -> (Self, Self) { + let (a, b) = self.0.interleave(other.0); + (Self(a), Self(b)) + } + + /// Deinterleaves two vectors. + pub fn deinterleave(self, other: Self) -> (Self, Self) { + let (a, b) = self.0.deinterleave(other.0); + (Self(a), Self(b)) + } + + /// Sums all the elements in the vector. + pub fn pointwise_sum(self) -> M31 { + self.to_array().into_iter().sum() + } + + /// Doubles each element in the vector. + pub fn double(self) -> Self { + // TODO: Make more optimal. + self + self + } + + pub fn into_simd(self) -> Simd { + self.0 + } + + /// # Safety + /// + /// Vector elements must be in the range `[0, P]`. + pub unsafe fn from_simd_unchecked(v: Simd) -> Self { + Self(v) + } + + /// # Safety + /// + /// Behavior is undefined if the pointer does not have the same alignment as + /// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`. + pub unsafe fn load(mem_addr: *const u32) -> Self { + Self(ptr::read(mem_addr as *const u32x16)) + } + + /// # Safety + /// + /// Behavior is undefined if the pointer does not have the same alignment as + /// [`PackedM31`]. + pub unsafe fn store(self, dst: *mut u32) { + ptr::write(dst as *mut u32x16, self.0) + } +} + +impl Add for PackedM31 { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: Self) -> Self::Output { + // Add word by word. Each word is in the range [0, 2P]. + let c = self.0 + rhs.0; + // Apply min(c, c-P) to each word. + // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P]. + // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1]. + Self(Simd::simd_min(c, c - MODULUS)) + } +} + +impl AddAssign for PackedM31 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Mul for PackedM31 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + // TODO: Ensure all these branches get tested in the CI. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + _mul_neon(self, rhs) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + _mul_wasm(self, rhs) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + _mul_avx512(self, rhs) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2f"))] { + _mul_avx2(self, rhs) + } else { + _mul_simd(self, rhs) + } + } + } +} + +impl MulAssign for PackedM31 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl Neg for PackedM31 { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self(MODULUS - self.0) + } +} + +impl Sub for PackedM31 { + type Output = Self; + + #[inline(always)] + fn sub(self, rhs: Self) -> Self::Output { + // Subtract word by word. Each word is in the range [-P, P]. + let c = self.0 - rhs.0; + // Apply min(c, c+P) to each word. + // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P]. + // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than + // [2^32-P,2^32-1]. + Self(Simd::simd_min(c + MODULUS, c)) + } +} + +impl SubAssign for PackedM31 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Zero for PackedM31 { + fn zero() -> Self { + Self(Simd::from_array([0; N_LANES])) + } + + fn is_zero(&self) -> bool { + self.to_array().iter().all(M31::is_zero) + } +} + +impl One for PackedM31 { + fn one() -> Self { + Self(Simd::::from_array([1; N_LANES])) + } +} + +impl FieldExpOps for PackedM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + pow2147483645(*self) + } +} + +unsafe impl Pod for PackedM31 {} + +unsafe impl Zeroable for PackedM31 { + fn zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} + +impl From<[BaseField; N_LANES]> for PackedM31 { + fn from(v: [BaseField; N_LANES]) -> Self { + Self::from_array(v) + } +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> PackedM31 { + PackedM31::from_array(rng.gen()) + } +} + +/// Returns `a * b`. +#[cfg(target_arch = "aarch64")] +fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + use core::arch::aarch64::{int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. +#[cfg(target_arch = "aarch64")] +fn _mul_doubled_neon(a: PackedM31, b_double: PackedM31) -> PackedM31 { + use core::arch::aarch64::{uint32x2_t, vmull_u32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +/// Returns `a * b`. +#[cfg(target_arch = "wasm32")] +fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_wasm(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. +#[cfg(target_arch = "wasm32")] +fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; + + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd +} + +/// Returns `a * b`. +#[cfg(target_arch = "x86_64")] +fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_avx512(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. +#[cfg(target_arch = "x86_64")] +fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) +} + +/// Returns `a * b`. +#[cfg(target_arch = "x86_64")] +fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_avx2(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. +#[cfg(target_arch = "x86_64")] +fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + + let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) +} + +/// Returns `a * b`. +/// +/// Should only be used in the absence of a platform specific implementation. +fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_simd(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// Should only be used in the absence of a platform specific implementation. +/// +/// `b_double` should be in the range `[0, 2P]`. +fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; + + let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; + let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) +} + +#[cfg(test)] +mod tests { + use std::array; + + use aligned::{Aligned, A64}; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::PackedM31; + use crate::core::fields::m31::BaseField; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), array::from_fn(|i| -values[i])); + } + + #[test] + fn load_works() { + let v: Aligned = Aligned(array::from_fn(|i| i as u32)); + + let res = unsafe { PackedM31::load(v.as_ptr()) }; + + assert_eq!(res.to_array().map(|v| v.0), *v); + } + + #[test] + fn store_works() { + let v = PackedM31::from_array(array::from_fn(BaseField::from)); + + let mut res: Aligned = Aligned([0; 16]); + unsafe { v.store(res.as_mut_ptr()) }; + + assert_eq!(*res, v.to_array().map(|v| v.0)); + } +} diff --git a/crates/prover/src/core/backend/simd/mod.rs b/crates/prover/src/core/backend/simd/mod.rs new file mode 100644 index 000000000..ff94a9f5b --- /dev/null +++ b/crates/prover/src/core/backend/simd/mod.rs @@ -0,0 +1,7 @@ +pub mod cm31; +pub mod m31; +pub mod qm31; +mod utils; + +#[derive(Copy, Clone, Debug)] +pub struct SimdBackend; diff --git a/crates/prover/src/core/backend/simd/qm31.rs b/crates/prover/src/core/backend/simd/qm31.rs new file mode 100644 index 000000000..0401dcbf6 --- /dev/null +++ b/crates/prover/src/core/backend/simd/qm31.rs @@ -0,0 +1,294 @@ +use std::array; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; +use rand::distributions::{Distribution, Standard}; + +use super::cm31::PackedCM31; +use super::m31::{PackedM31, N_LANES}; +use crate::core::fields::qm31::QM31; +use crate::core::fields::FieldExpOps; + +pub type PackedSecureField = PackedQM31; + +/// SIMD implementation of [`QM31`]. +#[derive(Copy, Clone, Debug)] +pub struct PackedQM31(pub [PackedCM31; 2]); + +impl PackedQM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(value: QM31) -> Self { + Self([ + PackedCM31::broadcast(value.0), + PackedCM31::broadcast(value.1), + ]) + } + + /// Returns all `a` values such that each vector element is represented as `a + bu`. + pub fn a(&self) -> PackedCM31 { + self.0[0] + } + + /// Returns all `b` values such that each vector element is represented as `a + bu`. + pub fn b(&self) -> PackedCM31 { + self.0[1] + } + + pub fn to_array(&self) -> [QM31; N_LANES] { + let a = self.a().to_array(); + let b = self.b().to_array(); + array::from_fn(|i| QM31(a[i], b[i])) + } + + pub fn from_array(values: [QM31; N_LANES]) -> Self { + let a = values.map(|v| v.0); + let b = values.map(|v| v.1); + Self([PackedCM31::from_array(a), PackedCM31::from_array(b)]) + } + + /// Interleaves two vectors. + pub fn interleave(self, other: Self) -> (Self, Self) { + let Self([a_evens, b_evens]) = self; + let Self([a_odds, b_odds]) = other; + let (a_lhs, a_rhs) = a_evens.interleave(a_odds); + let (b_lhs, b_rhs) = b_evens.interleave(b_odds); + (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) + } + + /// Deinterleaves two vectors. + pub fn deinterleave(self, other: Self) -> (Self, Self) { + let Self([a_lhs, b_lhs]) = self; + let Self([a_rhs, b_rhs]) = other; + let (a_evens, a_odds) = a_lhs.deinterleave(a_rhs); + let (b_evens, b_odds) = b_lhs.deinterleave(b_rhs); + (Self([a_evens, b_evens]), Self([a_odds, b_odds])) + } + + /// Sums all the elements in the vector. + pub fn pointwise_sum(self) -> QM31 { + self.to_array().into_iter().sum() + } + + /// Doubles each element in the vector. + pub fn double(self) -> Self { + let Self([a, b]) = self; + Self([a.double(), b.double()]) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self([self.a() + rhs.a(), self.b() + rhs.b()]) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self([self.a() - rhs.a(), self.b() - rhs.b()]) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // Compute using Karatsuba. + // (a + ub) * (c + ud) = + // (ac + (2+i)bd) + (ad + bc)u = + // ac + 2bd + ibd + (ad + bc)u. + let ac = self.a() * rhs.a(); + let bd = self.b() * rhs.b(); + let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]); + // Computes ac + bd. + let ac_p_bd = ac + bd; + // Computes ad + bc. + let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd; + // ac + 2bd + ibd = + // ac + bd + bd + ibd + let l = PackedCM31([ + ac_p_bd.a() + bd_times_1_plus_i.a(), + ac_p_bd.b() + bd_times_1_plus_i.b(), + ]); + Self([l, ad_p_bc]) + } +} + +impl Zero for PackedQM31 { + fn zero() -> Self { + Self([PackedCM31::zero(), PackedCM31::zero()]) + } + + fn is_zero(&self) -> bool { + self.a().is_zero() && self.b().is_zero() + } +} + +impl One for PackedQM31 { + fn one() -> Self { + Self([PackedCM31::one(), PackedCM31::zero()]) + } +} + +impl AddAssign for PackedQM31 { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl MulAssign for PackedQM31 { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl FieldExpOps for PackedQM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). + let b2 = self.b().square(); + let ib2 = PackedCM31([-b2.b(), b2.a()]); + let denom = self.a().square() - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + Self([self.a() * denom_inverse, -self.b() * denom_inverse]) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: PackedM31) -> Self::Output { + Self([self.a() + rhs, self.b()]) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a * rhs, b * rhs]) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a - rhs, b]) + } +} + +impl SubAssign for PackedQM31 { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +unsafe impl Pod for PackedQM31 {} + +unsafe impl Zeroable for PackedQM31 { + fn zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} + +impl Sum for PackedQM31 { + fn sum(mut iter: I) -> Self + where + I: Iterator, + { + let first = iter.next().unwrap_or_else(Self::zero); + iter.fold(first, |a, b| a + b) + } +} + +impl<'a> Sum<&'a Self> for PackedQM31 { + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.copied().sum() + } +} + +impl Neg for PackedQM31 { + type Output = Self; + + fn neg(self) -> Self::Output { + let Self([a, b]) = self; + Self([-a, -b]) + } +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> PackedQM31 { + PackedQM31::from_array(rng.gen()) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::qm31::PackedQM31; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = PackedQM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = PackedQM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = PackedQM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedQM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), values.map(|v| -v)); + } +} diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs new file mode 100644 index 000000000..87dfd2246 --- /dev/null +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -0,0 +1,52 @@ +use std::simd::Swizzle; + +/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. +pub struct InterleaveEvens; + +impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); +} + +/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. +pub struct InterleaveOdds; + +impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); +} + +const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res +} + +#[cfg(test)] +mod tests { + use std::simd::{u32x4, Swizzle}; + + use super::{InterleaveEvens, InterleaveOdds}; + + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveEvens::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } + + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveOdds::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } +} diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index a1dcb3442..9d9a5f926 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -9,7 +9,8 @@ get_many_mut, int_roundings, slice_flatten, - assert_matches + assert_matches, + portable_simd )] pub mod core; pub mod examples;