Skip to content

Commit

Permalink
Protect metadata with Oblivious HTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Dec 11, 2023
1 parent a1d090c commit 5935775
Show file tree
Hide file tree
Showing 14 changed files with 1,592 additions and 477 deletions.
438 changes: 395 additions & 43 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions payjoin-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ path = "src/main.rs"
[features]
native-certs = ["ureq/native-certs"]
danger-local-https = ["rcgen", "rustls/dangerous_configuration", "hyper-rustls"]
v2 = ["payjoin/v2"]

[dependencies]
anyhow = "1.0.70"
Expand Down
4 changes: 3 additions & 1 deletion payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ impl App {
let (req, ctx) = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri)
.with_context(|| "Failed to build payjoin request")?
.build_recommended(fee_rate)
.with_context(|| "Failed to build payjoin request")?;
.with_context(|| "Failed to build payjoin request")?
.extract_v1()?;

let http = http_agent()?;
println!("Sending fallback request to {}", &req.url);
let response = spawn_blocking(move || {
Expand Down
1 change: 1 addition & 0 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#[cfg(feature = "danger-local-https")]
#[cfg(not(feature = "v2"))]
mod e2e {
use std::env;
use std::process::Stdio;
Expand Down
6 changes: 4 additions & 2 deletions payjoin-relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ exclude = ["tests"]
danger-local-https = ["hyper-rustls", "rcgen", "rustls"]

[dependencies]
anyhow = "1.0.71"
bitcoin = { version = "0.30.0", features = ["base64"] }
bhttp = { version = "0.4.0", features = ["http"] }
hyper = { version = "0.14", features = ["full"] }
hyper-rustls = { version = "0.24", optional = true }
anyhow = "1.0.71"
payjoin = { path = "../payjoin", features = ["base64"] }
ohttp = "0.4.0"
rcgen = { version = "0.11", optional = true }
rustls = { version = "0.21", optional = true }
sqlx = { version = "0.7.1", features = ["postgres", "runtime-tokio"] }
Expand Down
195 changes: 163 additions & 32 deletions payjoin-relay/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use std::env;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

use anyhow::Result;
use bitcoin::{self, base64};
use hyper::server::conn::AddrIncoming;
use hyper::server::Builder;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, HeaderMap, Method, Request, Response, Server, StatusCode};
use tracing::{debug, error, info};
use hyper::{Body, Method, Request, Response, Server, StatusCode, Uri};
use tokio::sync::Mutex;
use tracing::{debug, error, info, trace};
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::EnvFilter;

const DEFAULT_RELAY_PORT: &str = "8080";
const DEFAULT_DB_HOST: &str = "localhost:5432";
const DEFAULT_TIMEOUT_SECS: u64 = 30;
const MAX_BUFFER_SIZE: usize = 65536;
const V1_REJECT_RES_JSON: &str =
r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#;
const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#;

mod db;
Expand All @@ -32,10 +36,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string());

let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let make_svc = make_service_fn(|_| {
let pool = pool.clone();
let ohttp = ohttp.clone();
async move {
let handler = move |req| handle_web_req(pool.clone(), req);
let handler = move |req| handle_ohttp_gateway(req, pool.clone(), ohttp.clone());
Ok::<_, hyper::Error>(service_fn(handler))
}
});
Expand Down Expand Up @@ -88,17 +94,42 @@ fn init_server(bind_addr: &SocketAddr) -> Result<Builder<hyper_rustls::TlsAccept
Ok(Server::builder(acceptor))
}

