Skip to content

Commit

Permalink
refactor!: upgrade thiserror and refine error handling (#88)
Browse files Browse the repository at this point in the history
Upgraded `thiserror` version and made specific error messages clearer. Adjusted error variants and assertions to provide more meaningful feedback, enhancing clarity in debug and runtime scenarios.
  • Loading branch information
shuhuiluo authored Nov 6, 2024
1 parent f94c2a1 commit bd5c71a
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 56 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "uniswap-sdk-core"
version = "3.0.0"
version = "3.1.0"
edition = "2021"
authors = ["malik <[email protected]>", "Shuhui Luo <twitter.com/aureliano_law>"]
description = "The Uniswap SDK Core in Rust provides essential functionality for interacting with the Uniswap decentralized exchange"
Expand All @@ -17,10 +17,10 @@ num-integer = "0.1"
num-traits = "0.2"
regex = { version = "1.11", optional = true }
rustc-hash = "2.0"
thiserror = { version = "1.0", optional = true }
thiserror = { version = "2", default-features = false }

[features]
std = ["thiserror"]
std = ["thiserror/std"]
validate_parse_address = ["eth_checksum", "regex"]

[lib]
Expand Down
8 changes: 4 additions & 4 deletions src/entities/fractions/currency_amount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl<T: BaseCurrency> CurrencyAmount<T> {
let denominator = denominator.into();
// Ensure the amount does not exceed MAX_UINT256
if numerator.div_floor(&denominator) > *MAX_UINT256 {
return Err(Error::MaxUint);
return Err(Error::UintOverflow);
}
let exponent = currency.decimals();
Ok(FractionBase::new(
Expand Down Expand Up @@ -88,7 +88,7 @@ impl<T: BaseCurrency> CurrencyAmount<T> {
#[inline]
pub fn add(&self, other: &Self) -> Result<Self, Error> {
if !self.currency.equals(&other.currency) {
return Err(Error::NotEqual);
return Err(Error::CurrencyMismatch);
}
let added = self.as_fraction() + other.as_fraction();
Self::from_fractional_amount(self.currency.clone(), added.numerator, added.denominator)
Expand All @@ -98,7 +98,7 @@ impl<T: BaseCurrency> CurrencyAmount<T> {
#[inline]
pub fn subtract(&self, other: &Self) -> Result<Self, Error> {
if !self.currency.equals(&other.currency) {
return Err(Error::NotEqual);
return Err(Error::CurrencyMismatch);
}
let subtracted = self.as_fraction() - other.as_fraction();
Self::from_fractional_amount(
Expand All @@ -123,7 +123,7 @@ impl<T: BaseCurrency> CurrencyAmount<T> {
#[inline]
pub fn to_fixed(&self, decimal_places: u8, rounding: Rounding) -> Result<String, Error> {
if decimal_places > self.currency.decimals() {
return Err(Error::NotEqual);
return Err(Error::Invalid("DECIMALS"));
}

if decimal_places == 0 {
Expand Down
2 changes: 1 addition & 1 deletion src/entities/fractions/fraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub trait FractionBase<M: Clone>: Sized {
#[inline]
fn to_significant(&self, significant_digits: u8, rounding: Rounding) -> Result<String, Error> {
if significant_digits == 0 {
return Err(Error::Invalid);
return Err(Error::Invalid("SIGNIFICANT_DIGITS"));
}
let rounding_strategy = to_rounding_strategy(rounding);
let quotient = self.to_decimal().with_precision_round(
Expand Down
4 changes: 2 additions & 2 deletions src/entities/fractions/price.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ where
other: &Price<TQuote, TOtherQuote>,
) -> Result<Price<TBase, TOtherQuote>, Error> {
if !self.quote_currency.equals(&other.base_currency) {
return Err(Error::NotEqual);
return Err(Error::CurrencyMismatch);
}

let fraction = self.as_fraction() * other.as_fraction();
Expand All @@ -102,7 +102,7 @@ where
currency_amount: &CurrencyAmount<TBase>,
) -> Result<CurrencyAmount<TQuote>, Error> {
if !currency_amount.currency.equals(&self.base_currency) {
return Err(Error::NotEqual);
return Err(Error::CurrencyMismatch);
}
let fraction = self.as_fraction() * currency_amount.as_fraction();
CurrencyAmount::from_fractional_amount(
Expand Down
40 changes: 17 additions & 23 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,58 @@
/// Represents errors that can occur in the context of currency operations.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
/// Custom error types that are used throughout the SDK to handle various error conditions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, thiserror::Error)]
pub enum Error {
/// Triggers when the compared chain IDs do not match.
#[cfg_attr(feature = "std", error("chain IDs do not match: {0} and {1}"))]
#[error("chain IDs do not match: {0} and {1}")]
ChainIdMismatch(u64, u64),

/// Triggers when compared addresses are the same.
#[cfg_attr(feature = "std", error("addresses are equal"))]
#[error("addresses are equal")]
EqualAddresses,

/// Triggers when it tries to exceed the max uint.
#[cfg_attr(feature = "std", error("amount has exceeded MAX_UINT256"))]
MaxUint,
/// Triggers when it tries to exceed [`alloy_primitives::U256::MAX`].
#[error("amount exceeds U256::MAX")]
UintOverflow,

/// Triggers when the compared values are not equal.
#[cfg_attr(feature = "std", error("not equal"))]
NotEqual,
/// Triggers when the currency values are not equal.
#[error("currency values are not equal")]
CurrencyMismatch,

/// Triggers when the value is invalid.
#[cfg_attr(feature = "std", error("invalid"))]
Invalid,
#[error("{0}")]
Invalid(&'static str),
}

#[cfg(all(feature = "std", test))]
mod tests {
use super::*;

/// Test that `Error::ChainIdMismatch` displays the correct error message.
#[test]
fn test_chain_id_mismatch_error() {
let error = Error::ChainIdMismatch(1, 2);
assert_eq!(error.to_string(), "chain IDs do not match: 1 and 2");
}

/// Test that `Error::EqualAddresses` displays the correct error message.
#[test]
fn test_equal_addresses_error() {
let error = Error::EqualAddresses;
assert_eq!(error.to_string(), "addresses are equal");
}

/// Test that `Error::MaxUint` displays the correct error message.
#[test]
fn test_max_uint_error() {
let error = Error::MaxUint;
assert_eq!(error.to_string(), "amount has exceeded MAX_UINT256");
let error = Error::UintOverflow;
assert_eq!(error.to_string(), "amount exceeds U256::MAX");
}

/// Test that `Error::NotEqual` displays the correct error message.
#[test]
fn test_not_equal_error() {
let error = Error::NotEqual;
assert_eq!(error.to_string(), "not equal");
let error = Error::CurrencyMismatch;
assert_eq!(error.to_string(), "currency values are not equal");
}

/// Test that `Error::Invalid` displays the correct error message.
#[test]
fn test_incorrect_error() {
let error = Error::Invalid;
let error = Error::Invalid("invalid");
assert_eq!(error.to_string(), "invalid");
}
}
41 changes: 19 additions & 22 deletions src/utils/sorted_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ pub fn sorted_insert<T: Clone>(
add: T,
max_size: usize,
comparator: fn(&T, &T) -> Ordering,
) -> Result<Option<T>, Error> {
if max_size == 0 {
return Err(Error::Invalid);
}

if items.len() > max_size {
return Err(Error::Invalid);
}
) -> Option<T> {
assert!(max_size > 0, "max_size must be greater than 0");
assert!(
items.len() <= max_size,
"array length cannot exceed max_size"
);

let removed_item = if items.len() == max_size {
match items.last() {
Some(last) if comparator(&add, last) != Ordering::Greater => items.pop(),
// short circuit if full and the additional item does not come before the last item
_ => return Ok(Some(add)),
_ => return Some(add),
}
} else {
None
Expand All @@ -32,7 +30,7 @@ pub fn sorted_insert<T: Clone>(
};

items.insert(pos, add);
Ok(removed_item)
removed_item
}

#[cfg(test)]
Expand All @@ -55,73 +53,72 @@ mod tests {
}

#[test]
#[should_panic]
#[should_panic(expected = "array length cannot exceed max_size")]
fn test_length_greater_than_max_size() {
let mut arr = vec![1, 2];
let _w = sorted_insert(&mut arr, 1, 1, cmp).unwrap();
assert!(_w.is_none(), "array length cannot exceed max_size");
sorted_insert(&mut arr, 1, 1, cmp);
}

#[test]
fn test_add_if_empty() {
let mut arr = Vec::new();
assert_eq!(sorted_insert(&mut arr, 3, 2, cmp).unwrap(), None);
assert_eq!(sorted_insert(&mut arr, 3, 2, cmp), None);
assert_eq!(arr, vec![3]);
}

#[test]
fn test_add_if_not_full() {
let mut arr = vec![1, 5];
assert_eq!(sorted_insert(&mut arr, 3, 3, cmp).unwrap(), None);
assert_eq!(sorted_insert(&mut arr, 3, 3, cmp), None);
assert_eq!(arr, vec![1, 3, 5]);
}

#[test]
fn test_add_if_will_not_be_full_after() {
let mut arr = vec![1];
assert_eq!(sorted_insert(&mut arr, 0, 3, cmp).unwrap(), None);
assert_eq!(sorted_insert(&mut arr, 0, 3, cmp), None);
assert_eq!(arr, vec![0, 1]);
}

#[test]
fn test_return_add_if_sorts_after_last() {
let mut arr = vec![1, 2, 3];
assert_eq!(sorted_insert(&mut arr, 4, 3, cmp).unwrap(), Some(4));
assert_eq!(sorted_insert(&mut arr, 4, 3, cmp), Some(4));
assert_eq!(arr, vec![1, 2, 3]);
}

#[test]
fn test_remove_from_end_if_full() {
let mut arr = vec![1, 3, 4];
assert_eq!(sorted_insert(&mut arr, 2, 3, cmp).unwrap(), Some(4));
assert_eq!(sorted_insert(&mut arr, 2, 3, cmp), Some(4));
assert_eq!(arr, vec![1, 2, 3]);
}

#[test]
fn test_uses_comparator() {
let mut arr = vec![4, 2, 1];
assert_eq!(sorted_insert(&mut arr, 3, 3, reverse_cmp).unwrap(), Some(1));
assert_eq!(sorted_insert(&mut arr, 3, 3, reverse_cmp), Some(1));
assert_eq!(arr, vec![4, 3, 2]);
}

#[test]
fn test_max_size_of_1_empty_add() {
let mut arr = Vec::new();
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), None);
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), None);
assert_eq!(arr, vec![3]);
}

#[test]
fn test_max_size_of_1_full_add_greater() {
let mut arr = vec![2];
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), Some(3));
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), Some(3));
assert_eq!(arr, vec![2]);
}

#[test]
fn test_max_size_of_1_full_add_lesser() {
let mut arr = vec![4];
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp).unwrap(), Some(4));
assert_eq!(sorted_insert(&mut arr, 3, 1, cmp), Some(4));
assert_eq!(arr, vec![3]);
}
}
2 changes: 1 addition & 1 deletion src/utils/sqrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use num_traits::Signed;
#[inline]
pub fn sqrt(value: &BigInt) -> Result<BigInt, Error> {
if value.is_negative() {
Err(Error::Invalid)
Err(Error::Invalid("NEGATIVE"))
} else {
Ok(value.sqrt())
}
Expand Down

0 comments on commit bd5c71a

Please sign in to comment.