Skip to content

Commit

Permalink
fix BitShift::shl overflows (#190)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

## Pull Request type

<!-- Please try to limit your pull request to one type; submit multiple
pull requests if needed. -->

Please check the type of change your PR introduces:

- [x] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no API changes)
- [ ] Build-related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

Current implementation will return in premature overflows, for example:
`BitShift::shl(255_u8, 1)` results in `u8_mul Overflow` while it should
just return `254`.

Issue Number: N/A

## What is the new behavior?

`BitShift::shl(255_u8, 1) == 254` etc

<!-- Please describe the behavior or changes that are being added by
this PR. -->

-
-
-

## Does this introduce a breaking change?

- [ ] Yes
- [ ] No

<!-- If this does introduce a breaking change, please describe the
impact and migration path for existing applications below. -->

## Other information

<!-- Any other information that is important to this PR, such as
screenshots of how the component looks before and after the change. -->
  • Loading branch information
maciejka authored Oct 11, 2023
1 parent 23f0357 commit 3356bf0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/math/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use option::OptionTrait;
use traits::Into;
use integer::{
u8_wide_mul, u16_wide_mul, u32_wide_mul, u64_wide_mul, u128_wide_mul, u256_overflow_mul,
BoundedInt
};
use debug::PrintTrait;

/// Raise a number to a power.
/// O(log n) time complexity.
Expand Down Expand Up @@ -56,7 +61,7 @@ impl U8BitShift of BitShift<u8> {
}
}
fn shl(x: u8, n: u8) -> u8 {
x * BitShift::fpow(2, n)
(u8_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u8>::max().into()).try_into().unwrap()
}

fn shr(x: u8, n: u8) -> u8 {
Expand All @@ -77,7 +82,9 @@ impl U16BitShift of BitShift<u16> {
}
}
fn shl(x: u16, n: u16) -> u16 {
x * BitShift::fpow(2, n)
(u16_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u16>::max().into())
.try_into()
.unwrap()
}

fn shr(x: u16, n: u16) -> u16 {
Expand All @@ -98,7 +105,9 @@ impl U32BitShift of BitShift<u32> {
}
}
fn shl(x: u32, n: u32) -> u32 {
x * BitShift::fpow(2, n)
(u32_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u32>::max().into())
.try_into()
.unwrap()
}

fn shr(x: u32, n: u32) -> u32 {
Expand All @@ -119,7 +128,9 @@ impl U64BitShift of BitShift<u64> {
}
}
fn shl(x: u64, n: u64) -> u64 {
x * BitShift::fpow(2, n)
(u64_wide_mul(x, BitShift::fpow(2, n)) & BoundedInt::<u64>::max().into())
.try_into()
.unwrap()
}

fn shr(x: u64, n: u64) -> u64 {
Expand All @@ -140,7 +151,8 @@ impl U128BitShift of BitShift<u128> {
}
}
fn shl(x: u128, n: u128) -> u128 {
x * BitShift::fpow(2, n)
let (_, bottom_word) = u128_wide_mul(x, BitShift::fpow(2, n));
bottom_word
}

fn shr(x: u128, n: u128) -> u128 {
Expand All @@ -161,7 +173,8 @@ impl U256BitShift of BitShift<u256> {
}
}
fn shl(x: u256, n: u256) -> u256 {
x * BitShift::fpow(2, n)
let (r, _) = u256_overflow_mul(x, BitShift::fpow(2, n));
r
}

fn shr(x: u256, n: u256) -> u256 {
Expand Down
12 changes: 12 additions & 0 deletions src/math/src/tests/math_test.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alexandria_math::{pow, BitShift, count_digits_of_base};
use integer::{BoundedInt};

// Test power function
#[test]
Expand Down Expand Up @@ -168,3 +169,14 @@ fn fpow_test() {
fn fpow_test_u256() {
assert(BitShift::fpow(3_u256, 8) == 6561, 'invalid result');
}

#[test]
#[available_gas(2000000)]
fn shl_should_not_overflow() {
assert(BitShift::shl(BitShift::<u8>::fpow(2, 7), 1) == 0, 'invalid result');
assert(BitShift::shl(BitShift::<u16>::fpow(2, 15), 1) == 0, 'invalid result');
assert(BitShift::shl(BitShift::<u32>::fpow(2, 31), 1) == 0, 'invalid result');
assert(BitShift::shl(BitShift::<u64>::fpow(2, 63), 1) == 0, 'invalid result');
assert(BitShift::shl(BitShift::<u128>::fpow(2, 127), 1) == 0, 'invalid result');
assert(BitShift::shl(BitShift::<u256>::fpow(2, 255), 1) == 0, 'invalid result');
}

0 comments on commit 3356bf0

Please sign in to comment.