Skip to content

Commit

Permalink
perf: optim sha512 🚀 🚀 🚀 (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
shramee authored May 29, 2024
1 parent 6d42b70 commit 7bd1d4e
Showing 1 changed file with 154 additions and 49 deletions.
203 changes: 154 additions & 49 deletions src/math/src/sha512.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,45 @@ pub const SHA512_LEN: usize = 64;

pub const U64_BIT_NUM: u64 = 64;

// Powers of two to avoid recomputing
pub const TWO_POW_56: u64 = 0x100000000000000;
pub const TWO_POW_48: u64 = 0x1000000000000;
pub const TWO_POW_40: u64 = 0x10000000000;
pub const TWO_POW_32: u64 = 0x100000000;
pub const TWO_POW_24: u64 = 0x1000000;
pub const TWO_POW_16: u64 = 0x10000;
pub const TWO_POW_8: u64 = 0x100;
pub const TWO_POW_4: u64 = 0x10;
pub const TWO_POW_2: u64 = 0x4;
pub const TWO_POW_1: u64 = 0x2;
pub const TWO_POW_0: u64 = 0x1;

const TWO_POW_7: u64 = 0x80;
const TWO_POW_14: u64 = 0x4000;
const TWO_POW_18: u64 = 0x40000;
const TWO_POW_19: u64 = 0x80000;
const TWO_POW_28: u64 = 0x10000000;
const TWO_POW_34: u64 = 0x400000000;
const TWO_POW_39: u64 = 0x8000000000;
const TWO_POW_41: u64 = 0x20000000000;
const TWO_POW_61: u64 = 0x2000000000000000;

const TWO_POW_64_MINUS_1: u64 = 0x8000000000000000;
const TWO_POW_64_MINUS_6: u64 = 0x40;
const TWO_POW_64_MINUS_8: u64 = 0x100000000000000;
const TWO_POW_64_MINUS_14: u64 = 0x4000000000000;
const TWO_POW_64_MINUS_18: u64 = 0x400000000000;
const TWO_POW_64_MINUS_19: u64 = 0x200000000000;
const TWO_POW_64_MINUS_28: u64 = 0x1000000000;
const TWO_POW_64_MINUS_34: u64 = 0x40000000;
const TWO_POW_64_MINUS_39: u64 = 0x2000000;
const TWO_POW_64_MINUS_41: u64 = 0x800000;
const TWO_POW_64_MINUS_61: u64 = 0x8;

// Max u8 and u64 for bitwise operations
pub const MAX_U8: u64 = 0xff;
pub const MAX_U64: u128 = 0xffffffffffffffff;

#[derive(Drop, Copy)]
pub struct Word64 {
pub data: u64,
Expand Down Expand Up @@ -46,10 +85,23 @@ impl WordAdd of Add<Word64> {
}
}

impl U128IntoWord of Into<u128, Word64> {
fn into(self: u128) -> Word64 {
Word64 { data: self.try_into().unwrap() }
}
}

impl U64IntoWord of Into<u64, Word64> {
fn into(self: u64) -> Word64 {
Word64 { data: self }
}
}

pub trait WordOperations<T> {
fn shr(self: T, n: u64) -> T;
fn shl(self: T, n: u64) -> T;
fn rotr(self: T, n: u64) -> T;
fn rotr_precomputed(self: T, two_pow_n: u64, two_pow_64_n: u64) -> T;
fn rotl(self: T, n: u64) -> T;
}

Expand All @@ -66,6 +118,21 @@ pub impl Word64WordOperations of WordOperations<Word64> {
);
Word64 { data }
}
// does the work of rotr but with precomputed values 2**n and 2**(64-n)
fn rotr_precomputed(self: Word64, two_pow_n: u64, two_pow_64_n: u64) -> Word64 {
let data = self.data.into();
let data: u128 = BitOr::bitor(
math_shr_precomputed::<u128>(data, two_pow_n.into()),
math_shl_precomputed::<u128>(data, two_pow_64_n.into())
);

let data: u64 = match data.try_into() {
Option::Some(data) => data,
Option::None => (data & MAX_U64).try_into().unwrap()
};

Word64 { data }
}
fn rotl(self: Word64, n: u64) -> Word64 {
let data = BitOr::bitor(
math_shl_u64(self.data, n), math_shr_u64(self.data, (U64_BIT_NUM - n))
Expand All @@ -83,22 +150,43 @@ fn maj(x: Word64, y: Word64, z: Word64) -> Word64 {
(x & y) ^ (x & z) ^ (y & z)
}

/// Performs x.rotr(28) ^ x.rotr(34) ^ x.rotr(39),
/// Using precomputed values to avoid recomputation
fn bsig0(x: Word64) -> Word64 {
x.rotr(28) ^ x.rotr(34) ^ x.rotr(39)
// x.rotr(28) ^ x.rotr(34) ^ x.rotr(39)
x.rotr_precomputed(TWO_POW_28, TWO_POW_64_MINUS_28)
^ x.rotr_precomputed(TWO_POW_34, TWO_POW_64_MINUS_34)
^ x.rotr_precomputed(TWO_POW_39, TWO_POW_64_MINUS_39)
}

/// Performs x.rotr(14) ^ x.rotr(18) ^ x.rotr(41),
/// Using precomputed values to avoid recomputation
fn bsig1(x: Word64) -> Word64 {
x.rotr(14) ^ x.rotr(18) ^ x.rotr(41)
// x.rotr(14) ^ x.rotr(18) ^ x.rotr(41)
x.rotr_precomputed(TWO_POW_14, TWO_POW_64_MINUS_14)
^ x.rotr_precomputed(TWO_POW_18, TWO_POW_64_MINUS_18)
^ x.rotr_precomputed(TWO_POW_41, TWO_POW_64_MINUS_41)
}

/// Performs x.rotr(1) ^ x.rotr(8) ^ x.shr(7),
/// Using precomputed values to avoid recomputation
fn ssig0(x: Word64) -> Word64 {
x.rotr(1) ^ x.rotr(8) ^ x.shr(7)
// x.rotr(1) ^ x.rotr(8) ^ x.shr(7)
x.rotr_precomputed(TWO_POW_1, TWO_POW_64_MINUS_1)
^ x.rotr_precomputed(TWO_POW_8, TWO_POW_64_MINUS_8)
^ math_shr_precomputed::<u64>(x.data.into(), TWO_POW_7).into() // 2 ** 7
}

/// Performs x.rotr(19) ^ x.rotr(61) ^ x.shr(6),
/// Using precomputed values to avoid recomputation
fn ssig1(x: Word64) -> Word64 {
x.rotr(19) ^ x.rotr(61) ^ x.shr(6)
// x.rotr(19) ^ x.rotr(61) ^ x.shr(6)
x.rotr_precomputed(TWO_POW_19, TWO_POW_64_MINUS_19)
^ x.rotr_precomputed(TWO_POW_61, TWO_POW_64_MINUS_61)
^ math_shr_precomputed::<u64>(x.data, TWO_POW_64_MINUS_6).into() // 2 ** 6
}

/// Calculates base ** power
pub fn fpow(mut base: u128, mut power: u128) -> u128 {
// Return invalid input error
assert!(base != 0, "fpow: invalid input");
Expand All @@ -116,18 +204,17 @@ pub fn fpow(mut base: u128, mut power: u128) -> u128 {
result
}

// uses cache for faster powers of 2 in a u128
pub fn two_pow(mut power: u128) -> u128 {
let two_squarings: Array<u128> = array![
0x2, 0x4, 0x10, 0x100, 0x10000, 0x100000000, 0x10000000000000000
];

// Uses cache for faster powers of 2 in a u128
// Uses TWO_POW_* constants
// Generic T to use with both u128 and u64
pub fn two_pow<T, +DivRem<T>, +Mul<T>, +Into<u64, T>, +Drop<T>>(mut power: u64) -> T {
let two_squarings = array![TWO_POW_1, TWO_POW_2, TWO_POW_4, TWO_POW_8, TWO_POW_16, TWO_POW_32,];
let mut i = 0;
let mut result = 1;
let mut result: T = 1_u64.into();
while (power != 0) {
let (q, r) = DivRem::div_rem(power, 2);
if r == 1 {
result = result * *two_squarings[i];
result = result * (*two_squarings[i]).into();
}
i = i + 1;
power = q;
Expand All @@ -136,27 +223,40 @@ pub fn two_pow(mut power: u128) -> u128 {
result
}

fn math_shl(x: u128, n: u128) -> u128 {
x * two_pow(n) % BoundedInt::max()
// Shift left with math_shl_precomputed function
fn math_shl(x: u128, n: u64) -> u128 {
math_shl_precomputed(x, two_pow(n))
}

// Shift right with math_shr_precomputed function
fn math_shr(x: u128, n: u64) -> u128 {
math_shr_precomputed(x, two_pow(n))
}

fn math_shr(x: u128, n: u128) -> u128 {
x / two_pow(n) % BoundedInt::max()
// Shift left with precomputed powers of 2
fn math_shl_precomputed<T, +Mul<T>, +Rem<T>, +Drop<T>, +Copy<T>, +Into<T, u128>>(
x: T, two_power_n: T
) -> T {
x * two_power_n
}

fn math_shr_precomputed(x: u128, two_power_n: u128) -> u128 {
x / two_power_n % BoundedInt::max()
// Shift right with precomputed powers of 2
fn math_shr_precomputed<T, +Div<T>, +Rem<T>, +Drop<T>, +Copy<T>, +Into<T, u128>>(
x: T, two_power_n: T
) -> T {
x / two_power_n
}

// Shift left wrapper for u64
fn math_shl_u64(x: u64, n: u64) -> u64 {
(math_shl(x.into(), n.into()) % BoundedInt::<u64>::max().into()).try_into().unwrap()
(math_shl(x.into(), n) % BoundedInt::<u64>::max().into()).try_into().unwrap()
}

// Shift right wrapper for u64
fn math_shr_u64(x: u64, n: u64) -> u64 {
(math_shr(x.into(), n.into()) % BoundedInt::<u64>::max().into()).try_into().unwrap()
(math_shr(x.into(), n) % BoundedInt::<u64>::max().into()).try_into().unwrap()
}


fn add_trailing_zeroes(ref data: Array<u8>, msg_len: usize) {
let mdi = msg_len % 128;
let padding_len = if (mdi < 112) {
Expand All @@ -176,16 +276,18 @@ fn from_u8Array_to_WordArray(data: Array<u8>) -> Array<Word64> {
let mut new_arr: Array<Word64> = array![];
let mut i = 0;

// Use precomputed powers of 2 for shift left to avoid recomputation
// Safe to use u64 coz we shift u8 to the left by max 56 bits in u64
while (i < data.len()) {
let new_word: u128 = (BitShift::shl((*data[i + 0]).into(), 56)
+ BitShift::shl((*data[i + 1]).into(), 48)
+ BitShift::shl((*data[i + 2]).into(), 40)
+ BitShift::shl((*data[i + 3]).into(), 32)
+ BitShift::shl((*data[i + 4]).into(), 24)
+ BitShift::shl((*data[i + 5]).into(), 16)
+ BitShift::shl((*data[i + 6]).into(), 8)
+ BitShift::shl((*data[i + 7]).into(), 0));
new_arr.append(Word64 { data: new_word.try_into().unwrap() });
let new_word: u64 = math_shl_precomputed::<u64>((*data[i + 0]).into(), TWO_POW_56)
+ math_shl_precomputed((*data[i + 1]).into(), TWO_POW_48)
+ math_shl_precomputed((*data[i + 2]).into(), TWO_POW_40)
+ math_shl_precomputed((*data[i + 3]).into(), TWO_POW_32)
+ math_shl_precomputed((*data[i + 4]).into(), TWO_POW_24)
+ math_shl_precomputed((*data[i + 5]).into(), TWO_POW_16)
+ math_shl_precomputed((*data[i + 6]).into(), TWO_POW_8)
+ math_shl_precomputed((*data[i + 7]).into(), TWO_POW_0);
new_arr.append(Word64 { data: new_word });
i += 8;
};
new_arr
Expand All @@ -195,23 +297,23 @@ fn from_WordArray_to_u8array(data: Span<Word64>) -> Array<u8> {
let mut arr: Array<u8> = array![];

let mut i = 0;
// Use precomputed powers of 2 for shift right to avoid recomputation
while (i != data.len()) {
let mut res: u128 = BitShift::shr((*data.at(i).data).into(), 56)
& BoundedInt::<u8>::max().into();
let mut res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_56) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 48) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_48) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 40) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_40) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 32) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_32) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 24) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_24) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 16) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_16) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 8) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_8) & MAX_U8;
arr.append(res.try_into().unwrap());
res = BitShift::shr((*data.at(i).data).into(), 0) & BoundedInt::<u8>::max().into();
res = math_shr_precomputed((*data.at(i).data).into(), TWO_POW_0) & MAX_U8;
arr.append(res.try_into().unwrap());
i += 1;
};
Expand Down Expand Up @@ -291,9 +393,11 @@ fn digest_hash(data: Span<Word64>, msg_len: usize) -> Array<Word64> {
}

