Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: optimize DoublePairingCheck i.e. e(a,b)e(c,d) == 1 #1230

Closed
wants to merge 12 commits into from
Closed
142 changes: 142 additions & 0 deletions std/algebra/emulated/sw_bls12381/hints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package sw_bls12381

import (
"math/big"

"github.com/consensys/gnark-crypto/ecc/bls12-381"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/std/math/emulated"
)

func init() {
solver.RegisterHint(GetHints()...)
}

// GetHints returns all hint functions used in the package.
func GetHints() []solver.Hint {
return []solver.Hint{
doublePairingCheckHint,
}
}

func doublePairingCheckHint(nativeMod *big.Int, nativeInputs, nativeOutputs []*big.Int) error {
// This is inspired from https://eprint.iacr.org/2024/640.pdf
// and based on a personal communication with the author Andrija Novakovic.
return emulated.UnwrapHint(nativeInputs, nativeOutputs,
func(mod *big.Int, inputs, outputs []*big.Int) error {
var P0, P1 bls12381.G1Affine
var Q0, Q1 bls12381.G2Affine

P0.X.SetBigInt(inputs[0])
P0.Y.SetBigInt(inputs[1])
P1.X.SetBigInt(inputs[2])
P1.Y.SetBigInt(inputs[3])
Q0.X.A0.SetBigInt(inputs[4])
Q0.X.A1.SetBigInt(inputs[5])
Q0.Y.A0.SetBigInt(inputs[6])
Q0.Y.A1.SetBigInt(inputs[7])
Q1.X.A0.SetBigInt(inputs[8])
Q1.X.A1.SetBigInt(inputs[9])
Q1.Y.A0.SetBigInt(inputs[10])
Q1.Y.A1.SetBigInt(inputs[11])

lines0 := bls12381.PrecomputeLines(Q0)
lines1 := bls12381.PrecomputeLines(Q1)
millerLoop, err := bls12381.MillerLoopFixedQ(
[]bls12381.G1Affine{P0, P1},
[][2][len(bls12381.LoopCounter) - 1]bls12381.LineEvaluationAff{lines0, lines1},
)
millerLoop.Conjugate(&millerLoop)
if err != nil {
return err
}

var root, rootPthInverse, root27thInverse, residueWitness, scalingFactor bls12381.E12
var order3rd, order3rdPower, exponent, exponentInv, finalExpFactor, polyFactor big.Int
// polyFactor = (1-x)/3
polyFactor.SetString("5044125407647214251", 10)
// finalExpFactor = ((q^12 - 1) / r) / (27 * polyFactor)
finalExpFactor.SetString("2366356426548243601069753987687709088104621721678962410379583120840019275952471579477684846670499039076873213559162845121989217658133790336552276567078487633052653005423051750848782286407340332979263075575489766963251914185767058009683318020965829271737924625612375201545022326908440428522712877494557944965298566001441468676802477524234094954960009227631543471415676620753242466901942121887152806837594306028649150255258504417829961387165043999299071444887652375514277477719817175923289019181393803729926249507024121957184340179467502106891835144220611408665090353102353194448552304429530104218473070114105759487413726485729058069746063140422361472585604626055492939586602274983146215294625774144156395553405525711143696689756441298365274341189385646499074862712688473936093315628166094221735056483459332831845007196600723053356837526749543765815988577005929923802636375670820616189737737304893769679803809426304143627363860243558537831172903494450556755190448279875942974830469855835666815454271389438587399739607656399812689280234103023464545891697941661992848552456326290792224091557256350095392859243101357349751064730561345062266850238821755009430903520645523345000326783803935359711318798844368754833295302563158150573540616830138810935344206231367357992991289265295323280", 10)

// 1. get pth-root inverse
exponent.Mul(&finalExpFactor, big.NewInt(27))
root.Exp(millerLoop, &exponent)
if root.IsOne() {
rootPthInverse.SetOne()
} else {
exponentInv.ModInverse(&exponent, &polyFactor)
exponent.Neg(&exponentInv).Mod(&exponent, &polyFactor)
rootPthInverse.Exp(root, &exponent)
}

// 2.1. get order of 3rd primitive root
var three big.Int
three.SetUint64(3)
exponent.Mul(&polyFactor, &finalExpFactor)
root.Exp(millerLoop, &exponent)
if root.IsOne() {
order3rdPower.SetUint64(0)
}
root.Exp(root, &three)
if root.IsOne() {
order3rdPower.SetUint64(1)
}
root.Exp(root, &three)
if root.IsOne() {
order3rdPower.SetUint64(2)
}
root.Exp(root, &three)
if root.IsOne() {
order3rdPower.SetUint64(3)
}

// 2.2. get 27th root inverse
if order3rdPower.Uint64() == 0 {
root27thInverse.SetOne()
} else {
order3rd.Exp(&three, &order3rdPower, nil)
exponent.Mul(&polyFactor, &finalExpFactor)
root.Exp(millerLoop, &exponent)
exponentInv.ModInverse(&exponent, &order3rd)
exponent.Neg(&exponentInv).Mod(&exponent, &order3rd)
root27thInverse.Exp(root, &exponent)
}

// 2.3. shift the Miller loop result so that millerLoop * scalingFactor
// is of order finalExpFactor
scalingFactor.Mul(&rootPthInverse, &root27thInverse)
millerLoop.Mul(&millerLoop, &scalingFactor)

// 3. get the witness residue
//
// lambda = q - u, the optimal exponent
var lambda big.Int
lambda.SetString("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129030796414117214202539", 10)
exponent.ModInverse(&lambda, &finalExpFactor)
residueWitness.Exp(millerLoop, &exponent)

// return the witness residue
residueWitness.C0.B0.A0.BigInt(outputs[0])
residueWitness.C0.B0.A1.BigInt(outputs[1])
residueWitness.C0.B1.A0.BigInt(outputs[2])
residueWitness.C0.B1.A1.BigInt(outputs[3])
residueWitness.C0.B2.A0.BigInt(outputs[4])
residueWitness.C0.B2.A1.BigInt(outputs[5])
residueWitness.C1.B0.A0.BigInt(outputs[6])
residueWitness.C1.B0.A1.BigInt(outputs[7])
residueWitness.C1.B1.A0.BigInt(outputs[8])
residueWitness.C1.B1.A1.BigInt(outputs[9])
residueWitness.C1.B2.A0.BigInt(outputs[10])
residueWitness.C1.B2.A1.BigInt(outputs[11])

// return the scaling factor
scalingFactor.C0.B0.A0.BigInt(outputs[12])
scalingFactor.C0.B0.A1.BigInt(outputs[13])
scalingFactor.C0.B1.A0.BigInt(outputs[14])
scalingFactor.C0.B1.A1.BigInt(outputs[15])
scalingFactor.C0.B2.A0.BigInt(outputs[16])
scalingFactor.C0.B2.A1.BigInt(outputs[17])

return nil
})
}
117 changes: 117 additions & 0 deletions std/algebra/emulated/sw_bls12381/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,123 @@ func (pr Pairing) PairingCheck(P []*G1Affine, Q []*G2Affine) error {
return nil
}

