From bd5c71a2801b28dcf58742bc5796b40984efc031 Mon Sep 17 00:00:00 2001 From: Shuhui Luo <107524008+shuhuiluo@users.noreply.github.com> Date: Wed, 6 Nov 2024 02:11:54 -0500 Subject: [PATCH] refactor!: upgrade `thiserror` and refine error handling (#88) Upgraded `thiserror` version and made specific error messages clearer. Adjusted error variants and assertions to provide more meaningful feedback, enhancing clarity in debug and runtime scenarios. --- Cargo.toml | 6 ++-- src/entities/fractions/currency_amount.rs | 8 ++--- src/entities/fractions/fraction.rs | 2 +- src/entities/fractions/price.rs | 4 +-- src/error.rs | 40 ++++++++++------------ src/utils/sorted_insert.rs | 41 +++++++++++------------ src/utils/sqrt.rs | 2 +- 7 files changed, 47 insertions(+), 56 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2770fbe..84bcafc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "uniswap-sdk-core" -version = "3.0.0" +version = "3.1.0" edition = "2021" authors = ["malik ", "Shuhui Luo "] description = "The Uniswap SDK Core in Rust provides essential functionality for interacting with the Uniswap decentralized exchange" @@ -17,10 +17,10 @@ num-integer = "0.1" num-traits = "0.2" regex = { version = "1.11", optional = true } rustc-hash = "2.0" -thiserror = { version = "1.0", optional = true } +thiserror = { version = "2", default-features = false } [features] -std = ["thiserror"] +std = ["thiserror/std"] validate_parse_address = ["eth_checksum", "regex"] [lib] diff --git a/src/entities/fractions/currency_amount.rs b/src/entities/fractions/currency_amount.rs index 3827b27..141f00c 100644 --- a/src/entities/fractions/currency_amount.rs +++ b/src/entities/fractions/currency_amount.rs @@ -25,7 +25,7 @@ impl CurrencyAmount { let denominator = denominator.into(); // Ensure the amount does not exceed MAX_UINT256 if numerator.div_floor(&denominator) > *MAX_UINT256 { - return Err(Error::MaxUint); + return Err(Error::UintOverflow); } let exponent = currency.decimals(); Ok(FractionBase::new( @@ -88,7 +88,7 @@ impl CurrencyAmount { #[inline] pub fn add(&self, other: &Self) -> Result { if !self.currency.equals(&other.currency) { - return Err(Error::NotEqual); + return Err(Error::CurrencyMismatch); } let added = self.as_fraction() + other.as_fraction(); Self::from_fractional_amount(self.currency.clone(), added.numerator, added.denominator) @@ -98,7 +98,7 @@ impl CurrencyAmount { #[inline] pub fn subtract(&self, other: &Self) -> Result { if !self.currency.equals(&other.currency) { - return Err(Error::NotEqual); + return Err(Error::CurrencyMismatch); } let subtracted = self.as_fraction() - other.as_fraction(); Self::from_fractional_amount( @@ -123,7 +123,7 @@ impl CurrencyAmount { #[inline] pub fn to_fixed(&self, decimal_places: u8, rounding: Rounding) -> Result { if decimal_places > self.currency.decimals() { - return Err(Error::NotEqual); + return Err(Error::Invalid("DECIMALS")); } if decimal_places == 0 { diff --git a/src/entities/fractions/fraction.rs b/src/entities/fractions/fraction.rs index ac73b03..24508ba 100644 --- a/src/entities/fractions/fraction.rs +++ b/src/entities/fractions/fraction.rs @@ -123,7 +123,7 @@ pub trait FractionBase: Sized { #[inline] fn to_significant(&self, significant_digits: u8, rounding: Rounding) -> Result { if significant_digits == 0 { - return Err(Error::Invalid); + return Err(Error::Invalid("SIGNIFICANT_DIGITS")); } let rounding_strategy = to_rounding_strategy(rounding); let quotient = self.to_decimal().with_precision_round( diff --git a/src/entities/fractions/price.rs b/src/entities/fractions/price.rs index ceea6ef..fce0a50 100644 --- a/src/entities/fractions/price.rs +++ b/src/entities/fractions/price.rs @@ -83,7 +83,7 @@ where other: &Price, ) -> Result, Error> { if !self.quote_currency.equals(&other.base_currency) { - return Err(Error::NotEqual); + return Err(Error::CurrencyMismatch); } let fraction = self.as_fraction() * other.as_fraction(); @@ -102,7 +102,7 @@ where currency_amount: &CurrencyAmount, ) -> Result, Error> { if !currency_amount.currency.equals(&self.base_currency) { - return Err(Error::NotEqual); + return Err(Error::CurrencyMismatch); } let fraction = self.as_fraction() * currency_amount.as_fraction(); CurrencyAmount::from_fractional_amount( diff --git a/src/error.rs b/src/error.rs index dfde6d8..3cce5a6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,64 +1,58 @@ -/// Represents errors that can occur in the context of currency operations. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -#[cfg_attr(feature = "std", derive(thiserror::Error))] +/// Custom error types that are used throughout the SDK to handle various error conditions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, thiserror::Error)] pub enum Error { /// Triggers when the compared chain IDs do not match. - #[cfg_attr(feature = "std", error("chain IDs do not match: {0} and {1}"))] + #[error("chain IDs do not match: {0} and {1}")] ChainIdMismatch(u64, u64), /// Triggers when compared addresses are the same. - #[cfg_attr(feature = "std", error("addresses are equal"))] + #[error("addresses are equal")] EqualAddresses, - /// Triggers when it tries to exceed the max uint. - #[cfg_attr(feature = "std", error("amount has exceeded MAX_UINT256"))] - MaxUint, + /// Triggers when it tries to exceed [`alloy_primitives::U256::MAX`]. + #[error("amount exceeds U256::MAX")] + UintOverflow, - /// Triggers when the compared values are not equal. - #[cfg_attr(feature = "std", error("not equal"))] - NotEqual, + /// Triggers when the currency values are not equal. + #[error("currency values are not equal")] + CurrencyMismatch, /// Triggers when the value is invalid. - #[cfg_attr(feature = "std", error("invalid"))] - Invalid, + #[error("{0}")] + Invalid(&'static str), } #[cfg(all(feature = "std", test))] mod tests { use super::*; - /// Test that `Error::ChainIdMismatch` displays the correct error message. #[test] fn test_chain_id_mismatch_error() { let error = Error::ChainIdMismatch(1, 2); assert_eq!(error.to_string(), "chain IDs do not match: 1 and 2"); } - /// Test that `Error::EqualAddresses` displays the correct error message. #[test] fn test_equal_addresses_error() { let error = Error::EqualAddresses; assert_eq!(error.to_string(), "addresses are equal"); } - /// Test that `Error::MaxUint` displays the correct error message. #[test] fn test_max_uint_error() { - let error = Error::MaxUint; - assert_eq!(error.to_string(), "amount has exceeded MAX_UINT256"); + let error = Error::UintOverflow; + assert_eq!(error.to_string(), "amount exceeds U256::MAX"); } - /// Test that `Error::NotEqual` displays the correct error message. #[test] fn test_not_equal_error() { - let error = Error::NotEqual; - assert_eq!(error.to_string(), "not equal"); + let error = Error::CurrencyMismatch; + assert_eq!(error.to_string(), "currency values are not equal"); } - /// Test that `Error::Invalid` displays the correct error message. #[test] fn test_incorrect_error() { - let error = Error::Invalid; + let error = Error::Invalid("invalid"); assert_eq!(error.to_string(), "invalid"); } } diff --git a/src/utils/sorted_insert.rs b/src/utils/sorted_insert.rs index de386d1..88abe53 100644 --- a/src/utils/sorted_insert.rs +++ b/src/utils/sorted_insert.rs @@ -8,20 +8,18 @@ pub fn sorted_insert( add: T, max_size: usize, comparator: fn(&T, &T) -> Ordering, -) -> Result, Error> { - if max_size == 0 { - return Err(Error::Invalid); - } - - if items.len() > max_size { - return Err(Error::Invalid); - } +) -> Option { + assert!(max_size > 0, "max_size must be greater than 0"); + assert!( + items.len() <= max_size, + "array length cannot exceed max_size" + ); let removed_item = if items.len() == max_size { match items.last() { Some(last) if comparator(&add, last) != Ordering::Greater => items.pop(), // short circuit if full and the additional item does not come before the last item - _ => return Ok(Some(add)), + _ => return Some(add), } } else { None @@ -32,7 +30,7 @@ pub fn sorted_insert( }; items.insert(pos, add); - Ok(removed_item) + removed_item } #[cfg(test)] @@ -55,73 +53,72 @@ mod tests { } #[test] - #[should_panic] + #[should_panic(expected = "array length cannot exceed max_size")] fn test_length_greater_than_max_size() { let mut arr = vec![1, 2]; - let _w = sorted_insert(&mut arr, 1, 1, cmp).unwrap(); - assert!(_w.is_none(), "array length cannot exceed max_size"); + sorted_insert(&mut arr, 1, 1, cmp); } #[test] fn test_add_if_empty() { let mut arr = Vec::new(); - assert_eq!(sorted_insert(&mut arr, 3, 2, cmp).unwrap(), None); + assert_eq!(sorted_insert(&mut arr, 3, 2, cmp), None); assert_eq!(arr, vec![3]); } #[test] fn test_add_if_not_full() { let mut arr = vec![1, 5]; - assert_eq!(sorted_insert(&mut arr, 3, 3, cmp).unwrap(), None); + assert_eq!(sorted_insert(&mut arr, 3, 3, cmp), None); assert_eq!(arr, vec![1, 3, 5]); } #[test] fn test_add_if_will_not_be_full_after() { let mut arr = vec![1]; - assert_eq!(sorted_insert(&mut arr, 0, 3, cmp).unwrap(), None); + assert_eq!(sorted_insert(&mut arr, 0, 3, cmp), None); assert_eq!(arr, vec![0, 1]); } #[test] fn test_return_add_if_sorts_after_last() { let mut arr = vec![1, 2, 3]; - assert_eq!(sorted_insert(&mut arr, 4, 3, cmp).unwrap(), Some(4)); + assert_eq!(sorted_insert(&mut arr, 4, 3, cmp), Some(4)); assert_eq!(arr, vec![1, 2, 3]); } #[test] fn test_remove_from_end_if_full() { let mut arr = vec![1, 3, 4]; - assert_eq!(sorted_insert(&mut arr, 2, 3, cmp).unwrap(), Some(4)); + assert_eq!(sorted_insert(&mut arr, 2, 3, cmp), Some(4)); assert_eq!(arr, vec![1, 2, 3]); } #[test] fn test_uses_comparator() { let mut arr = vec![4, 2, 1]; - assert_eq!(sorted_insert(&mut arr, 3, 3, reverse_cmp).unwrap(), Some(1)); + assert_eq!(sorted_insert(&mut arr, 3, 3, reverse_cmp), Some(1)); assert_eq!(arr, vec![4, 3, 2]); } #[test] fn test_max_size_of_1_empty_add() { let mut arr = Vec::new(); - assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), None); + assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), None); assert_eq!(arr, vec![3]); } #[test] fn test_max_size_of_1_full_add_greater() { let mut arr = vec![2]; - assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), Some(3)); + assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), Some(3)); assert_eq!(arr, vec![2]); } #[test] fn test_max_size_of_1_full_add_lesser() { let mut arr = vec![4]; - assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), Some(4)); + assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), Some(4)); assert_eq!(arr, vec![3]); } } diff --git a/src/utils/sqrt.rs b/src/utils/sqrt.rs index c54a58e..c028ac6 100644 --- a/src/utils/sqrt.rs +++ b/src/utils/sqrt.rs @@ -11,7 +11,7 @@ use num_traits::Signed; #[inline] pub fn sqrt(value: &BigInt) -> Result { if value.is_negative() { - Err(Error::Invalid) + Err(Error::Invalid("NEGATIVE")) } else { Ok(value.sqrt()) }