From 3da51dcc61d3c5c07aba5a548fe8b1721b0ebb3e Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Wed, 8 May 2024 18:20:12 -0400 Subject: [PATCH] Add GKR implementation of Logup lookups --- .../src/core/backend/cpu/lookups/gkr.rs | 331 +++++++++++++++++- crates/prover/src/core/lookups/gkr_prover.rs | 113 +++++- .../prover/src/core/lookups/gkr_verifier.rs | 18 +- crates/prover/src/core/lookups/utils.rs | 55 ++- 4 files changed, 487 insertions(+), 30 deletions(-) diff --git a/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/crates/prover/src/core/backend/cpu/lookups/gkr.rs index 86f73ab52..1c3afe207 100644 --- a/crates/prover/src/core/backend/cpu/lookups/gkr.rs +++ b/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -1,14 +1,17 @@ -use num_traits::Zero; +use std::ops::{Add, Index}; + +use num_traits::{One, Zero}; use crate::core::backend::CpuBackend; +use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::Field; +use crate::core::fields::{ExtensionOf, Field}; use crate::core::lookups::gkr_prover::{ correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, }; -use crate::core::lookups::mle::Mle; +use crate::core::lookups::mle::{Mle, MleOps}; use crate::core::lookups::sumcheck::MultivariatePolyOracle; -use crate::core::lookups::utils::UnivariatePoly; +use crate::core::lookups::utils::{Fraction, UnivariatePoly}; impl GkrOps for CpuBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { @@ -18,7 +21,17 @@ impl GkrOps for CpuBackend { fn next_layer(layer: &Layer) -> Layer { match layer { Layer::GrandProduct(layer) => next_grand_product_layer(layer), - Layer::_LogUp(_) => todo!(), + Layer::LogUpGeneric { + numerators, + denominators, + } => next_logup_layer(MleExpr::Mle(numerators), denominators), + Layer::LogUpMultiplicities { + numerators, + denominators, + } => next_logup_layer(MleExpr::Mle(numerators), denominators), + Layer::LogUpSingles { denominators } => { + next_logup_layer(MleExpr::Constant(BaseField::one()), denominators) + } } } @@ -32,11 +45,22 @@ impl GkrOps for CpuBackend { let eq_evals = h.eq_evals; // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. let y = eq_evals.y(); + let lambda = h.lambda; let input_layer = &h.input_layer; let (mut eval_at_0, mut eval_at_2) = match input_layer { Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms), - Layer::_LogUp(_) => todo!(), + Layer::LogUpGeneric { + numerators, + denominators, + } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpMultiplicities { + numerators, + denominators, + } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpSingles { denominators } => { + eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda) + } }; eval_at_0 *= h.eq_fixed_var_correction; @@ -79,6 +103,134 @@ fn eval_grand_product_sum( (eval_at_0, eval_at_2) } +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) + +/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, +/// x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_sum( + eq_evals: &EqEvals, + input_numerators: &Mle, + input_denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) +where + SecureField: ExtensionOf + Field, +{ + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), {0, 1})`. + let inp_numer_at_r0i0 = input_numerators[i * 2]; + let inp_denom_at_r0i0 = input_denominators[i * 2]; + let inp_numer_at_r0i1 = input_numerators[i * 2 + 1]; + let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; + let inp_numer_at_r1i0 = input_numerators[(n_terms + i) * 2]; + let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; + let inp_numer_at_r1i1 = input_numerators[(n_terms + i) * 2 + 1]; + let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2i0 = inp_numer_at_r1i0.double() - inp_numer_at_r0i0; + let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; + let inp_numer_at_r2i1 = inp_numer_at_r1i1.double() - inp_numer_at_r0i1; + let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` + // at points `(r, {0, 2}, bits(i))`. + let Fraction { + numerator: numer_at_r0i, + denominator: denom_at_r0i, + } = Fraction::new(inp_numer_at_r0i0, inp_denom_at_r0i0) + + Fraction::new(inp_numer_at_r0i1, inp_denom_at_r0i1); + let Fraction { + numerator: numer_at_r2i, + denominator: denom_at_r2i, + } = Fraction::new(inp_numer_at_r2i0, inp_denom_at_r2i0) + + Fraction::new(inp_numer_at_r2i1, inp_denom_at_r2i1); + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); + eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); + } + + (eval_at_0, eval_at_2) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + +/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_singles_sum( + eq_evals: &EqEvals, + input_denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) { + /// Represents the fraction `1 / x` + struct Reciprocal { + x: SecureField, + } + + impl Add for Reciprocal { + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + // `1/a + 1/b = (a + b)/(a * b)` + Fraction { + numerator: self.x + rhs.x, + denominator: self.x * rhs.x, + } + } + } + + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. + let inp_denom_at_r0i0 = input_denominators[i * 2]; + let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; + let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; + let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; + let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; + + // Fraction addition polynomials at points: + // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` + // at points `(r, {0, 2}, bits(i))`. + let Fraction { + numerator: numer_at_r0i, + denominator: denom_at_r0i, + } = Reciprocal { + x: inp_denom_at_r0i0, + } + Reciprocal { + x: inp_denom_at_r0i1, + }; + let Fraction { + numerator: numer_at_r2i, + denominator: denom_at_r2i, + } = Reciprocal { + x: inp_denom_at_r2i0, + } + Reciprocal { + x: inp_denom_at_r2i1, + }; + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); + eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); + } + + (eval_at_0, eval_at_2) +} + /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. /// /// Evaluations are returned in bit-reversed order. @@ -104,18 +256,66 @@ fn next_grand_product_layer(layer: &Mle) -> Layer( + numerators: MleExpr<'_, F>, + denominators: &Mle, +) -> Layer +where + F: Field, + SecureField: ExtensionOf, + CpuBackend: MleOps, +{ + let half_n = 1 << (denominators.n_variables() - 1); + let mut next_numerators = Vec::with_capacity(half_n); + let mut next_denominators = Vec::with_capacity(half_n); + + for i in 0..half_n { + let a = Fraction::new(numerators[i * 2], denominators[i * 2]); + let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]); + let res = a + b; + next_numerators.push(res.numerator); + next_denominators.push(res.denominator); + } + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +enum MleExpr<'a, F: Field> { + Constant(F), + Mle(&'a Mle), +} + +impl<'a, F: Field> Index for MleExpr<'a, F> { + type Output = F; + + fn index(&self, index: usize) -> &F { + match self { + Self::Constant(v) => v, + Self::Mle(mle) => &mle[index], + } + } +} + #[cfg(test)] mod tests { + use std::iter::zip; + use num_traits::{One, Zero}; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; use crate::core::backend::CpuBackend; use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; use crate::core::lookups::mle::Mle; - use crate::core::lookups::utils::eq; + use crate::core::lookups::utils::{eq, Fraction}; use crate::core::test_utils::test_channel; #[test] @@ -150,11 +350,126 @@ mod tests { let GkrArtifact { ood_point: r, claims_to_verify_by_instance, - .. + n_variables_by_instance: _, } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; assert_eq!(proof.output_claims_by_instance, [vec![product]]); assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]); Ok(()) } + + #[test] + fn logup_with_generic_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 5; + let mut rng = SmallRng::seed_from_u64(0); + let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerator_values, &denominator_values) + .map(|(&n, &d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerator_values); + let denominators = Mle::::new(denominator_values); + let top_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_singles_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 5; + println!("{}", BaseField::from(2).inverse()); + println!("{}", BaseField::from(1) - BaseField::from(2).inverse()); + + let mut rng = SmallRng::seed_from_u64(0); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominator_values + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominator_values); + let top_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [SecureField::one(), denominators.eval_at_point(&ood_point)] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 5; + let mut rng = SmallRng::seed_from_u64(0); + let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerator_values, &denominator_values) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerator_values); + let denominators = Mle::::new(denominator_values); + let top_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } } diff --git a/crates/prover/src/core/lookups/gkr_prover.rs b/crates/prover/src/core/lookups/gkr_prover.rs index cae983e75..c0486f3b5 100644 --- a/crates/prover/src/core/lookups/gkr_prover.rs +++ b/crates/prover/src/core/lookups/gkr_prover.rs @@ -86,24 +86,45 @@ impl> Deref for EqEvals { /// /// [LogUp]: https://eprint.iacr.org/2023/1284.pdf pub enum Layer { - _LogUp(B), GrandProduct(Mle), + LogUpGeneric { + numerators: Mle, + denominators: Mle, + }, + LogUpMultiplicities { + numerators: Mle, + denominators: Mle, + }, + /// All numerators implicitly equal "1". + LogUpSingles { + denominators: Mle, + }, } impl Layer { /// Returns the number of variables used to interpolate the layer's gate values. fn n_variables(&self) -> usize { match self { - Self::_LogUp(_) => todo!(), - Self::GrandProduct(mle) => mle.n_variables(), + Self::GrandProduct(mle) + | Self::LogUpSingles { denominators: mle } + | Self::LogUpMultiplicities { + denominators: mle, .. + } + | Self::LogUpGeneric { + denominators: mle, .. + } => mle.n_variables(), } } + fn is_output_layer(&self) -> bool { + self.n_variables() == 0 + } + /// Produces the next layer from the current layer. /// /// The next layer is strictly half the size of the current layer. /// Returns [`None`] if called on an output layer. - fn next_layer(&self) -> Option> { + pub fn next_layer(&self) -> Option { if self.is_output_layer() { return None; } @@ -111,10 +132,6 @@ impl Layer { Some(B::next_layer(self)) } - fn is_output_layer(&self) -> bool { - self.n_variables() == 0 - } - /// Returns each column output if the layer is an output layer, otherwise returns an `Err`. fn try_into_output_layer_values(self) -> Result, NotOutputLayerError> { if !self.is_output_layer() { @@ -122,10 +139,30 @@ impl Layer { } Ok(match self { - Self::GrandProduct(col) => { + Layer::LogUpSingles { denominators } => { + let numerator = SecureField::one(); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::LogUpMultiplicities { + numerators, + denominators, + } => { + let numerator = numerators.at(0).into(); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::LogUpGeneric { + numerators, + denominators, + } => { + let numerator = numerators.at(0); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::GrandProduct(col) => { vec![col.at(0)] } - Self::_LogUp(_) => todo!(), }) } @@ -136,8 +173,24 @@ impl Layer { } match self { - Self::_LogUp(_) => todo!(), Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)), + Self::LogUpGeneric { + numerators, + denominators, + } => Self::LogUpGeneric { + numerators: numerators.fix_first_variable(x0), + denominators: denominators.fix_first_variable(x0), + }, + Self::LogUpMultiplicities { + numerators, + denominators, + } => Self::LogUpGeneric { + numerators: numerators.fix_first_variable(x0), + denominators: denominators.fix_first_variable(x0), + }, + Self::LogUpSingles { denominators } => Self::LogUpSingles { + denominators: denominators.fix_first_variable(x0), + }, } } @@ -170,13 +223,14 @@ impl Layer { /// hypercube that interpolates `c_i`. fn into_multivariate_poly( self, - _lambda: SecureField, + lambda: SecureField, eq_evals: &EqEvals, ) -> GkrMultivariatePolyOracle<'_, B> { GkrMultivariatePolyOracle { eq_evals, input_layer: self, eq_fixed_var_correction: SecureField::one(), + lambda, } } } @@ -207,6 +261,8 @@ pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { pub eq_evals: &'a EqEvals, pub input_layer: Layer, pub eq_fixed_var_correction: SecureField, + /// Used by LogUp to perform a random linear combination of the numerators and denominators. + pub lambda: SecureField, } impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { @@ -219,7 +275,7 @@ impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> } fn fix_first_variable(self, challenge: SecureField) -> Self { - if self.n_variables() == 0 { + if self.is_constant() { return self; } @@ -230,11 +286,16 @@ impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> eq_evals: self.eq_evals, eq_fixed_var_correction, input_layer: self.input_layer.fix_first_variable(challenge), + lambda: self.lambda, } } } impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { + fn is_constant(&self) -> bool { + self.n_variables() == 0 + } + /// Returns all input layer columns restricted to a line. /// /// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants @@ -246,14 +307,30 @@ impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { /// /// For more context see page 64. fn try_into_mask(self) -> Result { - if self.n_variables() != 0 { + if !self.is_constant() { return Err(NotConstantPolyError); } - match self.input_layer { - Layer::_LogUp(_) => todo!(), - Layer::GrandProduct(mle) => Ok(GkrMask::new(vec![mle.to_cpu().try_into().unwrap()])), - } + let columns = match self.input_layer { + Layer::GrandProduct(mle) => vec![mle.to_cpu().try_into().unwrap()], + Layer::LogUpGeneric { + numerators, + denominators, + } => { + let numerators = numerators.to_cpu().try_into().unwrap(); + let denominators = denominators.to_cpu().try_into().unwrap(); + vec![numerators, denominators] + } + // Should never get called. + Layer::LogUpMultiplicities { .. } => unimplemented!(), + Layer::LogUpSingles { denominators } => { + let numerators = [SecureField::one(); 2]; + let denominators = denominators.to_cpu().try_into().unwrap(); + vec![numerators, denominators] + } + }; + + Ok(GkrMask::new(columns)) } } diff --git a/crates/prover/src/core/lookups/gkr_verifier.rs b/crates/prover/src/core/lookups/gkr_verifier.rs index 7c9598b79..b65ceb162 100644 --- a/crates/prover/src/core/lookups/gkr_verifier.rs +++ b/crates/prover/src/core/lookups/gkr_verifier.rs @@ -7,6 +7,7 @@ use crate::core::channel::Channel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::lookups::sumcheck; +use crate::core::lookups::utils::Fraction; /// Partially verifies a batch GKR proof. /// @@ -177,7 +178,7 @@ pub struct GkrArtifact { /// [Thaler13]: https://eprint.iacr.org/2013/351.pdf #[derive(Debug, Clone, Copy)] pub enum Gate { - _LogUp, + LogUp, GrandProduct, } @@ -185,7 +186,20 @@ impl Gate { /// Returns the output after applying the gate to the mask. fn eval(&self, mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { Ok(match self { - Self::_LogUp => todo!(), + Self::LogUp => { + if mask.columns().len() != 2 { + return Err(InvalidNumMaskColumnsError); + } + + let [numerator_a, numerator_b] = mask.columns()[0]; + let [denominator_a, denominator_b] = mask.columns()[1]; + + let a = Fraction::new(numerator_a, denominator_a); + let b = Fraction::new(numerator_b, denominator_b); + let res = a + b; + + vec![res.numerator, res.denominator] + } Self::GrandProduct => { if mask.columns().len() != 1 { return Err(InvalidNumMaskColumnsError); diff --git a/crates/prover/src/core/lookups/utils.rs b/crates/prover/src/core/lookups/utils.rs index e2aba5107..70adb4c64 100644 --- a/crates/prover/src/core/lookups/utils.rs +++ b/crates/prover/src/core/lookups/utils.rs @@ -1,7 +1,7 @@ -use std::iter::zip; +use std::iter::{zip, Sum}; use std::ops::{Add, Deref, Mul, Neg, Sub}; -use num_traits::Zero; +use num_traits::{One, Zero}; use crate::core::fields::qm31::SecureField; use crate::core::fields::{ExtensionOf, Field}; @@ -194,6 +194,57 @@ where assignment * (eval1 - eval0) + eval0 } +/// Projective fraction. +#[derive(Debug, Clone, Copy)] +pub struct Fraction { + pub numerator: F, + pub denominator: SecureField, +} + +impl Fraction { + pub fn new(numerator: F, denominator: SecureField) -> Self { + Self { + numerator, + denominator, + } + } +} + +impl Add for Fraction +where + F: Field, + SecureField: ExtensionOf + Field, +{ + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + Fraction { + numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, + denominator: self.denominator * rhs.denominator, + } + } +} + +impl Zero for Fraction { + fn zero() -> Self { + Self { + numerator: SecureField::zero(), + denominator: SecureField::one(), + } + } + + fn is_zero(&self) -> bool { + self.numerator.is_zero() && !self.denominator.is_zero() + } +} + +impl Sum for Fraction { + fn sum>(mut iter: I) -> Self { + let first = iter.next().unwrap_or_else(Self::zero); + iter.fold(first, |a, b| a + b) + } +} + #[cfg(test)] mod tests { use std::iter::zip;