Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl reqactor #448

Merged
merged 9 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl TryFrom<Vec<TransactionSigned>> for TaikoGuestInput {
}
}

#[derive(Clone, Debug, Serialize, Deserialize, Default)]
#[derive(Clone, Debug, Serialize, Deserialize, Default, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[serde(rename_all = "snake_case")]
pub enum BlobProofType {
/// Guest runs through the entire computation from blob to Kzg commitment
Expand Down
6 changes: 4 additions & 2 deletions lib/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ pub type ProverResult<T, E = ProverError> = core::result::Result<T, E>;
pub type ProverConfig = serde_json::Value;
pub type ProofKey = (ChainId, u64, B256, u8);

#[derive(Clone, Debug, Serialize, ToSchema, Deserialize, Default, PartialEq, Eq, Hash)]
#[derive(
Clone, Debug, Serialize, ToSchema, Deserialize, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
/// The response body of a proof request.
pub struct Proof {
/// The proof either TEE or ZK.
Expand All @@ -50,7 +52,7 @@ pub trait IdWrite: Send {

#[async_trait::async_trait]
pub trait IdStore: IdWrite {
async fn read_id(&self, key: ProofKey) -> ProverResult<String>;
async fn read_id(&mut self, key: ProofKey) -> ProverResult<String>;
}

#[allow(async_fn_in_trait)]
Expand Down
26 changes: 26 additions & 0 deletions reqactor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
name = "raiko-reqactor"
version = "0.1.0"
edition = "2021"

[dependencies]
raiko-lib = { workspace = true }
raiko-core = { workspace = true }
raiko-reqpool = { workspace = true }

serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
chrono = { workspace = true, features = ["serde"] }

reth-primitives = { workspace = true }
alloy-primitives = { workspace = true }

[dev-dependencies]

[features]
default = []
sp1 = ["raiko-core/sp1"]
risc0 = ["raiko-core/risc0"]
sgx = ["raiko-core/sgx"]
26 changes: 26 additions & 0 deletions reqactor/src/action.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use crate::{RequestEntity, RequestKey};
use raiko_reqpool::impl_display_using_json_pretty;
use serde::{Deserialize, Serialize};

/// The action message sent from **external** to the actor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Action {
Prove {
request_key: RequestKey,
request_entity: RequestEntity,
},
Cancel {
request_key: RequestKey,
},
}

impl Action {
pub fn request_key(&self) -> &RequestKey {
match self {
Action::Prove { request_key, .. } => request_key,
Action::Cancel { request_key, .. } => request_key,
}
}
}

impl_display_using_json_pretty!(Action);
212 changes: 212 additions & 0 deletions reqactor/src/actor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};

use raiko_core::interfaces::ProofRequestOpt;
use raiko_lib::consts::SupportedChainSpecs;
use raiko_reqpool::{Pool, RequestKey, StatusWithContext};
use tokio::sync::{mpsc::Sender, oneshot};

use crate::Action;

/// Actor is the main interface interacting with the backend and the pool.
#[derive(Debug, Clone)]
pub struct Actor {
default_request_config: ProofRequestOpt,
chain_specs: SupportedChainSpecs,
action_tx: Sender<(Action, oneshot::Sender<Result<StatusWithContext, String>>)>,
pause_tx: Sender<()>,
is_paused: Arc<AtomicBool>,

// TODO: Remove Mutex. currently, in order to pass `&mut Pool`, we need to use Arc<Mutex<Pool>>.
pool: Arc<Mutex<Pool>>,
}

impl Actor {
pub fn new(
pool: Pool,
default_request_config: ProofRequestOpt,
chain_specs: SupportedChainSpecs,
action_tx: Sender<(Action, oneshot::Sender<Result<StatusWithContext, String>>)>,
pause_tx: Sender<()>,
) -> Self {
Self {
default_request_config,
chain_specs,
action_tx,
pause_tx,
is_paused: Arc::new(AtomicBool::new(false)),
pool: Arc::new(Mutex::new(pool)),
}
}

/// Return the default request config.
pub fn default_request_config(&self) -> &ProofRequestOpt {
&self.default_request_config
}

/// Return the chain specs.
pub fn chain_specs(&self) -> &SupportedChainSpecs {
&self.chain_specs
}

/// Check if the system is paused.
pub fn is_paused(&self) -> bool {
self.is_paused.load(Ordering::SeqCst)
}

/// Get the status of the request from the pool.
pub fn pool_get_status(
&self,
request_key: &RequestKey,
) -> Result<Option<StatusWithContext>, String> {
self.pool.lock().unwrap().get_status(request_key)
}

pub fn pool_list_status(&self) -> Result<HashMap<RequestKey, StatusWithContext>, String> {
self.pool.lock().unwrap().list()
}

/// Send an action to the backend and wait for the response.
pub async fn act(&self, action: Action) -> Result<StatusWithContext, String> {
let (resp_tx, resp_rx) = oneshot::channel();

// Send the action to the backend
self.action_tx
.send((action, resp_tx))
.await
.map_err(|e| format!("failed to send action: {e}"))?;

// Wait for response of the action
resp_rx
.await
.map_err(|e| format!("failed to receive action response: {e}"))?
}

/// Set the pause flag and notify the task manager to pause, then wait for the task manager to
/// finish the pause process.
///
/// Note that this function is blocking until the task manager finishes the pause process.
pub async fn pause(&self) -> Result<(), String> {
self.is_paused.store(true, Ordering::SeqCst);
self.pause_tx
.send(())
.await
.map_err(|e| format!("failed to send pause signal: {e}"))?;
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use alloy_primitives::Address;
use raiko_lib::{
consts::SupportedChainSpecs,
input::BlobProofType,
primitives::{ChainId, B256},
proof_type::ProofType,
};
use raiko_reqpool::{
Pool, RedisPoolConfig, RequestEntity, RequestKey, SingleProofRequestEntity,
SingleProofRequestKey, StatusWithContext,
};
use std::collections::HashMap;
use tokio::sync::mpsc;

#[tokio::test]
async fn test_pause_sets_is_paused_flag() {
let (action_tx, _) = mpsc::channel(1);
let (pause_tx, _pause_rx) = mpsc::channel(1);

let config = RedisPoolConfig {
redis_url: "redis://localhost:6379/0".to_string(),
redis_ttl: 3600,
};

let actor = Actor::new(
Pool::open(config).expect("Failed to create pool"),
ProofRequestOpt::default(),
SupportedChainSpecs::default(),
action_tx,
pause_tx,
);

assert!(!actor.is_paused(), "Actor should not be paused initially");

actor.pause().await.expect("Pause should succeed");
assert!(
actor.is_paused(),
"Actor should be paused after calling pause()"
);
}

#[tokio::test]
async fn test_act_sends_action_and_returns_response() {
let (action_tx, mut action_rx) = mpsc::channel(1);
let (pause_tx, _) = mpsc::channel(1);

let config = RedisPoolConfig {
redis_url: "redis://localhost:6379/0".to_string(),
redis_ttl: 3600,
};

let actor = Actor::new(
Pool::open(config).expect("Failed to create pool"),
ProofRequestOpt::default(),
SupportedChainSpecs::default(),
action_tx,
pause_tx,
);

// Create a test action
let request_key = RequestKey::SingleProof(SingleProofRequestKey::new(
ChainId::default(),
1,
B256::default(),
ProofType::default(),
"test_prover".to_string(),
));
let request_entity = RequestEntity::SingleProof(SingleProofRequestEntity::new(
1,
1,
"test_network".to_string(),
"test_l1_network".to_string(),
B256::default(),
Address::default(),
ProofType::default(),
BlobProofType::default(),
HashMap::new(),
));
let test_action = Action::Prove {
request_key: request_key.clone(),
request_entity,
};

// Spawn a task to handle the action and send back a response
let status = StatusWithContext::new_registered();
let status_clone = status.clone();
let handle = tokio::spawn(async move {
let (action, resp_tx) = action_rx.recv().await.expect("Should receive action");
// Verify we received the expected action
assert_eq!(action.request_key(), &request_key);
// Send back a mock response with Registered status
resp_tx
.send(Ok(status_clone))
.expect("Should send response");
});

// Send the action and wait for response
let result = actor.act(test_action).await;

// Make sure we got back an Ok response
assert_eq!(result, Ok(status), "Should receive successful response");

// Wait for the handler to complete
handle.await.expect("Handler should complete");
}
}
Loading
Loading