Skip to content

Commit

Permalink
feat(katana-rpc): rpc server builder (#2788)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored Dec 10, 2024
1 parent 2738f67 commit 71db0b4
Show file tree
Hide file tree
Showing 12 changed files with 420 additions and 73 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/dojo/test-utils/src/sequencer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub fn get_default_test_config(sequencing: SequencingConfig) -> Config {
chain.genesis.sequencer_address = *DEFAULT_SEQUENCER_ADDRESS;

let rpc = RpcConfig {
cors_origins: None,
cors_origins: Vec::new(),
port: 0,
addr: DEFAULT_RPC_ADDR,
max_connections: DEFAULT_RPC_MAX_CONNECTIONS,
Expand Down
1 change: 1 addition & 0 deletions crates/katana/cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dojo-utils.workspace = true
katana-core.workspace = true
katana-node.workspace = true
katana-primitives.workspace = true
katana-rpc.workspace = true
katana-slot-controller = { workspace = true, optional = true }

alloy-primitives.workspace = true
Expand Down
20 changes: 20 additions & 0 deletions crates/katana/cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ mod test {
};
use katana_primitives::chain::ChainId;
use katana_primitives::{address, felt, ContractAddress, Felt};
use katana_rpc::cors::HeaderValue;

use super::*;

Expand Down Expand Up @@ -615,4 +616,23 @@ chain_id.Named = "Mainnet"
assert_eq!(config.chain.genesis.gas_prices.strk, 8888);
assert_eq!(config.chain.id, ChainId::Id(Felt::from_str("0x123").unwrap()));
}

#[test]
#[cfg(feature = "server")]
fn parse_cors_origins() {
let config = NodeArgs::parse_from([
"katana",
"--http.cors_origins",
"*,http://localhost:3000,https://example.com",
])
.config()
.unwrap();

let cors_origins = config.rpc.cors_origins;

assert_eq!(cors_origins.len(), 3);
assert!(cors_origins.contains(&HeaderValue::from_static("*")));
assert!(cors_origins.contains(&HeaderValue::from_static("http://localhost:3000")));
assert!(cors_origins.contains(&HeaderValue::from_static("https://example.com")));
}
}
16 changes: 12 additions & 4 deletions crates/katana/cli/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ use katana_node::config::rpc::{
use katana_primitives::block::BlockHashOrNumber;
use katana_primitives::chain::ChainId;
use katana_primitives::genesis::Genesis;
use katana_rpc::cors::HeaderValue;
use serde::{Deserialize, Serialize};
use url::Url;

use crate::utils::{parse_block_hash_or_number, parse_genesis, LogFormat};
use crate::utils::{
deserialize_cors_origins, parse_block_hash_or_number, parse_genesis, serialize_cors_origins,
LogFormat,
};

const DEFAULT_DEV_SEED: &str = "0";
const DEFAULT_DEV_ACCOUNTS: u16 = 10;
Expand Down Expand Up @@ -85,8 +89,12 @@ pub struct ServerOptions {
/// Comma separated list of domains from which to accept cross origin requests.
#[arg(long = "http.cors_origins")]
#[arg(value_delimiter = ',')]
#[serde(default)]
pub http_cors_origins: Option<Vec<String>>,
#[serde(
default,
serialize_with = "serialize_cors_origins",
deserialize_with = "deserialize_cors_origins"
)]
pub http_cors_origins: Vec<HeaderValue>,

/// Maximum number of concurrent connections allowed.
#[arg(long = "rpc.max-connections", value_name = "COUNT")]
Expand All @@ -108,7 +116,7 @@ impl Default for ServerOptions {
http_addr: DEFAULT_RPC_ADDR,
http_port: DEFAULT_RPC_PORT,
max_connections: DEFAULT_RPC_MAX_CONNECTIONS,
http_cors_origins: None,
http_cors_origins: Vec::new(),
max_event_page_size: DEFAULT_RPC_MAX_EVENT_PAGE_SIZE,
}
}
Expand Down
30 changes: 29 additions & 1 deletion crates/katana/cli/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use katana_primitives::genesis::constant::{
};
use katana_primitives::genesis::json::GenesisJson;
use katana_primitives::genesis::Genesis;
use serde::{Deserialize, Serialize};
use katana_rpc::cors::HeaderValue;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tracing::info;

