Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(blockifier): add Secp256r1 cairo native syscalls #1675

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 109 additions & 24 deletions crates/blockifier/src/execution/native/syscall_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::hash::RandomState;
use std::sync::Arc;

use ark_ec::short_weierstrass::{Affine, Projective, SWCurveConfig};
use ark_ff::PrimeField;
use ark_ff::{BigInt, PrimeField};
use cairo_native::starknet::{
BlockInfo,
ExecutionInfo,
Expand All @@ -18,7 +18,7 @@ use cairo_native::starknet::{
TxV2Info,
U256,
};
use cairo_native::starknet_stub::u256_to_biguint;
use num_bigint::BigUint;
use starknet_api::contract_class::EntryPointType;
use starknet_api::core::{
calculate_contract_address,
Expand Down Expand Up @@ -266,7 +266,7 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {

match syscall_base::get_block_hash_base(self.context, block_number, self.state) {
Ok(value) => Ok(value),
Err(e) => Err(self.handle_error(remaining_gas, e.into())),
Err(e) => Err(self.handle_error(remaining_gas, e)),
}
}

Expand Down Expand Up @@ -652,46 +652,65 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {

fn secp256r1_new(
&mut self,
_x: U256,
_y: U256,
_remaining_gas: &mut u64,
x: U256,
y: U256,
remaining_gas: &mut u64,
) -> SyscallResult<Option<Secp256r1Point>> {
todo!("Implement secp256r1_new syscall.");
self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_new_gas_cost)?;

Secp256Point::new(x, y)
.map(|option| option.map(|p| p.into()))
.map_err(|err| self.handle_error(remaining_gas, err))
}

fn secp256r1_add(
&mut self,
_p0: Secp256r1Point,
_p1: Secp256r1Point,
_remaining_gas: &mut u64,
p0: Secp256r1Point,
p1: Secp256r1Point,
remaining_gas: &mut u64,
) -> SyscallResult<Secp256r1Point> {
todo!("Implement secp256r1_add syscall.");
self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_add_gas_cost)?;
Ok(Secp256Point::add(p0.into(), p1.into()).into())
}

fn secp256r1_mul(
&mut self,
_p: Secp256r1Point,
_m: U256,
_remaining_gas: &mut u64,
p: Secp256r1Point,
m: U256,
remaining_gas: &mut u64,
) -> SyscallResult<Secp256r1Point> {
todo!("Implement secp256r1_mul syscall.");
self.pre_execute_syscall(remaining_gas, self.context.gas_costs().secp256r1_mul_gas_cost)?;

Ok(Secp256Point::mul(p.into(), m).into())
}

fn secp256r1_get_point_from_x(
&mut self,
_x: U256,
_y_parity: bool,
_remaining_gas: &mut u64,
x: U256,
y_parity: bool,
remaining_gas: &mut u64,
) -> SyscallResult<Option<Secp256r1Point>> {
todo!("Implement secp256r1_get_point_from_x syscall.");
self.pre_execute_syscall(
remaining_gas,
self.context.gas_costs().secp256r1_get_point_from_x_gas_cost,
)?;

Secp256Point::get_point_from_x(x, y_parity)
.map(|option| option.map(|p| p.into()))
.map_err(|err| self.handle_error(remaining_gas, err))
}

fn secp256r1_get_xy(
&mut self,
_p: Secp256r1Point,
_remaining_gas: &mut u64,
p: Secp256r1Point,
remaining_gas: &mut u64,
) -> SyscallResult<(U256, U256)> {
todo!("Implement secp256r1_get_xy syscall.");
self.pre_execute_syscall(
remaining_gas,
self.context.gas_costs().secp256r1_get_xy_gas_cost,
)?;

Ok((p.x, p.y))
}

fn sha256_process_block(
Expand Down Expand Up @@ -725,9 +744,46 @@ impl<'state> StarknetSyscallHandler for &mut NativeSyscallHandler<'state> {
/// secp256k1 and secp256r1 curves through the generic `Curve` parameter.
#[derive(PartialEq, Clone, Copy)]
struct Secp256Point<Curve: SWCurveConfig>(Affine<Curve>);
impl From<Secp256Point<ark_secp256k1::Config>> for Secp256k1Point {
fn from(Secp256Point(Affine { x, y, infinity }): Secp256Point<ark_secp256k1::Config>) -> Self {
Secp256k1Point {
x: big4int_to_u256(x.into()),
y: big4int_to_u256(y.into()),
is_infinity: infinity,
}
}
}

impl From<Secp256Point<ark_secp256r1::Config>> for Secp256r1Point {
fn from(Secp256Point(Affine { x, y, infinity }): Secp256Point<ark_secp256r1::Config>) -> Self {
Secp256r1Point {
x: big4int_to_u256(x.into()),
y: big4int_to_u256(y.into()),
is_infinity: infinity,
}
}
}

impl From<Secp256k1Point> for Secp256Point<ark_secp256k1::Config> {
fn from(p: Secp256k1Point) -> Self {
Secp256Point(Affine {
x: u256_to_big4int(p.x).into(),
y: u256_to_big4int(p.y).into(),
infinity: p.is_infinity,
})
}
}

impl From<Secp256r1Point> for Secp256Point<ark_secp256r1::Config> {
fn from(p: Secp256r1Point) -> Self {
Secp256Point(Affine {
x: u256_to_big4int(p.x).into(),
y: u256_to_big4int(p.y).into(),
infinity: p.is_infinity,
})
}
}

// todo(xrvdg) remove dead_code annotation after adding syscalls
#[allow(dead_code)]
impl<Curve: SWCurveConfig> Secp256Point<Curve>
where
Curve::BaseField: PrimeField, // constraint for get_point_by_id
Expand Down Expand Up @@ -780,6 +836,35 @@ impl<Curve: SWCurveConfig> fmt::Debug for Secp256Point<Curve> {
}
}

fn u256_to_biguint(u256: U256) -> BigUint {
let lo = BigUint::from(u256.lo);
let hi = BigUint::from(u256.hi);

(hi << 128) + lo
}

fn big4int_to_u256(b_int: BigInt<4>) -> U256 {
let [a, b, c, d] = b_int.0;

let lo = u128::from(a) | (u128::from(b) << 64);
let hi = u128::from(c) | (u128::from(d) << 64);

U256 { lo, hi }
}

fn u256_to_big4int(u256: U256) -> BigInt<4> {
fn to_u64s(bytes: [u8; 16]) -> (u64, u64) {
let lo_bytes: [u8; 8] = bytes[0..8].try_into().expect("Take high bytes");
let lo: u64 = u64::from_le_bytes(lo_bytes);
let hi_bytes: [u8; 8] = bytes[8..16].try_into().expect("Take low bytes");
let hi: u64 = u64::from_le_bytes(hi_bytes);
(lo, hi)
}
let (hi_lo, hi_hi) = to_u64s(u256.hi.to_le_bytes());
let (lo_lo, lo_hi) = to_u64s(u256.lo.to_le_bytes());
BigInt::new([lo_lo, lo_hi, hi_lo, hi_hi])
}

#[cfg(test)]
mod test {
use cairo_native::starknet::U256;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ fn test_secp256k1(test_contract: FeatureContract) {
);
}

#[cfg_attr(
feature = "cairo_native",
test_case(FeatureContract::TestContract(CairoVersion::Native); "Native")
)]
#[test_case(FeatureContract::TestContract(CairoVersion::Cairo1); "VM")]
fn test_secp256r1(test_contract: FeatureContract) {
let chain_info = &ChainInfo::create_for_testing();
Expand Down
Loading