// DoublePairingCheck calculates the reduced pairing for a 2 pairs of points and asserts if the result is One
//
// e(P0, Q0) * e(P1, Q1) =? 1
//
// This function doesn't check that the inputs are in the correct subgroups. See
// [Pairing.AssertIsOnG1] and [Pairing.AssertIsOnG2].
func (pr Pairing) DoublePairingCheck(P [2]*G1Affine, Q [2]*G2Affine) error {
// hint the non-residue witness
hint, err := pr.curveF.NewHint(doublePairingCheckHint, 18, &P[0].X, &P[0].Y, &P[1].X, &P[1].Y, &Q[0].P.X.A0, &Q[0].P.X.A1, &Q[0].P.Y.A0, &Q[0].P.Y.A1, &Q[1].P.X.A0, &Q[1].P.X.A1, &Q[1].P.Y.A0, &Q[1].P.Y.A1)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}

residueWitness := fields_bls12381.E12{
C0: fields_bls12381.E6{
B0: fields_bls12381.E2{A0: *hint[0], A1: *hint[1]},
B1: fields_bls12381.E2{A0: *hint[2], A1: *hint[3]},
B2: fields_bls12381.E2{A0: *hint[4], A1: *hint[5]},
},
C1: fields_bls12381.E6{
B0: fields_bls12381.E2{A0: *hint[6], A1: *hint[7]},
B1: fields_bls12381.E2{A0: *hint[8], A1: *hint[9]},
B2: fields_bls12381.E2{A0: *hint[10], A1: *hint[11]},
},
}
// constrain scaling factor to be in Fp6
scalingFactor := fields_bls12381.E12{
C0: fields_bls12381.E6{
B0: fields_bls12381.E2{A0: *hint[12], A1: *hint[13]},
B1: fields_bls12381.E2{A0: *hint[14], A1: *hint[15]},
B2: fields_bls12381.E2{A0: *hint[16], A1: *hint[17]},
},
C1: (*pr.Ext6.Zero()),
}

// residueWitnessInv = 1 / residueWitness
residueWitnessInv := pr.Inverse(&residueWitness)

if Q[0].Lines == nil {
Q0lines := pr.computeLines(&Q[0].P)
Q[0].Lines = &Q0lines
}
lines0 := *Q[0].Lines
if Q[1].Lines == nil {
Q1lines := pr.computeLines(&Q[1].P)
Q[1].Lines = &Q1lines
}
lines1 := *Q[1].Lines

// precomputations
y0Inv := pr.curveF.Inverse(&P[0].Y)
x0NegOverY0 := pr.curveF.Mul(&P[0].X, y0Inv)
x0NegOverY0 = pr.curveF.Neg(x0NegOverY0)
y1Inv := pr.curveF.Inverse(&P[1].Y)
x1NegOverY1 := pr.curveF.Mul(&P[1].X, y1Inv)
x1NegOverY1 = pr.curveF.Neg(x1NegOverY1)