use crate::args::LOG_TARGET;
Expand Down Expand Up @@ -191,6 +192,33 @@ PREFUNDED ACCOUNTS
}
}

pub fn serialize_cors_origins<S>(values: &[HeaderValue], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let string = values
.iter()
.map(|v| v.to_str())
.collect::<Result<Vec<_>, _>>()
.map_err(serde::ser::Error::custom)?
.join(",");

serializer.serialize_str(&string)
}

pub fn deserialize_cors_origins<'de, D>(deserializer: D) -> Result<Vec<HeaderValue>, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)?
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.map(HeaderValue::from_str)
.collect::<Result<Vec<HeaderValue>, _>>()
.map_err(serde::de::Error::custom)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
6 changes: 4 additions & 2 deletions crates/katana/node/src/config/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

use katana_rpc::cors::HeaderValue;

/// The default maximum number of concurrent RPC connections.
pub const DEFAULT_RPC_MAX_CONNECTIONS: u32 = 100;
pub const DEFAULT_RPC_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
Expand Down Expand Up @@ -28,7 +30,7 @@ pub struct RpcConfig {
pub max_connections: u32,
pub apis: HashSet<ApiKind>,
pub max_event_page_size: Option<u64>,
pub cors_origins: Option<Vec<String>>,
pub cors_origins: Vec<HeaderValue>,
}

