Skip to content

Commit

Permalink
GSW-1045 feat: update uint256 overflow calcualtion logic
Browse files Browse the repository at this point in the history
  • Loading branch information
r3v4s committed May 2, 2024
1 parent 0505710 commit 0e2122d
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 79 deletions.
184 changes: 105 additions & 79 deletions _deploy/p/demo/gnoswap/uint256/gs_overflow_calculation.gno
Original file line number Diff line number Diff line change
@@ -1,102 +1,128 @@
// REF: https://github.com/Uniswap/solidity-lib/blob/master/contracts/libraries/FullMath.sol
// REF: https://github.com/Uniswap/v3-core/blob/main/contracts/libraries/FullMath.sol
package uint256

const (
MAX_UINT256 = "115792089237316195423570985008687907853269984665640564039457584007913129639935"
)

func fullMul(
x *Uint,
y *Uint,
) (*Uint, *Uint) { // l, h
mm := new(Uint).MulMod(x, y, MustFromDecimal(MAX_UINT256))

l := new(Uint).Mul(x, y)
h := new(Uint).Sub(mm, l)

if mm.Lt(l) {
h = new(Uint).Sub(h, One())
}

return l, h
}

func fullDiv(
l *Uint,
h *Uint,
d *Uint,
func MulDiv(
a, b, denominator *Uint,
) *Uint {
// uint256 pow2 = d & -d;
// d
_negD := new(Uint).Neg(d)
pow2 := new(Uint).And(d, _negD)
d = new(Uint).Div(d, pow2)
l = new(Uint).Div(l, pow2)

_negPow2 := new(Uint).Neg(pow2)

value1 := new(Uint).Div(_negPow2, pow2) // (-pow2) / pow2
value2 := new(Uint).Add(value1, One()) // (-pow2) / pow2 + 1)
value3 := new(Uint).Mul(h, value2) // h * ((-pow2) / pow2 + 1);
l = new(Uint).Add(l, value3)

r := One()
for i := 0; i < 8; i++ {
value1 := new(Uint).Mul(d, r) // d * r
value2 := new(Uint).Sub(NewUint(2), value1) // 2 - ( d * r )
r = new(Uint).Mul(r, value2) // r *= 2 - d * r;
prod0 := Zero()
prod1 := Zero()

{
mm := new(Uint).MulMod(a, b, new(Uint).Not(Zero()))
prod0 = new(Uint).Mul(a, b)

ltBool := mm.Lt(prod0)
ltUint := Zero()
if ltBool {
ltUint = One()
}
prod1 = new(Uint).Sub(new(Uint).Sub(mm, prod0), ltUint)
}
res := new(Uint).Mul(l, r)
return res
}

func MulDiv(
x *Uint,
y *Uint,
d *Uint,
) *Uint {
l, h := fullMul(x, y)
mm := new(Uint).MulMod(x, y, d)
// Handle non-overflow cases, 256 by 256 division
if prod1.IsZero() {
if !(denominator.Gt(Zero())) { // require(denominator > 0);
panic("denominator > 0")
}

if mm.Gt(l) {
h = new(Uint).Sub(h, One())
result := new(Uint).Div(prod0, denominator)
return result
}
l = new(Uint).Sub(l, mm)

if h.IsZero() {
return new(Uint).Div(l, d)
// Make sure the result is less than 2**256.
// Also prevents denominator == 0
if !(denominator.Gt(prod1)) { // require(denominator > prod1)
panic("denominator > prod1")
}

if !(h.Lt(d)) {
panic("FULLDIV_OVERFLOW")
}
///////////////////////////////////////////////
// 512 by 256 division.
///////////////////////////////////////////////

// Make division exact by subtracting the remainder from [prod1 prod0]
// Compute remainder using mulmod
remainder := Zero()
remainder = new(Uint).MulMod(a, b, denominator)

return fullDiv(l, h, d)
// Subtract 256 bit number from 512 bit number
gtBool := remainder.Gt(prod0)
gtUint := Zero()
if gtBool {
gtUint = One()
}
prod1 = new(Uint).Sub(prod1, gtUint)
prod0 = new(Uint).Sub(prod0, remainder)

// Factor powers of two out of denominator
// Compute largest power of two divisor of denominator.
// Always >= 1.
twos := Zero()
twos = new(Uint).And(new(Uint).Neg(denominator), denominator)

// Divide denominator by power of two
denominator = new(Uint).Div(denominator, twos)

// Divide [prod1 prod0] by the factors of two
prod0 = new(Uint).Div(prod0, twos)

// Shift in bits from prod1 into prod0. For this we need
// to flip `twos` such that it is 2**256 / twos.
// If twos is zero, then it becomes one
twos = new(Uint).Add(
new(Uint).Div(
new(Uint).Sub(Zero(), twos),
twos,
),
One(),
)
prod0 = new(Uint).Or(prod0, new(Uint).Mul(prod1, twos))

// Invert denominator mod 2**256
// Now that denominator is an odd number, it has an inverse
// modulo 2**256 such that denominator * inv = 1 mod 2**256.
// Compute the inverse by starting with a seed that is correct
// correct for four bits. That is, denominator * inv = 1 mod 2**4
inv := Zero()
inv = new(Uint).Mul(NewUint(3), denominator)
inv = new(Uint).Xor(inv, NewUint(2))

// Now use Newton-Raphson iteration to improve the precision.
// Thanks to Hensel's lifting lemma, this also works in modular
// arithmetic, doubling the correct bits in each step.

inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**8
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**16
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**32
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**64
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**128
inv = new(Uint).Mul(inv, new(Uint).Sub(NewUint(2), new(Uint).Mul(denominator, inv))) // inverse mod 2**256

// Because the division is now exact we can divide by multiplying
// with the modular inverse of denominator. This will give us the
// correct result modulo 2**256. Since the precoditions guarantee
// that the outcome is less than 2**256, this is the final result.
// We don't need to compute the high bits of the result and prod1
// is no longer required.
result := new(Uint).Mul(prod0, inv)
return result
}

func DivRoundingUp(
x *Uint,
y *Uint,
func MulDivRoundingUp(
a, b, denominator *Uint,
) *Uint {
div := new(Uint).Div(x, y)
result := MulDiv(a, b, denominator)

mod := new(Uint).Mod(x, y)
return new(Uint).Add(div, gt(mod, Zero()))
}
if new(Uint).MulMod(a, b, denominator).Gt(Zero()) {
if !(result.Lt(MustFromDecimal(MAX_UINT256))) { // require(result < MAX_UINT256)
panic("result < MAX_UINT256")
}

// HELPERs
func lt(x, y *Uint) *Uint {
if x.Lt(y) {
return One()
} else {
return Zero()
result = new(Uint).Add(result, One())
}
}

func gt(x, y *Uint) *Uint {
if x.Gt(y) {
return One()
} else {
return Zero()
}
return result
}
25 changes: 25 additions & 0 deletions _deploy/p/demo/gnoswap/uint256/gs_overflow_calculation_test.gno
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package uint256

import "testing"

func TestMulDiv(t *testing.T) {
a := MustFromDecimal("3961170441225674086664416884948992")
b := MustFromDecimal("1461300573427867316490840528175048480732148624513")
c := MustFromDecimal("1461300573427867316570072651998408279850435624081")

z := MulDiv(a, b, c)
if z.ToString() != "3961170441225674086449641121090634" {
t.Errorf("expected 3961170441225674086449641121090634, got %s", z.ToString())
}
}

func TestMulDivRoundingUp(t *testing.T) {
a := MustFromDecimal("3961170441225674086664416884948992")
b := MustFromDecimal("1461300573427867316490840528175048480732148624513")
c := MustFromDecimal("1461300573427867316570072651998408279850435624081")

z := MulDivRoundingUp(a, b, c)
if z.ToString() != "3961170441225674086449641121090635" {
t.Errorf("expected 3961170441225674086449641121090635, got %s", z.ToString())
}
}

0 comments on commit 0e2122d

Please sign in to comment.