async fn handle_web_req(pool: DbPool, req: Request<Body>) -> Result<Response<Body>> {
fn init_ohttp() -> Result<ohttp::Server> {
use ohttp::hpke::{Aead, Kdf, Kem};
use ohttp::{KeyId, SymmetricSuite};

const KEY_ID: KeyId = 1;
const KEM: Kem = Kem::X25519Sha256;
const SYMMETRIC: &[SymmetricSuite] =
&[SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];

// create or read from file
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?;
let encoded_config = server_config.encode()?;
let b64_config = base64::encode_config(
encoded_config,
base64::Config::new(base64::CharacterSet::UrlSafe, false),
);
info!("ohttp server config base64 UrlSafe: {:?}", b64_config);
Ok(ohttp::Server::new(server_config)?)
}

async fn handle_ohttp_gateway(
req: Request<Body>,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
) -> Result<Response<Body>> {
let path = req.uri().path().to_string();
let query = req.uri().query().unwrap_or_default().to_string();
let (parts, body) = req.into_parts();

let path_segments: Vec<&str> = path.split('/').collect();
debug!("{:?}", &path_segments);
debug!("handle_ohttp_gateway: {:?}", &path_segments);
let mut response = match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", ""]) => post_enroll(body).await,
(Method::POST, &["", id]) => post_fallback(id, body, parts.headers, pool).await,
(Method::GET, &["", id]) => get_fallback(id, pool).await,
(Method::POST, &["", id, "payjoin"]) => post_payjoin(id, body, pool).await,
(Method::POST, ["", ""]) => handle_ohttp(body, pool, ohttp).await,
(Method::GET, ["", "ohttp-config"]) =>
Ok(get_ohttp_config(ohttp_config(&ohttp).await?).await),
(Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await,
_ => Ok(not_found()),
}
.unwrap_or_else(|e| e.to_response());
Expand All @@ -111,30 +142,88 @@ async fn handle_web_req(pool: DbPool, req: Request<Body>) -> Result<Response<Bod
Ok(response)
}

async fn handle_ohttp(
body: Body,
pool: DbPool,
ohttp: Arc<Mutex<ohttp::Server>>,
) -> Result<Response<Body>, HandlerError> {
// decapsulate
let ohttp_body =
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
let mut ohttp_locked = ohttp.lock().await;
let (bhttp_req, res_ctx) =
ohttp_locked.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?;
drop(ohttp_locked);
let mut cursor = std::io::Cursor::new(bhttp_req);
let req =
bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?;
let uri = Uri::builder()
.scheme(req.control().scheme().unwrap_or_default())
.authority(req.control().authority().unwrap_or_default())
.path_and_query(req.control().path().unwrap_or_default())
.build()?;
let body = req.content().to_vec();
let mut http_req =
Request::builder().uri(uri).method(req.control().method().unwrap_or_default());
for header in req.header().fields() {
http_req = http_req.header(header.name(), header.value())
}
let request = http_req.body(Body::from(body))?;

let response = handle_v2(pool, request).await?;

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
let full_body = hyper::body::to_bytes(body)
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
bhttp_res.write_content(&full_body);
let mut bhttp_bytes = Vec::new();
bhttp_res
.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
let ohttp_res = res_ctx
.encapsulate(&bhttp_bytes)
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
Ok(Response::new(Body::from(ohttp_res)))
}

async fn handle_v2(pool: DbPool, req: Request<Body>) -> Result<Response<Body>, HandlerError> {
let path = req.uri().path().to_string();
let (parts, body) = req.into_parts();

let path_segments: Vec<&str> = path.split('/').collect();
debug!("handle_v2: {:?}", &path_segments);
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", ""]) => post_enroll(body).await,
(Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await,
(Method::GET, &["", id]) => get_fallback(id, pool).await,
(Method::POST, &["", id, "payjoin"]) => post_payjoin(id, body, pool).await,
_ => Ok(not_found()),
}
}

enum HandlerError {
PayloadTooLarge,
ReceiverOffline,
InternalServerError(anyhow::Error),
BadRequest(anyhow::Error),
}

impl HandlerError {
fn to_response(&self) -> Response<Body> {
let (status, body) = match self {
HandlerError::PayloadTooLarge => (StatusCode::PAYLOAD_TOO_LARGE, Body::empty()),
HandlerError::ReceiverOffline =>
(StatusCode::SERVICE_UNAVAILABLE, Body::from(V1_UNAVAILABLE_RES_JSON)),
let status = match self {
HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE,
HandlerError::BadRequest(e) => {
error!("Bad request: {}", e);
(StatusCode::BAD_REQUEST, Body::empty())
StatusCode::BAD_REQUEST
}
HandlerError::InternalServerError(e) => {
error!("Internal server error: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Body::empty())
StatusCode::INTERNAL_SERVER_ERROR
}
};

let mut res = Response::new(body);
let mut res = Response::new(Body::empty());
*res.status_mut() = status;
res
}
Expand All @@ -145,7 +234,6 @@ impl From<hyper::http::Error> for HandlerError {
}

async fn post_enroll(body: Body) -> Result<Response<Body>, HandlerError> {
use payjoin::{base64, bitcoin};
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
let bytes =
hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?;
Expand All @@ -159,16 +247,51 @@ async fn post_enroll(body: Body) -> Result<Response<Body>, HandlerError> {
Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?)
}

async fn post_fallback(
async fn post_fallback_v1(
id: &str,
query: String,
body: Body,
pool: DbPool,
) -> Result<Response<Body>, HandlerError> {
trace!("Post fallback v1");
let none_response = Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Body::from(V1_UNAVAILABLE_RES_JSON))?;
let bad_request_body_res =
Response::builder().status(StatusCode::BAD_REQUEST).body(Body::from(V1_REJECT_RES_JSON))?;

let body_bytes = match hyper::body::to_bytes(body).await {
Ok(bytes) => bytes.to_vec(),
Err(_) => return Ok(bad_request_body_res),
};

let body_str = match String::from_utf8(body_bytes) {
Ok(body_str) => body_str,
Err(_) => return Ok(bad_request_body_res),
};

let v2_compat_body = Body::from(format!("{}\n{}", body_str, query));
post_fallback(id, v2_compat_body, pool, none_response).await
}

async fn post_fallback_v2(
id: &str,
body: Body,
headers: HeaderMap,
pool: DbPool,
) -> Result<Response<Body>, HandlerError> {
use hyper::header::HeaderValue;
trace!("Post fallback v2");
let none_response = Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?;
post_fallback(id, body, pool, none_response).await
}

async fn post_fallback(
id: &str,
body: Body,
pool: DbPool,
none_response: Response<Body>,
) -> Result<Response<Body>, HandlerError> {
tracing::trace!("Post fallback");
let id = shorten_string(id);
let is_async = headers.get("Async") == Some(&HeaderValue::from_static("true"));
let req = hyper::body::to_bytes(body)
.await
.map_err(|e| HandlerError::InternalServerError(e.into()))?;
Expand All @@ -186,19 +309,12 @@ async fn post_fallback(
Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => fallback_timeout_response(is_async),
}
}

fn fallback_timeout_response(is_req_async: bool) -> Result<Response<Body>, HandlerError> {
if is_req_async {
Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?)
} else {
Err(HandlerError::ReceiverOffline)
None => Ok(none_response),
}
}

async fn get_fallback(id: &str, pool: DbPool) -> Result<Response<Body>, HandlerError> {
trace!("GET fallback");
let id = shorten_string(id);
match pool.peek_req(&id).await {
Some(result) => match result {
Expand All @@ -210,6 +326,7 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result<Response<Body>, HandlerE
}

async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result<Response<Body>, HandlerError> {
trace!("POST payjoin");
let id = shorten_string(id);
let res = hyper::body::to_bytes(body)
.await
Expand All @@ -227,4 +344,18 @@ fn not_found() -> Response<Body> {
res
}

async fn get_ohttp_config(config: String) -> Response<Body> {
trace!("GET ohttp config: {:?}", config);
let mut res = Response::default();
*res.body_mut() = Body::from(config);
res
}

fn shorten_string(input: &str) -> String { input.chars().take(8).collect() }

async fn ohttp_config(server: &Arc<Mutex<ohttp::Server>>) -> Result<String> {
let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
let server = server.lock().await;
let encoded_config = server.config().encode()?;
Ok(base64::encode_config(encoded_config, b64_config))
}
4 changes: 3 additions & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ exclude = ["tests"]
send = []
receive = ["rand"]
base64 = ["bitcoin/base64"]
v2 = ["bitcoin/rand-std", "chacha20poly1305"]
v2 = ["bitcoin/rand-std", "chacha20poly1305", "ohttp", "bhttp"]

[dependencies]
bitcoin = { version = "0.30.0", features = ["base64"] }
bip21 = "0.3.1"
chacha20poly1305 = { version = "0.10.1", optional = true }
log = { version = "0.4.14"}
ohttp = { version = "0.4.0", optional = true }
bhttp = { version = "0.4.0", optional = true }
rand = { version = "0.8.4", optional = true }
url = "2.2.2"

Expand Down
Loading

0 comments on commit 5935775

Please sign in to comment.