pub fn sha512(mut data: Array<u8>) -> Array<u8> {
let bit_numbers: u128 = (data.len() * 8).into();
let bit_numbers = bit_numbers & BoundedInt::<u64>::max().into();
let bit_numbers: u128 = data.len().into() * 8;
// any u32 * 8 fits in u64
// let bit_numbers = bit_numbers & BoundedInt::<u64>::max().into();

let max_u8: u128 = MAX_U8.into();
let mut msg_len = data.len();

// Appends 1
Expand All @@ -302,21 +406,22 @@ pub fn sha512(mut data: Array<u8>) -> Array<u8> {
add_trailing_zeroes(ref data, msg_len);

// add length to the end
let mut res: u128 = math_shr(bit_numbers, 56).into() & BoundedInt::<u8>::max().into();
// Use precomputed powers of 2 for shift right to avoid recomputation
let mut res: u128 = math_shr_precomputed(bit_numbers, TWO_POW_56.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 48).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_48.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 40).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_40.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 32).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_32.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 24).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_24.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 16).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_16.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 8).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_8.into()) & max_u8;
data.append(res.try_into().unwrap());
res = math_shr(bit_numbers, 0).into() & BoundedInt::<u8>::max().into();
res = math_shr_precomputed(bit_numbers, TWO_POW_0.into()) & max_u8;
data.append(res.try_into().unwrap());

msg_len = data.len();
Expand Down

0 comments on commit 7bd1d4e

Please sign in to comment.