diff --git a/std/algebra/algopts/algopts.go b/std/algebra/algopts/algopts.go
index 9792f87c62..2434fd57a8 100644
--- a/std/algebra/algopts/algopts.go
+++ b/std/algebra/algopts/algopts.go
@@ -11,6 +11,7 @@ type algebraCfg struct {
 	NbScalarBits       int
 	FoldMulti          bool
 	CompleteArithmetic bool
+	ToBitsCanonical    bool
 }
 
 // AlgebraOption allows modifying algebraic operation behaviour.
@@ -57,6 +58,25 @@ func WithCompleteArithmetic() AlgebraOption {
 	}
 }
 
+// WithCanonicalBitRepresentation enforces the marshalling methods to assert
+// that the bit representation is in canonical form. For field elements this
+// means that the bits represent a number less than the modulus.
+//
+// This option is useful when performing a direct comparison between the bit
+// forms of two elements. It can be omitted when the bit representation is used
+// in other cases, such as computing a challenge using a hash function, where a
+// non-canonical bit representation only leads to an incorrect challenge (which
+// in turn makes the verification fail).
+func WithCanonicalBitRepresentation() AlgebraOption {
+	return func(ac *algebraCfg) error {
+		if ac.ToBitsCanonical {
+			return fmt.Errorf("WithCanonicalBitRepresentation already set")
+		}
+		ac.ToBitsCanonical = true
+		return nil
+	}
+}
+
 // NewConfig applies all given options and returns a configuration to be used.
 func NewConfig(opts ...AlgebraOption) (*algebraCfg, error) {
 	ret := new(algebraCfg)
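For orientation, a minimal call-site sketch of when the new option matters (the helper names and the generic instantiation below are hypothetical; the `MarshalG1` method they call is the one updated in the following files): request canonical bits when two marshalled values are compared inside the circuit, and omit the option when the bits only seed a hash-derived challenge, where a non-canonical representation merely yields a failing proof.

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/algebra/algopts"
	"github.com/consensys/gnark/std/algebra/emulated/sw_emulated"
	"github.com/consensys/gnark/std/math/emulated"
)

// assertSameMarshalledPoint compares two points through their marshalled bits.
// The canonical option is required: without it the same point could admit two
// different bit representations and the comparison would be unsound.
func assertSameMarshalledPoint[B, S emulated.FieldParams](api frontend.API, c *sw_emulated.Curve[B, S], p, q sw_emulated.AffinePoint[B]) {
	pb := c.MarshalG1(p, algopts.WithCanonicalBitRepresentation())
	qb := c.MarshalG1(q, algopts.WithCanonicalBitRepresentation())
	for i := range pb {
		api.AssertIsEqual(pb[i], qb[i])
	}
}

// challengeBits feeds a point into a transcript. The canonical check can be
// skipped here, as a non-canonical representation only produces a wrong
// challenge and thus a failing verification.
func challengeBits[B, S emulated.FieldParams](c *sw_emulated.Curve[B, S], p sw_emulated.AffinePoint[B]) []frontend.Variable {
	return c.MarshalG1(p)
}
```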
diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go
index 7e308f649d..e930a58d3a 100644
--- a/std/algebra/emulated/sw_emulated/point.go
+++ b/std/algebra/emulated/sw_emulated/point.go
@@ -3,12 +3,12 @@ package sw_emulated
 import (
 	"fmt"
 	"math/big"
+	"slices"
 
 	"github.com/consensys/gnark/frontend"
 	"github.com/consensys/gnark/std/algebra/algopts"
 	"github.com/consensys/gnark/std/math/emulated"
 	"github.com/consensys/gnark/std/math/emulated/emparams"
-	"golang.org/x/exp/slices"
 )
 
 // New returns a new [Curve] instance over the base field Base and scalar field
@@ -101,26 +101,41 @@ type AffinePoint[Base emulated.FieldParams] struct {
 
 // MarshalScalar marshals the scalar into bits. Compatible with scalar
 // marshalling in gnark-crypto.
-func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S]) []frontend.Variable {
+func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S], opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	var fr S
 	nbBits := 8 * ((fr.Modulus().BitLen() + 7) / 8)
-	sReduced := c.scalarApi.Reduce(&s)
-	res := c.scalarApi.ToBits(sReduced)[:nbBits]
-	for i, j := 0, nbBits-1; i < j; {
-		res[i], res[j] = res[j], res[i]
-		i++
-		j--
+	var sReduced *emulated.Element[S]
+	if cfg.ToBitsCanonical {
+		sReduced = c.scalarApi.ReduceStrict(&s)
+	} else {
+		sReduced = c.scalarApi.Reduce(&s)
 	}
+	res := c.scalarApi.ToBits(sReduced)[:nbBits]
+	slices.Reverse(res)
 	return res
 }
 
 // MarshalG1 marshals the affine point into bits. The output is compatible with
 // the point marshalling in gnark-crypto.
-func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable {
+func (c *Curve[B, S]) MarshalG1(p AffinePoint[B], opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	var fp B
 	nbBits := 8 * ((fp.Modulus().BitLen() + 7) / 8)
-	x := c.baseApi.Reduce(&p.X)
-	y := c.baseApi.Reduce(&p.Y)
+	var x, y *emulated.Element[B]
+	if cfg.ToBitsCanonical {
+		x = c.baseApi.ReduceStrict(&p.X)
+		y = c.baseApi.ReduceStrict(&p.Y)
+	} else {
+		x = c.baseApi.Reduce(&p.X)
+		y = c.baseApi.Reduce(&p.Y)
+	}
 	bx := c.baseApi.ToBits(x)[:nbBits]
 	by := c.baseApi.ToBits(y)[:nbBits]
 	slices.Reverse(bx)
diff --git a/std/algebra/interfaces.go b/std/algebra/interfaces.go
index c660fc441e..3775d0f922 100644
--- a/std/algebra/interfaces.go
+++ b/std/algebra/interfaces.go
@@ -55,10 +55,10 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface {
 
 	// MarshalG1 returns the binary decomposition G1.X || G1.Y. It matches the
 	// output of gnark-crypto's Marshal method on G1 points.
-	MarshalG1(G1El) []frontend.Variable
+	MarshalG1(G1El, ...algopts.AlgebraOption) []frontend.Variable
 
 	// MarshalScalar returns the binary decomposition of the argument.
-	MarshalScalar(emulated.Element[FR]) []frontend.Variable
+	MarshalScalar(emulated.Element[FR], ...algopts.AlgebraOption) []frontend.Variable
 
 	// Select sets p1 if b=1, p2 if b=0, and returns it. b must be boolean constrained
 	Select(b frontend.Variable, p1 *G1El, p2 *G1El) *G1El
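Since the generic `Curve` interface now also accepts the options, gadgets written against it can ask for canonical bits without committing to a concrete curve. A hedged sketch (the helper name and the expected-bits convention are assumptions, not part of the change):

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/algebra"
	"github.com/consensys/gnark/std/algebra/algopts"
	"github.com/consensys/gnark/std/math/emulated"
)

// assertScalarBits checks a scalar against an expected bit string (MSB first,
// as produced by gnark-crypto's Marshal). The bit-by-bit comparison is only
// meaningful on the canonical representation.
func assertScalarBits[FR emulated.FieldParams, G1El algebra.G1ElementT](
	api frontend.API,
	curve algebra.Curve[FR, G1El],
	s emulated.Element[FR],
	expected []frontend.Variable,
) {
	got := curve.MarshalScalar(s, algopts.WithCanonicalBitRepresentation())
	for i := range got {
		api.AssertIsEqual(got[i], expected[i])
	}
}
```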
diff --git a/std/algebra/native/sw_bls12377/pairing2.go b/std/algebra/native/sw_bls12377/pairing2.go
index f977ab916d..2faddb4039 100644
--- a/std/algebra/native/sw_bls12377/pairing2.go
+++ b/std/algebra/native/sw_bls12377/pairing2.go
@@ -3,6 +3,7 @@ package sw_bls12377
 import (
 	"fmt"
 	"math/big"
+	"slices"
 
 	"github.com/consensys/gnark-crypto/ecc"
 	bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
@@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
 }
 
 // MarshalScalar returns
-func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
+func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
-	ss := c.fr.Reduce(&s)
-	x := c.fr.ToBits(ss)
-	for i, j := 0, nbBits-1; i < j; {
-		x[i], x[j] = x[j], x[i]
-		i++
-		j--
+	var ss *emulated.Element[ScalarField]
+	if cfg.ToBitsCanonical {
+		ss = c.fr.ReduceStrict(&s)
+	} else {
+		ss = c.fr.Reduce(&s)
 	}
+	x := c.fr.ToBits(ss)[:nbBits]
+	slices.Reverse(x)
 	return x
 }
 
 // MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
 // in little endian.
-func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
+func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	nbBits := 8 * ((ecc.BLS12_377.BaseField().BitLen() + 7) / 8)
+	bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
+	if !cfg.ToBitsCanonical {
+		bOpts = append(bOpts, bits.OmitModulusCheck())
+	}
 	res := make([]frontend.Variable, 2*nbBits)
-	x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
-	y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
+	x := bits.ToBinary(c.api, P.X, bOpts...)
+	y := bits.ToBinary(c.api, P.Y, bOpts...)
 	for i := 0; i < nbBits; i++ {
 		res[i] = x[nbBits-1-i]
 		res[i+nbBits] = y[nbBits-1-i]
diff --git a/std/algebra/native/sw_bls24315/pairing2.go b/std/algebra/native/sw_bls24315/pairing2.go
index 643314ef4a..fc5fac0ba8 100644
--- a/std/algebra/native/sw_bls24315/pairing2.go
+++ b/std/algebra/native/sw_bls24315/pairing2.go
@@ -3,6 +3,7 @@ package sw_bls24315
 import (
 	"fmt"
 	"math/big"
+	"slices"
 
 	"github.com/consensys/gnark-crypto/ecc"
 	bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315"
@@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) {
 }
 
 // MarshalScalar returns
-func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable {
+func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8)
-	ss := c.fr.Reduce(&s)
-	x := c.fr.ToBits(ss)
-	for i, j := 0, nbBits-1; i < j; {
-		x[i], x[j] = x[j], x[i]
-		i++
-		j--
+	var ss *emulated.Element[ScalarField]
+	if cfg.ToBitsCanonical {
+		ss = c.fr.ReduceStrict(&s)
+	} else {
+		ss = c.fr.Reduce(&s)
 	}
+	x := c.fr.ToBits(ss)[:nbBits]
+	slices.Reverse(x)
 	return x
 }
 
 // MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are
 // in little endian.
-func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
+func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable {
+	cfg, err := algopts.NewConfig(opts...)
+	if err != nil {
+		panic(fmt.Sprintf("parse opts: %v", err))
+	}
 	nbBits := 8 * ((ecc.BLS24_315.BaseField().BitLen() + 7) / 8)
+	bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)}
+	if !cfg.ToBitsCanonical {
+		bOpts = append(bOpts, bits.OmitModulusCheck())
+	}
 	res := make([]frontend.Variable, 2*nbBits)
-	x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits))
-	y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits))
+	x := bits.ToBinary(c.api, P.X, bOpts...)
+	y := bits.ToBinary(c.api, P.Y, bOpts...)
 	for i := 0; i < nbBits; i++ {
 		res[i] = x[nbBits-1-i]
 		res[i+nbBits] = y[nbBits-1-i]
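On the native 2-chain curves the coordinates are plain frontend variables, so canonicity is delegated to `bits.ToBinary`. A small sketch of the two modes, assuming (as the change above implies) that `OmitModulusCheck` drops the below-modulus assertion on the decomposition; the function and variable names are illustrative:

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/math/bits"
)

func decomposeModes(api frontend.API, x frontend.Variable, nbBits int) {
	// canonical: the decomposition is additionally constrained to encode a
	// value below the native modulus, so it is unique.
	canonical := bits.ToBinary(api, x, bits.WithNbDigits(nbBits))
	// cheaper: the below-modulus assertion is skipped; fine when the bits only
	// feed a hash or are otherwise insensitive to the representation.
	loose := bits.ToBinary(api, x, bits.WithNbDigits(nbBits), bits.OmitModulusCheck())
	_, _ = canonical, loose
}
```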
diff --git a/std/math/emulated/element.go b/std/math/emulated/element.go
index f3da9d3c7c..171418d747 100644
--- a/std/math/emulated/element.go
+++ b/std/math/emulated/element.go
@@ -32,6 +32,12 @@ type Element[T FieldParams] struct {
 	// enforcement info in the Element to prevent modifying the witness.
 	internal bool
 
+	// modReduced indicates that the element has been reduced modulo the modulus
+	// and we have asserted that the integer value of the element is strictly
+	// less than the modulus. This is required for some operations which depend
+	// on the bit-representation of the element (ToBits, exponentiation etc.).
+	modReduced bool
+
 	isEvaluated bool
 	evaluation  frontend.Variable `gnark:"-"`
 }
@@ -95,6 +101,11 @@ func (e *Element[T]) GnarkInitHook() {
 		*e = ValueOf[T](0)
 		e.internal = false // we need to constrain in later.
 	}
+	// reset modReduced to false - when the circuit is compiled again we may
+	// assign a new value to an existing element. If we didn't reset it here,
+	// then during the second compilation we could take a short path where we
+	// assume that the modReduced flag is set.
+	e.modReduced = false
 }
 
 // copy makes a deep copy of the element.
@@ -104,5 +115,6 @@ func (e *Element[T]) copy() *Element[T] {
 	copy(r.Limbs, e.Limbs)
 	r.overflow = e.overflow
 	r.internal = e.internal
+	r.modReduced = e.modReduced
 	return &r
 }
diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go
index 675f296596..140df18f12 100644
--- a/std/math/emulated/element_test.go
+++ b/std/math/emulated/element_test.go
@@ -1098,3 +1098,177 @@ func TestExp(t *testing.T) {
 	testExp[BN254Fr](t)
 	testExp[emparams.Mod1e512](t)
 }
+
+type ReduceStrictCircuit[T FieldParams] struct {
+	Limbs        []frontend.Variable
+	Expected     []frontend.Variable
+	strictReduce bool
+}
+
+func (c *ReduceStrictCircuit[T]) Define(api frontend.API) error {
+	f, err := NewField[T](api)
+	if err != nil {
+		return fmt.Errorf("new variable modulus: %w", err)
+	}
+	el := f.newInternalElement(c.Limbs, 0)
+	var elR *Element[T]
+	if c.strictReduce {
+		elR = f.ReduceStrict(el)
+	} else {
+		elR = f.Reduce(el)
+	}
+	for i := range elR.Limbs {
+		api.AssertIsEqual(elR.Limbs[i], c.Expected[i])
+	}
+	return nil
+}
+
+func testReduceStrict[T FieldParams](t *testing.T) {
+	var fp T
+	assert := test.NewAssert(t)
+	assert.Run(func(assert *test.Assert) {
+		p := fp.Modulus()
+		plimbs := make([]*big.Int, int(fp.NbLimbs()))
+		for i := range plimbs {
+			plimbs[i] = new(big.Int)
+		}
+		err := decompose(p, fp.BitsPerLimb(), plimbs)
+		assert.NoError(err)
+		plimbs[0].Add(plimbs[0], big.NewInt(1))
+		exp := make([]*big.Int, int(fp.NbLimbs()))
+		exp[0] = big.NewInt(1)
+		for i := 1; i < int(fp.NbLimbs()); i++ {
+			exp[i] = big.NewInt(0)
+		}
+		circuitStrict := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs())), strictReduce: true}
+		circuitLax := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs()))}
+		witness := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs()))}
+		for i := range plimbs {
+			witness.Limbs[i] = plimbs[i]
+			witness.Expected[i] = exp[i]
+		}
+		assert.CheckCircuit(circuitStrict, test.WithValidAssignment(witness))
+		assert.CheckCircuit(circuitLax, test.WithInvalidAssignment(witness))
+
+	}, testName[T]())
+}
+
+func TestReduceStrict(t *testing.T) {
+	testReduceStrict[Goldilocks](t)
+	testReduceStrict[BN254Fr](t)
+	testReduceStrict[emparams.Mod1e512](t)
+}
+
+type ToBitsCanonicalCircuit[T FieldParams] struct {
+	Limbs    []frontend.Variable
+	Expected []frontend.Variable
+}
+
+func (c *ToBitsCanonicalCircuit[T]) Define(api frontend.API) error {
+	f, err := NewField[T](api)
+	if err != nil {
+		return fmt.Errorf("new variable modulus: %w", err)
+	}
+	el := f.newInternalElement(c.Limbs, 0)
+	bts := f.ToBitsCanonical(el)
+	for i := range bts {
+		api.AssertIsEqual(bts[i], c.Expected[i])
+	}
+	return nil
+}
+
+func testToBitsCanonical[T FieldParams](t *testing.T) {
+	var fp T
+	nbBits := fp.Modulus().BitLen()
+	assert := test.NewAssert(t)
+	assert.Run(func(assert *test.Assert) {
+		p := fp.Modulus()
+		plimbs := make([]*big.Int, int(fp.NbLimbs()))
+		for i := range plimbs {
+			plimbs[i] = new(big.Int)
+		}
+		err := decompose(p, fp.BitsPerLimb(), plimbs)
+		assert.NoError(err)
+		plimbs[0].Add(plimbs[0], big.NewInt(1))
+		exp := make([]*big.Int, int(nbBits))
+		exp[0] = big.NewInt(1)
+		for i := 1; i < len(exp); i++ {
+			exp[i] = big.NewInt(0)
+		}
+		circuit := &ToBitsCanonicalCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, nbBits)}
+		witness := &ToBitsCanonicalCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, nbBits)}
+		for i := range plimbs {
+			witness.Limbs[i] = plimbs[i]
+		}
+		for i := range exp {
+			witness.Expected[i] = exp[i]
+		}
+		assert.CheckCircuit(circuit, test.WithValidAssignment(witness))
+	}, testName[T]())
+}
+
+func TestToBitsCanonical(t *testing.T) {
+	testToBitsCanonical[Goldilocks](t)
+	testToBitsCanonical[BN254Fr](t)
+	testToBitsCanonical[emparams.Mod1e512](t)
+}
+
+type IsZeroEdgeCase[T FieldParams] struct {
+	Limbs    []frontend.Variable
+	Expected frontend.Variable
+}
+
+func (c *IsZeroEdgeCase[T]) Define(api frontend.API) error {
+	f, err := NewField[T](api)
+	if err != nil {
+		return err
+	}
+	el := f.newInternalElement(c.Limbs, 0)
+	res := f.IsZero(el)
+	api.AssertIsEqual(res, c.Expected)
+	return nil
+}
+
+func testIsZeroEdgeCases[T FieldParams](t *testing.T) {
+	var fp T
+	p := fp.Modulus()
+	assert := test.NewAssert(t)
+	assert.Run(func(assert *test.Assert) {
+		plimbs := make([]*big.Int, int(fp.NbLimbs()))
+		for i := range plimbs {
+			plimbs[i] = new(big.Int)
+		}
+		err := decompose(p, fp.BitsPerLimb(), plimbs)
+		assert.NoError(err)
+		// limbs are for zero
+		witness1 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 1}
+		for i := range plimbs {
+			witness1.Limbs[i] = big.NewInt(0)
+		}
+		// limbs are for p
+		witness2 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 1}
+		for i := range plimbs {
+			witness2.Limbs[i] = plimbs[i]
+		}
+		// limbs are for a nonzero value
+		witness3 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 0}
+		witness3.Limbs[0] = big.NewInt(1)
+		for i := 1; i < len(witness3.Limbs); i++ {
+			witness3.Limbs[i] = big.NewInt(0)
+		}
+		// limbs are for a nonzero value bigger than p
+		witness4 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 0}
+		witness4.Limbs[0] = new(big.Int).Add(plimbs[0], big.NewInt(1))
+		for i := 1; i < len(witness4.Limbs); i++ {
+			witness4.Limbs[i] = plimbs[i]
+		}
+		assert.CheckCircuit(&IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs()))}, test.WithValidAssignment(witness1), test.WithValidAssignment(witness2), test.WithValidAssignment(witness3), test.WithValidAssignment(witness4))
+
+	}, testName[T]())
+}
+
+func TestIsZeroEdgeCases(t *testing.T) {
+	testIsZeroEdgeCases[Goldilocks](t)
+	testIsZeroEdgeCases[BN254Fr](t)
+	testIsZeroEdgeCases[emparams.Mod1e512](t)
+}
diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go
index 86ff353424..3cac83ea2c 100644
--- a/std/math/emulated/field_assert.go
+++ b/std/math/emulated/field_assert.go
@@ -50,8 +50,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) {
 }
 
 // AssertIsLessOrEqual ensures that e is less or equal than a. For proper
-// bitwise comparison first reduce the element using [Reduce] and then assert
-// that its value is less than the modulus using [AssertIsInRange].
+// bitwise comparison first reduce the element using [Field.ReduceStrict].
 func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) {
 	// we omit conditional width assertion as is done in ToBits below
 	if e.overflow+a.overflow > 0 {
@@ -91,16 +90,28 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) {
 // it is not. For binary comparison the values have both to be below the
 // modulus.
 func (f *Field[T]) AssertIsInRange(a *Element[T]) {
+	// short path - this element is already enforced to be less than the modulus
+	if a.modReduced {
+		return
+	}
 	// we omit conditional width assertion as is done in ToBits down the calling stack
 	f.AssertIsLessOrEqual(a, f.modulusPrev())
+	a.modReduced = true
 }
 
 // IsZero returns a boolean indicating if the element is strictly zero. The
 // method internally reduces the element and asserts that the value is less than
 // the modulus.
 func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable {
+	// to avoid using strict reduction (which is expensive as it requires a
+	// binary assertion that the value is less than the modulus), we use ordinary
+	// reduction, but then the result can be either 0 or p (if the value is zero).
+	//
+	// so we check that the reduced value limbs are either all zeros or
+	// correspond to the modulus limbs.
 	ca := f.Reduce(a)
-	f.AssertIsInRange(ca)
+	p := f.Modulus()
+
 	// we use two approaches for checking if the element is exactly zero. The
 	// first approach is to check that every limb individually is zero. The
 	// second approach is to check if the sum of all limbs is zero. Usually, we
@@ -109,23 +120,32 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable {
 	// then we can ensure in most cases that no overflows happen.
 
 	// as ca is already reduced, then every limb overflow is already 0. Only
-	// every addition adds a bit to the overflow
+	// every addition adds a bit to the overflow.
+	var res0 frontend.Variable
 	totalOverflow := len(ca.Limbs) - 1
 	if totalOverflow > int(f.maxOverflow()) {
 		// the sums of limbs would overflow the native field. Use the first
 		// approach instead.
-		res := f.api.IsZero(ca.Limbs[0])
+		res0 = f.api.IsZero(ca.Limbs[0])
+		for i := 1; i < len(ca.Limbs); i++ {
+			res0 = f.api.Mul(res0, f.api.IsZero(ca.Limbs[i]))
+		}
+	} else {
+		// default case, limbs sum does not overflow the native field
+		limbSum := ca.Limbs[0]
 		for i := 1; i < len(ca.Limbs); i++ {
-			res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i]))
+			limbSum = f.api.Add(limbSum, ca.Limbs[i])
 		}
-		return res
+		res0 = f.api.IsZero(limbSum)
 	}
-	// default case, limbs sum does not overflow the native field
-	limbSum := ca.Limbs[0]
+	// however, for checking if the element is p, we cannot use the sum
+	// optimization as we may have underflows. So we have to check every limb
+	// individually.
+	resP := f.api.IsZero(f.api.Sub(p.Limbs[0], ca.Limbs[0]))
 	for i := 1; i < len(ca.Limbs); i++ {
-		limbSum = f.api.Add(limbSum, ca.Limbs[i])
+		resP = f.api.Mul(resP, f.api.IsZero(f.api.Sub(p.Limbs[i], ca.Limbs[i])))
 	}
-	return f.api.IsZero(limbSum)
+	return f.api.Or(res0, resP)
 }
 
 // // Cmp returns:
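A minimal circuit sketch of why the extra modulus-limb check matters (the circuit type and field choice are illustrative only): IsZero must return 1 for every in-circuit representation of zero, including a non-canonical one whose ordinary reduction comes out as p rather than 0.

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/math/emulated"
	"github.com/consensys/gnark/std/math/emulated/emparams"
)

type isZeroSketch struct {
	A emulated.Element[emparams.BN254Fr]
}

func (c *isZeroSketch) Define(api frontend.API) error {
	f, err := emulated.NewField[emparams.BN254Fr](api)
	if err != nil {
		return err
	}
	// A - A is zero whatever limbs A carries; the internal (non-strict)
	// reduction may represent that zero either as 0 or as p, and IsZero now
	// accepts both.
	api.AssertIsEqual(f.IsZero(f.Sub(&c.A, &c.A)), 1)
	return nil
}
```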
diff --git a/std/math/emulated/field_binary.go b/std/math/emulated/field_binary.go
index 8c949af2e4..d2dd5f3d6b 100644
--- a/std/math/emulated/field_binary.go
+++ b/std/math/emulated/field_binary.go
@@ -8,8 +8,7 @@ import (
 // ToBits returns the bit representation of the Element in little-endian (LSB
 // first) order. The returned bits are constrained to be 0-1. The number of
 // returned bits is nbLimbs*nbBits+overflow. To obtain the bits of the canonical
-// representation of Element, reduce Element first and take less significant
-// bits corresponding to the bitwidth of the emulated modulus.
+// representation of Element, use method [Field.ToBitsCanonical].
 func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable {
 	f.enforceWidthConditional(a)
 	ba, aConst := f.constantValue(a)
@@ -34,6 +33,29 @@
 	return fullBits
 }
 
+// ToBitsCanonical returns the unique bit representation of the element in
+// canonical form (strictly less than the modulus).
+func (f *Field[T]) ToBitsCanonical(a *Element[T]) []frontend.Variable {
+	// TODO: implement an inline version of this function. We perform binary
+	// decomposition both in the `ReduceStrict` and `ToBits` methods, but we can
+	// essentially do them at the same time.
+	//
+	// If we do this, then also check the places where we manually use `Reduce`
+	// followed by `ToBits` (e.g. in point and scalar marshalling) and replace
+	// them with this method.
+
+	var fp T
+	nbBits := fp.Modulus().BitLen()
+	// when the modulus is a power of 2, then we can remove the most significant
+	// bit as it is always zero.
+	if fp.Modulus().TrailingZeroBits() == uint(nbBits-1) {
+		nbBits--
+	}
+	ca := f.ReduceStrict(a)
+	bts := f.ToBits(ca)
+	return bts[:nbBits]
+}
+
 // FromBits returns a new Element given the bits is little-endian order.
 func (f *Field[T]) FromBits(bs ...frontend.Variable) *Element[T] {
 	nbLimbs := (uint(len(bs)) + f.fParams.BitsPerLimb() - 1) / f.fParams.BitsPerLimb()
diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go
index a9f0d9cda3..9ee181b7fb 100644
--- a/std/math/emulated/field_ops.go
+++ b/std/math/emulated/field_ops.go
@@ -164,21 +164,6 @@ func (f *Field[T]) Sum(inputs ...*Element[T]) *Element[T] {
 	return f.newInternalElement(limbs, overflow+uint(addOverflow))
 }
 
-// Reduce reduces a modulo the field order and returns it.
-func (f *Field[T]) Reduce(a *Element[T]) *Element[T] {
-	f.enforceWidthConditional(a)
-	if a.overflow == 0 {
-		// fast path - already reduced, omit reduction.
-		return a
-	}
-	// sanity check
-	if _, aConst := f.constantValue(a); aConst {
-		panic("trying to reduce a constant, which happen to have an overflow flag set")
-	}
-	// slow path - use hint to reduce value
-	return f.mulMod(a, f.One(), 0, nil)
-}
-
 // Sub subtracts b from a and returns it. Reduces locally if wouldn't fit into
 // Element. Doesn't mutate inputs.
 func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] {
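A short sketch of the intended use of ToBitsCanonical (the helper name is hypothetical): comparing two elements bit by bit is only meaningful on their canonical decompositions.

```go
package example

import (
	"github.com/consensys/gnark/frontend"
	"github.com/consensys/gnark/std/math/emulated"
)

// assertSameBits constrains a and b to have identical canonical bit
// decompositions, and therefore to be equal as field elements.
func assertSameBits[T emulated.FieldParams](api frontend.API, f *emulated.Field[T], a, b *emulated.Element[T]) {
	abits := f.ToBitsCanonical(a)
	bbits := f.ToBitsCanonical(b)
	for i := range abits {
		api.AssertIsEqual(abits[i], bbits[i])
	}
}
```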
diff --git a/std/math/emulated/field_reduce.go b/std/math/emulated/field_reduce.go
new file mode 100644
index 0000000000..7fe55377c2
--- /dev/null
+++ b/std/math/emulated/field_reduce.go
@@ -0,0 +1,53 @@
+package emulated
+
+// Reduce returns an element reduced modulo the modulus and constrained to have
+// the same number of limbs as the modulus. The output fits the bit width of
+// the modulus, but its value may still be up to twice the modulus (non-canonical).
+//
+// Does not mutate the input.
+//
+// In cases where the canonical representation of the element is required, use
+// [Field.ReduceStrict].
+func (f *Field[T]) Reduce(a *Element[T]) *Element[T] {
+	ret := f.reduce(a, false)
+	return ret
+}
+
+func (f *Field[T]) reduce(a *Element[T], strict bool) *Element[T] {
+	f.enforceWidthConditional(a)
+	if a.modReduced {
+		// fast path - the element has already been strictly reduced, nothing to do
+		return a
+	}
+	if !strict && a.overflow == 0 {
+		// fast path - we are in the non-strict case and the element has no
+		// overflow. We don't need to reduce now.
+		return a
+	}
+	// rest of the cases:
+	//  - in the strict case and the element was not strictly reduced before (even if it has no overflow)
+	//  - in the non-strict case and the element has overflow
+
+	// sanity check
+	if _, aConst := f.constantValue(a); aConst {
+		panic("trying to reduce a constant, which happen to have an overflow flag set")
+	}
+	// slow path - use hint to reduce value
+	return f.mulMod(a, f.One(), 0, nil)
+}
+
+// ReduceStrict returns an element reduced by the modulus. The output element
+// has the same width as the modulus and is guaranteed to be less than the
+// modulus.
+//
+// Does not mutate the input.
+//
+// This method is useful when the canonical representation of the element is
+// required, for example when the element is used in bitwise operations. The
+// reduction is therefore enforced even when the overflow of the element is 0
+// but it has not been strictly reduced before.
+func (f *Field[T]) ReduceStrict(a *Element[T]) *Element[T] {
+	ret := f.reduce(a, true)
+	f.AssertIsInRange(ret)
+	return ret
+}
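To summarise the two reduction flavours, a hedged sketch (the function name is hypothetical): Reduce only normalises the width, while ReduceStrict also pins the value below the modulus and records that fact on the element, so a subsequent AssertIsInRange takes the short path.

```go
package example

import "github.com/consensys/gnark/std/math/emulated"

func reduceFlavours[T emulated.FieldParams](f *emulated.Field[T], a *emulated.Element[T]) {
	lax := f.Reduce(a)          // fits the modulus width, value may still be >= p
	strict := f.ReduceStrict(a) // additionally asserted to be < p (canonical)
	// ReduceStrict marks its result as strictly reduced, so this range
	// assertion short-circuits and adds no further constraints.
	f.AssertIsInRange(strict)
	_ = lax
}
```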