diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 328194685c..cae64710d8 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -129,8 +129,6 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { // here bound is NOT a constant, // but a can be either constant or a wire. - _, aConst := builder.constantValue(a) - nbBits := builder.cs.FieldBitLen() aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs(), bits.OmitModulusCheck()) @@ -152,28 +150,18 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { // else // p[i] = p[i+1] * a[i] // t = 0 + v := builder.Mul(p[i+1], aBits[i]) - p[i] = builder.Select(boundBits[i], v, p[i+1]) + p[i] = builder.Select(boundBits[i], v, p[i+1]) t := builder.Select(boundBits[i], zero, p[i+1]) - - // (1 - t - ai) * ai == 0 - var l frontend.Variable - l = builder.cstOne() - l = builder.Sub(l, t, aBits[i]) - // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - if aConst { - // aBits[i] is a constant; - l = builder.Mul(l, aBits[i]) - // TODO @gbotrel this constraint seems useless. 
- added = append(added, builder.cs.AddR1C(builder.newR1C(l, zero, zero), builder.genericGate)) - } else { - added = append(added, builder.cs.AddR1C(builder.newR1C(l, aBits[i], zero), builder.genericGate)) - } + // (1 - t - ai) * ai == 0 + l := builder.Sub(builder.cstOne(), t, aBits[i]) + added = append(added, builder.cs.AddR1C(builder.newR1C(l, builder.Mul(aBits[i], builder.cstOne()), zero), builder.genericGate)) } if debug.Debug { diff --git a/frontend/cs/r1cs/builder.go b/frontend/cs/r1cs/builder.go index 3102addbb5..bb0ac3c50c 100644 --- a/frontend/cs/r1cs/builder.go +++ b/frontend/cs/r1cs/builder.go @@ -211,9 +211,7 @@ func (builder *builder) getLinearExpression(_l interface{}) constraint.LinearExp case constraint.LinearExpression: L = tl default: - if debug.Debug { - panic("invalid input for getLinearExpression") // sanity check - } + panic("invalid input for getLinearExpression") // sanity check } return L diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index bce25ec16b..dc183f3876 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -87,8 +87,10 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { // AssertIsDifferent fails if i1 == i2 func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { s := builder.Sub(i1, i2) - if c, ok := builder.constantValue(s); ok && c.IsZero() { - panic("AssertIsDifferent(x,x) will never be satisfied") + if c, ok := builder.constantValue(s); ok { + if c.IsZero() { + panic("AssertIsDifferent(x,x) will never be satisfied") + } } else if t := s.(expr.Term); t.Coeff.IsZero() { panic("AssertIsDifferent(x,x) will never be satisfied") } diff --git a/internal/regression_tests/issue1227/issue_1227_test.go b/internal/regression_tests/issue1227/issue_1227_test.go new file mode 100644 index 0000000000..a934299527 --- /dev/null +++ b/internal/regression_tests/issue1227/issue_1227_test.go @@ -0,0 +1,29 @@ +package issue1226 + +import ( + 
"testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type Circuit struct { + constVal int + X frontend.Variable +} + +func (circuit *Circuit) Define(api frontend.API) error { + api.AssertIsLessOrEqual(circuit.constVal, circuit.X) + return nil +} + +func TestConstantPath(t *testing.T) { + assert := test.NewAssert(t) + assert.CheckCircuit(&Circuit{constVal: 1}, + test.WithValidAssignment(&Circuit{X: 1}), // 1 <= 1 --> true + test.WithInvalidAssignment(&Circuit{X: 0})) // 1 <= 0 --> false + // test edge case where constant is 0 + assert.CheckCircuit(&Circuit{constVal: 0}, + test.WithValidAssignment(&Circuit{X: 1}), // 0 <= 1 --> true + test.WithValidAssignment(&Circuit{X: 0})) // 0 <= 0 --> true +} diff --git a/std/algebra/algopts/algopts.go b/std/algebra/algopts/algopts.go index 9792f87c62..2434fd57a8 100644 --- a/std/algebra/algopts/algopts.go +++ b/std/algebra/algopts/algopts.go @@ -11,6 +11,7 @@ type algebraCfg struct { NbScalarBits int FoldMulti bool CompleteArithmetic bool + ToBitsCanonical bool } // AlgebraOption allows modifying algebraic operation behaviour. @@ -57,6 +58,25 @@ func WithCompleteArithmetic() AlgebraOption { } } +// WithCanonicalBitRepresentation enforces the marshalling methods to assert +// that the bit representation is in canonical form. For field elements this +// means that the bits represent a number less than the modulus. +// +// This option is useful when performing direct comparison between the bit form +// of two elements. It can be avoided when the bit representation is used in +// other cases, such as computing a challenge using a hash function, where +// non-canonical bit representation leads to incorrect challenge (which in turn +// makes the verification fail). 
+func WithCanonicalBitRepresentation() AlgebraOption { + return func(ac *algebraCfg) error { + if ac.ToBitsCanonical { + return fmt.Errorf("WithCanonicalBitRepresentation already set") + } + ac.ToBitsCanonical = true + return nil + } +} + // NewConfig applies all given options and returns a configuration to be used. func NewConfig(opts ...AlgebraOption) (*algebraCfg, error) { ret := new(algebraCfg) diff --git a/std/algebra/emulated/fields_bls12381/e12_pairing.go b/std/algebra/emulated/fields_bls12381/e12_pairing.go index 7bbb60b7e5..a2a3f25ebc 100644 --- a/std/algebra/emulated/fields_bls12381/e12_pairing.go +++ b/std/algebra/emulated/fields_bls12381/e12_pairing.go @@ -384,14 +384,15 @@ func (e Ext12) FrobeniusSquareTorus(y *E6) *E6 { return &E6{B0: *t0, B1: *t1, B2: *t2} } -// FinalExponentiationCheck checks that a Miller function output x lies in the +// AssertFinalExponentiationIsOne checks that a Miller function output x lies in the // same equivalence class as the reduced pairing. This replaces the final // exponentiation step in-circuit. -// The method follows Section 4 of [On Proving Pairings] paper by A. Novakovic and L. Eagen. +// The method is inspired from [On Proving Pairings] paper by A. Novakovic and +// L. Eagen, and is based on a personal communication with A. Novakovic. 
// // [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf -func (e Ext12) FinalExponentiationCheck(x *E12) *E12 { - res, err := e.fp.NewHint(finalExpHint, 12, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) +func (e Ext12) AssertFinalExponentiationIsOne(x *E12) { + res, err := e.fp.NewHint(finalExpHint, 18, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -409,21 +410,27 @@ func (e Ext12) FinalExponentiationCheck(x *E12) *E12 { B2: E2{A0: *res[10], A1: *res[11]}, }, } + // constrain cubicNonResiduePower to be in Fp6 + scalingFactor := E12{ + C0: E6{ + B0: E2{A0: *res[12], A1: *res[13]}, + B1: E2{A0: *res[14], A1: *res[15]}, + B2: E2{A0: *res[16], A1: *res[17]}, + }, + C1: (*e.Ext6.Zero()), + } - // Check that x == residueWitness^r by checking that: - // x^k == residueWitness^(q-u) - // where k = (u-1)^2/3, u=-0xd201000000010000 the BLS12-381 seed - // and residueWitness from the hint. + // Check that x * scalingFactor == residueWitness^(q-u) + // where u=-0xd201000000010000 is the BLS12-381 seed, + // and residueWitness, scalingFactor from the hint. 
t0 := e.Frobenius(&residueWitness) // exponentiation by -u t1 := e.Expt(&residueWitness) t0 = e.Mul(t0, t1) - // exponentiation by U=(u-1)^2/3 - t1 = e.ExpByU(x) - e.AssertIsEqual(t0, t1) + t1 = e.Mul(x, &scalingFactor) - return nil + e.AssertIsEqual(t0, t1) } func (e Ext12) Frobenius(x *E12) *E12 { diff --git a/std/algebra/emulated/fields_bls12381/hints.go b/std/algebra/emulated/fields_bls12381/hints.go index 320aaacb9d..77863d3ffa 100644 --- a/std/algebra/emulated/fields_bls12381/hints.go +++ b/std/algebra/emulated/fields_bls12381/hints.go @@ -271,11 +271,11 @@ func divE12Hint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) erro } func finalExpHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error { - // This follows section 4.1 of https://eprint.iacr.org/2024/640.pdf (Th. 1) + // This is inspired from https://eprint.iacr.org/2024/640.pdf + // and based on a personal communication with the author Andrija Novakovic. return emulated.UnwrapHint(nativeInputs, nativeOutputs, func(mod *big.Int, inputs, outputs []*big.Int) error { - var millerLoop, residueWitness bls12381.E12 - var rInv big.Int + var millerLoop bls12381.E12 millerLoop.C0.B0.A0.SetBigInt(inputs[0]) millerLoop.C0.B0.A1.SetBigInt(inputs[1]) @@ -290,12 +290,71 @@ func finalExpHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) er millerLoop.C1.B2.A0.SetBigInt(inputs[10]) millerLoop.C1.B2.A1.SetBigInt(inputs[11]) - // compute r-th root: - // Exponentiate to rInv where - // rInv = 1/r mod (p^12-1)/r - 
rInv.SetString("169662389312441398885310937191698694666993326870281216192803558492181163400934408837135364582394949149589560242411491538960982200559697133935443307582773537814554128992403254243871087441488619811839498788505657962013599019994544063402394719913759780901881538869078447034832302535303591303383830742161317593225991746471557492001710830538428792119562309446698444646787667517629943447802199824630112988907247336627481159245442124709621313522294197747687500252452962523217400829932174349352696726049683687654879009114460723993703760367089269403767790334911644010940272722630305066645230222732316445557889124653426141642271480304669447694344127599708992364443461893123938202386892312748211835322692697497854107961493711137028209148238339237355911496376520814450515612396561384525661635220451168152178239892009375229296874955612623691164738926395993739297557487207643426168321070539996994036837992284584225139752716615623194417718962478029165908544042568334172107008712033983002554672734519081879196926275059798317879322062358113986901925780890205936071364647548199159506709147492864081514759663116291487638998943660232689862634717010538047493292265992334130695994203833154950619462266484292385471162124464248375625748097868775829652908052615424796255913420292818674303286242639225711610323988077268116737", 10) - residueWitness.Exp(millerLoop, &rInv) - + var root, rootPthInverse, root27thInverse, residueWitness, scalingFactor bls12381.E12 + var order3rd, order3rdPower, exponent, exponentInv, finalExpFactor, polyFactor big.Int + // polyFactor = (1-x)/3 + polyFactor.SetString("5044125407647214251", 10) + // finalExpFactor = ((q^12 - 1) / r) / (27 * polyFactor) + 
finalExpFactor.SetString("2366356426548243601069753987687709088104621721678962410379583120840019275952471579477684846670499039076873213559162845121989217658133790336552276567078487633052653005423051750848782286407340332979263075575489766963251914185767058009683318020965829271737924625612375201545022326908440428522712877494557944965298566001441468676802477524234094954960009227631543471415676620753242466901942121887152806837594306028649150255258504417829961387165043999299071444887652375514277477719817175923289019181393803729926249507024121957184340179467502106891835144220611408665090353102353194448552304429530104218473070114105759487413726485729058069746063140422361472585604626055492939586602274983146215294625774144156395553405525711143696689756441298365274341189385646499074862712688473936093315628166094221735056483459332831845007196600723053356837526749543765815988577005929923802636375670820616189737737304893769679803809426304143627363860243558537831172903494450556755190448279875942974830469855835666815454271389438587399739607656399812689280234103023464545891697941661992848552456326290792224091557256350095392859243101357349751064730561345062266850238821755009430903520645523345000326783803935359711318798844368754833295302563158150573540616830138810935344206231367357992991289265295323280", 10) + + // 1. get pth-root inverse + exponent.Mul(&finalExpFactor, big.NewInt(27)) + root.Exp(millerLoop, &exponent) + if root.IsOne() { + rootPthInverse.SetOne() + } else { + exponentInv.ModInverse(&exponent, &polyFactor) + exponent.Neg(&exponentInv).Mod(&exponent, &polyFactor) + rootPthInverse.Exp(root, &exponent) + } + + // 2.1. 
get order of 3rd primitive root + var three big.Int + three.SetUint64(3) + exponent.Mul(&polyFactor, &finalExpFactor) + root.Exp(millerLoop, &exponent) + if root.IsOne() { + order3rdPower.SetUint64(0) + } + root.Exp(root, &three) + if root.IsOne() { + order3rdPower.SetUint64(1) + } + root.Exp(root, &three) + if root.IsOne() { + order3rdPower.SetUint64(2) + } + root.Exp(root, &three) + if root.IsOne() { + order3rdPower.SetUint64(3) + } + + // 2.2. get 27th root inverse + if order3rdPower.Uint64() == 0 { + root27thInverse.SetOne() + } else { + order3rd.Exp(&three, &order3rdPower, nil) + exponent.Mul(&polyFactor, &finalExpFactor) + root.Exp(millerLoop, &exponent) + exponentInv.ModInverse(&exponent, &order3rd) + exponent.Neg(&exponentInv).Mod(&exponent, &order3rd) + root27thInverse.Exp(root, &exponent) + } + + // 2.3. shift the Miller loop result so that millerLoop * scalingFactor + // is of order finalExpFactor + scalingFactor.Mul(&rootPthInverse, &root27thInverse) + millerLoop.Mul(&millerLoop, &scalingFactor) + + // 3. 
get the witness residue + // + // lambda = q - u, the optimal exponent + var lambda big.Int + lambda.SetString("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129030796414117214202539", 10) + exponent.ModInverse(&lambda, &finalExpFactor) + residueWitness.Exp(millerLoop, &exponent) + + // return the witness residue residueWitness.C0.B0.A0.BigInt(outputs[0]) residueWitness.C0.B0.A1.BigInt(outputs[1]) residueWitness.C0.B1.A0.BigInt(outputs[2]) @@ -309,6 +368,14 @@ func finalExpHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) er residueWitness.C1.B2.A0.BigInt(outputs[10]) residueWitness.C1.B2.A1.BigInt(outputs[11]) + // return the scaling factor + scalingFactor.C0.B0.A0.BigInt(outputs[12]) + scalingFactor.C0.B0.A1.BigInt(outputs[13]) + scalingFactor.C0.B1.A0.BigInt(outputs[14]) + scalingFactor.C0.B1.A1.BigInt(outputs[15]) + scalingFactor.C0.B2.A0.BigInt(outputs[16]) + scalingFactor.C0.B2.A1.BigInt(outputs[17]) + return nil }) } diff --git a/std/algebra/emulated/fields_bn254/e12.go b/std/algebra/emulated/fields_bn254/e12.go index 22218351f7..1cc99fe365 100644 --- a/std/algebra/emulated/fields_bn254/e12.go +++ b/std/algebra/emulated/fields_bn254/e12.go @@ -173,6 +173,12 @@ func (e Ext12) CyclotomicSquare(x *E12) *E12 { } } +func (e Ext12) IsEqual(x, y *E12) frontend.Variable { + isC0Equal := e.Ext6.IsEqual(&x.C0, &y.C0) + isC1Equal := e.Ext6.IsEqual(&x.C1, &y.C1) + return e.api.And(isC0Equal, isC1Equal) +} + func (e Ext12) AssertIsEqual(x, y *E12) { e.Ext6.AssertIsEqual(&x.C0, &y.C0) e.Ext6.AssertIsEqual(&x.C1, &y.C1) diff --git a/std/algebra/emulated/fields_bn254/e12_pairing.go b/std/algebra/emulated/fields_bn254/e12_pairing.go index 64d0d11fb8..f7c93c75e7 100644 --- a/std/algebra/emulated/fields_bn254/e12_pairing.go +++ b/std/algebra/emulated/fields_bn254/e12_pairing.go @@ -422,13 +422,13 @@ func (e Ext12) FrobeniusCubeTorus(y *E6) *E6 { return res } -// FinalExponentiationCheck checks that a Miller function 
output x lies in the +// AssertFinalExponentiationIsOne checks that a Miller function output x lies in the // same equivalence class as the reduced pairing. This replaces the final // exponentiation step in-circuit. // The method follows Section 4 of [On Proving Pairings] paper by A. Novakovic and L. Eagen. // // [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf -func (e Ext12) FinalExponentiationCheck(x *E12) *E12 { +func (e Ext12) AssertFinalExponentiationIsOne(x *E12) { res, err := e.fp.NewHint(finalExpHint, 24, &x.C0.B0.A0, &x.C0.B0.A1, &x.C0.B1.A0, &x.C0.B1.A1, &x.C0.B2.A0, &x.C0.B2.A1, &x.C1.B0.A0, &x.C1.B0.A1, &x.C1.B1.A0, &x.C1.B1.A1, &x.C1.B2.A0, &x.C1.B2.A1) if err != nil { // err is non-nil only for invalid number of inputs @@ -474,8 +474,6 @@ func (e Ext12) FinalExponentiationCheck(x *E12) *E12 { t0 = e.Mul(t0, t1) e.AssertIsEqual(t0, t2) - - return nil } func (e Ext12) Frobenius(x *E12) *E12 { diff --git a/std/algebra/emulated/fields_bn254/e6.go b/std/algebra/emulated/fields_bn254/e6.go index 584043114c..feca59a5f7 100644 --- a/std/algebra/emulated/fields_bn254/e6.go +++ b/std/algebra/emulated/fields_bn254/e6.go @@ -382,6 +382,15 @@ func (e Ext6) FrobeniusSquare(x *E6) *E6 { return &E6{B0: x.B0, B1: *z01, B2: *z02} } +func (e Ext6) IsEqual(x, y *E6) frontend.Variable { + isB0Equal := e.Ext2.IsEqual(&x.B0, &y.B0) + isB1Equal := e.Ext2.IsEqual(&x.B1, &y.B1) + isB2Equal := e.Ext2.IsEqual(&x.B2, &y.B2) + res := e.api.And(isB0Equal, isB1Equal) + res = e.api.And(res, isB2Equal) + return res +} + func (e Ext6) AssertIsEqual(x, y *E6) { e.Ext2.AssertIsEqual(&x.B0, &y.B0) e.Ext2.AssertIsEqual(&x.B1, &y.B1) diff --git a/std/algebra/emulated/fields_bw6761/e6_pairing.go b/std/algebra/emulated/fields_bw6761/e6_pairing.go index 8361ed8146..8907c0535b 100644 --- a/std/algebra/emulated/fields_bw6761/e6_pairing.go +++ b/std/algebra/emulated/fields_bw6761/e6_pairing.go @@ -322,13 +322,13 @@ func (e *Ext6) MulBy02345(z *E6, x [5]*baseEl) *E6 { } } -// 
FinalExponentiationCheck checks that a Miller function output x lies in the +// AssertFinalExponentiationIsOne checks that a Miller function output x lies in the // same equivalence class as the reduced pairing. This replaces the final // exponentiation step in-circuit. // The method is adapted from Section 4 of [On Proving Pairings] paper by A. Novakovic and L. Eagen. // // [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf -func (e Ext6) FinalExponentiationCheck(x *E6) *E6 { +func (e Ext6) AssertFinalExponentiationIsOne(x *E6) { res, err := e.fp.NewHint(finalExpHint, 6, &x.A0, &x.A1, &x.A2, &x.A3, &x.A4, &x.A5) if err != nil { // err is non-nil only for invalid number of inputs @@ -357,8 +357,6 @@ func (e Ext6) FinalExponentiationCheck(x *E6) *E6 { t0 = e.DivUnchecked(t0, t1) e.AssertIsEqual(t0, x) - - return nil } // ExpByU2 set z to z^(x₀+1) in E12 and return z diff --git a/std/algebra/emulated/sw_bls12381/pairing.go b/std/algebra/emulated/sw_bls12381/pairing.go index c5f96e62d8..4bc99671d6 100644 --- a/std/algebra/emulated/sw_bls12381/pairing.go +++ b/std/algebra/emulated/sw_bls12381/pairing.go @@ -251,7 +251,7 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { } // We perform the easy part of the final exp to push f to the cyclotomic - // subgroup so that FinalExponentiationCheck is carried with optimized + // subgroup so that AssertFinalExponentiationIsOne is carried with optimized // cyclotomic squaring (e.g. Karabina12345). 
// // f = f^(p⁶-1)(p²+1) @@ -260,7 +260,7 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { f = pr.FrobeniusSquare(buf) f = pr.Mul(f, buf) - pr.FinalExponentiationCheck(f) + pr.AssertFinalExponentiationIsOne(f) return nil } diff --git a/std/algebra/emulated/sw_bls12381/pairing_test.go b/std/algebra/emulated/sw_bls12381/pairing_test.go index dc723eb012..9ffdd18d56 100644 --- a/std/algebra/emulated/sw_bls12381/pairing_test.go +++ b/std/algebra/emulated/sw_bls12381/pairing_test.go @@ -177,7 +177,7 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { if err != nil { return fmt.Errorf("new pairing: %w", err) } - err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2}) if err != nil { return fmt.Errorf("pair: %w", err) } @@ -186,10 +186,13 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { func TestPairingCheckTestSolve(t *testing.T) { assert := test.NewAssert(t) + // e(a,2b) * e(-2a,b) == 1 p1, q1 := randomG1G2Affines() - _, q2 := randomG1G2Affines() var p2 bls12381.G1Affine - p2.Neg(&p1) + p2.Double(&p1).Neg(&p2) + var q2 bls12381.G2Affine + q2.Set(&q1) + q1.Double(&q1) witness := PairingCheckCircuit{ In1G1: NewG1Affine(p1), In1G2: NewG2Affine(q1), @@ -228,11 +231,13 @@ func TestGroupMembershipSolve(t *testing.T) { // bench func BenchmarkPairing(b *testing.B) { - + // e(a,2b) * e(-2a,b) == 1 p1, q1 := randomG1G2Affines() - _, q2 := randomG1G2Affines() var p2 bls12381.G1Affine - p2.Neg(&p1) + p2.Double(&p1).Neg(&p2) + var q2 bls12381.G2Affine + q2.Set(&q1) + q1.Double(&q1) witness := PairingCheckCircuit{ In1G1: NewG1Affine(p1), In1G2: NewG2Affine(q1), diff --git a/std/algebra/emulated/sw_bn254/hints.go b/std/algebra/emulated/sw_bn254/hints.go index d7700dd24a..3c4d5642fe 100644 --- a/std/algebra/emulated/sw_bn254/hints.go +++ 
b/std/algebra/emulated/sw_bn254/hints.go @@ -157,12 +157,6 @@ func millerLoopAndCheckFinalExpHint(nativeMod *big.Int, nativeInputs, nativeOutp cubicNonResiduePower.C0.B1.A1.BigInt(outputs[15]) cubicNonResiduePower.C0.B2.A0.BigInt(outputs[16]) cubicNonResiduePower.C0.B2.A1.BigInt(outputs[17]) - cubicNonResiduePower.C1.B0.A0.BigInt(outputs[18]) - cubicNonResiduePower.C1.B0.A1.BigInt(outputs[19]) - cubicNonResiduePower.C1.B1.A0.BigInt(outputs[20]) - cubicNonResiduePower.C1.B1.A1.BigInt(outputs[21]) - cubicNonResiduePower.C1.B2.A0.BigInt(outputs[22]) - cubicNonResiduePower.C1.B2.A1.BigInt(outputs[23]) return nil }) diff --git a/std/algebra/emulated/sw_bn254/pairing.go b/std/algebra/emulated/sw_bn254/pairing.go index 3cc9ca67e8..61a7e051e6 100644 --- a/std/algebra/emulated/sw_bn254/pairing.go +++ b/std/algebra/emulated/sw_bn254/pairing.go @@ -251,7 +251,7 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { } // We perform the easy part of the final exp to push f to the cyclotomic - // subgroup so that FinalExponentiationCheck is carried with optimized + // subgroup so that AssertFinalExponentiationIsOne is carried with optimized // cyclotomic squaring (e.g. Karabina12345). // // f = f^(p⁶-1)(p²+1) @@ -260,11 +260,15 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { f = pr.FrobeniusSquare(buf) f = pr.Mul(f, buf) - pr.FinalExponentiationCheck(f) + pr.AssertFinalExponentiationIsOne(f) return nil } +func (pr Pairing) IsEqual(x, y *GTEl) frontend.Variable { + return pr.Ext12.IsEqual(x, y) +} + func (pr Pairing) AssertIsEqual(x, y *GTEl) { pr.Ext12.AssertIsEqual(x, y) } @@ -679,19 +683,12 @@ func (pr Pairing) MillerLoopAndMul(P *G1Affine, Q *G2Affine, previous *GTEl) (*G return res, err } -// MillerLoopAndFinalExpCheck computes the Miller loop between P and Q, -// multiplies it in 𝔽p¹² by previous and checks that the result lies in the -// same equivalence class as the reduced pairing purported to be 1. 
This check -// replaces the final exponentiation step in-circuit and follows Section 4 of -// [On Proving Pairings] paper by A. Novakovic and L. Eagen. -// -// This method is needed for evmprecompiles/ecpair. -// -// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf -func (pr Pairing) MillerLoopAndFinalExpCheck(P *G1Affine, Q *G2Affine, previous *GTEl) error { +// millerLoopAndFinalExpResult computes the Miller loop between P and Q, +// multiplies it in 𝔽p¹² by previous and returns the result. +func (pr Pairing) millerLoopAndFinalExpResult(P *G1Affine, Q *G2Affine, previous *GTEl) *GTEl { // hint the non-residue witness - hint, err := pr.curveF.NewHint(millerLoopAndCheckFinalExpHint, 24, &P.X, &P.Y, &Q.P.X.A0, &Q.P.X.A1, &Q.P.Y.A0, &Q.P.Y.A1, &previous.C0.B0.A0, &previous.C0.B0.A1, &previous.C0.B1.A0, &previous.C0.B1.A1, &previous.C0.B2.A0, &previous.C0.B2.A1, &previous.C1.B0.A0, &previous.C1.B0.A1, &previous.C1.B1.A0, &previous.C1.B1.A1, &previous.C1.B2.A0, &previous.C1.B2.A1) + hint, err := pr.curveF.NewHint(millerLoopAndCheckFinalExpHint, 18, &P.X, &P.Y, &Q.P.X.A0, &Q.P.X.A1, &Q.P.Y.A0, &Q.P.Y.A1, &previous.C0.B0.A0, &previous.C0.B0.A1, &previous.C0.B1.A0, &previous.C0.B1.A1, &previous.C0.B2.A0, &previous.C0.B2.A1, &previous.C1.B0.A0, &previous.C1.B0.A1, &previous.C1.B1.A0, &previous.C1.B1.A1, &previous.C1.B2.A0, &previous.C1.B2.A1) if err != nil { // err is non-nil only for invalid number of inputs panic(err) @@ -776,7 +773,7 @@ func (pr Pairing) MillerLoopAndFinalExpCheck(P *G1Affine, Q *G2Affine, previous // (ℓ × ℓ) × res res = pr.MulBy01234(res, prodLines) default: - return nil + panic(fmt.Sprintf("invalid loop counter value %d", loopCounter[i])) } } @@ -810,7 +807,36 @@ func (pr Pairing) MillerLoopAndFinalExpCheck(P *G1Affine, Q *G2Affine, previous t2 = pr.Mul(t2, t1) - pr.AssertIsEqual(t2, pr.One()) + return t2 +} - return nil +// IsMillerLoopAndFinalExpOne computes the Miller loop between P and Q, +// multiplies it in 𝔽p¹² by previous and 
returns a boolean indicating if +// the result lies in the same equivalence class as the reduced pairing +// purported to be 1. This check replaces the final exponentiation step +// in-circuit and follows Section 4 of [On Proving Pairings] paper by A. +// Novakovic and L. Eagen. +// +// This method is needed for evmprecompiles/ecpair. +// +// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf +func (pr Pairing) IsMillerLoopAndFinalExpOne(P *G1Affine, Q *G2Affine, previous *GTEl) frontend.Variable { + t2 := pr.millerLoopAndFinalExpResult(P, Q, previous) + + res := pr.IsEqual(t2, pr.One()) + return res +} + +// AssertMillerLoopAndFinalExpIsOne computes the Miller loop between P and Q, +// multiplies it in 𝔽p¹² by previous and checks that the result lies in the +// same equivalence class as the reduced pairing purported to be 1. This check +// replaces the final exponentiation step in-circuit and follows Section 4 of +// [On Proving Pairings] paper by A. Novakovic and L. Eagen. +// +// This method is needed for evmprecompiles/ecpair. 
+// +// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf +func (pr Pairing) AssertMillerLoopAndFinalExpIsOne(P *G1Affine, Q *G2Affine, previous *GTEl) { + t2 := pr.millerLoopAndFinalExpResult(P, Q, previous) + pr.AssertIsEqual(t2, pr.One()) } diff --git a/std/algebra/emulated/sw_bn254/pairing_test.go b/std/algebra/emulated/sw_bn254/pairing_test.go index b6a9a751e2..8b49607441 100644 --- a/std/algebra/emulated/sw_bn254/pairing_test.go +++ b/std/algebra/emulated/sw_bn254/pairing_test.go @@ -100,6 +100,51 @@ func TestMillerLoopTestSolve(t *testing.T) { assert.NoError(err) } +type MillerLoopAndMulCircuit struct { + Prev GTEl + P G1Affine + Q G2Affine + Current GTEl +} + +func (c *MillerLoopAndMulCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res, err := pairing.MillerLoopAndMul(&c.P, &c.Q, &c.Prev) + if err != nil { + return fmt.Errorf("pair: %w", err) + } + pairing.AssertIsEqual(res, &c.Current) + return nil + +} + +func TestMillerLoopAndMulTestSolve(t *testing.T) { + assert := test.NewAssert(t) + var prev, curr bn254.GT + prev.SetRandom() + p, q := randomG1G2Affines() + lines := bn254.PrecomputeLines(q) + // need to use ML with precomputed lines. 
Otherwise, the result will be different + mlres, err := bn254.MillerLoopFixedQ( + []bn254.G1Affine{p}, + [][2][len(bn254.LoopCounter)]bn254.LineEvaluationAff{lines}, + ) + assert.NoError(err) + curr.Mul(&prev, &mlres) + + witness := MillerLoopAndMulCircuit{ + Prev: NewGTEl(prev), + P: NewG1Affine(p), + Q: NewG2Affine(q), + Current: NewGTEl(curr), + } + err = test.IsSolved(&MillerLoopAndMulCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + type PairCircuit struct { InG1 G1Affine InG2 G2Affine @@ -214,7 +259,7 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { if err != nil { return fmt.Errorf("new pairing: %w", err) } - err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2}) if err != nil { return fmt.Errorf("pair: %w", err) } @@ -223,10 +268,13 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { func TestPairingCheckTestSolve(t *testing.T) { assert := test.NewAssert(t) + // e(a,2b) * e(-2a,b) == 1 p1, q1 := randomG1G2Affines() - _, q2 := randomG1G2Affines() var p2 bn254.G1Affine - p2.Neg(&p1) + p2.Double(&p1).Neg(&p2) + var q2 bn254.G2Affine + q2.Set(&q1) + q1.Double(&q1) witness := PairingCheckCircuit{ In1G1: NewG1Affine(p1), In1G2: NewG2Affine(q1), @@ -371,13 +419,74 @@ func TestIsOnG2Solve(t *testing.T) { assert.NoError(err) } +type IsMillerLoopAndFinalExpCircuit struct { + Prev GTEl + P G1Affine + Q G2Affine + Expected frontend.Variable +} + +func (c *IsMillerLoopAndFinalExpCircuit) Define(api frontend.API) error { + pairing, err := NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + res := pairing.IsMillerLoopAndFinalExpOne(&c.P, &c.Q, &c.Prev) + api.AssertIsEqual(res, c.Expected) + return nil + +} + +func TestIsMillerLoopAndFinalExpCircuitTestSolve(t *testing.T) { + assert := test.NewAssert(t) + p, q := 
randomG1G2Affines() + + var np bn254.G1Affine + np.Neg(&p) + + ok, err := bn254.PairingCheck([]bn254.G1Affine{p, np}, []bn254.G2Affine{q, q}) + assert.NoError(err) + assert.True(ok) + + lines := bn254.PrecomputeLines(q) + // need to use ML with precomputed lines. Otherwise, the result will be different + mlres, err := bn254.MillerLoopFixedQ( + []bn254.G1Affine{p}, + [][2][len(bn254.LoopCounter)]bn254.LineEvaluationAff{lines}, + ) + assert.NoError(err) + + witness := IsMillerLoopAndFinalExpCircuit{ + Prev: NewGTEl(mlres), + P: NewG1Affine(np), + Q: NewG2Affine(q), + Expected: 1, + } + err = test.IsSolved(&IsMillerLoopAndFinalExpCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) + + var randPrev bn254.GT + randPrev.SetRandom() + + witness = IsMillerLoopAndFinalExpCircuit{ + Prev: NewGTEl(randPrev), + P: NewG1Affine(np), + Q: NewG2Affine(q), + Expected: 0, + } + err = test.IsSolved(&IsMillerLoopAndFinalExpCircuit{}, &witness, ecc.BN254.ScalarField()) + assert.NoError(err) +} + // bench func BenchmarkPairing(b *testing.B) { - + // e(a,2b) * e(-2a,b) == 1 p1, q1 := randomG1G2Affines() - _, q2 := randomG1G2Affines() var p2 bn254.G1Affine - p2.Neg(&p1) + p2.Double(&p1).Neg(&p2) + var q2 bn254.G2Affine + q2.Set(&q1) + q1.Double(&q1) witness := PairingCheckCircuit{ In1G1: NewG1Affine(p1), In1G2: NewG2Affine(q1), diff --git a/std/algebra/emulated/sw_bw6761/pairing.go b/std/algebra/emulated/sw_bw6761/pairing.go index 88485288cb..47ad915567 100644 --- a/std/algebra/emulated/sw_bw6761/pairing.go +++ b/std/algebra/emulated/sw_bw6761/pairing.go @@ -147,7 +147,7 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { } // We perform the easy part of the final exp to push f to the cyclotomic - // subgroup so that FinalExponentiationCheck is carried with optimized + // subgroup so that AssertFinalExponentiationIsOne is carried with optimized // cyclotomic squaring (e.g. Karabina12345). 
// // f = f^(p³-1)(p+1) @@ -156,7 +156,7 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { f = pr.Frobenius(buf) f = pr.Mul(f, buf) - pr.FinalExponentiationCheck(f) + pr.AssertFinalExponentiationIsOne(f) return nil } diff --git a/std/algebra/emulated/sw_bw6761/pairing_test.go b/std/algebra/emulated/sw_bw6761/pairing_test.go index 06b3276afa..65f087a555 100644 --- a/std/algebra/emulated/sw_bw6761/pairing_test.go +++ b/std/algebra/emulated/sw_bw6761/pairing_test.go @@ -175,7 +175,7 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { if err != nil { return fmt.Errorf("new pairing: %w", err) } - err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2}) + err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2}) if err != nil { return fmt.Errorf("pair: %w", err) } @@ -184,10 +184,13 @@ func (c *PairingCheckCircuit) Define(api frontend.API) error { func TestPairingCheckTestSolve(t *testing.T) { assert := test.NewAssert(t) + // e(a,2b) * e(-2a,b) == 1 p1, q1 := randomG1G2Affines() - _, q2 := randomG1G2Affines() var p2 bw6761.G1Affine - p2.Neg(&p1) + p2.Double(&p1).Neg(&p2) + var q2 bw6761.G2Affine + q2.Set(&q1) + q1.Double(&q1) witness := PairingCheckCircuit{ In1G1: NewG1Affine(p1), In1G2: NewG2Affine(q1), @@ -226,16 +229,18 @@ func TestGroupMembershipSolve(t *testing.T) { // bench func BenchmarkPairing(b *testing.B) { - - p, q := randomG1G2Affines() - res, err := bw6761.Pair([]bw6761.G1Affine{p}, []bw6761.G2Affine{q}) - if err != nil { - b.Fatal(err) - } - witness := PairCircuit{ - InG1: NewG1Affine(p), - InG2: NewG2Affine(q), - Res: NewGTEl(res), + // e(a,2b) * e(-2a,b) == 1 + p1, q1 := randomG1G2Affines() + var p2 bw6761.G1Affine + p2.Double(&p1).Neg(&p2) + var q2 bw6761.G2Affine + q2.Set(&q1) + q1.Double(&q1) + witness := PairingCheckCircuit{ + In1G1: NewG1Affine(p1), + In1G2: NewG2Affine(q1), + In2G1: NewG1Affine(p2), + 
In2G2: NewG2Affine(q2), } w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) if err != nil { @@ -245,7 +250,7 @@ func BenchmarkPairing(b *testing.B) { b.Run("compile scs", func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &PairCircuit{}); err != nil { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &PairingCheckCircuit{}); err != nil { b.Fatal(err) } } @@ -267,7 +272,7 @@ func BenchmarkPairing(b *testing.B) { b.Run("compile r1cs", func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PairCircuit{}); err != nil { + if ccs, err = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &PairingCheckCircuit{}); err != nil { b.Fatal(err) } } diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index 7e308f649d..e930a58d3a 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -3,12 +3,12 @@ package sw_emulated import ( "fmt" "math/big" + "slices" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/algopts" "github.com/consensys/gnark/std/math/emulated" "github.com/consensys/gnark/std/math/emulated/emparams" - "golang.org/x/exp/slices" ) // New returns a new [Curve] instance over the base field Base and scalar field @@ -101,26 +101,41 @@ type AffinePoint[Base emulated.FieldParams] struct { // MarshalScalar marshals the scalar into bits. Compatible with scalar // marshalling in gnark-crypto. -func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S]) []frontend.Variable { +func (c *Curve[B, S]) MarshalScalar(s emulated.Element[S], opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) 
+ if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } var fr S nbBits := 8 * ((fr.Modulus().BitLen() + 7) / 8) - sReduced := c.scalarApi.Reduce(&s) - res := c.scalarApi.ToBits(sReduced)[:nbBits] - for i, j := 0, nbBits-1; i < j; { - res[i], res[j] = res[j], res[i] - i++ - j-- + var sReduced *emulated.Element[S] + if cfg.ToBitsCanonical { + sReduced = c.scalarApi.ReduceStrict(&s) + } else { + sReduced = c.scalarApi.Reduce(&s) } + res := c.scalarApi.ToBits(sReduced)[:nbBits] + slices.Reverse(res) return res } // MarshalG1 marshals the affine point into bits. The output is compatible with // the point marshalling in gnark-crypto. -func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable { +func (c *Curve[B, S]) MarshalG1(p AffinePoint[B], opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } var fp B nbBits := 8 * ((fp.Modulus().BitLen() + 7) / 8) - x := c.baseApi.Reduce(&p.X) - y := c.baseApi.Reduce(&p.Y) + var x, y *emulated.Element[B] + if cfg.ToBitsCanonical { + x = c.baseApi.ReduceStrict(&p.X) + y = c.baseApi.ReduceStrict(&p.Y) + } else { + x = c.baseApi.Reduce(&p.X) + y = c.baseApi.Reduce(&p.Y) + } bx := c.baseApi.ToBits(x)[:nbBits] by := c.baseApi.ToBits(y)[:nbBits] slices.Reverse(bx) diff --git a/std/algebra/interfaces.go b/std/algebra/interfaces.go index c660fc441e..3775d0f922 100644 --- a/std/algebra/interfaces.go +++ b/std/algebra/interfaces.go @@ -55,10 +55,10 @@ type Curve[FR emulated.FieldParams, G1El G1ElementT] interface { // MarshalG1 returns the binary decomposition G1.X || G1.Y. It matches the // output of gnark-crypto's Marshal method on G1 points. - MarshalG1(G1El) []frontend.Variable + MarshalG1(G1El, ...algopts.AlgebraOption) []frontend.Variable // MarshalScalar returns the binary decomposition of the argument. 
- MarshalScalar(emulated.Element[FR]) []frontend.Variable + MarshalScalar(emulated.Element[FR], ...algopts.AlgebraOption) []frontend.Variable // Select sets p1 if b=1, p2 if b=0, and returns it. b must be boolean constrained Select(b frontend.Variable, p1 *G1El, p2 *G1El) *G1El diff --git a/std/algebra/native/fields_bls12377/e12_pairing.go b/std/algebra/native/fields_bls12377/e12_pairing.go index a4895b425d..de72c44645 100644 --- a/std/algebra/native/fields_bls12377/e12_pairing.go +++ b/std/algebra/native/fields_bls12377/e12_pairing.go @@ -1,6 +1,8 @@ package fields_bls12377 -import "github.com/consensys/gnark/frontend" +import ( + "github.com/consensys/gnark/frontend" +) // nSquareKarabina2345 repeated compressed cyclotmic square func (e *E12) nSquareKarabina2345(api frontend.API, n int) { @@ -125,7 +127,7 @@ func (e *E12) MulBy01234(api frontend.API, x [5]E2) *E12 { return e } -// ExpX0 compute e1^X0, where X0=9586122913090633729 +// ExpX0 compute e1^X0, where X0=0x8508c00000000001 func (e *E12) ExpX0(api frontend.API, e1 E12) *E12 { res := e1 @@ -148,7 +150,7 @@ func (e *E12) ExpX0(api frontend.API, e1 E12) *E12 { } -// ExpX0Minus1Square computes e1^((X0-1)^2), where X0=9586122913090633729 +// ExpX0Minus1Square computes e1^((X0-1)^2), where X0=0x8508c00000000001 func (e *E12) ExpX0Minus1Square(api frontend.API, e1 E12) *E12 { var t0, t1, t2, t3, res E12 @@ -176,3 +178,67 @@ func (e *E12) ExpX0Minus1Square(api frontend.API, e1 E12) *E12 { return e } + +// ExpU compute e1^U, where U=(X0-1)^2/3 and X0=0x8508c00000000001 +func (e *E12) ExpU(api frontend.API, e1 E12) *E12 { + + var t0, t1, t2, t3 E12 + t0.CyclotomicSquare(api, e1) + e.Mul(api, e1, t0) + t0.Mul(api, t0, *e) + t1.CyclotomicSquare(api, t0) + t2.Mul(api, e1, t1) + t1.CyclotomicSquare(api, t2) + t1.Mul(api, e1, t1) + t3.CyclotomicSquare(api, t1) + t3.nSquareKarabina2345(api, 7) + t2.Mul(api, t2, t3) + t2.nSquareKarabina2345(api, 6) + t1.Mul(api, t1, t2) + t1.nSquareKarabina2345(api, 4) + t0.Mul(api, t0, 
t1) + t0.nSquareKarabina2345(api, 4) + t0.Mul(api, e1, t0) + t0.nSquareKarabina2345(api, 6) + e.Mul(api, *e, t0) + e.nSquareKarabina2345(api, 92) + + return e +} + +// AssertFinalExponentiationIsOne checks that a Miller function output x lies in the +// same equivalence class as the reduced pairing. This replaces the final +// exponentiation step in-circuit. +// The method follows Section 4 of [On Proving Pairings] paper by A. Novakovic and L. Eagen. +// +// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf +func (x *E12) AssertFinalExponentiationIsOne(api frontend.API) { + res, err := api.NewHint(finalExpHint, 18, x.C0.B0.A0, x.C0.B0.A1, x.C0.B1.A0, x.C0.B1.A1, x.C0.B2.A0, x.C0.B2.A1, x.C1.B0.A0, x.C1.B0.A1, x.C1.B1.A0, x.C1.B1.A1, x.C1.B2.A0, x.C1.B2.A1) + if err != nil { + // err is non-nil only for invalid number of inputs + panic(err) + } + + var residueWitness, scalingFactor, t0, t1 E12 + residueWitness.assign(res[:12]) + // constrain cubicNonResiduePower to be in Fp6 + scalingFactor.C0.B0.A0 = res[12] + scalingFactor.C0.B0.A1 = res[13] + scalingFactor.C0.B1.A0 = res[14] + scalingFactor.C0.B1.A1 = res[15] + scalingFactor.C0.B2.A0 = res[16] + scalingFactor.C0.B2.A1 = res[17] + scalingFactor.C1.SetZero() + + // Check that x * scalingFactor == residueWitness^(q-u) + // where u=0x8508c00000000001 is the BLS12-377 seed, + // and residueWitness, scalingFactor from the hint. 
+ t0.Frobenius(api, residueWitness) + // exponentiation by u + t1.ExpX0(api, residueWitness) + t0.DivUnchecked(api, t0, t1) + + t1.Mul(api, *x, scalingFactor) + + t0.AssertIsEqual(api, t1) +} diff --git a/std/algebra/native/fields_bls12377/hints.go b/std/algebra/native/fields_bls12377/hints.go index 9138ca997c..0325cf4aea 100644 --- a/std/algebra/native/fields_bls12377/hints.go +++ b/std/algebra/native/fields_bls12377/hints.go @@ -15,6 +15,7 @@ func GetHints() []solver.Hint { inverseE2Hint, inverseE6Hint, inverseE12Hint, + finalExpHint, } } @@ -183,3 +184,75 @@ func inverseE12Hint(_ *big.Int, inputs []*big.Int, res []*big.Int) error { return nil } + +func finalExpHint(_ *big.Int, inputs, outputs []*big.Int) error { + var millerLoop bls12377.E12 + + millerLoop.C0.B0.A0.SetBigInt(inputs[0]) + millerLoop.C0.B0.A1.SetBigInt(inputs[1]) + millerLoop.C0.B1.A0.SetBigInt(inputs[2]) + millerLoop.C0.B1.A1.SetBigInt(inputs[3]) + millerLoop.C0.B2.A0.SetBigInt(inputs[4]) + millerLoop.C0.B2.A1.SetBigInt(inputs[5]) + millerLoop.C1.B0.A0.SetBigInt(inputs[6]) + millerLoop.C1.B0.A1.SetBigInt(inputs[7]) + millerLoop.C1.B1.A0.SetBigInt(inputs[8]) + millerLoop.C1.B1.A1.SetBigInt(inputs[9]) + millerLoop.C1.B2.A0.SetBigInt(inputs[10]) + millerLoop.C1.B2.A1.SetBigInt(inputs[11]) + + var root, rootPthInverse, residueWitness, scalingFactor bls12377.E12 + var exponent, exponentInv, finalExpFactor, polyFactor big.Int + // polyFactor = 12(x-1) + polyFactor.SetString("115033474957087604736", 10) + // finalExpFactor = ((q^12 - 1) / r) / polyFactor + 
finalExpFactor.SetString("92351561334497520756349650336409370070948672672207914824247073415859727964231807559847070685040742345026775319680739143654748316009031763764029886042408725311062057776702838555815712331129279611544378217895455619058809454575474763035923260395518532422855090028311239234310116353269618927871828693919559964406939845784130633021661399269804065961999062695977580539176029238189119059338698461832966347603096853909366901376879505972606045770762516580639801134008192256366142553202619529638202068488750102055204336502584141399828818871664747496033599618827160583206926869573005874449182200210044444351826855938563862937638034918413235278166699461287943529570559518592586872860190313088429391521694808994276205429071153237122495989095857292965461625387657577981811772819764071512345106346232882471034669258055302790607847924560040527682025558360106509628206144255667203317787586698694011876342903106644003067103035176245790275561392007119121995936066014208972135762663107247939004517852248103325700169848524693333524025685325993207375736519358185783520948988673594976115901587076295116293065682366935313875411927779217584729138600463438806153265891176654957439524358472291492028580820575807385461119025678550977847392818655362610734928283105671242634809807533919011078145", 10) + + // 1. get pth-root inverse + exponent.Set(&finalExpFactor) + root.Exp(millerLoop, &finalExpFactor) + if root.IsOne() { + rootPthInverse.SetOne() + } else { + exponentInv.ModInverse(&exponent, &polyFactor) + exponent.Neg(&exponentInv).Mod(&exponent, &polyFactor) + rootPthInverse.Exp(root, &exponent) + } + + // 3. shift the Miller loop result so that millerLoop * scalingFactor + // is of order finalExpFactor + scalingFactor.Set(&rootPthInverse) + millerLoop.Mul(&millerLoop, &scalingFactor) + + // 4. 
get the witness residue + // + // lambda = q - u, the optimal exponent + var lambda big.Int + lambda.SetString("258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139563774001527230824448", 10) + exponent.ModInverse(&lambda, &finalExpFactor) + residueWitness.Exp(millerLoop, &exponent) + + // return the witness residue + residueWitness.C0.B0.A0.BigInt(outputs[0]) + residueWitness.C0.B0.A1.BigInt(outputs[1]) + residueWitness.C0.B1.A0.BigInt(outputs[2]) + residueWitness.C0.B1.A1.BigInt(outputs[3]) + residueWitness.C0.B2.A0.BigInt(outputs[4]) + residueWitness.C0.B2.A1.BigInt(outputs[5]) + residueWitness.C1.B0.A0.BigInt(outputs[6]) + residueWitness.C1.B0.A1.BigInt(outputs[7]) + residueWitness.C1.B1.A0.BigInt(outputs[8]) + residueWitness.C1.B1.A1.BigInt(outputs[9]) + residueWitness.C1.B2.A0.BigInt(outputs[10]) + residueWitness.C1.B2.A1.BigInt(outputs[11]) + + // return the scaling factor + scalingFactor.C0.B0.A0.BigInt(outputs[12]) + scalingFactor.C0.B0.A1.BigInt(outputs[13]) + scalingFactor.C0.B1.A0.BigInt(outputs[14]) + scalingFactor.C0.B1.A1.BigInt(outputs[15]) + scalingFactor.C0.B2.A0.BigInt(outputs[16]) + scalingFactor.C0.B2.A1.BigInt(outputs[17]) + + return nil +} diff --git a/std/algebra/native/sw_bls12377/pairing.go b/std/algebra/native/sw_bls12377/pairing.go index fa9febf7c3..9cd1ce851b 100644 --- a/std/algebra/native/sw_bls12377/pairing.go +++ b/std/algebra/native/sw_bls12377/pairing.go @@ -253,13 +253,22 @@ func Pair(api frontend.API, P []G1Affine, Q []G2Affine) (GT, error) { // // This function doesn't check that the inputs are in the correct subgroups func PairingCheck(api frontend.API, P []G1Affine, Q []G2Affine) error { - f, err := Pair(api, P, Q) + f, err := MillerLoop(api, P, Q) if err != nil { return err } - var one GT - one.SetOne() - f.AssertIsEqual(api, one) + // We perform the easy part of the final exp to push f to the cyclotomic + // subgroup so that AssertFinalExponentiationIsOne is carried with optimized + 
// cyclotomic squaring (e.g. Karabina12345). + // + // f = f^(p⁶-1)(p²+1) + var buf GT + buf.Conjugate(api, f) + buf.DivUnchecked(api, buf, f) + f.FrobeniusSquare(api, buf). + Mul(api, f, buf) + + f.AssertFinalExponentiationIsOne(api) return nil } diff --git a/std/algebra/native/sw_bls12377/pairing2.go b/std/algebra/native/sw_bls12377/pairing2.go index f977ab916d..d057ea13d1 100644 --- a/std/algebra/native/sw_bls12377/pairing2.go +++ b/std/algebra/native/sw_bls12377/pairing2.go @@ -3,6 +3,7 @@ package sw_bls12377 import ( "fmt" "math/big" + "slices" "github.com/consensys/gnark-crypto/ecc" bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) { } // MarshalScalar returns -func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable { +func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8) - ss := c.fr.Reduce(&s) - x := c.fr.ToBits(ss) - for i, j := 0, nbBits-1; i < j; { - x[i], x[j] = x[j], x[i] - i++ - j-- + var ss *emulated.Element[ScalarField] + if cfg.ToBitsCanonical { + ss = c.fr.ReduceStrict(&s) + } else { + ss = c.fr.Reduce(&s) } + x := c.fr.ToBits(ss)[:nbBits] + slices.Reverse(x) return x } // MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are // in little endian. -func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable { +func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) 
+ if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } nbBits := 8 * ((ecc.BLS12_377.BaseField().BitLen() + 7) / 8) + bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)} + if !cfg.ToBitsCanonical { + bOpts = append(bOpts, bits.OmitModulusCheck()) + } res := make([]frontend.Variable, 2*nbBits) - x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits)) - y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits)) + x := bits.ToBinary(c.api, P.X, bOpts...) + y := bits.ToBinary(c.api, P.Y, bOpts...) for i := 0; i < nbBits; i++ { res[i] = x[nbBits-1-i] res[i+nbBits] = y[nbBits-1-i] @@ -298,13 +312,21 @@ func (p *Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error { for i := range Q { inQ[i] = *Q[i] } - res, err := Pair(p.api, inP, inQ) + res, err := MillerLoop(p.api, inP, inQ) if err != nil { return err } - var one fields_bls12377.E12 - one.SetOne() - res.AssertIsEqual(p.api, one) + // We perform the easy part of the final exp to push res to the cyclotomic + // subgroup so that AssertFinalExponentiationIsOne is carried with optimized + // cyclotomic squaring (e.g. Karabina12345). 
+ // + // res = res^(p⁶-1)(p²+1) + var buf GT + buf.Conjugate(p.api, res) + buf.DivUnchecked(p.api, buf, res) + res.FrobeniusSquare(p.api, buf).Mul(p.api, res, buf) + + res.AssertFinalExponentiationIsOne(p.api) return nil } diff --git a/std/algebra/native/sw_bls12377/pairing_test.go b/std/algebra/native/sw_bls12377/pairing_test.go index 7524263c6e..1956764d2a 100644 --- a/std/algebra/native/sw_bls12377/pairing_test.go +++ b/std/algebra/native/sw_bls12377/pairing_test.go @@ -251,9 +251,11 @@ func pairingData() (P bls12377.G1Affine, Q bls12377.G2Affine, milRes, pairingRes } func pairingCheckData() (P [2]bls12377.G1Affine, Q [2]bls12377.G2Affine) { + // e(a,2b) * e(-2a,b) == 1 _, _, P[0], Q[0] = bls12377.Generators() - P[1].Neg(&P[0]) + P[1].Double(&P[0]).Neg(&P[1]) Q[1].Set(&Q[0]) + Q[0].Double(&Q[0]) return } diff --git a/std/algebra/native/sw_bls24315/pairing2.go b/std/algebra/native/sw_bls24315/pairing2.go index 643314ef4a..fc5fac0ba8 100644 --- a/std/algebra/native/sw_bls24315/pairing2.go +++ b/std/algebra/native/sw_bls24315/pairing2.go @@ -3,6 +3,7 @@ package sw_bls24315 import ( "fmt" "math/big" + "slices" "github.com/consensys/gnark-crypto/ecc" bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315" @@ -36,25 +37,38 @@ func NewCurve(api frontend.API) (*Curve, error) { } // MarshalScalar returns -func (c *Curve) MarshalScalar(s Scalar) []frontend.Variable { +func (c *Curve) MarshalScalar(s Scalar, opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) 
+ if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } nbBits := 8 * ((ScalarField{}.Modulus().BitLen() + 7) / 8) - ss := c.fr.Reduce(&s) - x := c.fr.ToBits(ss) - for i, j := 0, nbBits-1; i < j; { - x[i], x[j] = x[j], x[i] - i++ - j-- + var ss *emulated.Element[ScalarField] + if cfg.ToBitsCanonical { + ss = c.fr.ReduceStrict(&s) + } else { + ss = c.fr.Reduce(&s) } + x := c.fr.ToBits(ss)[:nbBits] + slices.Reverse(x) return x } // MarshalG1 returns [P.X || P.Y] in binary. Both P.X and P.Y are // in little endian. -func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable { +func (c *Curve) MarshalG1(P G1Affine, opts ...algopts.AlgebraOption) []frontend.Variable { + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } nbBits := 8 * ((ecc.BLS24_315.BaseField().BitLen() + 7) / 8) + bOpts := []bits.BaseConversionOption{bits.WithNbDigits(nbBits)} + if !cfg.ToBitsCanonical { + bOpts = append(bOpts, bits.OmitModulusCheck()) + } res := make([]frontend.Variable, 2*nbBits) - x := bits.ToBinary(c.api, P.X, bits.WithNbDigits(nbBits)) - y := bits.ToBinary(c.api, P.Y, bits.WithNbDigits(nbBits)) + x := bits.ToBinary(c.api, P.X, bOpts...) + y := bits.ToBinary(c.api, P.Y, bOpts...) for i := 0; i < nbBits; i++ { res[i] = x[nbBits-1-i] res[i+nbBits] = y[nbBits-1-i] diff --git a/std/evmprecompiles/08-bnpairing.go b/std/evmprecompiles/08-bnpairing.go index b796ee458e..4a31e1f058 100644 --- a/std/evmprecompiles/08-bnpairing.go +++ b/std/evmprecompiles/08-bnpairing.go @@ -1,30 +1,33 @@ package evmprecompiles import ( + "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/algebra/emulated/sw_bn254" ) // ECPair implements [ALT_BN128_PAIRING_CHECK] precompile contract at address 0x08. 
// -// [ALT_BN128_PAIRING_CHECK]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-pairing-check -// // To have a fixed-circuit regardless of the number of inputs, we need 2 fixed circuits: -// - MillerLoopAndMul: -// A Miller loop of fixed size 1 followed by a multiplication in 𝔽p¹². -// - MillerLoopAndFinalExpCheck: -// A Miller loop of fixed size 1 followed by a multiplication in 𝔽p¹², and -// a check that the result lies in the same equivalence class as the -// reduced pairing purported to be 1. This check replaces the final -// exponentiation step in-circuit and follows Section 4 of [On Proving -// Pairings] paper by A. Novakovic and L. Eagen. -// -// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf +// - MillerLoopAndMul: +// A Miller loop of fixed size 1 followed by a multiplication in 𝔽p¹². +// - MillerLoopAndFinalExpCheck: +// A Miller loop of fixed size 1 followed by a multiplication in 𝔽p¹², and +// a check that the result lies in the same equivalence class as the +// reduced pairing purported to be 1. This check replaces the final +// exponentiation step in-circuit and follows Section 4 of [On Proving +// Pairings] paper by A. Novakovic and L. Eagen. // // N.B.: This is a sub-optimal routine but defines a fixed circuit regardless // of the number of inputs. We can extend this routine to handle a 2-by-2 // logic but we prefer a minimal number of circuits (2). - +// +// See the methods [ECPairMillerLoopAndMul] and [ECPairMillerLoopAndFinalExpCheck] for the fixed circuits. +// See the method [ECPairIsOnG2] for the check that Qᵢ are on G2. 
+// +// [ALT_BN128_PAIRING_CHECK]: https://ethereum.github.io/execution-specs/autoapi/ethereum/paris/vm/precompiled_contracts/alt_bn128/index.html#alt-bn128-pairing-check +// [On Proving Pairings]: https://eprint.iacr.org/2024/640.pdf func ECPair(api frontend.API, P []*sw_bn254.G1Affine, Q []*sw_bn254.G2Affine) { if len(P) != len(Q) { panic("P and Q length mismatch") @@ -54,5 +57,49 @@ func ECPair(api frontend.API, P []*sw_bn254.G1Affine, Q []*sw_bn254.G2Affine) { } // fixed circuit 2 - pair.MillerLoopAndFinalExpCheck(P[n-1], Q[n-1], ml) + pair.AssertMillerLoopAndFinalExpIsOne(P[n-1], Q[n-1], ml) +} + +// ECPairIsOnG2 implements the fixed circuit for checking G2 membership and non-membership. +func ECPairIsOnG2(api frontend.API, Q *sw_bn254.G2Affine, expectedIsOnG2 frontend.Variable) error { + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return err + } + isOnG2 := pairing.IsOnG2(Q) + api.AssertIsEqual(expectedIsOnG2, isOnG2) + return nil +} + +// ECPairMillerLoopAndMul implements the fixed circuit for a Miller loop of +// fixed size 1 followed by a multiplication with an accumulator in 𝔽p¹². It +// asserts that the result corresponds to the expected result. +func ECPairMillerLoopAndMul(api frontend.API, accumulator *sw_bn254.GTEl, P *sw_bn254.G1Affine, Q *sw_bn254.G2Affine, expected *sw_bn254.GTEl) error { + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG2(Q) + ml, err := pairing.MillerLoopAndMul(P, Q, accumulator) + if err != nil { + return fmt.Errorf("miller loop and mul: %w", err) + } + pairing.AssertIsEqual(expected, ml) + return nil +} + +// ECPairMillerLoopAndFinalExpCheck implements the fixed circuit for a Miller +// loop of fixed size 1 followed by a multiplication with an accumulator in +// 𝔽p¹², and a check that the result corresponds to the expected result. 
+func ECPairMillerLoopAndFinalExpCheck(api frontend.API, accumulator *sw_bn254.GTEl, P *sw_bn254.G1Affine, Q *sw_bn254.G2Affine, expectedIsSuccess frontend.Variable) error { + api.AssertIsBoolean(expectedIsSuccess) + pairing, err := sw_bn254.NewPairing(api) + if err != nil { + return fmt.Errorf("new pairing: %w", err) + } + pairing.AssertIsOnG2(Q) + + isSuccess := pairing.IsMillerLoopAndFinalExpOne(P, Q, accumulator) + api.AssertIsEqual(expectedIsSuccess, isSuccess) + return nil } diff --git a/std/math/emulated/custommod.go b/std/math/emulated/custommod.go index 2f5cbaca1b..acc25c8eda 100644 --- a/std/math/emulated/custommod.go +++ b/std/math/emulated/custommod.go @@ -48,6 +48,13 @@ func (f *Field[T]) modSub(a, b *Element[T], modulus *Element[T]) *Element[T] { // instead of assuming T as a constant. And when doing as a hint, then need // to assert that the padding is a multiple of the modulus (done inside callSubPaddingHint) nextOverflow := max(b.overflow+1, a.overflow) + 1 + if nextOverflow > f.maxOverflow() { + // TODO: in general we should handle it more gracefully, but this method + // is only used in ModAssertIsEqual which in turn is only used in tests, + // then for now we avoid automatic overflow handling (like we have for fixed modulus case). + // We only panic here so that the user would know to manually handle the overflow. + panic("next overflow would overflow the native field") + } nbLimbs := max(len(a.Limbs), len(b.Limbs)) limbs := make([]frontend.Variable, nbLimbs) padding := f.computeSubPaddingHint(b.overflow, uint(nbLimbs), modulus) diff --git a/std/math/emulated/element.go b/std/math/emulated/element.go index f3da9d3c7c..171418d747 100644 --- a/std/math/emulated/element.go +++ b/std/math/emulated/element.go @@ -32,6 +32,12 @@ type Element[T FieldParams] struct { // enforcement info in the Element to prevent modifying the witness. 
internal bool + // modReduced indicates that the element has been reduced modulo the modulus + // and we have asserted that the integer value of the element is strictly + // less than the modulus. This is required for some operations which depend + // on the bit-representation of the element (ToBits, exponentiation etc.). + modReduced bool + isEvaluated bool evaluation frontend.Variable `gnark:"-"` } @@ -95,6 +101,11 @@ func (e *Element[T]) GnarkInitHook() { *e = ValueOf[T](0) e.internal = false // we need to constrain in later. } + // set modReduced to false - in case the circuit is compiled we may change + // the value for an existing element. If we don't reset it here, then during + // second compilation we may take a shortPath where we assume that modReduce + // flag is set. + e.modReduced = false } // copy makes a deep copy of the element. @@ -104,5 +115,6 @@ func (e *Element[T]) copy() *Element[T] { copy(r.Limbs, e.Limbs) r.overflow = e.overflow r.internal = e.internal + r.modReduced = e.modReduced return &r } diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go index 675f296596..96530c1126 100644 --- a/std/math/emulated/element_test.go +++ b/std/math/emulated/element_test.go @@ -1098,3 +1098,181 @@ func TestExp(t *testing.T) { testExp[BN254Fr](t) testExp[emparams.Mod1e512](t) } + +type ReduceStrictCircuit[T FieldParams] struct { + Limbs []frontend.Variable + Expected []frontend.Variable + strictReduce bool +} + +func (c *ReduceStrictCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + el := f.newInternalElement(c.Limbs, 0) + var elR *Element[T] + if c.strictReduce { + elR = f.ReduceStrict(el) + } else { + elR = f.Reduce(el) + } + for i := range elR.Limbs { + api.AssertIsEqual(elR.Limbs[i], c.Expected[i]) + } + + // TODO: dummy constraint to have at least two constraints in the circuit. + // Otherwise PLONK setup phase fails. 
+ api.AssertIsEqual(c.Expected[0], elR.Limbs[0]) + return nil +} + +func testReduceStrict[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + p := fp.Modulus() + plimbs := make([]*big.Int, int(fp.NbLimbs())) + for i := range plimbs { + plimbs[i] = new(big.Int) + } + err := decompose(p, fp.BitsPerLimb(), plimbs) + assert.NoError(err) + plimbs[0].Add(plimbs[0], big.NewInt(1)) + exp := make([]*big.Int, int(fp.NbLimbs())) + exp[0] = big.NewInt(1) + for i := 1; i < int(fp.NbLimbs()); i++ { + exp[i] = big.NewInt(0) + } + circuitStrict := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs())), strictReduce: true} + circuitLax := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs()))} + witness := &ReduceStrictCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, int(fp.NbLimbs()))} + for i := range plimbs { + witness.Limbs[i] = plimbs[i] + witness.Expected[i] = exp[i] + } + assert.CheckCircuit(circuitStrict, test.WithValidAssignment(witness)) + assert.CheckCircuit(circuitLax, test.WithInvalidAssignment(witness)) + + }, testName[T]()) +} + +func TestReduceStrict(t *testing.T) { + testReduceStrict[Goldilocks](t) + testReduceStrict[BN254Fr](t) + testReduceStrict[emparams.Mod1e512](t) +} + +type ToBitsCanonicalCircuit[T FieldParams] struct { + Limbs []frontend.Variable + Expected []frontend.Variable +} + +func (c *ToBitsCanonicalCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return fmt.Errorf("new variable modulus: %w", err) + } + el := f.newInternalElement(c.Limbs, 0) + bts := f.ToBitsCanonical(el) + for i := range bts { + api.AssertIsEqual(bts[i], c.Expected[i]) + } + return nil +} + +func testToBitsCanonical[T FieldParams](t *testing.T) { + var fp T + nbBits := 
fp.Modulus().BitLen() + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + p := fp.Modulus() + plimbs := make([]*big.Int, int(fp.NbLimbs())) + for i := range plimbs { + plimbs[i] = new(big.Int) + } + err := decompose(p, fp.BitsPerLimb(), plimbs) + assert.NoError(err) + plimbs[0].Add(plimbs[0], big.NewInt(1)) + exp := make([]*big.Int, int(nbBits)) + exp[0] = big.NewInt(1) + for i := 1; i < len(exp); i++ { + exp[i] = big.NewInt(0) + } + circuit := &ToBitsCanonicalCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, nbBits)} + witness := &ToBitsCanonicalCircuit[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: make([]frontend.Variable, nbBits)} + for i := range plimbs { + witness.Limbs[i] = plimbs[i] + } + for i := range exp { + witness.Expected[i] = exp[i] + } + assert.CheckCircuit(circuit, test.WithValidAssignment(witness)) + }, testName[T]()) +} + +func TestToBitsCanonical(t *testing.T) { + testToBitsCanonical[Goldilocks](t) + testToBitsCanonical[BN254Fr](t) + testToBitsCanonical[emparams.Mod1e512](t) +} + +type IsZeroEdgeCase[T FieldParams] struct { + Limbs []frontend.Variable + Expected frontend.Variable +} + +func (c *IsZeroEdgeCase[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + el := f.newInternalElement(c.Limbs, 0) + res := f.IsZero(el) + api.AssertIsEqual(res, c.Expected) + return nil +} + +func testIsZeroEdgeCases[T FieldParams](t *testing.T) { + var fp T + p := fp.Modulus() + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + plimbs := make([]*big.Int, int(fp.NbLimbs())) + for i := range plimbs { + plimbs[i] = new(big.Int) + } + err := decompose(p, fp.BitsPerLimb(), plimbs) + assert.NoError(err) + // limbs are for zero + witness1 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 1} + for i := range plimbs { + witness1.Limbs[i] = big.NewInt(0) + } + // limbs are for p 
+ witness2 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 1} + for i := range plimbs { + witness2.Limbs[i] = plimbs[i] + } + // limbs are for not zero + witness3 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 0} + witness3.Limbs[0] = big.NewInt(1) + for i := 1; i < len(witness3.Limbs); i++ { + witness3.Limbs[i] = big.NewInt(0) + } + // limbs are for not zero bigger than p + witness4 := &IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs())), Expected: 0} + witness4.Limbs[0] = new(big.Int).Add(plimbs[0], big.NewInt(1)) + for i := 1; i < len(witness4.Limbs); i++ { + witness4.Limbs[i] = plimbs[i] + } + assert.CheckCircuit(&IsZeroEdgeCase[T]{Limbs: make([]frontend.Variable, int(fp.NbLimbs()))}, test.WithValidAssignment(witness1), test.WithValidAssignment(witness2), test.WithValidAssignment(witness3), test.WithValidAssignment(witness4)) + + }, testName[T]()) +} + +func TestIsZeroEdgeCases(t *testing.T) { + testIsZeroEdgeCases[Goldilocks](t) + testIsZeroEdgeCases[BN254Fr](t) + testIsZeroEdgeCases[emparams.Mod1e512](t) +} diff --git a/std/math/emulated/field_assert.go b/std/math/emulated/field_assert.go index 86ff353424..3cac83ea2c 100644 --- a/std/math/emulated/field_assert.go +++ b/std/math/emulated/field_assert.go @@ -50,8 +50,7 @@ func (f *Field[T]) AssertIsEqual(a, b *Element[T]) { } // AssertIsLessOrEqual ensures that e is less or equal than a. For proper -// bitwise comparison first reduce the element using [Reduce] and then assert -// that its value is less than the modulus using [AssertIsInRange]. +// bitwise comparison first reduce the element using [Field.ReduceStrict]. func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { // we omit conditional width assertion as is done in ToBits below if e.overflow+a.overflow > 0 { @@ -91,16 +90,28 @@ func (f *Field[T]) AssertIsLessOrEqual(e, a *Element[T]) { // it is not. 
For binary comparison the values have both to be below the // modulus. func (f *Field[T]) AssertIsInRange(a *Element[T]) { + // short path - this element is already enforced to be less than the modulus + if a.modReduced { + return + } // we omit conditional width assertion as is done in ToBits down the calling stack f.AssertIsLessOrEqual(a, f.modulusPrev()) + a.modReduced = true } // IsZero returns a boolean indicating if the element is strictly zero. The // method internally reduces the element and asserts that the value is less than // the modulus. func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { + // to avoid using strict reduction (which is expensive as requires binary + // assertion that value is less than modulus), we use ordinary reduction but + // in this case the result can be either 0 or p (if it is zero). + // + // so we check that the reduced value limbs are either all zeros or + // corrspond to the modulus limbs. ca := f.Reduce(a) - f.AssertIsInRange(ca) + p := f.Modulus() + // we use two approaches for checking if the element is exactly zero. The // first approach is to check that every limb individually is zero. The // second approach is to check if the sum of all limbs is zero. Usually, we @@ -109,23 +120,32 @@ func (f *Field[T]) IsZero(a *Element[T]) frontend.Variable { // then we can ensure in most cases that no overflows happen. // as ca is already reduced, then every limb overflow is already 0. Only - // every addition adds a bit to the overflow + // every addition adds a bit to the overflow. + var res0 frontend.Variable totalOverflow := len(ca.Limbs) - 1 if totalOverflow > int(f.maxOverflow()) { // the sums of limbs would overflow the native field. Use the first // approach instead. 
- res := f.api.IsZero(ca.Limbs[0]) + res0 = f.api.IsZero(ca.Limbs[0]) + for i := 1; i < len(ca.Limbs); i++ { + res0 = f.api.Mul(res0, f.api.IsZero(ca.Limbs[i])) + } + } else { + // default case, limbs sum does not overflow the native field + limbSum := ca.Limbs[0] for i := 1; i < len(ca.Limbs); i++ { - res = f.api.Mul(res, f.api.IsZero(ca.Limbs[i])) + limbSum = f.api.Add(limbSum, ca.Limbs[i]) } - return res + res0 = f.api.IsZero(limbSum) } - // default case, limbs sum does not overflow the native field - limbSum := ca.Limbs[0] + // however, for checking if the element is p, we can not use the + // optimization as we may have underflows. So we have to check every limb + // individually. + resP := f.api.IsZero(f.api.Sub(p.Limbs[0], ca.Limbs[0])) for i := 1; i < len(ca.Limbs); i++ { - limbSum = f.api.Add(limbSum, ca.Limbs[i]) + resP = f.api.Mul(resP, f.api.IsZero(f.api.Sub(p.Limbs[i], ca.Limbs[i]))) } - return f.api.IsZero(limbSum) + return f.api.Or(res0, resP) } // // Cmp returns: diff --git a/std/math/emulated/field_binary.go b/std/math/emulated/field_binary.go index 8c949af2e4..d2dd5f3d6b 100644 --- a/std/math/emulated/field_binary.go +++ b/std/math/emulated/field_binary.go @@ -8,8 +8,7 @@ import ( // ToBits returns the bit representation of the Element in little-endian (LSB // first) order. The returned bits are constrained to be 0-1. The number of // returned bits is nbLimbs*nbBits+overflow. To obtain the bits of the canonical -// representation of Element, reduce Element first and take less significant -// bits corresponding to the bitwidth of the emulated modulus. +// representation of Element, use method [Field.ToBitsCanonical]. 
func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable { f.enforceWidthConditional(a) ba, aConst := f.constantValue(a) @@ -34,6 +33,29 @@ func (f *Field[T]) ToBits(a *Element[T]) []frontend.Variable { return fullBits } +// ToBitsCanonical represents the unique bit representation in the canonical +// format (less that the modulus). +func (f *Field[T]) ToBitsCanonical(a *Element[T]) []frontend.Variable { + // TODO: implement a inline version of this function. We perform binary + // decomposition both in the `ReduceStrict` and `ToBits` methods, but we can + // essentially do them at the same time. + // + // If we do this, then also check in places where we use `Reduce` and + // `ToBits` after that manually (e.g. in point and scalar marshaling) and + // replace them with this method. + + var fp T + nbBits := fp.Modulus().BitLen() + // when the modulus is a power of 2, then we can remove the most significant + // bit as it is always zero. + if fp.Modulus().TrailingZeroBits() == uint(nbBits-1) { + nbBits-- + } + ca := f.ReduceStrict(a) + bts := f.ToBits(ca) + return bts[:nbBits] +} + // FromBits returns a new Element given the bits is little-endian order. func (f *Field[T]) FromBits(bs ...frontend.Variable) *Element[T] { nbLimbs := (uint(len(bs)) + f.fParams.BitsPerLimb() - 1) / f.fParams.BitsPerLimb() diff --git a/std/math/emulated/field_ops.go b/std/math/emulated/field_ops.go index a9f0d9cda3..9ee181b7fb 100644 --- a/std/math/emulated/field_ops.go +++ b/std/math/emulated/field_ops.go @@ -164,21 +164,6 @@ func (f *Field[T]) Sum(inputs ...*Element[T]) *Element[T] { return f.newInternalElement(limbs, overflow+uint(addOverflow)) } -// Reduce reduces a modulo the field order and returns it. -func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { - f.enforceWidthConditional(a) - if a.overflow == 0 { - // fast path - already reduced, omit reduction. 
- return a - } - // sanity check - if _, aConst := f.constantValue(a); aConst { - panic("trying to reduce a constant, which happen to have an overflow flag set") - } - // slow path - use hint to reduce value - return f.mulMod(a, f.One(), 0, nil) -} - // Sub subtracts b from a and returns it. Reduces locally if wouldn't fit into // Element. Doesn't mutate inputs. func (f *Field[T]) Sub(a, b *Element[T]) *Element[T] { diff --git a/std/math/emulated/field_reduce.go b/std/math/emulated/field_reduce.go new file mode 100644 index 0000000000..7fe55377c2 --- /dev/null +++ b/std/math/emulated/field_reduce.go @@ -0,0 +1,53 @@ +package emulated + +// ReduceWidth returns an element reduced by the modulus and constrained to have +// same length as the modulus. The output element has the same width as the +// modulus but may up to twice larger than the modulus). +// +// Does not mutate the input. +// +// In cases where the canonical representation of the element is required, use +// [Field.ReduceStrict]. +func (f *Field[T]) Reduce(a *Element[T]) *Element[T] { + ret := f.reduce(a, false) + return ret +} + +func (f *Field[T]) reduce(a *Element[T], strict bool) *Element[T] { + f.enforceWidthConditional(a) + if a.modReduced { + // fast path - we are in the strict case and the element was just strictly reduced + return a + } + if !strict && a.overflow == 0 { + // fast path - we are in non-strict case and the element has no + // overflow. We don't need to reduce now. + return a + } + // rest of the cases: + // - in strict case and element was not recently reduced (even if it has no overflow) + // - in non-strict case and the element has overflow + + // sanity check + if _, aConst := f.constantValue(a); aConst { + panic("trying to reduce a constant, which happen to have an overflow flag set") + } + // slow path - use hint to reduce value + return f.mulMod(a, f.One(), 0, nil) +} + +// ReduceStrict returns an element reduced by the modulus. 
The output element +// has the same width as the modulus and is guaranteed to be less than the +// modulus. +// +// Does not mutate the input. +// +// This method is useful when the canonical representation of the element is +// required. For example, when the element is used in bitwise operations. This +// means that the reduction is enforced even when the overflow of the element is +// 0, but it has not been strictly reduced before. +func (f *Field[T]) ReduceStrict(a *Element[T]) *Element[T] { + ret := f.reduce(a, true) + f.AssertIsInRange(ret) + return ret +} diff --git a/std/math/emulated/hints.go b/std/math/emulated/hints.go index eab14b47e9..08c56a5f23 100644 --- a/std/math/emulated/hints.go +++ b/std/math/emulated/hints.go @@ -188,17 +188,32 @@ func subPaddingHint(mod *big.Int, inputs, outputs []*big.Int) error { } func (f *Field[T]) computeSubPaddingHint(overflow uint, nbLimbs uint, modulus *Element[T]) *Element[T] { + // we compute the subtraction padding hint in-circuit. The padding has satisfy: + // 1. padding % modulus = 0 + // 2. padding[i] >= (1 << (bits+overflow)) + // 3. padding[i] + a[i] < native_field for all valid a[i] (defined by overflow) var fp T inputs := []frontend.Variable{fp.NbLimbs(), fp.BitsPerLimb(), overflow, nbLimbs} inputs = append(inputs, modulus.Limbs...) + // compute the actual padding value res, err := f.api.NewHint(subPaddingHint, int(nbLimbs), inputs...) if err != nil { panic(fmt.Sprintf("sub padding hint: %v", err)) } + maxLimb := new(big.Int).Lsh(big.NewInt(1), fp.BitsPerLimb()+overflow) + maxLimb.Sub(maxLimb, big.NewInt(1)) for i := range res { - f.checker.Check(res[i], int(fp.BitsPerLimb()+overflow+1)) + // we can check conditions 2 and 3 together by subtracting the maximum + // value which can be subtracted from the padding. 
The result should not + // underflow (in which case the width of the subtraction result could be + // at least native_width-overflow) and should be nbBits+overflow+1 bits + // wide (as expected padding is one bit wider than the maximum allowed + // subtraction limb). + f.checker.Check(f.api.Sub(res[i], maxLimb), int(fp.BitsPerLimb()+overflow+1)) } - padding := f.newInternalElement(res, fp.BitsPerLimb()+overflow+1) + + // ensure that condition 1 holds + padding := f.newInternalElement(res, overflow+1) f.checkZero(padding, modulus) return padding }