diff --git a/Cargo.lock b/Cargo.lock index a9989e963b..cd43811b99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1627,9 +1627,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" -version = "0.1.82" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27b8a3a6e1a44fa4c8baf1f653e4172e81486d4941f2237e20dc2d0cf4ddff1" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", @@ -15569,6 +15569,7 @@ name = "torii-server" version = "1.0.1" dependencies = [ "anyhow", + "async-trait", "base64 0.21.7", "camino", "data-url", diff --git a/crates/torii/server/Cargo.toml b/crates/torii/server/Cargo.toml index 6aade5f7f0..3f59e66f23 100644 --- a/crates/torii/server/Cargo.toml +++ b/crates/torii/server/Cargo.toml @@ -30,3 +30,4 @@ tower.workspace = true tracing.workspace = true warp.workspace = true form_urlencoded = "1.2.1" +async-trait = "0.1.83" diff --git a/crates/torii/server/src/handlers/graphql.rs b/crates/torii/server/src/handlers/graphql.rs new file mode 100644 index 0000000000..c51e7cccd7 --- /dev/null +++ b/crates/torii/server/src/handlers/graphql.rs @@ -0,0 +1,46 @@ +use std::net::{IpAddr, SocketAddr}; + +use http::StatusCode; +use hyper::{Body, Request, Response}; +use tracing::error; + +use super::Handler; + +pub(crate) const LOG_TARGET: &str = "torii::server::handlers::graphql"; + +pub struct GraphQLHandler { + client_ip: IpAddr, + graphql_addr: Option, +} + +impl GraphQLHandler { + pub fn new(client_ip: IpAddr, graphql_addr: Option) -> Self { + Self { client_ip, graphql_addr } + } +} + +#[async_trait::async_trait] +impl Handler for GraphQLHandler { + fn should_handle(&self, req: &Request) -> bool { + req.uri().path().starts_with("/graphql") + } + + async fn handle(&self, req: Request) -> Response { + if let Some(addr) = self.graphql_addr { + let graphql_addr = format!("http://{}", addr); + match crate::proxy::GRAPHQL_PROXY_CLIENT.call(self.client_ip, &graphql_addr, req).await + { + Ok(response) => response, + Err(_error) => { + error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap() + } + } + } else { + Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap() + } + } +} diff --git a/crates/torii/server/src/handlers/grpc.rs b/crates/torii/server/src/handlers/grpc.rs new file mode 100644 index 0000000000..befa3e56a4 --- /dev/null +++ b/crates/torii/server/src/handlers/grpc.rs @@ -0,0 +1,49 @@ +use std::net::{IpAddr, SocketAddr}; + +use http::header::CONTENT_TYPE; +use hyper::{Body, Request, Response, StatusCode}; +use tracing::error; + +use super::Handler; + +pub(crate) const LOG_TARGET: &str = "torii::server::handlers::grpc"; + +pub struct GrpcHandler { + client_ip: IpAddr, + grpc_addr: Option, +} + +impl GrpcHandler { + pub fn new(client_ip: IpAddr, grpc_addr: Option) -> Self { + Self { client_ip, grpc_addr } + } +} + +#[async_trait::async_trait] +impl Handler for GrpcHandler { + fn should_handle(&self, req: &Request) -> bool { + req.headers() + .get(CONTENT_TYPE) + .and_then(|ct| ct.to_str().ok()) + .map(|ct| ct.starts_with("application/grpc")) + .unwrap_or(false) + } + + async fn handle(&self, req: Request) -> Response { + if let Some(grpc_addr) = self.grpc_addr { + let grpc_addr = format!("http://{}", grpc_addr); + match crate::proxy::GRPC_PROXY_CLIENT.call(self.client_ip, &grpc_addr, req).await { + Ok(response) => response, + Err(_error) => { + error!(target: LOG_TARGET, "{:?}", _error); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap() + } + } + } else { + Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap() + } + } +} diff --git a/crates/torii/server/src/handlers/mod.rs b/crates/torii/server/src/handlers/mod.rs new file mode 100644 index 0000000000..73e97cb4ab --- /dev/null +++ b/crates/torii/server/src/handlers/mod.rs @@ -0,0 +1,15 @@ +pub mod graphql; +pub mod grpc; +pub mod sql; +pub mod static_files; + +use hyper::{Body, Request, Response}; + +#[async_trait::async_trait] +pub trait Handler: Send + Sync { + // Check if this handler should handle the given request + fn should_handle(&self, req: &Request) -> bool; + + // Handle the request + async fn handle(&self, req: Request) -> Response; +} diff --git a/crates/torii/server/src/handlers/sql.rs b/crates/torii/server/src/handlers/sql.rs new file mode 100644 index 0000000000..2d48b20d07 --- /dev/null +++ b/crates/torii/server/src/handlers/sql.rs @@ -0,0 +1,134 @@ +use std::sync::Arc; + +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use http::header::CONTENT_TYPE; +use hyper::{Body, Method, Request, Response, StatusCode}; +use sqlx::{Column, Row, SqlitePool, TypeInfo}; + +use super::Handler; + +pub struct SqlHandler { + pool: Arc, +} + +impl SqlHandler { + pub fn new(pool: Arc) -> Self { + Self { pool } + } + + pub async fn execute_query(&self, query: String) -> Response { + match sqlx::query(&query).fetch_all(&*self.pool).await { + Ok(rows) => { + let result: Vec<_> = rows + .iter() + .map(|row| { + let mut obj = serde_json::Map::new(); + for (i, column) in row.columns().iter().enumerate() { + let value: serde_json::Value = match column.type_info().name() { + "TEXT" => row + .get::, _>(i) + .map_or(serde_json::Value::Null, serde_json::Value::String), + "INTEGER" | "NULL" => row + .get::, _>(i) + .map_or(serde_json::Value::Null, |n| { + serde_json::Value::Number(n.into()) + }), + "REAL" => row.get::, _>(i).map_or( + serde_json::Value::Null, + |f| { + serde_json::Number::from_f64(f).map_or( + serde_json::Value::Null, + serde_json::Value::Number, + ) + }, + ), + "BLOB" => row + .get::>, _>(i) + .map_or(serde_json::Value::Null, |bytes| { + serde_json::Value::String(STANDARD.encode(bytes)) + }), + _ => row + .get::, _>(i) + .map_or(serde_json::Value::Null, serde_json::Value::String), + }; + obj.insert(column.name().to_string(), value); + } + serde_json::Value::Object(obj) + }) + .collect(); + + let json = match serde_json::to_string(&result) { + Ok(json) => json, + Err(e) => { + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(format!("Failed to serialize result: {:?}", e))) + .unwrap(); + } + }; + + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(json)) + .unwrap() + } + Err(e) => Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from(format!("Query error: {:?}", e))) + .unwrap(), + } + } + + async fn extract_query(&self, req: Request) -> Result> { + match *req.method() { + Method::GET => { + // Get the query from the query params + let params = req.uri().query().unwrap_or_default(); + form_urlencoded::parse(params.as_bytes()) + .find(|(key, _)| key == "q" || key == "query") + .map(|(_, value)| value.to_string()) + .ok_or( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Missing 'q' or 'query' parameter.")) + .unwrap(), + ) + } + Method::POST => { + // Get the query from request body + let body_bytes = hyper::body::to_bytes(req.into_body()).await.map_err(|_| { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Failed to read query from request body")) + .unwrap() + })?; + String::from_utf8(body_bytes.to_vec()).map_err(|_| { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid query")) + .unwrap() + }) + } + _ => Err(Response::builder() + .status(StatusCode::METHOD_NOT_ALLOWED) + .body(Body::from("Only GET and POST methods are allowed")) + .unwrap()), + } + } +} + +#[async_trait::async_trait] +impl Handler for SqlHandler { + fn should_handle(&self, req: &Request) -> bool { + req.uri().path().starts_with("/sql") + } + + async fn handle(&self, req: Request) -> Response { + match self.extract_query(req).await { + Ok(query) => self.execute_query(query).await, + Err(response) => response, + } + } +} diff --git a/crates/torii/server/src/handlers/static_files.rs b/crates/torii/server/src/handlers/static_files.rs new file mode 100644 index 0000000000..631b032e11 --- /dev/null +++ b/crates/torii/server/src/handlers/static_files.rs @@ -0,0 +1,47 @@ +use std::net::{IpAddr, SocketAddr}; + +use hyper::{Body, Request, Response, StatusCode}; +use tracing::error; + +use super::Handler; + +pub(crate) const LOG_TARGET: &str = "torii::server::handlers::static"; + +pub struct StaticHandler { + client_ip: IpAddr, + artifacts_addr: Option, +} + +impl StaticHandler { + pub fn new(client_ip: IpAddr, artifacts_addr: Option) -> Self { + Self { client_ip, artifacts_addr } + } +} + +#[async_trait::async_trait] +impl Handler for StaticHandler { + fn should_handle(&self, req: &Request) -> bool { + req.uri().path().starts_with("/static") + } + + async fn handle(&self, req: Request) -> Response { + if let Some(artifacts_addr) = self.artifacts_addr { + let artifacts_addr = format!("http://{}", artifacts_addr); + match crate::proxy::GRAPHQL_PROXY_CLIENT + .call(self.client_ip, &artifacts_addr, req) + .await + { + Ok(response) => response, + Err(_error) => { + error!(target: LOG_TARGET, "{:?}", _error); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap() + } + } + } else { + Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap() + } + } +} diff --git a/crates/torii/server/src/lib.rs b/crates/torii/server/src/lib.rs index 621f66d155..9fe086e1fb 100644 --- a/crates/torii/server/src/lib.rs +++ b/crates/torii/server/src/lib.rs @@ -1,2 +1,3 @@ pub mod artifacts; +pub(crate) mod handlers; pub mod proxy; diff --git a/crates/torii/server/src/proxy.rs b/crates/torii/server/src/proxy.rs index 2808db2295..7f276aedaf 100644 --- a/crates/torii/server/src/proxy.rs +++ b/crates/torii/server/src/proxy.rs @@ -3,8 +3,6 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use base64::engine::general_purpose::STANDARD; -use base64::Engine; use http::header::CONTENT_TYPE; use http::{HeaderName, Method}; use hyper::client::connect::dns::GaiResolver; @@ -14,13 +12,16 @@ use hyper::service::make_service_fn; use hyper::{Body, Client, Request, Response, Server, StatusCode}; use hyper_reverse_proxy::ReverseProxy; use serde_json::json; -use sqlx::{Column, Row, SqlitePool, TypeInfo}; +use sqlx::SqlitePool; use tokio::sync::RwLock; use tower::ServiceBuilder; use tower_http::cors::{AllowOrigin, CorsLayer}; -use tracing::error; -pub(crate) const LOG_TARGET: &str = "torii::server::proxy"; +use crate::handlers::graphql::GraphQLHandler; +use crate::handlers::grpc::GrpcHandler; +use crate::handlers::sql::SqlHandler; +use crate::handlers::static_files::StaticHandler; +use crate::handlers::Handler; const DEFAULT_ALLOW_HEADERS: [&str; 13] = [ "accept", @@ -42,14 +43,14 @@ const DEFAULT_EXPOSED_HEADERS: [&str; 4] = const DEFAULT_MAX_AGE: Duration = Duration::from_secs(24 * 60 * 60); lazy_static::lazy_static! { - static ref GRAPHQL_PROXY_CLIENT: ReverseProxy> = { + pub(crate) static ref GRAPHQL_PROXY_CLIENT: ReverseProxy> = { ReverseProxy::new( Client::builder() .build_http(), ) }; - static ref GRPC_PROXY_CLIENT: ReverseProxy> = { + pub(crate) static ref GRPC_PROXY_CLIENT: ReverseProxy> = { ReverseProxy::new( Client::builder() .http2_only(true) @@ -170,158 +171,28 @@ async fn handle( pool: Arc, req: Request, ) -> Result, Infallible> { - if req.uri().path().starts_with("/static") { - if let Some(artifacts_addr) = artifacts_addr { - let artifacts_addr = format!("http://{}", artifacts_addr); - - return match GRAPHQL_PROXY_CLIENT.call(client_ip, &artifacts_addr, req).await { - Ok(response) => Ok(response), - Err(_error) => { - error!(target: LOG_TARGET, "Artifacts proxy error: {:?}", _error); - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap()) - } - }; - } else { - return Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap()); - } - } - - if req.uri().path().starts_with("/graphql") { - if let Some(graphql_addr) = graphql_addr { - let graphql_addr = format!("http://{}", graphql_addr); - return match GRAPHQL_PROXY_CLIENT.call(client_ip, &graphql_addr, req).await { - Ok(response) => Ok(response), - Err(_error) => { - error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error); - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap()) - } - }; - } else { - return Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap()); + let handlers: Vec> = vec![ + Box::new(SqlHandler::new(pool)), + Box::new(GraphQLHandler::new(client_ip, graphql_addr)), + Box::new(GrpcHandler::new(client_ip, grpc_addr)), + Box::new(StaticHandler::new(client_ip, artifacts_addr)), + ]; + + for handler in handlers { + if handler.should_handle(&req) { + return Ok(handler.handle(req).await); } } - if let Some(content_type) = req.headers().get(CONTENT_TYPE) { - if content_type.to_str().unwrap().starts_with("application/grpc") { - if let Some(grpc_addr) = grpc_addr { - let grpc_addr = format!("http://{}", grpc_addr); - return match GRPC_PROXY_CLIENT.call(client_ip, &grpc_addr, req).await { - Ok(response) => Ok(response), - Err(_error) => { - error!(target: LOG_TARGET, "GRPC proxy error: {:?}", _error); - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap()) - } - }; - } else { - return Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::empty()) - .unwrap()); - } - } - } - - if req.uri().path().starts_with("/sql") { - let query = if req.method() == Method::GET { - // Get the query from URL parameters - let params = req.uri().query().unwrap_or_default(); - form_urlencoded::parse(params.as_bytes()) - .find(|(key, _)| key == "q") - .map(|(_, value)| value.to_string()) - .unwrap_or_default() - } else if req.method() == Method::POST { - // Get the query from request body - let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap_or_default(); - String::from_utf8(body_bytes.to_vec()).unwrap_or_default() - } else { - return Ok(Response::builder() - .status(StatusCode::METHOD_NOT_ALLOWED) - .body(Body::from("Only GET and POST methods are allowed")) - .unwrap()); - }; - - // Execute the query - return match sqlx::query(&query).fetch_all(&*pool).await { - Ok(rows) => { - let result: Vec<_> = rows - .iter() - .map(|row| { - let mut obj = serde_json::Map::new(); - for (i, column) in row.columns().iter().enumerate() { - let value: serde_json::Value = match column.type_info().name() { - "TEXT" => row - .get::, _>(i) - .map_or(serde_json::Value::Null, serde_json::Value::String), - // for operators like count(*) the type info is NULL - // so we default to a number - "INTEGER" | "NULL" => row - .get::, _>(i) - .map_or(serde_json::Value::Null, |n| { - serde_json::Value::Number(n.into()) - }), - "REAL" => row.get::, _>(i).map_or( - serde_json::Value::Null, - |f| { - serde_json::Number::from_f64(f).map_or( - serde_json::Value::Null, - serde_json::Value::Number, - ) - }, - ), - "BLOB" => row - .get::>, _>(i) - .map_or(serde_json::Value::Null, |bytes| { - serde_json::Value::String(STANDARD.encode(bytes)) - }), - _ => row - .get::, _>(i) - .map_or(serde_json::Value::Null, serde_json::Value::String), - }; - obj.insert(column.name().to_string(), value); - } - serde_json::Value::Object(obj) - }) - .collect(); - - let json = serde_json::to_string(&result).unwrap(); - - Ok(Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, "application/json") - .body(Body::from(json)) - .unwrap()) - } - Err(e) => Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from(format!("Query error: {:?}", e))) - .unwrap()), - }; - } - + // Default response if no handler matches let json = json!({ "service": "torii", "success": true }); - let body = Body::from(json.to_string()); - let response = Response::builder() + + Ok(Response::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, "application/json") - .body(body) - .unwrap(); - Ok(response) + .body(Body::from(json.to_string())) + .unwrap()) }