Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: NNA quotient length computation edge cases #1340

Merged
merged 6 commits into from
Dec 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions std/math/emulated/field_mul.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,18 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool, customMod *Eleme
// the quotient can be the total length of the multiplication result.
modbits = 0
}
nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - //
modbits + //
nbBits - 1) /
nbBits
var nbQuoLimbs uint
if uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits+nextOverflow+nbBits > modbits {
// when the product of a*b is wider than the modulus, then we need
// non-zero limbs for the quotient. Otherwise the quotient is zero,
// represented on zero limbs. But we already handle cases when the
// quotient is zero in the calling functions, this is only for
// additional safety.
nbQuoLimbs = (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - //
modbits + //
nbBits - 1) /
nbBits
}
// the remainder is always less than modulus so can represent on the same
// number of limbs as the modulus.
nbRemLimbs := nbLimbs
Expand Down Expand Up @@ -435,15 +443,26 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error {
if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil {
return fmt.Errorf("decompose rem: %w", err)
}
xp := limbMul(alimbs, blimbs)
yp := limbMul(quoLimbs, plimbs)
// to compute the carries, we need to perform multiplication on limbs
lhs := limbMul(alimbs, blimbs)
rhs := limbMul(quoLimbs, plimbs)
// add the remainder to the rhs, it now only has k*p. This is only for very
// edge cases where by adding the remainder we get additional bits in the
// carry.
for i := range remLimbs {
if i < len(rhs) {
rhs[i].Add(rhs[i], remLimbs[i])
} else {
rhs = append(rhs, new(big.Int).Set(remLimbs[i]))
}
}
carry := new(big.Int)
for i := range carryLimbs {
if i < len(xp) {
carry.Add(carry, xp[i])
if i < len(lhs) {
carry.Add(carry, lhs[i])
}
if i < len(yp) {
carry.Sub(carry, yp[i])
if i < len(rhs) {
carry.Sub(carry, rhs[i])
}
carry.Rsh(carry, uint(nbBits))
carryLimbs[i] = new(big.Int).Set(carry)
Expand Down Expand Up @@ -714,7 +733,10 @@ func (f *Field[T]) callPolyMvHint(mv *multivariate[T], at []*Element[T]) (quo, r
nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb()
modBits := uint(f.fParams.Modulus().BitLen())
quoSize := f.polyMvEvalQuoSize(mv, at)
nbQuoLimbs := (uint(quoSize) - modBits + nbBits) / nbBits
var nbQuoLimbs uint
if quoSize+nbBits > modBits {
nbQuoLimbs = (quoSize - modBits + nbBits) / nbBits
}
nbRemLimbs := nbLimbs
nbCarryLimbs := nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs)) - 1

Expand All @@ -723,7 +745,7 @@ func (f *Field[T]) callPolyMvHint(mv *multivariate[T], at []*Element[T]) (quo, r
nbHintInputs += len(at[i].Limbs) + 1
}
hintInputs := make([]frontend.Variable, 0, nbHintInputs)
hintInputs = append(hintInputs, nbBits, nbLimbs, len(mv.Terms), len(at), nbQuoLimbs, nbRemLimbs, nbCarryLimbs)
hintInputs = append(hintInputs, nbBits, nbLimbs, len(mv.Terms), len(at), nbQuoLimbs, nbCarryLimbs)
// store the terms in the hint input. First the exponents
for i := range mv.Terms {
for j := range mv.Terms[i] {
Expand Down Expand Up @@ -837,24 +859,28 @@ func (mc *mvCheck[T]) cleanEvaluations() {
//
// As it only depends on the bit-length of the inputs, then we can precompute it
// regardless of the actual values.
func (f *Field[T]) polyMvEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize int) {
func (f *Field[T]) polyMvEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize uint) {
var fp T
quoSizes := make([]int, len(mv.Terms))
quoSizes := make([]uint, len(mv.Terms))
for i, term := range mv.Terms {
// for every term, the result length is the sum of the lengths of the
// variables and the coefficient.
var lengths []int
var lengths []uint
for j, pow := range term {
for k := 0; k < pow; k++ {
lengths = append(lengths, len(at[j].Limbs)*int(fp.BitsPerLimb())+int(at[j].overflow))
lengths = append(lengths, uint(len(at[j].Limbs))*fp.BitsPerLimb()+at[j].overflow)
}
}
lengths = append(lengths, bits.Len(uint(mv.Coefficients[i])))
quoSizes[i] = sum(lengths...) - 1
lengths = append(lengths, uint(bits.Len(uint(mv.Coefficients[i]))))
if lengthSum := sum(lengths...); lengthSum > 0 {
// in edge case when inputs are zeros and coefficient is zero, we
// would have a underflow otherwise.
quoSizes[i] = lengthSum - 1
}
}
// and for the full result, it is maximum of the inputs. We also add a bit
// for every term for overflow.
quoSize = max(quoSizes...) + len(quoSizes)
quoSize = max(quoSizes...) + uint(len(quoSizes))
return quoSize
}

Expand All @@ -871,8 +897,8 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error {
nbTerms = int(inputs[2].Int64())
nbVars = int(inputs[3].Int64())
nbQuoLimbs = int(inputs[4].Int64())
nbRemLimbs = int(inputs[5].Int64())
nbCarryLimbs = int(inputs[6].Int64())
nbRemLimbs = nbLimbs
nbCarryLimbs = int(inputs[5].Int64())
)
if len(outputs) != nbQuoLimbs+nbRemLimbs+nbCarryLimbs {
return fmt.Errorf("output length mismatch")
Expand All @@ -884,7 +910,7 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error {
outPtr += nbRemLimbs
carryLimbs := outputs[outPtr : outPtr+nbCarryLimbs]
terms := make([][]int, nbTerms)
ptr := 7
ptr := 6
// read the terms
for i := range terms {
terms[i] = make([]int, nbVars)
Expand Down Expand Up @@ -986,7 +1012,7 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error {
}

// compute the result as r + k*p on limbs
rhs := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs))
rhs := make([]*big.Int, max(nbLimbs, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs)))
for i := range rhs {
rhs[i] = new(big.Int)
}
Expand Down
Loading