diff --git a/std/math/emulated/field_mul.go b/std/math/emulated/field_mul.go index a1947cb7c..1d44483d5 100644 --- a/std/math/emulated/field_mul.go +++ b/std/math/emulated/field_mul.go @@ -347,10 +347,18 @@ func (f *Field[T]) callMulHint(a, b *Element[T], isMulMod bool, customMod *Eleme // the quotient can be the total length of the multiplication result. modbits = 0 } - nbQuoLimbs := (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // - modbits + // - nbBits - 1) / - nbBits + var nbQuoLimbs uint + if uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits+nextOverflow+nbBits > modbits { + // when the product of a*b is wider than the modulus, then we need + // non-zero limbs for the quotient. Otherwise the quotient is zero, + // represented on zero limbs. But we already handle cases when the + // quotient is zero in the calling functions, this is only for + // additional safety. + nbQuoLimbs = (uint(nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)))*nbBits + nextOverflow + 1 - // + modbits + // + nbBits - 1) / + nbBits + } // the remainder is always less than modulus so can represent on the same // number of limbs as the modulus. nbRemLimbs := nbLimbs @@ -435,15 +443,26 @@ func mulHint(field *big.Int, inputs, outputs []*big.Int) error { if err := limbs.Decompose(rem, uint(nbBits), remLimbs); err != nil { return fmt.Errorf("decompose rem: %w", err) } - xp := limbMul(alimbs, blimbs) - yp := limbMul(quoLimbs, plimbs) + // to compute the carries, we need to perform multiplication on limbs + lhs := limbMul(alimbs, blimbs) + rhs := limbMul(quoLimbs, plimbs) + // add the remainder to the rhs, it now only has k*p. This is only for very + // edge cases where by adding the remainder we get additional bits in the + // carry. + for i := range remLimbs { + if i < len(rhs) { + rhs[i].Add(rhs[i], remLimbs[i]) + } else { + rhs = append(rhs, new(big.Int).Set(remLimbs[i])) + } + } carry := new(big.Int) for i := range carryLimbs { - if i < len(xp) { - carry.Add(carry, xp[i]) + if i < len(lhs) { + carry.Add(carry, lhs[i]) } - if i < len(yp) { - carry.Sub(carry, yp[i]) + if i < len(rhs) { + carry.Sub(carry, rhs[i]) } carry.Rsh(carry, uint(nbBits)) carryLimbs[i] = new(big.Int).Set(carry) @@ -714,7 +733,10 @@ func (f *Field[T]) callPolyMvHint(mv *multivariate[T], at []*Element[T]) (quo, r nbLimbs, nbBits := f.fParams.NbLimbs(), f.fParams.BitsPerLimb() modBits := uint(f.fParams.Modulus().BitLen()) quoSize := f.polyMvEvalQuoSize(mv, at) - nbQuoLimbs := (uint(quoSize) - modBits + nbBits) / nbBits + var nbQuoLimbs uint + if quoSize+nbBits > modBits { + nbQuoLimbs = (quoSize - modBits + nbBits) / nbBits + } nbRemLimbs := nbLimbs nbCarryLimbs := nbMultiplicationResLimbs(int(nbQuoLimbs), int(nbLimbs)) - 1 @@ -723,7 +745,7 @@ func (f *Field[T]) callPolyMvHint(mv *multivariate[T], at []*Element[T]) (quo, r nbHintInputs += len(at[i].Limbs) + 1 } hintInputs := make([]frontend.Variable, 0, nbHintInputs) - hintInputs = append(hintInputs, nbBits, nbLimbs, len(mv.Terms), len(at), nbQuoLimbs, nbRemLimbs, nbCarryLimbs) + hintInputs = append(hintInputs, nbBits, nbLimbs, len(mv.Terms), len(at), nbQuoLimbs, nbCarryLimbs) // store the terms in the hint input. First the exponents for i := range mv.Terms { for j := range mv.Terms[i] { @@ -837,24 +859,28 @@ func (mc *mvCheck[T]) cleanEvaluations() { // // As it only depends on the bit-length of the inputs, then we can precompute it // regardless of the actual values. -func (f *Field[T]) polyMvEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize int) { +func (f *Field[T]) polyMvEvalQuoSize(mv *multivariate[T], at []*Element[T]) (quoSize uint) { var fp T - quoSizes := make([]int, len(mv.Terms)) + quoSizes := make([]uint, len(mv.Terms)) for i, term := range mv.Terms { // for every term, the result length is the sum of the lengths of the // variables and the coefficient. - var lengths []int + var lengths []uint for j, pow := range term { for k := 0; k < pow; k++ { - lengths = append(lengths, len(at[j].Limbs)*int(fp.BitsPerLimb())+int(at[j].overflow)) + lengths = append(lengths, uint(len(at[j].Limbs))*fp.BitsPerLimb()+at[j].overflow) } } - lengths = append(lengths, bits.Len(uint(mv.Coefficients[i]))) - quoSizes[i] = sum(lengths...) - 1 + lengths = append(lengths, uint(bits.Len(uint(mv.Coefficients[i])))) + if lengthSum := sum(lengths...); lengthSum > 0 { + // in edge case when inputs are zeros and coefficient is zero, we + // would have a underflow otherwise. + quoSizes[i] = lengthSum - 1 + } } // and for the full result, it is maximum of the inputs. We also add a bit // for every term for overflow. - quoSize = max(quoSizes...) + len(quoSizes) + quoSize = max(quoSizes...) + uint(len(quoSizes)) return quoSize } @@ -871,8 +897,8 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error { nbTerms = int(inputs[2].Int64()) nbVars = int(inputs[3].Int64()) nbQuoLimbs = int(inputs[4].Int64()) - nbRemLimbs = int(inputs[5].Int64()) - nbCarryLimbs = int(inputs[6].Int64()) + nbRemLimbs = nbLimbs + nbCarryLimbs = int(inputs[5].Int64()) ) if len(outputs) != nbQuoLimbs+nbRemLimbs+nbCarryLimbs { return fmt.Errorf("output length mismatch") @@ -884,7 +910,7 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error { outPtr += nbRemLimbs carryLimbs := outputs[outPtr : outPtr+nbCarryLimbs] terms := make([][]int, nbTerms) - ptr := 7 + ptr := 6 // read the terms for i := range terms { terms[i] = make([]int, nbVars) @@ -986,7 +1012,7 @@ func polyMvHint(mod *big.Int, inputs, outputs []*big.Int) error { } // compute the result as r + k*p on limbs - rhs := make([]*big.Int, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs)) + rhs := make([]*big.Int, max(nbLimbs, nbMultiplicationResLimbs(nbQuoLimbs, nbLimbs))) for i := range rhs { rhs[i] = new(big.Int) }