Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: establish proof system version handling for onchain prover #91

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 37 additions & 11 deletions onchain/bonsol/src/actions/status.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use crate::{
assertions::*,
error::ChannelError,
proof_handling::{output_digest, prepare_inputs, verify_risc0},
proof_handling::{output_digest_v1_0_1, prepare_inputs_v1_0_1, verify_risc0_v1_0_1},
utilities::*,
};

use bonsol_interface::{
bonsol_schema::{root_as_execution_request_v1, ChannelInstruction, ExitCode, StatusV1},
bonsol_schema::{
root_as_execution_request_v1, ChannelInstruction, ExecutionRequestV1, ExitCode, StatusV1,
},
prover_version::{ProverVersion, VERSION_V1_0_1},
util::execution_address_seeds,
};

use solana_program::{
account_info::AccountInfo,
clock::Clock,
Expand Down Expand Up @@ -94,15 +99,7 @@ pub fn process_status_v1<'a>(
er.input_digest()
.map(|x| check_bytes_match(x.bytes(), input_digest, ChannelError::InputsDontMatch));
}
let output_digest = output_digest(input_digest, co, asud);
let proof_inputs = prepare_inputs(
er.image_id().unwrap(),
exed,
output_digest.as_ref(),
st.exit_code_system(),
st.exit_code_user(),
)?;
let verified = verify_risc0(proof, &proof_inputs)?;
let verified = verify_with_prover(input_digest, co, asud, er, exed, st, proof)?;
let tip = er.tip();
if verified {
let callback_program_set =
Expand Down Expand Up @@ -185,3 +182,32 @@ pub fn process_status_v1<'a>(
}
Ok(())
}

fn verify_with_prover(
input_digest: &[u8],
co: &[u8],
asud: &[u8],
er: ExecutionRequestV1,
exed: &[u8],
st: StatusV1,
proof: &[u8; 256],
) -> Result<bool, ProgramError> {
let prover_version =
ProverVersion::try_from(er.prover_version()).unwrap_or(ProverVersion::default());

let verified = match prover_version {
VERSION_V1_0_1 => {
let output_digest = output_digest_v1_0_1(input_digest, co, asud);
let proof_inputs = prepare_inputs_v1_0_1(
er.image_id().unwrap(),
exed,
output_digest.as_ref(),
st.exit_code_system(),
st.exit_code_user(),
)?;
verify_risc0_v1_0_1(proof, &proof_inputs)?
}
_ => false,
};
Ok(verified)
}
2 changes: 2 additions & 0 deletions onchain/bonsol/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum ChannelError {
InvalidExecutionId,
#[error("Invalid Execution Account Owner")]
InvalidExecutionAccountOwner,
#[error("Unexpected Proof System")]
UnexpectedProofSystem,
}

impl From<ChannelError> for ProgramError {
Expand Down
8 changes: 6 additions & 2 deletions onchain/bonsol/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
#![allow(clippy::arithmetic_side_effects)]
#![cfg_attr(not(test), forbid(unsafe_code))]
use solana_program::declare_id;

pub mod actions;
mod assertions;
pub mod error;
pub mod program;
pub mod proof_handling;
pub mod prover;
pub mod utilities;

mod assertions;
mod verifying_key;

use solana_program::declare_id;

declare_id!("BoNsHRcyLLNdtnoDf8hiCNZpyehMC4FDMxs6NTxFi3ew");

#[cfg(not(feature = "no-entrypoint"))]
Expand Down
210 changes: 166 additions & 44 deletions onchain/bonsol/src/proof_handling.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
use crate::{error::ChannelError, verifying_key::VERIFYINGKEY};
use std::ops::Neg;

use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, Validate};
use groth16_solana::groth16::Groth16Verifier;
use hex_literal::hex;
use solana_program::hash::hashv;
use std::ops::Neg;
type G1 = ark_bn254::g1::G1Affine;

fn sized_range<const N: usize>(slice: &[u8]) -> Result<[u8; N], ChannelError> {
slice
.try_into()
.map_err(|_| ChannelError::InvalidInstruction)
}
use crate::{
error::ChannelError,
prover::{Groth16Prover, PROVER_CONSTANTS_V1_0_1},
verifying_key::VERIFYINGKEY,
};

type G1 = ark_bn254::g1::G1Affine;

