diff --git a/src/math/src/lib.cairo b/src/math/src/lib.cairo index 59d06263..b2110570 100644 --- a/src/math/src/lib.cairo +++ b/src/math/src/lib.cairo @@ -1,5 +1,10 @@ use option::OptionTrait; use traits::Into; +use integer::{ + u8_wide_mul, u16_wide_mul, u32_wide_mul, u64_wide_mul, u128_wide_mul, u256_overflow_mul, + BoundedInt +}; +use debug::PrintTrait; /// Raise a number to a power. /// O(log n) time complexity. @@ -56,7 +61,7 @@ impl U8BitShift of BitShift { } } fn shl(x: u8, n: u8) -> u8 { - x * BitShift::fpow(2, n) + (u8_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::::max().into()).try_into().unwrap() } fn shr(x: u8, n: u8) -> u8 { @@ -77,7 +82,9 @@ impl U16BitShift of BitShift { } } fn shl(x: u16, n: u16) -> u16 { - x * BitShift::fpow(2, n) + (u16_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::::max().into()) + .try_into() + .unwrap() } fn shr(x: u16, n: u16) -> u16 { @@ -98,7 +105,9 @@ impl U32BitShift of BitShift { } } fn shl(x: u32, n: u32) -> u32 { - x * BitShift::fpow(2, n) + (u32_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::::max().into()) + .try_into() + .unwrap() } fn shr(x: u32, n: u32) -> u32 { @@ -119,7 +128,9 @@ impl U64BitShift of BitShift { } } fn shl(x: u64, n: u64) -> u64 { - x * BitShift::fpow(2, n) + (u64_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::::max().into()) + .try_into() + .unwrap() } fn shr(x: u64, n: u64) -> u64 { @@ -140,7 +151,8 @@ impl U128BitShift of BitShift { } } fn shl(x: u128, n: u128) -> u128 { - x * BitShift::fpow(2, n) + let (_, bottom_word) = u128_wide_mul(x, BitShift::fpow(2, n)); + bottom_word } fn shr(x: u128, n: u128) -> u128 { @@ -161,7 +173,8 @@ impl U256BitShift of BitShift { } } fn shl(x: u256, n: u256) -> u256 { - x * BitShift::fpow(2, n) + let (r, _) = u256_overflow_mul(x, BitShift::fpow(2, n)); + r } fn shr(x: u256, n: u256) -> u256 { diff --git a/src/math/src/tests/math_test.cairo b/src/math/src/tests/math_test.cairo index fc9d132e..c9f06fd4 100644 --- a/src/math/src/tests/math_test.cairo +++ b/src/math/src/tests/math_test.cairo @@ -1,4 +1,5 @@ use alexandria_math::{pow, BitShift, count_digits_of_base}; +use integer::{BoundedInt}; // Test power function #[test] @@ -168,3 +169,14 @@ fn fpow_test() { fn fpow_test_u256() { assert(BitShift::fpow(3_u256, 8) == 6561, 'invalid result'); } + +#[test] +#[available_gas(2000000)] +fn shl_should_not_overflow() { + assert(BitShift::shl(BitShift::::fpow(2, 7), 1) == 0, 'invalid result'); + assert(BitShift::shl(BitShift::::fpow(2, 15), 1) == 0, 'invalid result'); + assert(BitShift::shl(BitShift::::fpow(2, 31), 1) == 0, 'invalid result'); + assert(BitShift::shl(BitShift::::fpow(2, 63), 1) == 0, 'invalid result'); + assert(BitShift::shl(BitShift::::fpow(2, 127), 1) == 0, 'invalid result'); + assert(BitShift::shl(BitShift::::fpow(2, 255), 1) == 0, 'invalid result'); +}