Skip to content

Commit

Permalink
Improve FieldElement struct construction
Browse files Browse the repository at this point in the history
  • Loading branch information
tcoratger committed Feb 17, 2024
1 parent 6ebaa9b commit 7247e7e
Showing 1 changed file with 51 additions and 69 deletions.
120 changes: 51 additions & 69 deletions starknet-ff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,36 @@
#[cfg_attr(test, macro_use)]
extern crate alloc;

use core::{
fmt, ops,
str::{self, FromStr},
};

use crate::fr::Fr;

use ark_ff::{
fields::{Field, Fp256, PrimeField},
BigInteger, BigInteger256,
};
use core::{
fmt, ops,
str::{self, FromStr},
};
use crypto_bigint::{CheckedAdd, CheckedMul, NonZero, Zero, U256};

mod fr;

const U256_BYTE_COUNT: usize = 32;

#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct FieldElement {
inner: Fr,
pub struct FieldElement(Fr);

impl std::ops::Deref for FieldElement {

Check failure on line 26 in starknet-ff/src/lib.rs

View workflow job for this annotation

GitHub Actions / WASM tests

failed to resolve: use of undeclared crate or module `std`
type Target = Fr;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl std::ops::DerefMut for FieldElement {

Check failure on line 34 in starknet-ff/src/lib.rs

View workflow job for this annotation

GitHub Actions / WASM tests

failed to resolve: use of undeclared crate or module `std`
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

mod from_str_error {
Expand Down Expand Up @@ -136,9 +146,7 @@ impl FieldElement {

/// Create a new [FieldElement] from its Montgomery representation
pub const fn from_mont(data: [u64; 4]) -> Self {
Self {
inner: Fp256::new_unchecked(BigInteger256::new(data)),
}
Self(Fp256::new_unchecked(BigInteger256::new(data)))
}

pub fn from_dec_str(value: &str) -> Result<Self, FromStrError> {
Expand Down Expand Up @@ -170,7 +178,7 @@ impl FieldElement {
}

Fr::from_bigint(u256_to_biginteger256(&res))
.map(|inner| Self { inner })
.map(|inner| Self(inner))
.ok_or(FromStrError::OutOfRange)
}

Expand Down Expand Up @@ -238,7 +246,7 @@ impl FieldElement {
/// Transforms [FieldElement] into little endian bit representation.
pub fn to_bits_le(self) -> [bool; 256] {
let mut bits = [false; 256];
for (ind_element, element) in self.inner.into_bigint().0.iter().enumerate() {
for (ind_element, element) in self.into_bigint().0.iter().enumerate() {
for ind_bit in 0..64 {
bits[ind_element * 64 + ind_bit] = (element >> ind_bit) & 1 == 1;
}
Expand All @@ -250,22 +258,22 @@ impl FieldElement {
/// Convert the field element into a big-endian byte representation
pub fn to_bytes_be(&self) -> [u8; 32] {
let mut buffer = [0u8; 32];
buffer.copy_from_slice(&self.inner.into_bigint().to_bytes_be());
buffer.copy_from_slice(&self.into_bigint().to_bytes_be());

buffer
}

/// Transforms [FieldElement] into its Montgomery representation
pub const fn into_mont(self) -> [u64; 4] {
self.inner.0 .0
self.0 .0 .0
}

pub fn invert(&self) -> Option<FieldElement> {
self.inner.inverse().map(|inner| Self { inner })
self.inverse().map(|inner| Self(inner))
}

pub fn sqrt(&self) -> Option<FieldElement> {
self.inner.sqrt().map(|inner| Self { inner })
Some(FieldElement(self.0.sqrt()?))
}

pub fn double(&self) -> FieldElement {
Expand All @@ -286,9 +294,7 @@ impl FieldElement {
let (quotient, _) = div_result;

// It's safe to unwrap here since `rem` is never out of range
FieldElement {
inner: Fr::from_bigint(u256_to_biginteger256(&quotient)).unwrap(),
}
FieldElement(Fr::from_bigint(u256_to_biginteger256(&quotient)).unwrap())
} else {
// TODO: add `checked_floor_div` for panic-less use
panic!("division by zero");
Expand All @@ -306,7 +312,7 @@ impl FieldElement {

// No need to check range as `from_bigint` already does that
let big_int = BigInteger256::from_bits_be(&bits);
Fr::from_bigint(big_int).map(|inner| Self { inner })
Fr::from_bigint(big_int).map(|inner| Self(inner))
}
}

Expand All @@ -326,63 +332,57 @@ impl ops::Add<FieldElement> for FieldElement {
type Output = FieldElement;

fn add(self, rhs: FieldElement) -> Self::Output {
FieldElement {
inner: self.inner + rhs.inner,
}
FieldElement(*self + *rhs)
}
}

impl ops::AddAssign<FieldElement> for FieldElement {
fn add_assign(&mut self, rhs: FieldElement) {
self.inner = self.inner + rhs.inner;
*self = *self + rhs;
}
}

impl ops::Sub<FieldElement> for FieldElement {
type Output = FieldElement;

fn sub(self, rhs: FieldElement) -> Self::Output {
FieldElement {
inner: self.inner - rhs.inner,
}
FieldElement(*self - *rhs)
}
}

impl ops::SubAssign<FieldElement> for FieldElement {
fn sub_assign(&mut self, rhs: FieldElement) {
self.inner = self.inner - rhs.inner;
*self = *self - rhs;
}
}

impl ops::Mul<FieldElement> for FieldElement {
type Output = FieldElement;

fn mul(self, rhs: FieldElement) -> Self::Output {
FieldElement {
inner: self.inner * rhs.inner,
}
FieldElement(*self * *rhs)
}
}

impl ops::MulAssign<FieldElement> for FieldElement {
fn mul_assign(&mut self, rhs: FieldElement) {
self.inner = self.inner * rhs.inner;
*self = *self * rhs;
}
}

impl ops::Neg for FieldElement {
type Output = FieldElement;

fn neg(self) -> Self::Output {
FieldElement { inner: -self.inner }
FieldElement { 0: -self.0 }
}
}

impl ops::Rem<FieldElement> for FieldElement {
type Output = FieldElement;

fn rem(self, rhs: FieldElement) -> Self::Output {
if self.inner < rhs.inner {
if self < rhs {
return self;
}

Expand All @@ -396,9 +396,7 @@ impl ops::Rem<FieldElement> for FieldElement {
let (_, rem) = lhs.div_rem(&rhs);

// It's safe to unwrap here since `rem` is never out of range
FieldElement {
inner: Fr::from_bigint(u256_to_biginteger256(&rem)).unwrap(),
}
FieldElement(Fr::from_bigint(u256_to_biginteger256(&rem)).unwrap())
} else {
// TODO: add `checked_rem` for panic-less use
panic!("division by zero");
Expand All @@ -414,9 +412,7 @@ impl ops::BitAnd<FieldElement> for FieldElement {
let rhs: U256 = (&rhs).into();

// It's safe to unwrap here since the result is never out of range
FieldElement {
inner: Fr::from_bigint(u256_to_biginteger256(&(lhs & rhs))).unwrap(),
}
FieldElement(Fr::from_bigint(u256_to_biginteger256(&(lhs & rhs))).unwrap())
}
}

Expand All @@ -428,9 +424,7 @@ impl ops::BitOr<FieldElement> for FieldElement {
let rhs: U256 = (&rhs).into();

// It's safe to unwrap here since the result is never out of range
FieldElement {
inner: Fr::from_bigint(u256_to_biginteger256(&(lhs | rhs))).unwrap(),
}
FieldElement(Fr::from_bigint(u256_to_biginteger256(&(lhs | rhs))).unwrap())
}
}

Expand Down Expand Up @@ -605,33 +599,25 @@ mod serde_field_element {

impl From<u8> for FieldElement {
fn from(value: u8) -> Self {
Self {
inner: Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap())
}
}

impl From<u16> for FieldElement {
fn from(value: u16) -> Self {
Self {
inner: Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap())
}
}

impl From<u32> for FieldElement {
fn from(value: u32) -> Self {
Self {
inner: Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap())
}
}

impl From<u64> for FieldElement {
fn from(value: u64) -> Self {
Self {
inner: Fr::from_bigint(BigInteger256::new([value, 0, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([value, 0, 0, 0])).unwrap())
}
}

Expand All @@ -640,17 +626,13 @@ impl From<u128> for FieldElement {
let low = value % (u64::MAX as u128 + 1);
let high = value / (u64::MAX as u128 + 1);

Self {
inner: Fr::from_bigint(BigInteger256::new([low as u64, high as u64, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([low as u64, high as u64, 0, 0])).unwrap())
}
}

impl From<usize> for FieldElement {
fn from(value: usize) -> Self {
Self {
inner: Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap(),
}
Self(Fr::from_bigint(BigInteger256::new([value as u64, 0, 0, 0])).unwrap())
}
}

Expand All @@ -670,7 +652,7 @@ impl TryFrom<FieldElement> for u8 {
type Error = ValueOutOfRangeError;

fn try_from(value: FieldElement) -> Result<Self, Self::Error> {
let repr = value.inner.into_bigint().0;
let repr = value.into_bigint().0;
if repr[0] > u8::MAX as u64 || repr[1] > 0 || repr[2] > 0 || repr[3] > 0 {
Err(ValueOutOfRangeError)
} else {
Expand All @@ -683,7 +665,7 @@ impl TryFrom<FieldElement> for u16 {
type Error = ValueOutOfRangeError;

fn try_from(value: FieldElement) -> Result<Self, Self::Error> {
let repr = value.inner.into_bigint().0;
let repr = value.into_bigint().0;
if repr[0] > u16::MAX as u64 || repr[1] > 0 || repr[2] > 0 || repr[3] > 0 {
Err(ValueOutOfRangeError)
} else {
Expand All @@ -696,7 +678,7 @@ impl TryFrom<FieldElement> for u32 {
type Error = ValueOutOfRangeError;

fn try_from(value: FieldElement) -> Result<Self, Self::Error> {
let repr = value.inner.into_bigint().0;
let repr = value.into_bigint().0;
if repr[0] > u32::MAX as u64 || repr[1] > 0 || repr[2] > 0 || repr[3] > 0 {
Err(ValueOutOfRangeError)
} else {
Expand All @@ -709,7 +691,7 @@ impl TryFrom<FieldElement> for u64 {
type Error = ValueOutOfRangeError;

fn try_from(value: FieldElement) -> Result<Self, Self::Error> {
let repr = value.inner.into_bigint().0;
let repr = value.into_bigint().0;
if repr[1] > 0 || repr[2] > 0 || repr[3] > 0 {
Err(ValueOutOfRangeError)
} else {
Expand All @@ -722,7 +704,7 @@ impl TryFrom<FieldElement> for u128 {
type Error = ValueOutOfRangeError;

fn try_from(value: FieldElement) -> Result<Self, Self::Error> {
let repr = value.inner.into_bigint().0;
let repr = value.into_bigint().0;
if repr[2] > 0 || repr[3] > 0 {
Err(ValueOutOfRangeError)
} else {
Expand All @@ -734,13 +716,13 @@ impl TryFrom<FieldElement> for u128 {
impl From<&FieldElement> for U256 {
#[cfg(target_pointer_width = "64")]
fn from(value: &FieldElement) -> Self {
U256::from_words(value.inner.into_bigint().0)
U256::from_words(value.into_bigint().0)
}

#[cfg(target_pointer_width = "32")]
fn from(value: &FieldElement) -> Self {
U256::from_words(unsafe {
core::mem::transmute::<[u64; 4], [u32; 8]>(value.inner.into_bigint().0)
core::mem::transmute::<[u64; 4], [u32; 8]>(value.into_bigint().0)
})
}
}
Expand Down

0 comments on commit 7247e7e

Please sign in to comment.