From 0b33db0bb90591121a395a9943d6a23bdd683a11 Mon Sep 17 00:00:00 2001 From: Marcus Liotta Date: Wed, 25 Oct 2023 15:45:56 +0200 Subject: [PATCH] adds bit rotation to math --- src/math/src/lib.cairo | 97 ++++++++++++++++++++++++++++++ src/math/src/tests/math_test.cairo | 48 ++++++++++++++- 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/math/src/lib.cairo b/src/math/src/lib.cairo index ff778d81..b08edd8f 100644 --- a/src/math/src/lib.cairo +++ b/src/math/src/lib.cairo @@ -111,6 +111,103 @@ impl U256BitShift of BitShift { } } +trait BitRotate { + fn rotl(x: T, n: T) -> T; + fn rotr(x: T, n: T) -> T; +} + +impl U8BitRotate of BitRotate { + fn rotl(x: u8, n: u8) -> u8 { + let word = u8_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x100_u16.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotr(x: u8, n: u8) -> u8 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 8 - n) + quotient + } +} + +impl U16BitRotate of BitRotate { + fn rotl(x: u16, n: u16) -> u16 { + let word = u16_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x10000_u32.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotr(x: u16, n: u16) -> u16 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 16 - n) + quotient + } +} + +impl U32BitRotate of BitRotate { + fn rotl(x: u32, n: u32) -> u32 { + let word = u32_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem(word, 0x100000000_u64.try_into().unwrap()); + (quotient + remainder).try_into().unwrap() + } + + fn rotr(x: u32, n: u32) -> u32 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 32 - n) + quotient + } +} + +impl U64BitRotate of BitRotate { + fn rotl(x: u64, n: u64) -> u64 { + let word = u64_wide_mul(x, pow(2, n)); + let (quotient, remainder) = DivRem::div_rem( + word, 0x10000000000000000_u128.try_into().unwrap() + ); + (quotient + remainder).try_into().unwrap() + } + + fn rotr(x: u64, n: u64) -> u64 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 64 - n) + quotient + } +} + +impl U128BitRotate of BitRotate { + fn rotl(x: u128, n: u128) -> u128 { + let (high, low) = u128_wide_mul(x, pow(2, n)); + let word = u256 { low, high }; + let (quotient, remainder) = DivRem::div_rem( + word, u256 { low: 0, high: 1 }.try_into().unwrap() + ); + (quotient + remainder).try_into().unwrap() + } + + fn rotr(x: u128, n: u128) -> u128 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 128 - n) + quotient + } +} + +impl U256BitRotate of BitRotate { + fn rotl(x: u256, n: u256) -> u256 { + // TODO(sveamarcus): missing non-zero implementation for u512 + // let word = u256_wide_mul(x, pow(2, n)); + // let (quotient, remainder) = DivRem::div_rem(word, + // u512_as_non_zero(u512{limb0: 0, limb1: 0, limb2: 1, limb3: 0 })); + // (quotient + remainder).try_into().unwrap() + panic_with_felt252('missing impl') + } + + fn rotr(x: u256, n: u256) -> u256 { + let step = pow(2, n); + let (quotient, remainder) = DivRem::div_rem(x, step.try_into().unwrap()); + remainder * pow(2, 256 - n) + quotient + } +} + mod aliquot_sum; mod armstrong_number; mod collatz_sequence; diff --git a/src/math/src/tests/math_test.cairo b/src/math/src/tests/math_test.cairo index 5262151d..46f136db 100644 --- a/src/math/src/tests/math_test.cairo +++ b/src/math/src/tests/math_test.cairo @@ -1,4 +1,4 @@ -use alexandria_math::{pow, BitShift, count_digits_of_base}; +use alexandria_math::{pow, BitShift, BitRotate, count_digits_of_base}; use integer::{BoundedInt}; // Test power function @@ -170,3 +170,49 @@ fn shl_should_not_overflow() { assert(BitShift::shl(pow::(2, 127), 1) == 0, 'invalid result'); assert(BitShift::shl(pow::(2, 255), 1) == 0, 'invalid result'); } + +#[test] +#[available_gas(2000000)] +fn test_rotl_min() { + assert(BitRotate::rotl(pow::(2, 7) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotl(pow::(2, 15) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotl(pow::(2, 31) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotl(pow::(2, 63) + 1, 1) == 3, 'invalid result'); + assert(BitRotate::rotl(pow::(2, 127) + 1, 1) == 3, 'invalid result'); +// TODO(sveamarcus): missing implementation +// assert(BitRotate::rotl(pow::(2, 255) + 1, 1) == 3, 'invalid result'); +} + +#[test] +#[available_gas(2000000)] +fn test_rotl_max() { + assert(BitRotate::rotl(0b101, 7) == pow::(2, 7) + 0b10, 'invalid result'); + assert(BitRotate::rotl(0b101, 15) == pow::(2, 15) + 0b10, 'invalid result'); + assert(BitRotate::rotl(0b101, 31) == pow::(2, 31) + 0b10, 'invalid result'); + assert(BitRotate::rotl(0b101, 63) == pow::(2, 63) + 0b10, 'invalid result'); + assert(BitRotate::rotl(0b101, 127) == pow::(2, 127) + 0b10, 'invalid result'); +// TODO(sveamarcus): missing implementation +// assert(BitRotate::rotl(0b101, 255) == pow::(2, 255) + 0b10, 'invalid result'); +} + +#[test] +#[available_gas(4000000)] +fn test_rotr_min() { + assert(BitRotate::rotr(pow::(2, 7) + 1, 1) == 0b11 * pow(2, 6), 'invalid result'); + assert(BitRotate::rotr(pow::(2, 15) + 1, 1) == 0b11 * pow(2, 14), 'invalid result'); + assert(BitRotate::rotr(pow::(2, 31) + 1, 1) == 0b11 * pow(2, 30), 'invalid result'); + assert(BitRotate::rotr(pow::(2, 63) + 1, 1) == 0b11 * pow(2, 62), 'invalid result'); + assert(BitRotate::rotr(pow::(2, 127) + 1, 1) == 0b11 * pow(2, 126), 'invalid result'); + assert(BitRotate::rotr(pow::(2, 255) + 1, 1) == 0b11 * pow(2, 254), 'invalid result'); +} + +#[test] +#[available_gas(2000000)] +fn test_rotr_max() { + assert(BitRotate::rotr(0b101_u8, 7) == 0b1010, 'invalid result'); + assert(BitRotate::rotr(0b101_u16, 15) == 0b1010, 'invalid result'); + assert(BitRotate::rotr(0b101_u32, 31) == 0b1010, 'invalid result'); + assert(BitRotate::rotr(0b101_u64, 63) == 0b1010, 'invalid result'); + assert(BitRotate::rotr(0b101_u128, 127) == 0b1010, 'invalid result'); + assert(BitRotate::rotr(0b101_u256, 255) == 0b1010, 'invalid result'); +}