diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 35f7a3a5..fd042df9 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -7,6 +7,19 @@ use tracing::debug; const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; +// TODO move to payjoin crate as pub? +// TODO impl From for ShortId +// TODO impl Display for ShortId (Base64) +// TODO impl TryFrom<&str> for ShortId (Base64) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct ShortId(pub [u8; 8]); + +impl ShortId { + pub fn column_key(&self, column: &str) -> Vec { + self.0.iter().chain(column.as_bytes()).copied().collect() + } +} + #[derive(Debug, Clone)] pub(crate) struct DbPool { client: Client, @@ -19,23 +32,28 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn push_default(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + pub async fn push_default(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { self.push(pubkey_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, pubkey_id: &str) -> Option>> { + pub async fn peek_default(&self, pubkey_id: &ShortId) -> Option>> { self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await } - pub async fn push_v1(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + pub async fn push_v1(&self, pubkey_id: &ShortId, data: Vec) -> RedisResult<()> { self.push(pubkey_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, pubkey_id: &str) -> Option>> { + pub async fn peek_v1(&self, pubkey_id: &ShortId) -> Option>> { self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await } - async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec) -> RedisResult<()> { + async fn push( + &self, + pubkey_id: &ShortId, + channel_type: &str, + data: Vec, + ) -> RedisResult<()> { let mut conn = self.client.get_async_connection().await?; let key = channel_name(pubkey_id, channel_type); () = conn.set(&key, data.clone()).await?; @@ -45,13 +63,13 @@ impl DbPool { async fn peek_with_timeout( &self, - pubkey_id: &str, + pubkey_id: &ShortId, channel_type: &str, ) -> Option>> { tokio::time::timeout(self.timeout, self.peek(pubkey_id, channel_type)).await.ok() } - async fn peek(&self, pubkey_id: &str, channel_type: &str) -> RedisResult> { + async fn peek(&self, pubkey_id: &ShortId, channel_type: &str) -> RedisResult> { let mut conn = self.client.get_async_connection().await?; let key = channel_name(pubkey_id, channel_type); @@ -99,6 +117,6 @@ impl DbPool { } } -fn channel_name(pubkey_id: &str, channel_type: &str) -> String { - format!("{}:{}", pubkey_id, channel_type) +fn channel_name(pubkey_id: &ShortId, channel_type: &str) -> Vec { + pubkey_id.column_key(channel_type) } diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index ef267c86..c673127d 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -3,6 +3,8 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; +use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; +use bitcoin::base64::Engine; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; @@ -15,6 +17,8 @@ use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{debug, error, info, trace}; +use crate::db::ShortId; + pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -295,7 +299,7 @@ async fn post_fallback_v1( }; let v2_compat_body = format!("{}\n{}", body_str, query); - let id = shorten_string(id); + let id = decode_short_id(id)?; pool.push_default(&id, v2_compat_body.into()) .await .map_err(|e| HandlerError::BadRequest(e.into()))?; @@ -316,7 +320,7 @@ async fn put_payjoin_v1( trace!("Put_payjoin_v1"); let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; - let id = shorten_string(id); + let id = decode_short_id(id)?; let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > MAX_BUFFER_SIZE { @@ -337,7 +341,7 @@ async fn post_subdir( let none_response = Response::builder().status(StatusCode::OK).body(empty())?; trace!("post_subdir"); - let id = shorten_string(id); + let id = decode_short_id(id)?; let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); if req.len() > MAX_BUFFER_SIZE { @@ -355,7 +359,7 @@ async fn get_subdir( pool: DbPool, ) -> Result>, HandlerError> { trace!("get_subdir"); - let id = shorten_string(id); + let id = decode_short_id(id)?; match pool.peek_default(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), @@ -385,7 +389,15 @@ async fn get_ohttp_keys( Ok(res) } -fn shorten_string(input: &str) -> String { input.chars().take(8).collect() } +fn decode_short_id(input: &str) -> Result { + let decoded = + BASE64_URL_SAFE_NO_PAD.decode(input).map_err(|e| HandlerError::BadRequest(e.into()))?; + + decoded[..8] + .try_into() + .map_err(|_| HandlerError::BadRequest(anyhow::anyhow!("Invalid subdirectory ID"))) + .map(ShortId) +} fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 86de9380..c62bc8c4 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -3,6 +3,7 @@ use std::time::{Duration, SystemTime}; use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; use bitcoin::base64::Engine; +use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; use bitcoin::{Address, FeeRate, OutPoint, Script, TxOut}; use serde::de::Deserializer; @@ -48,7 +49,8 @@ where } fn subdir_path_from_pubkey(pubkey: &HpkePublicKey) -> String { - BASE64_URL_SAFE_NO_PAD.encode(pubkey.to_compressed_bytes()) + let hash = sha256::Hash::hash(&pubkey.to_compressed_bytes()); + BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]) } /// A payjoin V2 receiver, allowing for polled requests to the @@ -188,22 +190,26 @@ impl Receiver { ) } - // The contents of the `&pj=` query parameter including the base64url-encoded public key receiver subdirectory. + // The contents of the `&pj=` query parameter. // This identifies a session at the payjoin directory server. pub fn pj_url(&self) -> Url { - let pubkey = &self.id(); - let pubkey_base64 = BASE64_URL_SAFE_NO_PAD.encode(pubkey); + let id_base64 = BASE64_URL_SAFE_NO_PAD.encode(self.id()); let mut url = self.context.directory.clone(); { let mut path_segments = url.path_segments_mut().expect("Payjoin Directory URL cannot be a base"); - path_segments.push(&pubkey_base64); + path_segments.push(&id_base64); } url } - /// The per-session public key to use as an identifier - pub fn id(&self) -> [u8; 33] { self.context.s.public_key().to_compressed_bytes() } + /// The per-session identifier + pub fn id(&self) -> [u8; 8] { + let hash = sha256::Hash::hash(&self.context.s.public_key().to_compressed_bytes()); + hash.as_byte_array()[..8] + .try_into() + .expect("truncating SHA256 to 8 bytes should always succeed") + } } /// The sender's original PSBT and optional parameters diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index a0823a05..f31bb6f0 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -25,6 +25,8 @@ use std::str::FromStr; #[cfg(feature = "v2")] use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; +#[cfg(feature = "v2")] +use bitcoin::hashes::{sha256, Hash}; use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; @@ -394,8 +396,10 @@ impl V2GetContext { ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { use crate::uri::UrlExt; let mut url = self.endpoint.clone(); - let subdir = BASE64_URL_SAFE_NO_PAD - .encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); + + // TODO unify with receiver's fn subdir_path_from_pubkey + let hash = sha256::Hash::hash(&self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); + let subdir = BASE64_URL_SAFE_NO_PAD.encode(&hash.as_byte_array()[..8]); url.set_path(&subdir); let body = encrypt_message_a( Vec::new(),