From af230b5402ca6cd4fe585df2dd3c9f27af2bb4be Mon Sep 17 00:00:00 2001 From: 0xTopaz Date: Mon, 16 Dec 2024 23:22:49 +0900 Subject: [PATCH] refactor: pool mint --- _deploy/r/gnoswap/consts/consts.gno | 12 +- pool/errors.gno | 4 +- pool/pool.gno | 17 +-- pool/pool_manager.gno | 93 +++++++++++---- pool/pool_manager_test.gno | 56 ++++++++++ pool/pool_transfer.gno | 44 +++++--- pool/pool_transfer_test.gno | 4 +- pool/protocol_fee_pool_creation.gno | 4 +- pool/protocol_fee_withdrawal.gno | 6 +- pool/swap.gno | 2 +- pool/utils.gno | 139 +++++++++++++++++++++-- pool/utils_test.gno | 168 +++++++++++++++++++++++++++- 12 files changed, 474 insertions(+), 75 deletions(-) diff --git a/_deploy/r/gnoswap/consts/consts.gno b/_deploy/r/gnoswap/consts/consts.gno index 9f9228e9f..3864049a7 100644 --- a/_deploy/r/gnoswap/consts/consts.gno +++ b/_deploy/r/gnoswap/consts/consts.gno @@ -6,9 +6,8 @@ import ( // GNOSWAP SERVICE const ( - ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" // Admin - DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" // DevOps - + ADMIN std.Address = "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d" + DEV_OPS std.Address = "g1mjvd83nnjee3z2g7683er55me9f09688pd4mj9" TOKEN_REGISTER std.Address = "g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" TOKEN_REGISTER_NAMESPACE string = "gno.land/r/g1er355fkjksqpdtwmhf5penwa82p0rhqxkkyhk5" @@ -21,7 +20,8 @@ const ( GNOT string = "gnot" WRAPPED_WUGNOT string = "gno.land/r/demo/wugnot" - UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + // defined in https://github.com/gnolang/gno/blob/81a88a2976ba9f2f9127ebbe7fb7d1e1f7fa4bd4/examples/gno.land/r/demo/wugnot/wugnot.gno#L19 + UGNOT_MIN_DEPOSIT_TO_WRAP uint64 = 1000 ) // CONTRACT PATH & ADDRESS @@ -91,9 +91,11 @@ const ( MAX_UINT128 string = "340282366920938463463374607431768211455" MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" - MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT128 string = "170141183460469231731687303715884105727" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" + // Tick Related MIN_TICK int32 = -887272 MAX_TICK int32 = 887272 diff --git a/pool/errors.gno b/pool/errors.gno index 11c4efa57..d06da6afc 100644 --- a/pool/errors.gno +++ b/pool/errors.gno @@ -32,7 +32,9 @@ var ( errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than tickUpper") - errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") // TODO: make as common error code + errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") + errOverFlow = errors.New("[GNOSWAP-POOL-026] overflow") + errBalanceUpdateFailed = errors.New("[GNOSWAP-POOL-027] balance update failed") ) // addDetailToError adds detail to an error message diff --git a/pool/pool.gno b/pool/pool.gno index db4df9b8e..786d0f0a7 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -22,7 +22,7 @@ func Mint( recipient std.Address, tickLower int32, tickUpper int32, - _liquidityAmount string, + liquidityAmount string, positionCaller std.Address, ) (string, string) { common.IsHalted() @@ -36,13 +36,14 @@ func Mint( } } - liquidityAmount := u256.MustFromDecimal(_liquidityAmount) - if liquidityAmount.IsZero() { + liquidity := u256.MustFromDecimal(liquidityAmount) + if liquidity.IsZero() { panic(errZeroLiquidity) } pool := GetPool(token0Path, token1Path, fee) - position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + liquidityDelta := safeConvertToInt128(liquidity) + position := newModifyPositionParams(recipient, tickLower, tickUpper, liquidityDelta) _, amount0, amount1 := pool.modifyPosition(position) if amount0.Gt(u256.Zero()) { @@ -198,7 +199,7 @@ func SetFeeProtocolByAdmin( newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() std.Emit( "SetFeeProtocolByAdmin", "prevAddr", prevAddr, @@ -221,7 +222,7 @@ func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { newFee := setFeeProtocol(feeProtocol0, feeProtocol1) - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() std.Emit( "SetFeeProtocol", "prevAddr", prevAddr, @@ -323,7 +324,7 @@ func CollectProtocolByAdmin( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() std.Emit( "CollectProtocolByAdmin", "prevAddr", prevAddr, @@ -368,7 +369,7 @@ func CollectProtocol( amount1Requested, ) - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() std.Emit( "CollectProtocol", "prevAddr", prevAddr, diff --git a/pool/pool_manager.gno b/pool/pool_manager.gno index c8f5c25be..7ac80e2ed 100644 --- a/pool/pool_manager.gno +++ b/pool/pool_manager.gno @@ -130,12 +130,12 @@ func CreatePool( token0Path string, token1Path string, fee uint32, - _sqrtPriceX96 string, + sqrtPriceX96 string, ) { common.IsHalted() en.MintAndDistributeGns() - poolInfo := newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) + poolInfo := newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) if poolInfo.isSameTokenPath() { panic(addDetailToError( @@ -153,7 +153,7 @@ func CreatePool( poolPath := GetPoolPath(token0Path, token1Path, fee) // reinitialize poolInfo with wrapped tokens - poolInfo = newPoolParams(token0Path, token1Path, fee, _sqrtPriceX96) + poolInfo = newPoolParams(token0Path, token1Path, fee, sqrtPriceX96) // then check if token0Path == token1Path if poolInfo.isSameTokenPath() { @@ -174,7 +174,7 @@ func CreatePool( } // TODO: make this as a parameter - prevAddr, prevRealm := getPrev() + prevAddr, prevRealm := getPrevAsString() // check whether the pool already exist pool, exist := pools.Get(poolPath) @@ -208,7 +208,7 @@ func CreatePool( "token0Path", token0Path, "token1Path", token1Path, "fee", ufmt.Sprintf("%d", fee), - "sqrtPriceX96", _sqrtPriceX96, + "sqrtPriceX96", sqrtPriceX96, "internal_poolPath", poolPath, ) } @@ -220,46 +220,91 @@ func DoesPoolPathExist(poolPath string) bool { return exist } -// GetPool retrieves the pool for the given token paths and fee. -// It constructs the poolPath from the given parameters and returns the corresponding pool. -// Returns pool struct +// GetPool retrieves a pool instance based on the provided token paths and fee tier. +// +// This function determines the pool path by combining the paths of token0 and token1 along with the fee tier, +// and then retrieves the corresponding pool instance using that path. +// +// Parameters: +// - token0Path (string): The unique identifier or path for token0. +// - token1Path (string): The unique identifier or path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the provided tokens and fee tier. +// +// Notes: +// - The order of token paths (token0Path and token1Path) matters and should match the pool's configuration. +// - Ensure that the tokens and fee tier provided are valid and registered in the system. +// +// Example: +// pool := GetPool("path/to/token0", "path/to/token1", 3000) func GetPool(token0Path, token1Path string, fee uint32) *Pool { poolPath := GetPoolPath(token0Path, token1Path, fee) - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPool() || expected poolPath(%s) to exist", poolPath), - )) - } - - return pool + return GetPoolFromPoolPath(poolPath) } -// GetPoolFromPoolPath retrieves the pool for the given poolPath. +// GetPoolFromPoolPath retrieves a pool instance based on the provided pool path. +// +// This function checks if a pool exists for the given poolPath in the `pools` mapping. +// If the pool exists, it returns the pool instance. Otherwise, it panics with a descriptive error. +// +// Parameters: +// - poolPath (string): The unique identifier or path for the pool. +// +// Returns: +// - *Pool: A pointer to the Pool instance corresponding to the given poolPath. +// +// Panics: +// - If the `poolPath` does not exist in the `pools` mapping, it panics with an error message +// indicating that the expected poolPath was not found. +// +// Notes: +// - Ensure that the `poolPath` provided is valid and corresponds to an existing pool in the `pools` mapping. +// +// Example: +// pool := GetPoolFromPoolPath("path/to/pool") func GetPoolFromPoolPath(poolPath string) *Pool { pool, exist := pools[poolPath] if !exist { panic(addDetailToError( errDataNotFound, - ufmt.Sprintf("pool_manager.gno__GetPoolFromPoolPath() || expected poolPath(%s) to exist", poolPath), + ufmt.Sprintf("expected poolPath(%s) to exist", poolPath), )) } - return pool } -// GetPoolPath generates a poolPath from the given token paths and fee. -// The poolPath is constructed by joining the token paths and fee with colons. +// GetPoolPath generates a unique pool path string based on the token paths and fee tier. +// +// This function ensures that the token paths are registered and sorted in alphabetical order +// before combining them with the fee tier to create a unique identifier for the pool. +// +// Parameters: +// - token0Path (string): The unique identifier or path for token0. +// - token1Path (string): The unique identifier or path for token1. +// - fee (uint32): The fee tier for the pool, expressed in basis points (e.g., 3000 for 0.3%). +// +// Returns: +// - string: A unique pool path string in the format "token0Path:token1Path:fee". +// +// Notes: +// - The function validates that both `token0Path` and `token1Path` are registered in the system +// using `common.MustRegistered`. +// - The token paths are sorted alphabetically to ensure consistent pool path generation, regardless +// of the input order. +// - This sorting guarantees that the pool path remains deterministic for the same pair of tokens and fee. +// +// Example: +// poolPath := GetPoolPath("path/to/token0", "path/to/token1", 3000) +// // Output: "path/to/token0:path/to/token1:3000" func GetPoolPath(token0Path, token1Path string, fee uint32) string { common.MustRegistered(token0Path) common.MustRegistered(token1Path) - // TODO: this check is not unnecessary, if we are sure that // all the token paths in the pool are sorted in alphabetical order. if strings.Compare(token1Path, token0Path) < 0 { token0Path, token1Path = token1Path, token0Path } - return ufmt.Sprintf("%s:%s:%d", token0Path, token1Path, fee) } diff --git a/pool/pool_manager_test.gno b/pool/pool_manager_test.gno index 328315b56..36637a281 100644 --- a/pool/pool_manager_test.gno +++ b/pool/pool_manager_test.gno @@ -202,3 +202,59 @@ func TestCreatePool(t *testing.T) { resetObject(t) } + +func TestGetPool(t *testing.T) { + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "Panic - unregisterd poolPath ", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + action: func(t *testing.T) { + GetPool(barPath, fooPath, fee500) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-008] requested data not found || expected poolPath(gno.land/r/onbloc/bar:gno.land/r/onbloc/foo:500) to exist", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + } + }) + } +} diff --git a/pool/pool_transfer.gno b/pool/pool_transfer.gno index 201ae2cfb..2078f2f61 100644 --- a/pool/pool_transfer.gno +++ b/pool/pool_transfer.gno @@ -53,10 +53,7 @@ func (p *Pool) transferAndVerify( if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { panic(err) } - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(absAmount) token := common.GetTokenTeller(tokenPath) checkTransferError(token.Transfer(to, amountUint64)) @@ -98,20 +95,41 @@ func (p *Pool) transferFromAndVerify( amount *u256.Uint, isToken0 bool, ) { - absAmount := amount - amountUint64, err := safeConvertToUint64(absAmount) - if err != nil { - panic(err) - } + amountUint64 := safeConvertToUint64(amount) - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.TransferFrom(from, to, amountUint64)) + token := common.GetToken(tokenPath) + beforeBalance := token.BalanceOf(to) + + teller := common.GetTokenTeller(tokenPath) + checkTransferError(teller.TransferFrom(from, to, amountUint64)) + + afterBalance := token.BalanceOf(to) + if (beforeBalance + amountUint64) != afterBalance { + panic(ufmt.Sprintf( + "%v. beforeBalance(%d) + amount(%d) != afterBalance(%d)", + errTransferFailed, beforeBalance, amountUint64, afterBalance, + )) + } // update pool balances if isToken0 { - p.balances.token0 = new(u256.Uint).Add(p.balances.token0, absAmount) + beforeToken0 := p.balances.token0.Clone() + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, amount) + if p.balances.token0.Lt(beforeToken0) { + panic(ufmt.Sprintf( + "%v. token0(%s) < beforeToken0(%s)", + errBalanceUpdateFailed, p.balances.token0.ToString(), beforeToken0.ToString(), + )) + } } else { - p.balances.token1 = new(u256.Uint).Add(p.balances.token1, absAmount) + beforeToken1 := p.balances.token1.Clone() + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, amount) + if p.balances.token1.Lt(beforeToken1) { + panic(ufmt.Sprintf( + "%v. token1(%s) < beforeToken1(%s)", + errBalanceUpdateFailed, p.balances.token1.ToString(), beforeToken1.ToString(), + )) + } } } diff --git a/pool/pool_transfer_test.gno b/pool/pool_transfer_test.gno index 562825b9a..d114003ab 100644 --- a/pool/pool_transfer_test.gno +++ b/pool/pool_transfer_test.gno @@ -134,8 +134,8 @@ func TestTransferFromAndVerify(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - TokenFaucet(t, fooPath, pusers.AddressOrName(tt.from)) - TokenApprove(t, fooPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + TokenFaucet(t, tt.tokenPath, pusers.AddressOrName(tt.from)) + TokenApprove(t, tt.tokenPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) diff --git a/pool/protocol_fee_pool_creation.gno b/pool/protocol_fee_pool_creation.gno index a56284298..c6e4e06ae 100644 --- a/pool/protocol_fee_pool_creation.gno +++ b/pool/protocol_fee_pool_creation.gno @@ -32,7 +32,7 @@ func SetPoolCreationFee(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFee", "prevAddr", prevAddr, @@ -54,7 +54,7 @@ func SetPoolCreationFeeByAdmin(fee uint64) { } setPoolCreationFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetPoolCreationFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/protocol_fee_withdrawal.gno b/pool/protocol_fee_withdrawal.gno index cacc602c7..d7d208b9b 100644 --- a/pool/protocol_fee_withdrawal.gno +++ b/pool/protocol_fee_withdrawal.gno @@ -72,7 +72,7 @@ func HandleWithdrawalFee( token1Teller := common.GetTokenTeller(token1Path) checkTransferError(token1Teller.TransferFrom(positionCaller, consts.PROTOCOL_FEE_ADDR, feeAmount1.Uint64())) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "WithdrawalFee", "prevAddr", prevAddr, @@ -106,7 +106,7 @@ func SetWithdrawalFee(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFee", "prevAddr", prevAddr, @@ -126,7 +126,7 @@ func SetWithdrawalFeeByAdmin(fee uint64) { setWithdrawalFee(fee) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "SetWithdrawalFeeByAdmin", "prevAddr", prevAddr, diff --git a/pool/swap.gno b/pool/swap.gno index 47c5e2d67..65512e123 100644 --- a/pool/swap.gno +++ b/pool/swap.gno @@ -107,7 +107,7 @@ func Swap( // actual swap pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - prevAddr, prevPkgPath := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Swap", diff --git a/pool/utils.gno b/pool/utils.gno index e45116217..719e04275 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -4,24 +4,102 @@ import ( "std" "gno.land/p/demo/ufmt" - pusers "gno.land/p/demo/users" - + i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/consts" ) -func safeConvertToUint64(value *u256.Uint) (uint64, error) { +// safeConvertToUint64 safely converts a *u256.Uint value to a uint64, ensuring no overflow. +// +// This function attempts to convert the given *u256.Uint value to a uint64. If the value exceeds +// the maximum allowable range for uint64 (`2^64 - 1`), it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - uint64: The converted value if it falls within the uint64 range. +// +// Panics: +// - If the `value` exceeds the range of uint64, the function will panic with an error indicating +// the overflow and the original value. +// +// Notes: +// - This function uses the `Uint64WithOverflow` method to detect overflow during the conversion. +// - It is essential to validate large values before calling this function to avoid unexpected panics. +// +// Example: +// safeValue := safeConvertToUint64(u256.MustFromDecimal("18446744073709551615")) // Valid conversion +// safeConvertToUint64(u256.MustFromDecimal("18446744073709551616")) // Panics due to overflow +func safeConvertToUint64(value *u256.Uint) uint64 { res, overflow := value.Uint64WithOverflow() if overflow { - return 0, ufmt.Errorf( + panic(ufmt.Sprintf( "%v: amount(%s) overflows uint64 range", - errOutOfRange, value.ToString(), - ) + errOutOfRange, value.ToString())) } + return res +} - return res, nil +// safeConvertToInt128 safely converts a *u256.Uint value to an *i256.Int, ensuring it does not exceed the int128 range. +// +// This function converts an unsigned 256-bit integer (*u256.Uint) into a signed 256-bit integer (*i256.Int). +// It checks whether the resulting value falls within the valid range of int128 (`-2^127` to `2^127 - 1`). +// If the value exceeds the maximum allowable int128 range, it triggers a panic with a descriptive error message. +// +// Parameters: +// - value (*u256.Uint): The unsigned 256-bit integer to be converted. +// +// Returns: +// - *i256.Int: The converted value if it falls within the int128 range. +// +// Panics: +// - If the converted value exceeds the maximum int128 value (`2^127 - 1`), the function will panic with an +// error message indicating the overflow and the original value. +// +// Notes: +// - The function uses `i256.FromUint256` to perform the conversion. +// - The constant `MAX_INT128` is used to define the upper bound of the int128 range (`170141183460469231731687303715884105727`). +// +// Example: +// validInt128 := safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105727")) // Valid conversion +// safeConvertToInt128(u256.MustFromDecimal("170141183460469231731687303715884105728")) // Panics due to overflow +func safeConvertToInt128(value *u256.Uint) *i256.Int { + liquidityDelta := i256.FromUint256(value) + if liquidityDelta.Gt(i256.MustFromDecimal(consts.MAX_INT128)) { + panic(ufmt.Sprintf( + "%v: amount(%s) overflows int128 range", + errOverFlow, value.ToString())) + } + return liquidityDelta } +// a2u converts a std.Address to a pusers.AddressOrName, ensuring the input address is valid. +// +// This function takes a `std.Address` and verifies its validity. If the address is invalid, +// the function triggers a panic with an appropriate error message. For valid addresses, +// it performs the conversion to `pusers.AddressOrName`. +// +// Parameters: +// - addr (std.Address): The input address to be converted. +// +// Returns: +// - pusers.AddressOrName: The converted address, wrapped as a `pusers.AddressOrName` type. +// +// Panics: +// - If the provided `addr` is invalid, the function will panic with an error indicating +// the invalid address. +// +// Notes: +// - The function relies on the `addr.IsValid()` method to determine the validity of the input address. +// - It uses `addDetailToError` to provide additional context for the error message when an invalid +// address is encountered. +// +// Example: +// converted := a2u(std.Address("validAddress")) // Successful conversion +// a2u(std.Address("")) // Panics due to invalid address func a2u(addr std.Address) pusers.AddressOrName { if !addr.IsValid() { panic(addDetailToError( @@ -32,23 +110,60 @@ func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } +// u256Min returns the smaller of two *u256.Uint values. +// +// This function compares two unsigned 256-bit integers and returns the smaller of the two. +// If `num1` is less than `num2`, it returns `num1`; otherwise, it returns `num2`. +// +// Parameters: +// - num1 (*u256.Uint): The first unsigned 256-bit integer. +// - num2 (*u256.Uint): The second unsigned 256-bit integer. +// +// Returns: +// - *u256.Uint: The smaller of `num1` and `num2`. +// +// Notes: +// - This function uses the `Lt` (less than) method of `*u256.Uint` to perform the comparison. +// - The function assumes both input values are non-nil. If nil inputs are possible in the usage context, +// additional validation may be needed. +// +// Example: +// smaller := u256Min(u256.MustFromDecimal("10"), u256.MustFromDecimal("20")) // Returns 10 +// smaller := u256Min(u256.MustFromDecimal("30"), u256.MustFromDecimal("20")) // Returns 20 func u256Min(num1, num2 *u256.Uint) *u256.Uint { if num1.Lt(num2) { return num1 } - return num2 } -func isUserCall() bool { - return std.PrevRealm().IsUser() +// derivePkgAddr derives the Realm address from it's pkgPath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) } -func getPrev() (string, string) { - prev := std.PrevRealm() +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrevAsString returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { + prev := getPrevRealm() return prev.Addr().String(), prev.PkgPath() } +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkTransferError checks transfer error. func checkTransferError(err error) { if err != nil { panic(addDetailToError( diff --git a/pool/utils_test.gno b/pool/utils_test.gno index 623cc4aef..a4b4a6388 100644 --- a/pool/utils_test.gno +++ b/pool/utils_test.gno @@ -6,10 +6,11 @@ import ( "gno.land/p/demo/testutils" "gno.land/p/demo/uassert" - + pusers "gno.land/p/demo/users" u256 "gno.land/p/gnoswap/uint256" "gno.land/r/demo/users" + "gno.land/r/gnoswap/v1/consts" ) func TestA2U(t *testing.T) { @@ -125,7 +126,7 @@ func TestIsUserCall(t *testing.T) { } } -func TestGetPrev(t *testing.T) { +func TestGetPrevAsString(t *testing.T) { tests := []struct { name string action func() (string, string) @@ -137,7 +138,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { userRealm := std.NewUserRealm(std.Address("user")) std.TestSetRealm(userRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: "user", expectedPkgPath: "", @@ -147,7 +148,7 @@ func TestGetPrev(t *testing.T) { action: func() (string, string) { codeRealm := std.NewCodeRealm("gno.land/r/demo/realm") std.TestSetRealm(codeRealm) - return getPrev() + return getPrevAsString() }, expectedAddr: std.DerivePkgAddr("gno.land/r/demo/realm").String(), expectedPkgPath: "gno.land/r/demo/realm", @@ -162,3 +163,162 @@ func TestGetPrev(t *testing.T) { }) } } + +func TestSafeConvertToUint64(t *testing.T) { + tests := []struct { + name string + value *u256.Uint + wantRes uint64 + wantPanic bool + }{ + {"normal conversion", u256.NewUint(123), 123, false}, + {"overflow", u256.MustFromDecimal(consts.MAX_UINT128), 0, true}, + {"max uint64", u256.NewUint(1<<64 - 1), 1<<64 - 1, false}, + {"zero", u256.NewUint(0), 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToUint64(tt.value) + if res != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestSafeConvertToInt128(t *testing.T) { + tests := []struct { + name string + value string + wantRes string + wantPanic bool + }{ + {"normal conversion", "170141183460469231731687303715884105727", "170141183460469231731687303715884105727", false}, + {"overflow", "170141183460469231731687303715884105728", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if !tt.wantPanic { + t.Errorf("unexpected panic: %v", r) + } + return + } + if tt.wantPanic { + t.Errorf("expected panic, but none occurred") + } + }() + + res := safeConvertToInt128(u256.MustFromDecimal(tt.value)) + if res.ToString() != tt.wantRes { + t.Errorf("safeConvertToUint64() = %v, want %v", res, tt.wantRes) + } + }) + } +} + +func TestA2u(t *testing.T) { + var ( + addr = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + input std.Address + expected pusers.AddressOrName + }{ + { + name: "Success - a2u", + input: addr, + expected: pusers.AddressOrName(addr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := a2u(tc.input) + uassert.Equal(t, users.Resolve(got).String(), users.Resolve(tc.expected).String()) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +}