Skip to content

Commit

Permalink
Add extract_output methods to ZkvmHost (#1256)
Browse files Browse the repository at this point in the history
* Add extract_public_input

* Remove StateTransitionOutput

* Add extract_public_input

* Add ValidityCond

* ValidityCond

* Renaming

* fix

* remove old tests

* Fiz lint

* Cleanup

* Add comment
  • Loading branch information
bkolad authored Dec 22, 2023
1 parent 3c9821d commit 4497a7f
Show file tree
Hide file tree
Showing 20 changed files with 201 additions and 96 deletions.
71 changes: 63 additions & 8 deletions adapters/mock-zkvm/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#![deny(missing_docs)]
#![doc = include_str!("../README.md")]

use std::collections::VecDeque;
use std::io::Write;
use std::sync::{Arc, Condvar, Mutex};

use anyhow::ensure;
use borsh::{BorshDeserialize, BorshSerialize};
use serde::{Deserialize, Serialize};
use sov_rollup_interface::zk::Matches;
use sov_rollup_interface::da::BlockHeaderTrait;
use sov_rollup_interface::zk::{Matches, StateTransitionData, ValidityCondition};

/// A mock commitment to a particular zkVM program.
#[derive(Debug, Clone, PartialEq, Eq, BorshDeserialize, BorshSerialize, Serialize, Deserialize)]
Expand Down Expand Up @@ -91,20 +93,31 @@ impl Notifier {
}

/// A mock implementing the zkVM trait.
#[derive(Clone, Default)]
pub struct MockZkvm {
#[derive(Clone)]
pub struct MockZkvm<ValidityCond> {
worker_thread_notifier: Notifier,
committed_data: VecDeque<Vec<u8>>,
validity_condition: ValidityCond,
}

impl MockZkvm {
impl<ValidityCond> MockZkvm<ValidityCond> {
/// Creates a new MockZkvm
pub fn new(validity_condition: ValidityCond) -> Self {
Self {
worker_thread_notifier: Default::default(),
committed_data: Default::default(),
validity_condition,
}
}

/// Simulates zk proof generation.
pub fn make_proof(&self) {
// We notify the worket thread.
self.worker_thread_notifier.notify();
}
}

impl sov_rollup_interface::zk::Zkvm for MockZkvm {
impl<ValidityCond: ValidityCondition> sov_rollup_interface::zk::Zkvm for MockZkvm<ValidityCond> {
type CodeCommitment = MockCodeCommitment;

type Error = anyhow::Error;
Expand Down Expand Up @@ -134,18 +147,54 @@ impl sov_rollup_interface::zk::Zkvm for MockZkvm {
}
}

impl sov_rollup_interface::zk::ZkvmHost for MockZkvm {
impl<ValidityCond: ValidityCondition> sov_rollup_interface::zk::ZkvmHost
for MockZkvm<ValidityCond>
{
type Guest = MockZkGuest;

fn add_hint<T: Serialize>(&mut self, _item: T) {}
fn add_hint<T: Serialize>(&mut self, item: T) {
let hint = bincode::serialize(&item).unwrap();
let proof_info = ProofInfo {
hint,
validity_condition: self.validity_condition,
};

let data = bincode::serialize(&proof_info).unwrap();
self.committed_data.push_back(data)
}

fn simulate_with_hints(&mut self) -> Self::Guest {
MockZkGuest {}
}

fn run(&mut self, _with_proof: bool) -> Result<sov_rollup_interface::zk::Proof, anyhow::Error> {
self.worker_thread_notifier.wait();
Ok(sov_rollup_interface::zk::Proof::Empty)
let data = self.committed_data.pop_front().unwrap_or_default();
Ok(sov_rollup_interface::zk::Proof::PublicInput(data))
}

fn extract_output<
Da: sov_rollup_interface::da::DaSpec,
Root: Serialize + serde::de::DeserializeOwned,
>(
proof: &sov_rollup_interface::zk::Proof,
) -> Result<sov_rollup_interface::zk::StateTransition<Da, Root>, Self::Error> {
match proof {
sov_rollup_interface::zk::Proof::PublicInput(pub_input) => {
let data: ProofInfo<Da::ValidityCondition> = bincode::deserialize(pub_input)?;
let st: StateTransitionData<Root, (), Da> = bincode::deserialize(&data.hint)?;

Ok(sov_rollup_interface::zk::StateTransition {
initial_state_root: st.initial_state_root,
final_state_root: st.final_state_root,
slot_hash: st.da_block_header.hash(),
validity_condition: data.validity_condition,
})
}
sov_rollup_interface::zk::Proof::Full(_) => {
panic!("Mock DA doesn't generate real proofs")
}
}
}
}

Expand Down Expand Up @@ -185,6 +234,12 @@ impl sov_rollup_interface::zk::ZkvmGuest for MockZkGuest {
}
}

#[derive(Debug, Serialize, Deserialize)]
struct ProofInfo<ValidityCond> {
hint: Vec<u8>,
validity_condition: ValidityCond,
}

#[test]
fn test_mock_proof_round_trip() {
let proof = MockProof {
Expand Down
24 changes: 20 additions & 4 deletions adapters/risc0/src/host.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module implements the [`ZkvmHost`] trait for the RISC0 VM.
use risc0_zkvm::{ExecutorEnvBuilder, ExecutorImpl, InnerReceipt, Receipt, Session};
use risc0_zkvm::{ExecutorEnvBuilder, ExecutorImpl, InnerReceipt, Journal, Receipt, Session};
use serde::de::DeserializeOwned;
use serde::Serialize;
use sov_rollup_interface::zk::{Proof, Zkvm, ZkvmHost};
Expand Down Expand Up @@ -89,10 +89,26 @@ impl<'a> ZkvmHost for Risc0Host<'a> {
if with_proof {
let receipt = self.run()?;
let data = bincode::serialize(&receipt)?;
Ok(Proof::Data(data))
Ok(Proof::Full(data))
} else {
self.run_without_proving()?;
Ok(Proof::Empty)
let session = self.run_without_proving()?;
let data = bincode::serialize(&session.journal)?;
Ok(Proof::PublicInput(data))
}
}

fn extract_output<Da: sov_rollup_interface::da::DaSpec, Root: Serialize + DeserializeOwned>(
proof: &Proof,
) -> Result<sov_rollup_interface::zk::StateTransition<Da, Root>, Self::Error> {
match proof {
Proof::PublicInput(journal) => {
let journal: Journal = bincode::deserialize(journal)?;
Ok(journal.decode()?)
}
Proof::Full(data) => {
let receipt: Receipt = bincode::deserialize(data)?;
Ok(receipt.journal.decode()?)
}
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions examples/demo-rollup/stf/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ pub(crate) type C = DefaultContext;
pub(crate) type Da = MockDaSpec;

pub(crate) type RuntimeTest = Runtime<DefaultContext, Da>;
pub(crate) type StfBlueprintTest =
StfBlueprint<DefaultContext, Da, sov_mock_zkvm::MockZkvm, RuntimeTest, BasicKernel<C>>;
pub(crate) type StfBlueprintTest = StfBlueprint<
DefaultContext,
Da,
sov_mock_zkvm::MockZkvm<<Da as DaSpec>::ValidityCondition>,
RuntimeTest,
BasicKernel<C>,
>;

pub(crate) fn create_storage_manager_for_tests(
path: impl AsRef<Path>,
Expand Down
4 changes: 2 additions & 2 deletions examples/demo-simple-stf/tests/stf_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn test_stf_success() {
let address = MockAddress::from([1; 32]);

let stf = &mut CheckHashPreimageStf::<MockValidityCond>::default();
StateTransitionFunction::<MockZkvm, MockDaSpec>::init_chain(stf, (), ());
StateTransitionFunction::<MockZkvm<MockValidityCond>, MockDaSpec>::init_chain(stf, (), ());

let mut blobs = {
let incorrect_preimage = vec![1; 32];
Expand All @@ -26,7 +26,7 @@ fn test_stf_success() {
blob.data.advance(blob.data.total_len());
}

let result = StateTransitionFunction::<MockZkvm, MockDaSpec>::apply_slot(
let result = StateTransitionFunction::<MockZkvm<MockValidityCond>, MockDaSpec>::apply_slot(
stf,
&[],
(),
Expand Down
28 changes: 3 additions & 25 deletions full-node/sov-stf-runner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use std::path::Path;

#[cfg(feature = "native")]
use anyhow::Context;
use borsh::{BorshDeserialize, BorshSerialize};
#[cfg(feature = "native")]
pub use config::RpcConfig;
#[cfg(feature = "native")]
Expand All @@ -26,36 +25,15 @@ mod runner;
pub use config::{from_toml_path, ProverServiceConfig, RollupConfig, RunnerConfig, StorageConfig};
#[cfg(feature = "native")]
pub use runner::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use sov_rollup_interface::da::DaSpec;

/// Implements the `StateTransitionVerifier` type for checking the validity of a state transition
pub mod verifier;

#[derive(Serialize, BorshDeserialize, BorshSerialize, Deserialize)]
// Prevent serde from generating spurious trait bounds. The correct serde bounds are already enforced by the
// StateTransitionFunction, DA, and Zkvm traits.
#[serde(bound = "StateRoot: Serialize + DeserializeOwned, Witness: Serialize + DeserializeOwned")]
/// Data required to verify a state transition.
pub struct StateTransitionData<StateRoot, Witness, Da: DaSpec> {
/// The state root before the state transition
pub initial_state_root: StateRoot,
/// The header of the da block that is being processed
pub da_block_header: Da::BlockHeader,
/// The proof of inclusion for all blobs
pub inclusion_proof: Da::InclusionMultiProof,
/// The proof that the provided set of blobs is complete
pub completeness_proof: Da::CompletenessProof,
/// The blobs that are being processed
pub blobs: Vec<<Da as DaSpec>::BlobTransaction>,
/// The witness for the state transition
pub state_transition_witness: Witness,
}

#[cfg(feature = "native")]
/// Reads json file.
pub fn read_json_file<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> anyhow::Result<T> {
pub fn read_json_file<T: serde::de::DeserializeOwned, P: AsRef<Path>>(
path: P,
) -> anyhow::Result<T> {
let path_str = path.as_ref().display();

let data = std::fs::read_to_string(&path)
Expand Down
3 changes: 1 addition & 2 deletions full-node/sov-stf-runner/src/prover_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ pub use parallel::ParallelProverService;
use serde::Serialize;
use sov_rollup_interface::da::DaSpec;
use sov_rollup_interface::services::da::DaService;
use sov_rollup_interface::zk::StateTransitionData;
use thiserror::Error;

use crate::StateTransitionData;

/// The possible configurations of the prover.
pub enum RollupProverConfig {
/// Skip proving.
Expand Down
4 changes: 2 additions & 2 deletions full-node/sov-stf-runner/src/prover_service/parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ use serde::Serialize;
use sov_rollup_interface::da::DaSpec;
use sov_rollup_interface::services::da::DaService;
use sov_rollup_interface::stf::StateTransitionFunction;
use sov_rollup_interface::zk::ZkvmHost;
use sov_rollup_interface::zk::{StateTransitionData, ZkvmHost};

use super::{ProverService, ProverServiceError};
use crate::config::ProverServiceConfig;
use crate::verifier::StateTransitionVerifier;
use crate::{
ProofGenConfig, ProofProcessingStatus, ProofSubmissionStatus, RollupProverConfig,
StateTransitionData, WitnessSubmissionStatus,
WitnessSubmissionStatus,
};

/// Prover service that generates proofs in parallel.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ use serde::Serialize;
use sov_rollup_interface::da::{BlockHeaderTrait, DaSpec};
use sov_rollup_interface::services::da::DaService;
use sov_rollup_interface::stf::StateTransitionFunction;
use sov_rollup_interface::zk::{Proof, ZkvmHost};
use sov_rollup_interface::zk::{Proof, StateTransitionData, ZkvmHost};

use super::ProverServiceError;
use crate::{
ProofGenConfig, ProofProcessingStatus, ProofSubmissionStatus, StateTransitionData,
WitnessSubmissionStatus,
ProofGenConfig, ProofProcessingStatus, ProofSubmissionStatus, WitnessSubmissionStatus,
};

enum ProverStatus<StateRoot, Witness, Da: DaSpec> {
Expand Down Expand Up @@ -223,10 +222,10 @@ where
V::PreState: Send + Sync + 'static,
{
match config.deref() {
ProofGenConfig::Skip => Ok(Proof::Empty),
ProofGenConfig::Skip => Ok(Proof::PublicInput(Vec::default())),
ProofGenConfig::Simulate(verifier) => verifier
.run_block(vm.simulate_with_hints(), zk_storage)
.map(|_| Proof::Empty)
.map(|_| Proof::PublicInput(Vec::default()))
.map_err(|e| anyhow::anyhow!("Guest execution must succeed but failed with {:?}", e)),
ProofGenConfig::Execute => vm.run(false),
ProofGenConfig::Prover => vm.run(true),
Expand Down
7 changes: 4 additions & 3 deletions full-node/sov-stf-runner/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ use sov_rollup_interface::da::{BlobReaderTrait, BlockHeaderTrait, DaSpec};
use sov_rollup_interface::services::da::{DaService, SlotData};
use sov_rollup_interface::stf::StateTransitionFunction;
use sov_rollup_interface::storage::HierarchicalStorageManager;
use sov_rollup_interface::zk::{Zkvm, ZkvmHost};
use sov_rollup_interface::zk::{StateTransitionData, Zkvm, ZkvmHost};
use tokio::sync::oneshot;
use tracing::{debug, info};

use crate::verifier::StateTransitionVerifier;
use crate::{ProofSubmissionStatus, ProverService, RunnerConfig, StateTransitionData};
use crate::{ProofSubmissionStatus, ProverService, RunnerConfig};

type StateRoot<ST, Vm, Da> = <ST as StateTransitionFunction<Vm, Da>>::StateRoot;
type GenesisParams<ST, Vm, Da> = <ST as StateTransitionFunction<Vm, Da>>::GenesisParams;

Expand Down Expand Up @@ -237,7 +238,7 @@ where
StateTransitionData {
// TODO(https://github.com/Sovereign-Labs/sovereign-sdk/issues/1247): incorrect pre-state root in case of re-org
initial_state_root: self.state_root.clone(),

final_state_root: slot_result.state_root.clone(),
da_block_header: filtered_block.header().clone(),
inclusion_proof,
completeness_proof,
Expand Down
5 changes: 1 addition & 4 deletions full-node/sov-stf-runner/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ use std::marker::PhantomData;

use sov_rollup_interface::da::{BlockHeaderTrait, DaVerifier};
use sov_rollup_interface::stf::StateTransitionFunction;
use sov_rollup_interface::zk::{StateTransition, Zkvm, ZkvmGuest};

use crate::StateTransitionData;

use sov_rollup_interface::zk::{StateTransition, StateTransitionData, Zkvm, ZkvmGuest};
/// Verifies a state transition
pub struct StateTransitionVerifier<ST, Da, Zk>
where
Expand Down
30 changes: 16 additions & 14 deletions full-node/sov-stf-runner/tests/hash_stf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,10 @@ pub fn get_result_from_blocks(
let stf = HashStf::<MockValidityCond>::new();

let (genesis_state_root, mut storage) =
<HashStf<MockValidityCond> as StateTransitionFunction<MockZkvm, MockDaSpec>>::init_chain(
&stf,
storage,
genesis_params.to_vec(),
);
<HashStf<MockValidityCond> as StateTransitionFunction<
MockZkvm<MockValidityCond>,
MockDaSpec,
>>::init_chain(&stf, storage, genesis_params.to_vec());

let mut state_root = genesis_state_root;

Expand All @@ -209,15 +208,18 @@ pub fn get_result_from_blocks(
for block in blocks {
let mut blobs = block.blobs.clone();

let result =
<HashStf<MockValidityCond> as StateTransitionFunction<MockZkvm, MockDaSpec>>::apply_slot::<&mut Vec<MockBlob>>(
&stf,
&state_root,
storage,
ArrayWitness::default(),
&block.header,
&block.validity_cond,
&mut blobs);
let result = <HashStf<MockValidityCond> as StateTransitionFunction<
MockZkvm<MockValidityCond>,
MockDaSpec,
>>::apply_slot::<&mut Vec<MockBlob>>(
&stf,
&state_root,
storage,
ArrayWitness::default(),
&block.header,
&block.validity_cond,
&mut blobs,
);

state_root = result.state_root;
storage = result.change_set;
Expand Down
Loading

0 comments on commit 4497a7f

Please sign in to comment.