Skip to content

Commit

Permalink
precomp: use normalized extended points (#59)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: Ignacio Hagopian <[email protected]>
  • Loading branch information
jsign authored Oct 21, 2023
1 parent ceac265 commit ff2c8f7
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 16 deletions.
63 changes: 63 additions & 0 deletions bandersnatch/bandersnatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@ import (
"io"

gnarkbandersnatch "github.com/consensys/gnark-crypto/ecc/bls12-381/bandersnatch"
gnarkfr "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/crate-crypto/go-ipa/bandersnatch/fp"
)

var CurveParams = gnarkbandersnatch.GetEdwardsCurve()

type PointAffine = gnarkbandersnatch.PointAffine
type PointProj = gnarkbandersnatch.PointProj
type PointExtended = gnarkbandersnatch.PointExtended

var Identity = PointProj{
X: fp.Zero(),
Y: fp.One(),
Z: fp.One(),
}

var IdentityExt = PointExtendedFromProj(&Identity)

// Reads an uncompressed affine point
// Point is not guaranteed to be in the prime subgroup
func ReadUncompressedPoint(r io.Reader) (PointAffine, error) {
Expand Down Expand Up @@ -92,3 +96,62 @@ func computeY(x *fp.Element, choose_largest bool) *fp.Element {
return sqrtY.Neg(sqrtY)
}
}

// PointExtendedFromProj converts a point in projective coordinates to extended coordinates.
func PointExtendedFromProj(p *PointProj) PointExtended {
var pzinv fp.Element
pzinv.Inverse(&p.Z)
var z fp.Element
z.Mul(&p.X, &p.Y).Mul(&z, &pzinv)
return PointExtended{
X: p.X,
Y: p.Y,
Z: p.Z,
T: z,
}
}

// PointExtendedNormalized is an extended point which is normalized.
// i.e: Z=1. We store it this way to save 32 bytes per point in memory.
type PointExtendedNormalized struct {
X, Y, T gnarkfr.Element
}

// Neg computes p = -p1
func (p *PointExtendedNormalized) Neg(p1 *PointExtendedNormalized) *PointExtendedNormalized {
p.X.Neg(&p1.X)
p.Y = p1.Y
p.T.Neg(&p1.T)
return p
}

// ExtendedAddNormalized computes p = p1 + p2.
// https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html#addition-madd-2008-hwcd
func ExtendedAddNormalized(p, p1 *PointExtended, p2 *PointExtendedNormalized) *gnarkbandersnatch.PointExtended {
var A, B, C, D, E, F, G, H, tmp gnarkfr.Element
A.Mul(&p1.X, &p2.X)
B.Mul(&p1.Y, &p2.Y)
C.Mul(&p1.T, &p2.T).Mul(&C, &CurveParams.D)
D.Set(&p1.Z)
tmp.Add(&p1.X, &p1.Y)
E.Add(&p2.X, &p2.Y).
Mul(&E, &tmp).
Sub(&E, &A).
Sub(&E, &B)
F.Sub(&D, &C)
G.Add(&D, &C)
H.Set(&A)

// mulBy5(&H)
H.Neg(&H)
gnarkfr.MulBy5(&H)

H.Sub(&B, &H)

p.X.Mul(&E, &F)
p.Y.Mul(&G, &H)
p.T.Mul(&E, &H)
p.Z.Mul(&F, &G)

return p
}
78 changes: 62 additions & 16 deletions banderwagon/precomp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,24 @@ func NewPrecompMSM(points []Element) (MSMPrecomp, error) {
// MSM calculates the 256-MSM of the given scalars on the fixed basis.
// It automatically detects how many non-zero scalars there are and parallelizes the computation.
func (msm *MSMPrecomp) MSM(scalars []fr.Element) Element {
result := Identity.inner
result := bandersnatch.IdentityExt

for i := range scalars {
if !scalars[i].IsZero() {
msm.precompPoints[i].ScalarMul(scalars[i], &result)
}
}
return Element{inner: result}
return Element{inner: bandersnatch.PointProj{
X: result.X,
Y: result.Y,
Z: result.Z,
}}
}

// PrecompPoint is a precomputed table for a single point.
type PrecompPoint struct {
windowSize int
windows [][]bandersnatch.PointAffine
windows [][]bandersnatch.PointExtendedNormalized
}

// NewPrecompPoint creates a new PrecompPoint for the given point and window size.
Expand All @@ -88,27 +92,23 @@ func NewPrecompPoint(point Element, windowSize int) (PrecompPoint, error) {

res := PrecompPoint{
windowSize: windowSize,
windows: make([][]bandersnatch.PointAffine, 256/windowSize),
windows: make([][]bandersnatch.PointExtendedNormalized, 256/windowSize),
}

windows := make([][]bandersnatch.PointProj, 256/windowSize)
windows := make([][]bandersnatch.PointExtended, 256/windowSize)
group, _ := errgroup.WithContext(context.Background())
group.SetLimit(runtime.NumCPU())
for i := 0; i < len(res.windows); i++ {
i := i
base := point.inner
base := bandersnatch.PointExtendedFromProj(&point.inner)
group.Go(func() error {
windows[i] = make([]bandersnatch.PointProj, 1<<(windowSize-1))
windows[i] = make([]bandersnatch.PointExtended, 1<<(windowSize-1))
curr := base
for j := 0; j < len(windows[i]); j++ {
windows[i][j] = curr
curr.Add(&curr, &base)
}
batchProjToAffine(windows[i])
res.windows[i] = make([]bandersnatch.PointAffine, 1<<(windowSize-1))
for j := range windows[i] {
res.windows[i][j].FromProj(&windows[i][j])
}
res.windows[i] = batchToExtendedPointNormalized(windows[i])
return nil
})
point.ScalarMul(&point, &specialWindow)
Expand All @@ -121,12 +121,12 @@ func NewPrecompPoint(point Element, windowSize int) (PrecompPoint, error) {
// ScalarMul multiplies the point by the given scalar using the precomputed points.
// It applies a trick to push a carry between windows since our precomputed tables
// avoid storing point inverses.
func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointProj) {
func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointExtended) {
numWindowsInLimb := 64 / pp.windowSize

scalar.FromMont()
var carry uint64
var pNeg bandersnatch.PointAffine
var pNeg bandersnatch.PointExtendedNormalized
for l := 0; l < fr.Limbs; l++ {
for w := 0; w < numWindowsInLimb; w++ {
windowValue := (scalar[l]>>(pp.windowSize*w))&((1<<pp.windowSize)-1) + carry
Expand All @@ -139,11 +139,11 @@ func (pp *PrecompPoint) ScalarMul(scalar fr.Element, res *bandersnatch.PointProj
windowValue = (1 << pp.windowSize) - windowValue
if windowValue != 0 {
pNeg.Neg(&pp.windows[l*numWindowsInLimb+w][windowValue-1])
res.MixedAdd(res, &pNeg)
bandersnatch.ExtendedAddNormalized(res, res, &pNeg)
}
carry = 1
} else {
res.MixedAdd(res, &pp.windows[l*numWindowsInLimb+w][windowValue-1])
bandersnatch.ExtendedAddNormalized(res, res, &pp.windows[l*numWindowsInLimb+w][windowValue-1])
}
}
}
Expand Down Expand Up @@ -195,3 +195,49 @@ func batchProjToAffine(points []bandersnatch.PointProj) []bandersnatch.PointAffi

return result
}

func batchToExtendedPointNormalized(points []bandersnatch.PointExtended) []bandersnatch.PointExtendedNormalized {
result := make([]bandersnatch.PointExtendedNormalized, len(points))
zeroes := make([]bool, len(points))
accumulator := fp.One()

// batch invert all points[].Z coordinates with Montgomery batch inversion trick
// (stores points[].Z^-1 in result[i].X to avoid allocating a slice of fr.Elements)
for i := 0; i < len(points); i++ {
if points[i].Z.IsZero() {
zeroes[i] = true
continue
}
result[i].X = accumulator
accumulator.Mul(&accumulator, &points[i].Z)
}

var accInverse fp.Element
accInverse.Inverse(&accumulator)

for i := len(points) - 1; i >= 0; i-- {
if zeroes[i] {
// do nothing, (X=0, Y=0) is infinity point in affine
continue
}
result[i].X.Mul(&result[i].X, &accInverse)
accInverse.Mul(&accInverse, &points[i].Z)
}

// batch convert to affine.
parallel.Execute(len(points), func(start, end int) {
for i := start; i < end; i++ {
if zeroes[i] {
// do nothing, (X=0, Y=0) is infinity point in affine
continue
}

a := result[i].X
result[i].X.Mul(&points[i].X, &a)
result[i].Y.Mul(&points[i].Y, &a)
result[i].T.Mul(&result[i].X, &result[i].Y)
}
})

return result
}

0 comments on commit ff2c8f7

Please sign in to comment.