From 49a22a8fca30ee58f196f3babe0c14e9d0defe12 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:26:07 -0600 Subject: [PATCH] fix beta update --- backend/groth16/bn254/mpcsetup/marshal.go | 5 +- backend/groth16/bn254/mpcsetup/phase1.go | 4 +- backend/groth16/bn254/mpcsetup/phase2.go | 2 + backend/groth16/bn254/mpcsetup/setup.go | 2 +- backend/groth16/bn254/mpcsetup/unit_test.go | 110 ++++++++++++++++++++ 5 files changed, 118 insertions(+), 5 deletions(-) diff --git a/backend/groth16/bn254/mpcsetup/marshal.go b/backend/groth16/bn254/mpcsetup/marshal.go index 3511a1a87..4c1ab916c 100644 --- a/backend/groth16/bn254/mpcsetup/marshal.go +++ b/backend/groth16/bn254/mpcsetup/marshal.go @@ -150,11 +150,12 @@ func (p *Phase2) ReadFrom(reader io.Reader) (int64, error) { func (c *Phase2Evaluations) refsSlice() []any { N := uint64(len(c.G1.A)) - expectedLen := 3*N + 3 - refs := make([]any, 3, expectedLen) + expectedLen := 3*N + 4 + refs := make([]any, 4, expectedLen) refs[0] = &c.G1.CKK refs[1] = &c.G1.VKK refs[2] = &c.PublicAndCommitmentCommitted + refs[3] = &c.NbConstraints refs = appendRefs(refs, c.G1.A) refs = appendRefs(refs, c.G1.B) refs = appendRefs(refs, c.G2.B) diff --git a/backend/groth16/bn254/mpcsetup/phase1.go b/backend/groth16/bn254/mpcsetup/phase1.go index 2c9251880..c4f9f0c00 100644 --- a/backend/groth16/bn254/mpcsetup/phase1.go +++ b/backend/groth16/bn254/mpcsetup/phase1.go @@ -114,12 +114,12 @@ func (c *SrsCommons) update(tauUpdate, alphaUpdate, betaUpdate *fr.Element) { betaUpdates := make([]fr.Element, len(c.G1.BetaTau)) betaUpdates[0].Set(betaUpdate) for i := range betaUpdates { - alphaUpdates[i].Mul(&tauUpdates[i], betaUpdate) + betaUpdates[i].Mul(&tauUpdates[i], betaUpdate) } scaleG1InPlace(c.G1.BetaTau, betaUpdates) var betaUpdateI big.Int - betaUpdate.SetBigInt(&betaUpdateI) + betaUpdate.BigInt(&betaUpdateI) c.G2.Beta.ScalarMultiplication(&c.G2.Beta, &betaUpdateI) } diff --git a/backend/groth16/bn254/mpcsetup/phase2.go b/backend/groth16/bn254/mpcsetup/phase2.go index 12f638d0b..664bf502c 100644 --- a/backend/groth16/bn254/mpcsetup/phase2.go +++ b/backend/groth16/bn254/mpcsetup/phase2.go @@ -34,6 +34,7 @@ type Phase2Evaluations struct { // TODO @Tabaie rename B []curve.G2Affine // B are the right coefficient polynomials for each witness element, evaluated at τ } PublicAndCommitmentCommitted [][]int + NbConstraints uint64 } type Phase2 struct { @@ -225,6 +226,7 @@ func (p *Phase2) Initialize(r1cs *cs.R1CS, commons *SrsCommons) Phase2Evaluation var evals Phase2Evaluations commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) evals.PublicAndCommitmentCommitted = commitmentInfo.GetPublicAndCommitmentCommitted(commitmentInfo.CommitmentIndexes(), nbPublic) + evals.NbConstraints = uint64(r1cs.GetNbConstraints()) evals.G1.A = make([]curve.G1Affine, nWires) // recall: A are the left coefficients in DIZK parlance evals.G1.B = make([]curve.G1Affine, nWires) // recall: B are the right coefficients in DIZK parlance evals.G2.B = make([]curve.G2Affine, nWires) // recall: A only appears in 𝔾₁ elements in the proof, but B needs to appear in a 𝔾₂ element so the verifier can compute something resembling (A.x).(B.x) via pairings diff --git a/backend/groth16/bn254/mpcsetup/setup.go b/backend/groth16/bn254/mpcsetup/setup.go index bb0f3449a..f4e2662df 100644 --- a/backend/groth16/bn254/mpcsetup/setup.go +++ b/backend/groth16/bn254/mpcsetup/setup.go @@ -34,7 +34,7 @@ func (p *Phase2) Seal(commons *SrsCommons, evals *Phase2Evaluations, beaconChall ) // Initialize PK - pk.Domain = *fft.NewDomain(uint64(len(evals.G1.A))) + pk.Domain = *fft.NewDomain(evals.NbConstraints) pk.G1.Alpha.Set(&commons.G1.AlphaTau[0]) pk.G1.Beta.Set(&commons.G1.BetaTau[0]) pk.G1.Delta.Set(&p.Parameters.G1.Delta) diff --git a/backend/groth16/bn254/mpcsetup/unit_test.go b/backend/groth16/bn254/mpcsetup/unit_test.go index 8c72766fa..fbf34e8bd 100644 --- a/backend/groth16/bn254/mpcsetup/unit_test.go +++ b/backend/groth16/bn254/mpcsetup/unit_test.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark/backend/groth16" groth16Impl "github.com/consensys/gnark/backend/groth16/bn254" "github.com/stretchr/testify/require" @@ -70,6 +71,8 @@ func TestSetupBeaconOnly(t *testing.T) { p1.Initialize(domainSize) commons := p1.Seal([]byte("beacon 1")) + commons = commonsSmallValues(domainSize, 2, 3, 4) + evals := p2.Initialize(ccs, &commons) pk, vk := p2.Seal(&commons, &evals, []byte("beacon 2")) @@ -96,4 +99,111 @@ func TestSetupBeaconOnly(t *testing.T) { proveVerifyCircuit(t, rpk, rvk) fmt.Println("regular proof verified") proveVerifyCircuit(t, pk, vk) + fmt.Println("mpc proof verified") +} + +func TestPhase1Contribute(t *testing.T) { + +} + +func TestPhase1Seal(t *testing.T) { + +} + +func commonsSmallValues(N, tau, alpha, beta uint64) SrsCommons { + var ( + res SrsCommons + I big.Int + coeff fr.Element + ) + _, _, g1, g2 := curve.Generators() + tauPowers := powersI(tau, int(2*N-1)) + res.G1.Tau = make([]curve.G1Affine, 2*N-1) + for i := range res.G1.Tau { + tauPowers[i].BigInt(&I) + res.G1.Tau[i].ScalarMultiplication(&g1, &I) + } + + res.G2.Tau = make([]curve.G2Affine, N) + for i := range res.G2.Tau { + tauPowers[i].BigInt(&I) + res.G2.Tau[i].ScalarMultiplication(&g2, &I) + } + + res.G1.AlphaTau = make([]curve.G1Affine, N) + coeff.SetUint64(alpha) + for i := range res.G1.AlphaTau { + var x fr.Element + x.Mul(&tauPowers[i], &coeff) + x.BigInt(&I) + res.G1.AlphaTau[i].ScalarMultiplication(&g1, &I) + } + + res.G1.BetaTau = make([]curve.G1Affine, N) + coeff.SetUint64(beta) + for i := range res.G1.BetaTau { + var x fr.Element + x.Mul(&tauPowers[i], &coeff) + x.BigInt(&I) + res.G1.BetaTau[i].ScalarMultiplication(&g1, &I) + } + + I.SetUint64(beta) + res.G2.Beta.ScalarMultiplication(&g2, &I) + + return res +} + +func powersI(x uint64, n int) []fr.Element { + var y fr.Element + y.SetUint64(x) + return powers(&y, n) +} + +func TestPowers(t *testing.T) { + var x fr.Element + x.SetUint64(2) + x2 := powers(&x, 10) + for i := range x2 { + require.True(t, x2[i].IsUint64()) + require.Equal(t, x2[i].Uint64(), uint64(1<