From 0e2122ddfd043f60360068f6d3aec721397670a2 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Thu, 2 May 2024 19:54:14 +0900 Subject: [PATCH] GSW-1045 feat: update uint256 overflow calcualtion logic --- .../uint256/gs_overflow_calculation.gno | 184 ++++++++++-------- .../uint256/gs_overflow_calculation_test.gno | 25 +++ 2 files changed, 130 insertions(+), 79 deletions(-) create mode 100644 _deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno diff --git a/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno b/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno index 5a3ac866..88f7dae4 100644 --- a/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno +++ b/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno @@ -1,102 +1,128 @@ -// REF: https://github.com/Uniswap/solidity-lib/blob/master/contracts/libraries/FullMath.sol +// REF: https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol package uint256 const ( MAX_UINT256 = "115792089237316195423570985008687907853269984665640564039457584007913129639935" ) -func fullMul( - x *Uint, - y *Uint, -) (*Uint, *Uint) { // l, h - mm := new(Uint).MulMod(x, y, MustFromDecimal(MAX_UINT256)) - - l := new(Uint).Mul(x, y) - h := new(Uint).Sub(mm, l) - - if mm.Lt(l) { - h = new(Uint).Sub(h, One()) - } - - return l, h -} - -func fullDiv( - l *Uint, - h *Uint, - d *Uint, +func MulDiv( + a, b, denominator *Uint, ) *Uint { - // uint256 pow2 = d & -d; - // d - _negD := new(Uint).Neg(d) - pow2 := new(Uint).And(d, _negD) - d = new(Uint).Div(d, pow2) - l = new(Uint).Div(l, pow2) - - _negPow2 := new(Uint).Neg(pow2) - - value1 := new(Uint).Div(_negPow2, pow2) // (-pow2) / pow2 - value2 := new(Uint).Add(value1, One()) // (-pow2) / pow2 + 1) - value3 := new(Uint).Mul(h, value2) // h * ((-pow2) / pow2 + 1); - l = new(Uint).Add(l, value3) - - r := One() - for i := 0; i < 8; i++ { - value1 := new(Uint).Mul(d, r) // d * r - value2 := new(Uint).Sub(NewUint(2), value1) // 2 - ( d * r ) - r = new(Uint).Mul(r, value2) // r *= 2 - d * r; + prod0 := Zero() + prod1 := Zero() + + { + mm := new(Uint).MulMod(a, b, new(Uint).Not(Zero())) + prod0 = new(Uint).Mul(a, b) + + ltBool := mm.Lt(prod0) + ltUint := Zero() + if ltBool { + ltUint = One() + } + prod1 = new(Uint).Sub(new(Uint).Sub(mm, prod0), ltUint) } - res := new(Uint).Mul(l, r) - return res -} -func MulDiv( - x *Uint, - y *Uint, - d *Uint, -) *Uint { - l, h := fullMul(x, y) - mm := new(Uint).MulMod(x, y, d) + // Handle non-overflow cases, 256 by 256 division + if prod1.IsZero() { + if !(denominator.Gt(Zero())) { // require(denominator > 0); + panic("denominator > 0") + } - if mm.Gt(l) { - h = new(Uint).Sub(h, One()) + result := new(Uint).Div(prod0, denominator) + return result } - l = new(Uint).Sub(l, mm) - if h.IsZero() { - return new(Uint).Div(l, d) + // Make sure the result is less than 2**256. + // Also prevents denominator == 0 + if !(denominator.Gt(prod1)) { // require(denominator > prod1) + panic("denominator > prod1") } - if !(h.Lt(d)) { - panic("FULLDIV_OVERFLOW") - } + /////////////////////////////////////////////// + // 512 by 256 division. + /////////////////////////////////////////////// + + // Make division exact by subtracting the remainder from [prod1 prod0] + // Compute remainder using mulmod + remainder := Zero() + remainder = new(Uint).MulMod(a, b, denominator) - return fullDiv(l, h, d) + // Subtract 256 bit number from 512 bit number + gtBool := remainder.Gt(prod0) + gtUint := Zero() + if gtBool { + gtUint = One() + } + prod1 = new(Uint).Sub(prod1, gtUint) + prod0 = new(Uint).Sub(prod0, remainder) + + // Factor powers of two out of denominator + // Compute largest power of two divisor of denominator. + // Always >= 1. + twos := Zero() + twos = new(Uint).And(new(Uint).Neg(denominator), denominator) + + // Divide denominator by power of two + denominator = new(Uint).Div(denominator, twos) + + // Divide [prod1 prod0] by the factors of two + prod0 = new(Uint).Div(prod0, twos) + + // Shift in bits from prod1 into prod0. For this we need + // to flip `twos` such that it is 2**256 / twos. + // If twos is zero, then it becomes one + twos = new(Uint).Add( + new(Uint).Div( + new(Uint).Sub(Zero(), twos), + twos, + ), + One(), + ) + prod0 = new(Uint).Or(prod0, new(Uint).Mul(prod1, twos)) + + // Invert denominator mod 2**256 + // Now that denominator is an odd number, it has an inverse + // modulo 2**256 such that denominator * inv = 1 mod 2**256. + // Compute the inverse by starting with a seed that is correct + // correct for four bits. That is, denominator * inv = 1 mod 2**4 + inv := Zero() + inv = new(Uint).Mul(NewUint(3), denominator) + inv = new(Uint).Xor(inv, NewUint(2)) + + // Now use Newton-Raphson iteration to improve the precision. + // Thanks to Hensel's lifting lemma, this also works in modular + // arithmetic, doubling the correct bits in each step. + + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**8 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**16 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**32 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**64 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**128 + inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**256 + + // Because the division is now exact we can divide by multiplying + // with the modular inverse of denominator. This will give us the + // correct result modulo 2**256. Since the precoditions guarantee + // that the outcome is less than 2**256, this is the final result. + // We don't need to compute the high bits of the result and prod1 + // is no longer required. + result := new(Uint).Mul(prod0, inv) + return result } -func DivRoundingUp( - x *Uint, - y *Uint, +func MulDivRoundingUp( + a, b, denominator *Uint, ) *Uint { - div := new(Uint).Div(x, y) + result := MulDiv(a, b, denominator) - mod := new(Uint).Mod(x, y) - return new(Uint).Add(div, gt(mod, Zero())) -} + if new(Uint).MulMod(a, b, denominator).Gt(Zero()) { + if !(result.Lt(MustFromDecimal(MAX_UINT256))) { // require(result < MAX_UINT256) + panic("result < MAX_UINT256") + } -// HELPERs -func lt(x, y *Uint) *Uint { - if x.Lt(y) { - return One() - } else { - return Zero() + result = new(Uint).Add(result, One()) } -} -func gt(x, y *Uint) *Uint { - if x.Gt(y) { - return One() - } else { - return Zero() - } + return result } diff --git a/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno b/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno new file mode 100644 index 00000000..6c3dfd0e --- /dev/null +++ b/_deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno @@ -0,0 +1,25 @@ +package uint256 + +import "testing" + +func TestMulDiv(t *testing.T) { + a := MustFromDecimal("3961170441225674086664416884948992") + b := MustFromDecimal("1461300573427867316490840528175048480732148624513") + c := MustFromDecimal("1461300573427867316570072651998408279850435624081") + + z := MulDiv(a, b, c) + if z.ToString() != "3961170441225674086449641121090634" { + t.Errorf("expected 3961170441225674086449641121090634, got %s", z.ToString()) + } +} + +func TestMulDivRoundingUp(t *testing.T) { + a := MustFromDecimal("3961170441225674086664416884948992") + b := MustFromDecimal("1461300573427867316490840528175048480732148624513") + c := MustFromDecimal("1461300573427867316570072651998408279850435624081") + + z := MulDivRoundingUp(a, b, c) + if z.ToString() != "3961170441225674086449641121090635" { + t.Errorf("expected 3961170441225674086449641121090635, got %s", z.ToString()) + } +}