fn change_endianness(bytes: &[u8]) -> Vec<u8> {
let mut vec = Vec::new();
for b in bytes.chunks(32) {
for byte in b.iter().rev() {
vec.push(*byte);
}
pub fn verify_risc0(
proof: &[u8],
inputs: &[u8],
groth16_prover: Groth16Prover,
) -> Result<bool, ChannelError> {
match groth16_prover {
Groth16Prover::V1_0_1 => verify_risc0_v1_0_1(proof, inputs),
_ => Err(ChannelError::UnexpectedProofSystem),
}
vec
}

pub fn verify_risc0(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
let ace: Vec<u8> = change_endianness(&*[&proof[0..64], &[0u8][..]].concat());
pub fn verify_risc0_v1_0_1(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
let ace: Vec<u8> = toggle_endianness_256(&*[&proof[0..64], &[0u8][..]].concat());
let proof_a: G1 = G1::deserialize_with_mode(&*ace, Compress::No, Validate::No).unwrap();

let mut proof_a_neg = [0u8; 65];
G1::serialize_with_mode(&proof_a.neg(), &mut proof_a_neg[..], Compress::No)
.map_err(|_| ChannelError::InvalidInstruction)?;

let proof_a = change_endianness(&proof_a_neg[..64])
let proof_a = toggle_endianness_256(&proof_a_neg[..64])
.try_into()
.map_err(|_| ChannelError::InvalidInstruction)?;

Expand Down Expand Up @@ -58,32 +59,23 @@ pub fn verify_risc0(proof: &[u8], inputs: &[u8]) -> Result<bool, ChannelError> {
.map_err(|_| ChannelError::ProofVerificationFailed)
}

const CONTROL_ROOT: [u8; 32] =
hex!("a516a057c9fbf5629106300934d48e0e775d4230e41e503347cad96fcbde7e2e");
const BN254_CONTROL_ID_BYTES: [u8; 32] =
hex!("0eb6febcf06c5df079111be116f79bd8c7e85dc9448776ef9a59aaf2624ab551");
const OUTPUT_HASH: [u8; 32] =
hex!("77eafeb366a78b47747de0d7bb176284085ff5564887009a5be63da32d3559d4");
const RECIEPT_CLAIM_HASH: [u8; 32] =
hex!("cb1fefcd1f2d9a64975cbbbf6e161e2914434b0cbb9960b84df5d717e86b48af");

pub fn output_digest(
pub fn output_digest_v1_0_1(
input_digest: &[u8],
committed_outputs: &[u8],
assumption_digest: &[u8],
) -> [u8; 32] {
let jbytes = [input_digest, committed_outputs].concat(); // bad copy here
let journal = hashv(&[jbytes.as_slice()]);
hashv(&[
OUTPUT_HASH.as_ref(),
PROVER_CONSTANTS_V1_0_1.output_hash.as_ref(),
journal.as_ref(),
assumption_digest,
&2u16.to_le_bytes(),
])
.to_bytes()
}

pub fn prepare_inputs(
pub fn prepare_inputs_v1_0_1(
image_id: &str,
execution_digest: &[u8],
output_digest: &[u8],
Expand All @@ -92,7 +84,7 @@ pub fn prepare_inputs(
) -> Result<Vec<u8>, ChannelError> {
let imgbytes = hex::decode(image_id).map_err(|_| ChannelError::InvalidFieldElement)?;
let mut digest = hashv(&[
RECIEPT_CLAIM_HASH.as_ref(),
PROVER_CONSTANTS_V1_0_1.receipt_claim_hash.as_ref(),
&[0u8; 32],
&imgbytes,
execution_digest,
Expand All @@ -102,32 +94,162 @@ pub fn prepare_inputs(
&4u16.to_le_bytes(),
])
.to_bytes();
let (c0, c1) =
split_digest(&mut CONTROL_ROOT.clone()).map_err(|_| ChannelError::InvalidFieldElement)?;
let (c0, c1) = split_digest_reversed(&mut PROVER_CONSTANTS_V1_0_1.control_root.clone())
.map_err(|_| ChannelError::InvalidFieldElement)?;
let (half1_bytes, half2_bytes) =
split_digest(&mut digest).map_err(|_| ChannelError::InvalidFieldElement)?;
split_digest_reversed(&mut digest).map_err(|_| ChannelError::InvalidFieldElement)?;
let inputs = [
c0,
c1,
half1_bytes.try_into().unwrap(),
half2_bytes.try_into().unwrap(),
BN254_CONTROL_ID_BYTES,
PROVER_CONSTANTS_V1_0_1.bn254_control_id_bytes,
]
.concat();
Ok(inputs)
}

pub fn split_digest(d: &mut [u8]) -> Result<([u8; 32], [u8; 32]), ChannelError> {
/**
* Reverse and split a digest into two halves
* The first half is the left half of the digest
* The second half is the right half of the digest
*
* @param d: The digest to split
* @return A tuple containing the left and right halves of the digest
*/
pub fn split_digest_reversed_256(d: &mut [u8]) -> Result<([u8; 32], [u8; 32]), ChannelError> {
split_digest_reversed::<32>(d)
}

fn split_digest_reversed<const N: usize>(d: &mut [u8]) -> Result<([u8; N], [u8; N]), ChannelError> {
if d.len() != N {
return Err(ChannelError::UnexpectedProofSystem);
}
d.reverse();
let (a, b) = d.split_at(16);
let af = to_fixed_array(a.to_vec());
let bf = to_fixed_array(b.to_vec());
let split_index = (N + 1) / 2;
let (a, b) = d.split_at(split_index);
let af = to_fixed_array(a);
let bf = to_fixed_array(b);
Ok((bf, af))
}

fn to_fixed_array(input: Vec<u8>) -> [u8; 32] {
let mut fixed_array = [0u8; 32];
let start = core::cmp::max(32, input.len()) - core::cmp::min(32, input.len());
fixed_array[start..].copy_from_slice(&input[input.len().saturating_sub(32)..]);
fn to_fixed_array<const N: usize>(input: &[u8]) -> [u8; N] {
let mut fixed_array = [0u8; N];
if input.len() >= N {
// Copy the last N bytes of input into fixed_array
fixed_array.copy_from_slice(&input[input.len() - N..]);
} else {
// Copy input into the end of fixed_array
let start = N - input.len();
fixed_array[start..].copy_from_slice(input);
}
fixed_array
}

fn sized_range<const N: usize>(slice: &[u8]) -> Result<[u8; N], ChannelError> {
slice
.try_into()
.map_err(|_| ChannelError::InvalidInstruction)
}

// hello ethereum! Toggle endianness of a slice of bytes assuming 256 bit word size
fn toggle_endianness_256(bytes: &[u8]) -> Vec<u8> {
toggle_endianness::<32>(bytes)
}

fn toggle_endianness<const N: usize>(bytes: &[u8]) -> Vec<u8> {
let mut vec = Vec::with_capacity(bytes.len());
let chunk_size = N;

for chunk in bytes.chunks(chunk_size) {
// Reverse the chunk and extend the vector
vec.extend(chunk.iter().rev());
}

vec
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_toggle_endianness() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8];
let expected = [8u8, 7, 6, 5, 4, 3, 2, 1];
assert_eq!(toggle_endianness::<8>(&bytes), expected);
}

#[test]
fn test_toggle_endianness_odd() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7];
let expected = [7u8, 6, 5, 4, 3, 2, 1];
assert_eq!(toggle_endianness::<7>(&bytes), expected);
}

#[test]
fn test_toggle_endianness_quad_word() {
let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let expected = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
assert_eq!(toggle_endianness_256(&bytes), expected);
}

#[test]
fn test_split_digest() {
let mut digest = [1u8; 32];
digest[0] = 103;
let (a, b) = split_digest_reversed(&mut digest).unwrap();
let expect_digest_right = to_fixed_array::<32>(&[1u8; 16]);
let mut expect_digest_left = expect_digest_right.clone();
expect_digest_left[31] = 103;
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_odd() {
let mut digest = [1u8; 31];
digest[0] = 103;
let (a, b) = split_digest_reversed(&mut digest).unwrap();
let expect_digest_right = to_fixed_array::<31>(&[1u8; 16]);
let mut expect_digest_left = to_fixed_array::<31>(&[1u8; 15]);
expect_digest_left[30] = 103;
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_16() {
let digest = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
let (a, b) = split_digest_reversed::<16>(&mut digest.to_vec()).unwrap();
let expect_digest_left = to_fixed_array::<16>(&[7, 6, 5, 4, 3, 2, 1, 0]);
let expect_digest_right = to_fixed_array::<16>(&[15, 14, 13, 12, 11, 10, 9, 8]);
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_split_digest_8() {
let digest = [0, 1, 2, 3, 4, 5, 6, 7];
let (a, b) = split_digest_reversed::<8>(&mut digest.to_vec()).unwrap();
let expect_digest_left = to_fixed_array::<8>(&[3, 2, 1, 0]);
let expect_digest_right = to_fixed_array::<8>(&[7, 6, 5, 4]);
assert_eq!(a, expect_digest_left);
assert_eq!(b, expect_digest_right);
}

#[test]
fn test_invalid_digest_wrong_size() {
let mut d1 = [1u8; 31];
assert!(split_digest_reversed_256(&mut d1).is_err());
let mut d2 = [1u8; 33];
assert!(split_digest_reversed_256(&mut d2).is_err());
}

#[test]
fn test_sized_range() {
let slice = [1u8; 32];
let expected = [1u8; 32];
assert_eq!(sized_range::<32>(&slice).unwrap(), expected);
}
}
Loading
Loading