Skip to content

Commit

Permalink
Implement FieldOps for SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 1, 2024
1 parent d2ca30c commit 9bd0fc4
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 1 deletion.
159 changes: 159 additions & 0 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use bytemuck::{cast_slice, cast_slice_mut, Zeroable};
use itertools::Itertools;
use num_traits::Zero;

use super::m31::{PackedBaseField, N_LANES};
use super::qm31::PackedSecureField;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;

#[derive(Clone, Debug)]
pub struct BaseFieldVec {
pub data: Vec<PackedBaseField>,
pub length: usize,
}

impl AsRef<[BaseField]> for BaseFieldVec {
fn as_ref(&self) -> &[BaseField] {
&cast_slice(&self.data)[..self.length]
}
}

impl AsMut<[BaseField]> for BaseFieldVec {
fn as_mut(&mut self) -> &mut [BaseField] {
&mut cast_slice_mut(&mut self.data)[..self.length]
}
}

impl Column<BaseField> for BaseFieldVec {
fn zeros(length: usize) -> Self {
let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)];
Self { data, length }
}

fn to_cpu(&self) -> Vec<BaseField> {
self.as_ref().to_vec()
}

fn len(&self) -> usize {
self.length
}

fn at(&self, index: usize) -> BaseField {
self.data[index / N_LANES].to_array()[index % N_LANES]
}
}

impl FromIterator<BaseField> for BaseFieldVec {
fn from_iter<I: IntoIterator<Item = BaseField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut data = (&mut chunks).map(PackedBaseField::from_array).collect_vec();
let mut length = data.len() * N_LANES;

if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let mut last = [BaseField::zero(); N_LANES];
last[..remainder.len()].copy_from_slice(remainder.as_slice());
data.push(PackedBaseField::from_array(last));
}
}

Self { data, length }
}
}

#[derive(Clone, Debug)]
pub struct SecureFieldVec {
pub data: Vec<PackedSecureField>,
pub length: usize,
}

impl Column<SecureField> for SecureFieldVec {
fn zeros(length: usize) -> Self {
Self {
data: vec![PackedSecureField::zeroed(); length.div_ceil(N_LANES)],
length,
}
}

fn to_cpu(&self) -> Vec<SecureField> {
self.data
.iter()
.flat_map(|x| x.to_array())
.take(self.length)
.collect()
}

fn len(&self) -> usize {
self.length
}

fn at(&self, index: usize) -> SecureField {
self.data[index / N_LANES].to_array()[index % N_LANES]
}
}

impl FromIterator<SecureField> for SecureFieldVec {
fn from_iter<I: IntoIterator<Item = SecureField>>(iter: I) -> Self {
let mut chunks = iter.into_iter().array_chunks();
let mut data = (&mut chunks)
.map(PackedSecureField::from_array)
.collect_vec();
let mut length = data.len() * N_LANES;

if let Some(remainder) = chunks.into_remainder() {
if !remainder.is_empty() {
length += remainder.len();
let mut last = [SecureField::zero(); N_LANES];
last[..remainder.len()].copy_from_slice(remainder.as_slice());
data.push(PackedSecureField::from_array(last));
}
}

Self { data, length }
}
}

impl FromIterator<PackedSecureField> for SecureFieldVec {
fn from_iter<I: IntoIterator<Item = PackedSecureField>>(iter: I) -> Self {
let data = (&mut iter.into_iter()).collect_vec();
let length = data.len() * N_LANES;

Self { data, length }
}
}

#[cfg(test)]
mod tests {
use std::array;

use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::BaseFieldVec;
use crate::core::backend::simd::column::SecureFieldVec;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;

#[test]
fn base_field_vec_from_iter_works() {
let values: [BaseField; 30] = array::from_fn(BaseField::from);

let res = values.into_iter().collect::<BaseFieldVec>();

assert_eq!(res.to_cpu(), values);
}

#[test]
fn secure_field_vec_from_iter_works() {
let mut rng = SmallRng::seed_from_u64(0);
let values: [SecureField; 30] = rng.gen();

let res = values.into_iter().collect::<SecureFieldVec>();

assert_eq!(res.to_cpu(), values);
}
}
9 changes: 8 additions & 1 deletion crates/prover/src/core/backend/simd/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::simd::{u32x16, Simd, Swizzle};

use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};

use crate::core::backend::simd::utils::{LoEvensInterleaveHiEvens, LoOddsInterleaveHiOdds};
use crate::core::fields::m31::{BaseField, M31, P};
Expand Down Expand Up @@ -60,7 +61,7 @@ impl PackedBaseField {
self.to_array().into_iter().sum()
}

/// Doubles each element.
/// Doubles each element in the vector.
pub fn double(self) -> Self {
// TODO: Make more optimal.
self + self
Expand Down Expand Up @@ -212,6 +213,12 @@ impl From<[BaseField; N_LANES]> for PackedBaseField {
}
}

impl Distribution<PackedBaseField> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedBaseField {
PackedBaseField::from_array(rng.gen())
}
}

#[cfg(target_arch = "aarch64")]
fn _mul_neon(a: PackedBaseField, b: PackedBaseField) -> PackedBaseField {
use core::arch::aarch64::{int32x2_t, vqdmull_s32};
Expand Down
37 changes: 37 additions & 0 deletions crates/prover/src/core/backend/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
use self::column::{BaseFieldVec, SecureFieldVec};
use self::m31::PackedBaseField;
use self::qm31::PackedSecureField;
use super::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{FieldExpOps, FieldOps};

pub mod cm31;
pub mod column;
pub mod m31;
pub mod qm31;
mod utils;

#[derive(Copy, Clone, Debug)]
pub struct SimdBackend;

impl ColumnOps<BaseField> for SimdBackend {
type Column = BaseFieldVec;

fn bit_reverse_column(_column: &mut Self::Column) {
todo!()
}
}

impl FieldOps<BaseField> for SimdBackend {
fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
PackedBaseField::batch_inverse(&column.data, &mut dst.data);
}
}

impl ColumnOps<SecureField> for SimdBackend {
type Column = SecureFieldVec;

fn bit_reverse_column(_column: &mut Self::Column) {
todo!()
}
}

impl FieldOps<SecureField> for SimdBackend {
fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) {
PackedSecureField::batch_inverse(&column.data, &mut dst.data);
}
}
7 changes: 7 additions & 0 deletions crates/prover/src/core/backend/simd/qm31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use bytemuck::{Pod, Zeroable};
use num_traits::{One, Zero};
use rand::distributions::{Distribution, Standard};

use super::cm31::PackedCM31;
use super::m31::{PackedBaseField, N_LANES};
Expand Down Expand Up @@ -224,6 +225,12 @@ impl Neg for PackedSecureField {
}
}

impl Distribution<PackedSecureField> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> PackedSecureField {
PackedSecureField::from_array(rng.gen())
}
}

#[cfg(test)]
mod tests {
use std::array;
Expand Down

0 comments on commit 9bd0fc4

Please sign in to comment.