Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(torii-server): server proxy handlers #2708

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/torii/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ tower.workspace = true
tracing.workspace = true
warp.workspace = true
form_urlencoded = "1.2.1"
async-trait = "0.1.83"
46 changes: 46 additions & 0 deletions crates/torii/server/src/handlers/graphql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::net::{IpAddr, SocketAddr};

use http::StatusCode;
use hyper::{Body, Request, Response};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::graphql";

pub struct GraphQLHandler {
client_ip: IpAddr,
graphql_addr: Option<SocketAddr>,
}

impl GraphQLHandler {
pub fn new(client_ip: IpAddr, graphql_addr: Option<SocketAddr>) -> Self {
Self { client_ip, graphql_addr }
}
}

#[async_trait::async_trait]
impl Handler for GraphQLHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/graphql")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(addr) = self.graphql_addr {
let graphql_addr = format!("http://{}", addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT.call(self.client_ip, &graphql_addr, req).await
{
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "GraphQL proxy error: {:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Handle Response builder error case, sensei!

The unwrap() call on the Response builder could panic. Consider using unwrap_or_else for safer error handling.

-            Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
+            Response::builder()
+                .status(StatusCode::NOT_FOUND)
+                .body(Body::empty())
+                .unwrap_or_else(|_| Response::new(Body::empty()))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::empty())
.unwrap_or_else(|_| Response::new(Body::empty()))

}
}
}
49 changes: 49 additions & 0 deletions crates/torii/server/src/handlers/grpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::net::{IpAddr, SocketAddr};

use http::header::CONTENT_TYPE;
use hyper::{Body, Request, Response, StatusCode};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::grpc";

pub struct GrpcHandler {
client_ip: IpAddr,
grpc_addr: Option<SocketAddr>,
}

impl GrpcHandler {
pub fn new(client_ip: IpAddr, grpc_addr: Option<SocketAddr>) -> Self {
Self { client_ip, grpc_addr }
}
}

#[async_trait::async_trait]
impl Handler for GrpcHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.headers()
.get(CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.map(|ct| ct.starts_with("application/grpc"))
.unwrap_or(false)
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(grpc_addr) = self.grpc_addr {
let grpc_addr = format!("http://{}", grpc_addr);
match crate::proxy::GRPC_PROXY_CLIENT.call(self.client_ip, &grpc_addr, req).await {
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "{:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
}
}
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
}
15 changes: 15 additions & 0 deletions crates/torii/server/src/handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub mod graphql;
pub mod grpc;
pub mod sql;
pub mod static_files;

use hyper::{Body, Request, Response};

#[async_trait::async_trait]
pub trait Handler: Send + Sync {
// Check if this handler should handle the given request
fn should_handle(&self, req: &Request<Body>) -> bool;

// Handle the request
async fn handle(&self, req: Request<Body>) -> Response<Body>;
}
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
116 changes: 116 additions & 0 deletions crates/torii/server/src/handlers/sql.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::sync::Arc;

use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use http::header::CONTENT_TYPE;
use hyper::{Body, Method, Request, Response, StatusCode};
use sqlx::{Column, Row, SqlitePool, TypeInfo};

use super::Handler;

pub struct SqlHandler {
pool: Arc<SqlitePool>,
}

