-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
89b4fcd
commit 7ddb0a1
Showing
9 changed files
with
1,217 additions
and
3 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
use std::array; | ||
use std::ops::{Add, Mul, MulAssign, Neg, Sub}; | ||
|
||
use num_traits::{One, Zero}; | ||
|
||
use super::m31::{PackedBaseField, N_LANES}; | ||
use crate::core::fields::cm31::CM31; | ||
use crate::core::fields::FieldExpOps; | ||
|
||
/// SIMD implementation of [`CM31`]. | ||
#[derive(Copy, Clone, Debug)] | ||
pub struct PackedCM31(pub [PackedBaseField; 2]); | ||
|
||
impl PackedCM31 { | ||
/// Constructs a new instance with all vector elements set to `value`. | ||
pub fn broadcast(value: CM31) -> Self { | ||
Self([ | ||
PackedBaseField::broadcast(value.0), | ||
PackedBaseField::broadcast(value.1), | ||
]) | ||
} | ||
|
||
/// Returns all `a` values such that each vector element is represented as `a + bi`. | ||
pub fn a(&self) -> PackedBaseField { | ||
self.0[0] | ||
} | ||
|
||
/// Returns all `b` values such that each vector element is represented as `a + bi`. | ||
pub fn b(&self) -> PackedBaseField { | ||
self.0[1] | ||
} | ||
|
||
pub fn to_array(&self) -> [CM31; N_LANES] { | ||
let a = self.a().to_array(); | ||
let b = self.b().to_array(); | ||
array::from_fn(|i| CM31(a[i], b[i])) | ||
} | ||
|
||
pub fn from_array(values: [CM31; N_LANES]) -> Self { | ||
Self([ | ||
PackedBaseField::from_array(values.map(|v| v.0)), | ||
PackedBaseField::from_array(values.map(|v| v.1)), | ||
]) | ||
} | ||
|
||
/// Interleaves two vectors. | ||
pub fn interleave(self, other: Self) -> (Self, Self) { | ||
let Self([a_evens, b_evens]) = self; | ||
let Self([a_odds, b_odds]) = other; | ||
let (a_lhs, a_rhs) = a_evens.interleave(a_odds); | ||
let (b_lhs, b_rhs) = b_evens.interleave(b_odds); | ||
(Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) | ||
} | ||
|
||
/// Deinterleaves two vectors. | ||
pub fn deinterleave(self, other: Self) -> (Self, Self) { | ||
let Self([a_self, b_self]) = self; | ||
let Self([a_other, b_other]) = other; | ||
let (a_evens, a_odds) = a_self.deinterleave(a_other); | ||
let (b_evens, b_odds) = b_self.deinterleave(b_other); | ||
(Self([a_evens, b_evens]), Self([a_odds, b_odds])) | ||
} | ||
|
||
/// Doubles each element in the vector. | ||
pub fn double(self) -> Self { | ||
let Self([a, b]) = self; | ||
Self([a.double(), b.double()]) | ||
} | ||
} | ||
|
||
impl Add for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn add(self, rhs: Self) -> Self::Output { | ||
Self([self.a() + rhs.a(), self.b() + rhs.b()]) | ||
} | ||
} | ||
|
||
impl Sub for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn sub(self, rhs: Self) -> Self::Output { | ||
Self([self.a() - rhs.a(), self.b() - rhs.b()]) | ||
} | ||
} | ||
|
||
impl Mul for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn mul(self, rhs: Self) -> Self::Output { | ||
// Compute using Karatsuba. | ||
let ac = self.a() * rhs.a(); | ||
let bd = self.b() * rhs.b(); | ||
// Computes (a + b) * (c + d). | ||
let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b()); | ||
// (ac - bd) + (ad + bc)i. | ||
Self([ac - bd, ab_t_cd - ac - bd]) | ||
} | ||
} | ||
|
||
impl Zero for PackedCM31 { | ||
fn zero() -> Self { | ||
Self([PackedBaseField::zero(), PackedBaseField::zero()]) | ||
} | ||
|
||
fn is_zero(&self) -> bool { | ||
self.a().is_zero() && self.b().is_zero() | ||
} | ||
} | ||
|
||
impl One for PackedCM31 { | ||
fn one() -> Self { | ||
Self([PackedBaseField::one(), PackedBaseField::zero()]) | ||
} | ||
} | ||
|
||
impl MulAssign for PackedCM31 { | ||
fn mul_assign(&mut self, rhs: Self) { | ||
*self = *self * rhs; | ||
} | ||
} | ||
|
||
impl FieldExpOps for PackedCM31 { | ||
fn inverse(&self) -> Self { | ||
assert!(!self.is_zero(), "0 has no inverse"); | ||
// 1 / (a + bi) = (a - bi) / (a^2 + b^2). | ||
Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() | ||
} | ||
} | ||
|
||
impl Add<PackedBaseField> for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn add(self, rhs: PackedBaseField) -> Self::Output { | ||
Self([self.a() + rhs, self.b()]) | ||
} | ||
} | ||
|
||
impl Sub<PackedBaseField> for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn sub(self, rhs: PackedBaseField) -> Self::Output { | ||
let Self([a, b]) = self; | ||
Self([a - rhs, b]) | ||
} | ||
} | ||
|
||
impl Mul<PackedBaseField> for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn mul(self, rhs: PackedBaseField) -> Self::Output { | ||
let Self([a, b]) = self; | ||
Self([a * rhs, b * rhs]) | ||
} | ||
} | ||
|
||
impl Neg for PackedCM31 { | ||
type Output = Self; | ||
|
||
fn neg(self) -> Self::Output { | ||
let Self([a, b]) = self; | ||
Self([-a, -b]) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use std::array; | ||
|
||
use rand::rngs::SmallRng; | ||
use rand::{Rng, SeedableRng}; | ||
|
||
use crate::core::backend::simd::cm31::PackedCM31; | ||
|
||
#[test] | ||
fn addition_works() { | ||
let mut rng = SmallRng::seed_from_u64(0); | ||
let lhs = rng.gen(); | ||
let rhs = rng.gen(); | ||
let packed_lhs = PackedCM31::from_array(lhs); | ||
let packed_rhs = PackedCM31::from_array(rhs); | ||
|
||
let res = packed_lhs + packed_rhs; | ||
|
||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); | ||
} | ||
|
||
#[test] | ||
fn subtraction_works() { | ||
let mut rng = SmallRng::seed_from_u64(0); | ||
let lhs = rng.gen(); | ||
let rhs = rng.gen(); | ||
let packed_lhs = PackedCM31::from_array(lhs); | ||
let packed_rhs = PackedCM31::from_array(rhs); | ||
|
||
let res = packed_lhs - packed_rhs; | ||
|
||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); | ||
} | ||
|
||
#[test] | ||
fn multiplication_works() { | ||
let mut rng = SmallRng::seed_from_u64(0); | ||
let lhs = rng.gen(); | ||
let rhs = rng.gen(); | ||
let packed_lhs = PackedCM31::from_array(lhs); | ||
let packed_rhs = PackedCM31::from_array(rhs); | ||
|
||
let res = packed_lhs * packed_rhs; | ||
|
||
assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); | ||
} | ||
|
||
#[test] | ||
fn negation_works() { | ||
let mut rng = SmallRng::seed_from_u64(0); | ||
let values = rng.gen(); | ||
let packed_values = PackedCM31::from_array(values); | ||
|
||
let res = -packed_values; | ||
|
||
assert_eq!(res.to_array(), values.map(|v| -v)); | ||
} | ||
} |
Oops, something went wrong.