impl RpcConfig {
Expand All @@ -41,7 +43,7 @@ impl RpcConfig {
impl Default for RpcConfig {
fn default() -> Self {
Self {
cors_origins: None,
cors_origins: Vec::new(),
addr: DEFAULT_RPC_ADDR,
port: DEFAULT_RPC_PORT,
max_connections: DEFAULT_RPC_MAX_CONNECTIONS,
Expand Down
89 changes: 25 additions & 64 deletions crates/katana/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,14 @@ pub mod exit;
pub mod version;

use std::future::IntoFuture;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use config::rpc::{ApiKind, RpcConfig};
use config::Config;
use dojo_metrics::exporters::prometheus::PrometheusRecorder;
use dojo_metrics::{Report, Server as MetricsServer};
use hyper::{Method, Uri};
use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer;
use jsonrpsee::server::{AllowHosts, ServerBuilder, ServerHandle};
use hyper::Method;
use jsonrpsee::RpcModule;
use katana_core::backend::gas_oracle::L1GasOracle;
use katana_core::backend::storage::Blockchain;
Expand All @@ -37,19 +33,19 @@ use katana_pool::ordering::FiFo;
use katana_pool::TxPool;
use katana_primitives::block::GasPrices;
use katana_primitives::env::{CfgEnv, FeeTokenAddressses};
use katana_rpc::cors::Cors;
use katana_rpc::dev::DevApi;
use katana_rpc::metrics::RpcServerMetrics;
use katana_rpc::saya::SayaApi;
use katana_rpc::starknet::forking::ForkedClient;
use katana_rpc::starknet::{StarknetApi, StarknetApiConfig};
use katana_rpc::torii::ToriiApi;
use katana_rpc::{RpcServer, RpcServerHandle};
use katana_rpc_api::dev::DevApiServer;
use katana_rpc_api::saya::SayaApiServer;
use katana_rpc_api::starknet::{StarknetApiServer, StarknetTraceApiServer, StarknetWriteApiServer};
use katana_rpc_api::torii::ToriiApiServer;
use katana_stage::Sequencing;
use katana_tasks::TaskManager;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::info;

use crate::exit::NodeStoppedFuture;
Expand All @@ -59,7 +55,7 @@ use crate::exit::NodeStoppedFuture;
pub struct LaunchedNode {
pub node: Node,
/// Handle to the rpc server.
pub rpc: RpcServer,
pub rpc: RpcServerHandle,
}

impl LaunchedNode {
Expand Down Expand Up @@ -261,16 +257,21 @@ pub async fn build(mut config: Config) -> Result<Node> {
pub async fn spawn<EF: ExecutorFactory>(
node_components: (TxPool, Arc<Backend<EF>>, BlockProducer<EF>, Option<ForkedClient>),
config: RpcConfig,
) -> Result<RpcServer> {
) -> Result<RpcServerHandle> {
let (pool, backend, block_producer, forked_client) = node_components;

let mut methods = RpcModule::new(());
methods.register_method("health", |_, _| Ok(serde_json::json!({ "health": true })))?;
let mut modules = RpcModule::new(());

let cors = Cors::new()
.allow_origins(config.cors_origins.clone())
// Allow `POST` when accessing the resource
.allow_methods([Method::POST, Method::GET])
.allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]);

if config.apis.contains(&ApiKind::Starknet) {
let cfg = StarknetApiConfig { max_event_page_size: config.max_event_page_size };

let server = if let Some(client) = forked_client {
let api = if let Some(client) = forked_client {
StarknetApi::new_forked(
backend.clone(),
pool.clone(),
Expand All @@ -282,68 +283,28 @@ pub async fn spawn<EF: ExecutorFactory>(
StarknetApi::new(backend.clone(), pool.clone(), Some(block_producer.clone()), cfg)
};

methods.merge(StarknetApiServer::into_rpc(server.clone()))?;
methods.merge(StarknetWriteApiServer::into_rpc(server.clone()))?;
methods.merge(StarknetTraceApiServer::into_rpc(server))?;
modules.merge(StarknetApiServer::into_rpc(api.clone()))?;
modules.merge(StarknetWriteApiServer::into_rpc(api.clone()))?;
modules.merge(StarknetTraceApiServer::into_rpc(api))?;
}

if config.apis.contains(&ApiKind::Dev) {
methods.merge(DevApi::new(backend.clone(), block_producer.clone()).into_rpc())?;
let api = DevApi::new(backend.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
}

if config.apis.contains(&ApiKind::Torii) {
methods.merge(
ToriiApi::new(backend.clone(), pool.clone(), block_producer.clone()).into_rpc(),
)?;
let api = ToriiApi::new(backend.clone(), pool.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
}

if config.apis.contains(&ApiKind::Saya) {
methods.merge(SayaApi::new(backend.clone(), block_producer.clone()).into_rpc())?;
let api = SayaApi::new(backend.clone(), block_producer.clone());
modules.merge(api.into_rpc())?;
}

let cors = CorsLayer::new()
// Allow `POST` when accessing the resource
.allow_methods([Method::POST, Method::GET])
.allow_headers([hyper::header::CONTENT_TYPE, "argent-client".parse().unwrap(), "argent-version".parse().unwrap()]);

let cors =
config.cors_origins.clone().map(|allowed_origins| match allowed_origins.as_slice() {
[origin] if origin == "*" => cors.allow_origin(AllowOrigin::mirror_request()),
origins => cors.allow_origin(
origins
.iter()
.map(|o| {
let _ = o.parse::<Uri>().expect("Invalid URI");

o.parse().expect("Invalid origin")
})
.collect::<Vec<_>>(),
),
});

let middleware = tower::ServiceBuilder::new()
.option_layer(cors)
.layer(ProxyGetRequestLayer::new("/", "health")?)
.timeout(Duration::from_secs(20));

let server = ServerBuilder::new()
.set_logger(RpcServerMetrics::new(&methods))
.set_host_filtering(AllowHosts::Any)
.set_middleware(middleware)
.max_connections(config.max_connections)
.build(config.socket_addr())
.await?;

let addr = server.local_addr()?;
let handle = server.start(methods)?;

info!(target: "rpc", %addr, "RPC server started.");

Ok(RpcServer { handle, addr })
}
let server = RpcServer::new().metrics().health_check().cors(cors).module(modules);
let handle = server.start(config.socket_addr()).await?;

#[derive(Debug)]
pub struct RpcServer {
pub addr: SocketAddr,
pub handle: ServerHandle,
Ok(handle)
}
4 changes: 4 additions & 0 deletions crates/katana/rpc/rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ version.workspace = true
anyhow.workspace = true
dojo-metrics.workspace = true
futures.workspace = true
http.workspace = true
jsonrpsee = { workspace = true, features = [ "server" ] }
katana-core.workspace = true
katana-executor.workspace = true
Expand All @@ -21,9 +22,12 @@ katana-rpc-types.workspace = true
katana-rpc-types-builder.workspace = true
katana-tasks.workspace = true
metrics.workspace = true
serde_json.workspace = true
starknet.workspace = true
thiserror.workspace = true
tokio.workspace = true
tower.workspace = true
tower-http.workspace = true
tracing.workspace = true
url.workspace = true

Expand Down
Loading

0 comments on commit 71db0b4

Please sign in to comment.