Skip to content

Commit

Permalink
Use More Semantic HTTP (#346)
Browse files Browse the repository at this point in the history
Close #345
  • Loading branch information
DanGould authored Aug 15, 2024
2 parents cc97788 + 11353f9 commit ca3f400
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 82 deletions.
66 changes: 9 additions & 57 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion payjoin-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ clap = { version = "~4.0.32", features = ["derive"] }
config = "0.13.3"
env_logger = "0.9.0"
http-body-util = { version = "0.1", optional = true }
hyper = { version = "1", features = ["full"], optional = true }
hyper = { version = "1", features = ["http1", "server"], optional = true }
hyper-rustls = { version = "0.26", optional = true }
hyper-util = { version = "0.1", optional = true }
log = "0.4.7"
Expand Down
9 changes: 8 additions & 1 deletion payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,14 @@ mod e2e {
let db = docker.run(Redis::default());
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls(
format!("http://localhost:{}", port),
port,
db_host,
timeout,
local_cert_key,
)
.await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down
2 changes: 1 addition & 1 deletion payjoin-directory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bitcoin = { version = "0.32.2", features = ["base64"] }
bhttp = { version = "=0.5.1", features = ["http"] }
futures = "0.3.17"
http-body-util = "0.1.2"
hyper = { version = "1" }
hyper = { version = "1", features = ["http1", "server"] }
hyper-rustls = { version = "0.26", optional = true }
hyper-util = "0.1"
ohttp = "0.5.1"
Expand Down
63 changes: 53 additions & 10 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use bitcoin::base64::Engine;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Body, Bytes, Incoming};
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE};
use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, LOCATION};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode, Uri};
Expand All @@ -20,6 +20,7 @@ use tracing::{debug, error, info, trace};
pub const DEFAULT_DIR_PORT: u16 = 8080;
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_BASE_URL: &str = "https://localhost";

const MAX_BUFFER_SIZE: usize = 65536;

Expand All @@ -31,6 +32,7 @@ mod db;
use crate::db::DbPool;

pub async fn listen_tcp(
base_url: String,
port: u16,
db_host: String,
timeout: Duration,
Expand All @@ -42,13 +44,14 @@ pub async fn listen_tcp(
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let base_url = base_url.clone();
let io = TokioIo::new(stream);
tokio::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(
io,
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
}),
)
.with_upgrades()
Expand All @@ -64,6 +67,7 @@ pub async fn listen_tcp(

#[cfg(feature = "danger-local-https")]
pub async fn listen_tcp_with_tls(
base_url: String,
port: u16,
db_host: String,
timeout: Duration,
Expand All @@ -77,6 +81,7 @@ pub async fn listen_tcp_with_tls(
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let base_url = base_url.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Expand All @@ -90,7 +95,7 @@ pub async fn listen_tcp_with_tls(
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
serve_payjoin_directory(req, pool.clone(), ohttp.clone(), base_url.clone())
}),
)
.with_upgrades()
Expand Down Expand Up @@ -143,6 +148,7 @@ async fn serve_payjoin_directory(
req: Request<Incoming>,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
base_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>> {
let path = req.uri().path().to_string();
let query = req.uri().query().unwrap_or_default().to_string();
Expand All @@ -151,7 +157,7 @@ async fn serve_payjoin_directory(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("serve_payjoin_directory: {:?}", &path_segments);
let mut response = match (parts.method, path_segments.as_slice()) {
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await,
(Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp, base_url).await,
(Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await,
(Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await,
(Method::GET, ["", "health"]) => health_check().await,
Expand All @@ -169,6 +175,7 @@ async fn handle_ohttp_gateway(
body: Incoming,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
base_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
// decapsulate
let ohttp_body =
Expand All @@ -194,10 +201,13 @@ async fn handle_ohttp_gateway(
}
let request = http_req.body(full(body))?;

let response = handle_v2(pool, request).await?;
let response = handle_v2(pool, base_url, request).await?;

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
for (name, value) in parts.headers.iter() {
bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default());
}
let full_body =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
bhttp_res.write_content(&full_body);
Expand All @@ -213,6 +223,7 @@ async fn handle_ohttp_gateway(

async fn handle_v2(
pool: DbPool,
base_url: String,
req: Request<BoxBody<Bytes, hyper::Error>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
let path = req.uri().path().to_string();
Expand All @@ -221,10 +232,10 @@ async fn handle_v2(
let path_segments: Vec<&str> = path.split('/').collect();
debug!("handle_v2: {:?}", &path_segments);
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", ""]) => post_session(body).await,
(Method::POST, &["", ""]) => post_session(base_url, body).await,
(Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await,
(Method::GET, &["", id]) => get_fallback(id, pool).await,
(Method::POST, &["", id, "payjoin"]) => post_payjoin(id, body, pool).await,
(Method::PUT, &["", id]) => post_payjoin(id, body, pool).await,
_ => Ok(not_found()),
}
}
Expand All @@ -233,6 +244,7 @@ async fn health_check() -> Result<Response<BoxBody<Bytes, hyper::Error>>, Handle
Ok(Response::new(empty()))
}

#[derive(Debug)]
enum HandlerError {
PayloadTooLarge,
InternalServerError(anyhow::Error),
Expand Down Expand Up @@ -273,6 +285,7 @@ impl From<hyper::http::Error> for HandlerError {
}

async fn post_session(
base_url: String,
body: BoxBody<Bytes, hyper::Error>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
let bytes = body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes();
Expand All @@ -283,9 +296,10 @@ async fn post_session(
let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(|e| HandlerError::BadRequest(e.into()))?;
tracing::info!("Initialized session with pubkey: {:?}", pubkey);
let mut res = Response::new(empty());
*res.status_mut() = StatusCode::NO_CONTENT;
Ok(res)
Ok(Response::builder()
.header(LOCATION, format!("{}/{}", base_url, pubkey))
.status(StatusCode::CREATED)
.body(empty())?)
}

async fn post_fallback_v1(
Expand Down Expand Up @@ -413,3 +427,32 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
}

#[cfg(test)]
mod tests {
use hyper::Request;

use super::*;

/// Ensure that the POST / endpoint returns a 201 Created with a Location header
/// as is semantically correct when creating a resource.
///
/// https://datatracker.ietf.org/doc/html/rfc9110#name-post
#[tokio::test]
async fn test_post_session() -> Result<(), Box<dyn std::error::Error>> {
let base_url = "https://localhost".to_string();
let body = full("some_base64_encoded_pubkey");

let request = Request::builder().method(Method::POST).uri("/").body(body)?;

let response = post_session(base_url.clone(), request.into_body())
.await
.map_err(|e| format!("{:?}", e))?;

assert_eq!(response.status(), StatusCode::CREATED);
assert!(response.headers().contains_key(LOCATION));
let location_header = response.headers().get(LOCATION).ok_or("Missing LOCATION header")?;
assert!(location_header.to_str()?.starts_with(&base_url));
Ok(())
}
}
4 changes: 3 additions & 1 deletion payjoin-directory/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string());

payjoin_directory::listen_tcp(dir_port, db_host, timeout).await
let base_url = env::var("PJ_DIR_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());

payjoin_directory::listen_tcp(base_url, dir_port, db_host, timeout).await
}

fn init_logging() {
Expand Down
Loading

0 comments on commit ca3f400

Please sign in to comment.