// init Miller loop accumulator to residueWitnessInv to share the squarings
// of residueWitnessInv^{-x₀}
res := residueWitnessInv

// Compute f_{x₀,Q}(P)
for i := 62; i >= 0; i-- {
res = pr.Square(res)

if loopCounter[i] == 0 {
// ℓ × res
res = pr.MulBy014(res,
pr.MulByElement(&lines0[0][i].R1, y0Inv),
pr.MulByElement(&lines0[0][i].R0, x0NegOverY0),
)

// ℓ × res
res = pr.MulBy014(res,
pr.MulByElement(&lines1[0][i].R1, y1Inv),
pr.MulByElement(&lines1[0][i].R0, x1NegOverY1),
)
} else {
// multiply by residueWitnessInv when bit=1
res = pr.Mul(res, residueWitnessInv)

res = pr.MulBy014(res,
pr.MulByElement(&lines0[0][i].R1, y0Inv),
pr.MulByElement(&lines0[0][i].R0, x0NegOverY0),
)
res = pr.MulBy014(res,
pr.MulByElement(&lines0[1][i].R1, y0Inv),
pr.MulByElement(&lines0[1][i].R0, x0NegOverY0),
)

res = pr.MulBy014(res,
pr.MulByElement(&lines1[0][i].R1, y1Inv),
pr.MulByElement(&lines1[0][i].R0, x1NegOverY1),
)
res = pr.MulBy014(res,
pr.MulByElement(&lines1[1][i].R1, y1Inv),
pr.MulByElement(&lines1[1][i].R0, x1NegOverY1),
)
}
}

// Check that res * scalingFactor * residueWitnessInv^λ' == 1
// where λ' = q, with u the BLS12-381 seed
// and residueWitnessInv, scalingFactor from the hint.
// Note that res is already MillerLoop(P,Q) * residueWitnessInv^{x₀} since
// we initialized the Miller loop accumulator with residueWitnessInv.
t0 := pr.Frobenius(residueWitnessInv)
t0 = pr.Mul(t0, res)
t0 = pr.Mul(t0, &scalingFactor)

pr.AssertIsEqual(t0, pr.One())

return nil

}

func (pr Pairing) AssertIsEqual(x, y *GTEl) {
pr.Ext12.AssertIsEqual(x, y)
}
Expand Down
43 changes: 39 additions & 4 deletions std/algebra/emulated/sw_bls12381/pairing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,26 @@ func TestMultiPairTestSolve(t *testing.T) {
}
}

type PairingCheckCircuit struct {
type DoublePairingCheckCircuit struct {
In1G1 G1Affine
In2G1 G1Affine
In1G2 G2Affine
In2G2 G2Affine
}

func (c *PairingCheckCircuit) Define(api frontend.API) error {
func (c *DoublePairingCheckCircuit) Define(api frontend.API) error {
pairing, err := NewPairing(api)
if err != nil {
return fmt.Errorf("new pairing: %w", err)
}
err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2})
err = pairing.DoublePairingCheck([2]*G1Affine{&c.In1G1, &c.In2G1}, [2]*G2Affine{&c.In1G2, &c.In2G2})
if err != nil {
return fmt.Errorf("pair: %w", err)
}
return nil
}

func TestPairingCheckTestSolve(t *testing.T) {
func TestDoublePairingCheckTestSolve(t *testing.T) {
assert := test.NewAssert(t)
// e(a,2b) * e(-2a,b) == 1
p1, q1 := randomG1G2Affines()
Expand All @@ -193,6 +193,41 @@ func TestPairingCheckTestSolve(t *testing.T) {
var q2 bls12381.G2Affine
q2.Set(&q1)
q1.Double(&q1)
witness := DoublePairingCheckCircuit{
In1G1: NewG1Affine(p1),
In1G2: NewG2Affine(q1),
In2G1: NewG1Affine(p2),
In2G2: NewG2Affine(q2),
}
err := test.IsSolved(&DoublePairingCheckCircuit{}, &witness, ecc.BN254.ScalarField())
assert.NoError(err)
}

type PairingCheckCircuit struct {
In1G1 G1Affine
In2G1 G1Affine
In1G2 G2Affine
In2G2 G2Affine
}

func (c *PairingCheckCircuit) Define(api frontend.API) error {
pairing, err := NewPairing(api)
if err != nil {
return fmt.Errorf("new pairing: %w", err)
}
err = pairing.PairingCheck([]*G1Affine{&c.In1G1, &c.In1G1, &c.In2G1, &c.In2G1}, []*G2Affine{&c.In1G2, &c.In2G2, &c.In1G2, &c.In2G2})
if err != nil {
return fmt.Errorf("pair: %w", err)
}
return nil
}

func TestPairingCheckTestSolve(t *testing.T) {
assert := test.NewAssert(t)
p1, q1 := randomG1G2Affines()
_, q2 := randomG1G2Affines()
var p2 bls12381.G1Affine
p2.Neg(&p1)
witness := PairingCheckCircuit{
In1G1: NewG1Affine(p1),
In1G2: NewG2Affine(q1),
Expand Down
Loading
Loading