Skip to content

Commit

Permalink
fix(Python): Add invariant max/min checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
johngrantuk committed Dec 19, 2024
1 parent f1fc5c3 commit 01c89cb
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/src/add_liquidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def add_liquidity(add_liquidity_input, pool_state, pool_class, hook_class, hook_
max_amounts_in_scaled18,
pool_state["totalSupply"],
pool_state["swapFee"],
pool_class.get_maximum_invariant_ratio(),
lambda balances_live_scaled18, rounding: pool_class.compute_invariant(
balances_live_scaled18, rounding
),
Expand All @@ -71,6 +72,7 @@ def add_liquidity(add_liquidity_input, pool_state, pool_class, hook_class, hook_
bpt_amount_out,
pool_state["totalSupply"],
pool_state["swapFee"],
pool_class.get_maximum_invariant_ratio(),
lambda balances_live_scaled18, token_index, invariant_ratio: pool_class.compute_balance(
balances_live_scaled18, token_index, invariant_ratio
),
Expand Down
37 changes: 33 additions & 4 deletions python/src/base_pool_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def compute_add_liquidity_unbalanced(
exact_amounts,
total_supply,
swap_fee_percentage,
max_invariant_ratio,
compute_invariant,
):
# /***********************************************************************
Expand Down Expand Up @@ -47,6 +48,12 @@ def compute_add_liquidity_unbalanced(
# Calculate the new invariant ratio by dividing the new invariant by the old invariant.
invariant_ratio = div_down_fixed(new_invariant, current_invariant)

# Add check for max invariant ratio
if invariant_ratio > max_invariant_ratio:
raise ValueError(
f"InvariantRatioAboveMax {invariant_ratio} {max_invariant_ratio}"
)

# Loop through each token to apply fees if necessary.
for index in range(len(current_balances)):
# // Check if the new balance is greater than the equivalent proportional balance.
Expand Down Expand Up @@ -104,17 +111,23 @@ def compute_add_liquidity_single_token_exact_out(
exact_bpt_amount_out,
total_supply,
swap_fee_percentage,
max_invariant_ratio,
compute_balance,
):
# Calculate new supply after minting exactBptamount_out
new_supply = exact_bpt_amount_out + total_supply

invariant_ratio = div_up_fixed(new_supply, total_supply)
# Add check for max invariant ratio
if invariant_ratio > max_invariant_ratio:
raise ValueError(
f"InvariantRatioAboveMax {invariant_ratio} {max_invariant_ratio}"
)

# Calculate the initial amount of the input token needed for the desired amount of BPT out
# "divUp" leads to a higher "new_balance," which in turn results in a larger "amountIn."
# This leads to receiving more tokens for the same amount of BTP minted.
new_balance = compute_balance(
current_balances, token_in_index, div_up_fixed(new_supply, total_supply)
)
new_balance = compute_balance(current_balances, token_in_index, invariant_ratio)
amount_in = new_balance - current_balances[token_in_index]

# Calculate the taxable amount, which is the difference
Expand Down Expand Up @@ -199,17 +212,26 @@ def compute_remove_liquidity_single_token_exact_in(
exact_bpt_amount_in,
total_supply,
swap_fee_percentage,
min_invariant_ratio,
compute_balance,
):
# // Calculate new supply accounting for burning exactBptAmountIn
new_supply = total_supply - exact_bpt_amount_in

invariant_ratio = div_up_fixed(new_supply, total_supply)
# Add check for min invariant ratio
if invariant_ratio < min_invariant_ratio:
raise ValueError(
f"InvariantRatioBelowMin {invariant_ratio} {min_invariant_ratio}"
)

# // Calculate the new balance of the output token after the BPT burn.
# // "divUp" leads to a higher "new_balance," which in turn results in a lower "amount_out."
# // This leads to giving less tokens for the same amount of BTP burned.
new_balance = compute_balance(
current_balances,
token_out_index,
div_up_fixed(new_supply, total_supply),
invariant_ratio,
)

# // Compute the amount to be withdrawn from the pool.
Expand Down Expand Up @@ -253,6 +275,7 @@ def compute_remove_liquidity_single_token_exact_out(
exact_amount_out,
total_supply,
swap_fee_percentage,
min_invariant_ratio,
compute_invariant,
):
# // Determine the number of tokens in the pool.
Expand All @@ -275,6 +298,12 @@ def compute_remove_liquidity_single_token_exact_out(
compute_invariant(new_balances, Rounding.ROUND_UP), current_invariant
)

# Add check for min invariant ratio
if invariant_ratio < min_invariant_ratio:
raise ValueError(
f"InvariantRatioBelowMin {invariant_ratio} {min_invariant_ratio}"
)

# Taxable amount is proportional to invariant ratio; a larger taxable amount rounds in the Vault's favor.
taxable_amount = (
mul_up_fixed(
Expand Down
8 changes: 8 additions & 0 deletions python/src/pools/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
compute_out_given_exact_in,
compute_in_given_exact_out,
compute_balance,
_MAX_INVARIANT_RATIO,
_MIN_INVARIANT_RATIO,
)
from src.swap import SwapKind

Expand All @@ -12,6 +14,12 @@ class Stable:
def __init__(self, pool_state):
self.amp = pool_state["amp"]

def get_maximum_invariant_ratio(self) -> int:
return _MAX_INVARIANT_RATIO

def get_minimum_invariant_ratio(self) -> int:
return _MIN_INVARIANT_RATIO

def on_swap(self, swap_params):
invariant = compute_invariant(self.amp, swap_params["balances_live_scaled18"])

Expand Down
5 changes: 5 additions & 0 deletions python/src/pools/stable_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
# we have chosen the rounding direction to favor the protocol in all cases.
AMP_PRECISION = int(1000)

# Invariant growth limit: non-proportional add cannot cause the invariant to increase by more than this ratio.
_MIN_INVARIANT_RATIO = int(60e16) # 60%
# Invariant shrink limit: non-proportional remove cannot cause the invariant to decrease by less than this ratio.
_MAX_INVARIANT_RATIO = int(500e16) # 500%


def compute_invariant(amplification_parameter: int, balances: list[int]) -> int:
"""
Expand Down
8 changes: 8 additions & 0 deletions python/src/pools/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
compute_invariant_up,
compute_invariant_down,
compute_balance_out_given_invariant,
_MAX_INVARIANT_RATIO,
_MIN_INVARIANT_RATIO,
)
from src.swap import SwapKind

Expand All @@ -13,6 +15,12 @@ class Weighted:
def __init__(self, pool_state):
self.normalized_weights = pool_state["weights"]

def get_maximum_invariant_ratio(self) -> int:
return _MAX_INVARIANT_RATIO

def get_minimum_invariant_ratio(self) -> int:
return _MIN_INVARIANT_RATIO

def on_swap(self, swap_params):
if swap_params["swap_kind"] == SwapKind.GIVENIN.value:
return compute_out_given_exact_in(
Expand Down
2 changes: 2 additions & 0 deletions python/src/remove_liquidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def remove_liquidity(
remove_liquidity_input["max_bpt_amount_in_raw"],
pool_state["totalSupply"],
pool_state["swapFee"],
pool_class.get_minimum_invariant_ratio(),
lambda balancesLiveScaled18, tokenIndex, invariantRatio: pool_class.compute_balance(
balancesLiveScaled18, tokenIndex, invariantRatio
),
Expand All @@ -92,6 +93,7 @@ def remove_liquidity(
amounts_out_scaled18[token_out_index],
pool_state["totalSupply"],
pool_state["swapFee"],
pool_class.get_minimum_invariant_ratio(),
lambda balances_live_scaled18, rounding: pool_class.compute_invariant(
balances_live_scaled18, rounding
),
Expand Down
6 changes: 6 additions & 0 deletions python/test/hooks/after_remove_liquidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ class CustomPool():
def __init__(self, pool_state):
self.pool_state = pool_state

def get_maximum_invariant_ratio(self) -> int:
return 1

def get_minimum_invariant_ratio(self) -> int:
return 1

def on_swap(self, swap_params):
return 1

Expand Down
6 changes: 6 additions & 0 deletions python/test/hooks/before_remove_liquidity.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class CustomPool():
def __init__(self, pool_state):
self.pool_state = pool_state

def get_maximum_invariant_ratio(self) -> int:
return 1

def get_minimum_invariant_ratio(self) -> int:
return 1

def on_swap(self, swap_params):
return 1

Expand Down

0 comments on commit 01c89cb

Please sign in to comment.