impl SqlHandler {
pub fn new(pool: Arc<SqlitePool>) -> Self {
Self { pool }
}

pub async fn execute_query(&self, query: String) -> Response<Body> {
match sqlx::query(&query).fetch_all(&*self.pool).await {
Ok(rows) => {
let result: Vec<_> = rows
.iter()
.map(|row| {
let mut obj = serde_json::Map::new();
for (i, column) in row.columns().iter().enumerate() {
let value: serde_json::Value = match column.type_info().name() {
"TEXT" => row
.get::<Option<String>, _>(i)
.map_or(serde_json::Value::Null, serde_json::Value::String),
"INTEGER" | "NULL" => row
.get::<Option<i64>, _>(i)
.map_or(serde_json::Value::Null, |n| {
serde_json::Value::Number(n.into())
}),
"REAL" => row.get::<Option<f64>, _>(i).map_or(
serde_json::Value::Null,
|f| {
serde_json::Number::from_f64(f).map_or(
serde_json::Value::Null,
serde_json::Value::Number,
)
},
),
"BLOB" => row
.get::<Option<Vec<u8>>, _>(i)
.map_or(serde_json::Value::Null, |bytes| {
serde_json::Value::String(STANDARD.encode(bytes))
}),
_ => row
.get::<Option<String>, _>(i)
.map_or(serde_json::Value::Null, serde_json::Value::String),
};
obj.insert(column.name().to_string(), value);
}
serde_json::Value::Object(obj)
})
.collect();

let json = serde_json::to_string(&result).unwrap();

Larkooo marked this conversation as resolved.
Show resolved Hide resolved
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/json")
.body(Body::from(json))
.unwrap()
}
Err(e) => Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Query error: {:?}", e)))
.unwrap(),
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
}
}
Larkooo marked this conversation as resolved.
Show resolved Hide resolved

async fn extract_query(&self, req: Request<Body>) -> Result<String, Response<Body>> {
match *req.method() {
Method::GET => {
// Get the query from the query params
let params = req.uri().query().unwrap_or_default();
Ok(form_urlencoded::parse(params.as_bytes())
.find(|(key, _)| key == "q" || key == "query")
.map(|(_, value)| value.to_string())
.unwrap_or_default())
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
}
Method::POST => {
// Get the query from request body
let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap_or_default();
String::from_utf8(body_bytes.to_vec()).map_err(|_| {
Larkooo marked this conversation as resolved.
Show resolved Hide resolved
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid query"))
.unwrap()
})
}
_ => Err(Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::from("Only GET and POST methods are allowed"))
.unwrap()),
}
}
}

#[async_trait::async_trait]
impl Handler for SqlHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/sql")
}
Comment on lines +124 to +126
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Ohayo sensei! Enhance request validation

The should_handle method only checks the path prefix. Consider adding more validation:

  1. Content-Type validation for POST requests
  2. Query parameter validation for GET requests
     fn should_handle(&self, req: &Request<Body>) -> bool {
-        req.uri().path().starts_with("/sql")
+        if !req.uri().path().starts_with("/sql") {
+            return false;
+        }
+        match *req.method() {
+            Method::POST => req
+                .headers()
+                .get(CONTENT_TYPE)
+                .map_or(false, |ct| ct == "application/json"),
+            Method::GET => true,
+            _ => false,
+        }
     }

Committable suggestion skipped: line range outside the PR's diff.


async fn handle(&self, req: Request<Body>) -> Response<Body> {
match self.extract_query(req).await {
Ok(query) => self.execute_query(query).await,
Err(response) => response,
}
}
}
47 changes: 47 additions & 0 deletions crates/torii/server/src/handlers/static_files.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::net::{IpAddr, SocketAddr};

use hyper::{Body, Request, Response, StatusCode};
use tracing::error;

use super::Handler;

pub(crate) const LOG_TARGET: &str = "torii::server::handlers::static";

pub struct StaticHandler {
client_ip: IpAddr,
artifacts_addr: Option<SocketAddr>,
}

impl StaticHandler {
pub fn new(client_ip: IpAddr, artifacts_addr: Option<SocketAddr>) -> Self {
Self { client_ip, artifacts_addr }
}
}

#[async_trait::async_trait]
impl Handler for StaticHandler {
fn should_handle(&self, req: &Request<Body>) -> bool {
req.uri().path().starts_with("/static")
}

async fn handle(&self, req: Request<Body>) -> Response<Body> {
if let Some(artifacts_addr) = self.artifacts_addr {
let artifacts_addr = format!("http://{}", artifacts_addr);
match crate::proxy::GRAPHQL_PROXY_CLIENT
.call(self.client_ip, &artifacts_addr, req)
.await
{
Ok(response) => response,
Err(_error) => {
error!(target: LOG_TARGET, "{:?}", _error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
}
} else {
Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
}
}
}
1 change: 1 addition & 0 deletions crates/torii/server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod artifacts;
pub(crate) mod handlers;
pub mod proxy;
Loading
Loading