diff --git a/payjoin-cli/src/app/mod.rs b/payjoin-cli/src/app/mod.rs index 73cb0dd2..c55b8105 100644 --- a/payjoin-cli/src/app/mod.rs +++ b/payjoin-cli/src/app/mod.rs @@ -5,7 +5,7 @@ use anyhow::{anyhow, Context, Result}; use bitcoincore_rpc::bitcoin::Amount; use bitcoincore_rpc::RpcApi; use payjoin::bitcoin::psbt::Psbt; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use payjoin::{bitcoin, PjUri}; pub mod config; @@ -28,7 +28,7 @@ pub trait App { async fn send_payjoin(&self, bip21: &str, fee_rate: &f32) -> Result<()>; async fn receive_payjoin(self, amount_arg: &str) -> Result<()>; - fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { + fn create_pj_request(&self, uri: &PjUri, fee_rate: &f32) -> Result { let amount = uri.amount.ok_or_else(|| anyhow!("please specify the amount in the Uri"))?; // wallet_create_funded_psbt requires a HashMap @@ -64,7 +64,7 @@ pub trait App { .psbt; let psbt = Psbt::from_str(&psbt).with_context(|| "Failed to load PSBT from base64")?; log::debug!("Original psbt: {:#?}", psbt); - let req_ctx = payjoin::send::RequestBuilder::from_psbt_and_uri(psbt, uri.clone()) + let req_ctx = payjoin::send::SenderBuilder::from_psbt_and_uri(psbt, uri.clone()) .with_context(|| "Failed to build payjoin request")? .build_recommended(fee_rate) .with_context(|| "Failed to build payjoin request")?; diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index e90735bf..a1139b2e 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -8,7 +8,7 @@ use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::psbt::Psbt; use payjoin::bitcoin::{Amount, FeeRate}; use payjoin::receive::v2::ActiveSession; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use payjoin::{bitcoin, Error, Uri}; use tokio::signal; use tokio::sync::watch; @@ -91,7 +91,7 @@ impl AppTrait for App { } impl App { - async fn spawn_payjoin_sender(&self, mut req_ctx: RequestContext) -> Result<()> { + async fn spawn_payjoin_sender(&self, mut req_ctx: Sender) -> Result<()> { let mut interrupt = self.interrupt.clone(); tokio::select! { res = self.long_poll_post(&mut req_ctx) => { @@ -197,30 +197,57 @@ impl App { Ok(()) } - async fn long_poll_post(&self, req_ctx: &mut payjoin::send::RequestContext) -> Result { - loop { - let (req, ctx) = req_ctx.extract_v2(self.config.ohttp_relay.clone())?; - println!("Polling send request..."); - let http = http_agent()?; - let response = http - .post(req.url) - .header("Content-Type", req.content_type) - .body(req.body) - .send() - .await - .map_err(map_reqwest_err)?; - - println!("Sent fallback transaction"); - match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { - Ok(Some(psbt)) => return Ok(psbt), - Ok(None) => { - println!("No response yet."); - tokio::time::sleep(std::time::Duration::from_secs(5)).await; + async fn long_poll_post(&self, req_ctx: &mut payjoin::send::Sender) -> Result { + let (req, ctx) = req_ctx.extract_highest_version(self.config.ohttp_relay.clone())?; + println!("Posting Original PSBT Payload request..."); + let http = http_agent()?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + println!("Sent fallback transaction"); + match ctx { + payjoin::send::Context::V2(ctx) => { + let v2_ctx = Arc::new( + ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?, + ); + loop { + let (req, ohttp_ctx) = v2_ctx.extract_req(self.config.ohttp_relay.clone())?; + let response = http + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(map_reqwest_err)?; + match v2_ctx.process_response( + &mut response.bytes().await?.to_vec().as_slice(), + ohttp_ctx, + ) { + Ok(Some(psbt)) => return Ok(psbt), + Ok(None) => { + println!("No response yet."); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + return Err(anyhow!("Response error").context(re)); + } + } } - Err(re) => { - println!("{}", re); - log::debug!("{:?}", re); - return Err(anyhow!("Response error").context(re)); + } + payjoin::send::Context::V1(ctx) => { + match ctx.process_response(&mut response.bytes().await?.to_vec().as_slice()) { + Ok(psbt) => Ok(psbt), + Err(re) => { + println!("{}", re); + log::debug!("{:?}", re); + Err(anyhow!("Response error").context(re)) + } } } } diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 8ec7250b..67137efc 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -1,6 +1,6 @@ use bitcoincore_rpc::jsonrpc::serde_json; use payjoin::receive::v2::ActiveSession; -use payjoin::send::RequestContext; +use payjoin::send::Sender; use sled::{IVec, Tree}; use url::Url; @@ -35,11 +35,7 @@ impl Database { Ok(()) } - pub(crate) fn insert_send_session( - &self, - session: &mut RequestContext, - pj_url: &Url, - ) -> Result<()> { + pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let value = serde_json::to_string(session).map_err(Error::Serialize)?; send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?; @@ -47,23 +43,21 @@ impl Database { Ok(()) } - pub(crate) fn get_send_sessions(&self) -> Result> { + pub(crate) fn get_send_sessions(&self) -> Result> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let mut sessions = Vec::new(); for item in send_tree.iter() { let (_, value) = item?; - let session: RequestContext = - serde_json::from_slice(&value).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&value).map_err(Error::Deserialize)?; sessions.push(session); } Ok(sessions) } - pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { + pub(crate) fn get_send_session(&self, pj_url: &Url) -> Result> { let send_tree = self.0.open_tree("send_sessions")?; if let Some(val) = send_tree.get(pj_url.to_string())? { - let session: RequestContext = - serde_json::from_slice(&val).map_err(Error::Deserialize)?; + let session: Sender = serde_json::from_slice(&val).map_err(Error::Deserialize)?; Ok(Some(session)) } else { Ok(None) diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 26b69864..679a0f40 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -4,8 +4,8 @@ use futures::StreamExt; use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult}; use tracing::debug; -const RES_COLUMN: &str = "res"; -const REQ_COLUMN: &str = "req"; +const DEFAULT_COLUMN: &str = ""; +const PJ_V1_COLUMN: &str = "pjv1"; #[derive(Debug, Clone)] pub(crate) struct DbPool { @@ -19,20 +19,20 @@ impl DbPool { Ok(Self { client, timeout }) } - pub async fn peek_req(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, REQ_COLUMN).await + pub async fn push_default(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, DEFAULT_COLUMN, data).await } - pub async fn peek_res(&self, pubkey_id: &str) -> Option>> { - self.peek_with_timeout(pubkey_id, RES_COLUMN).await + pub async fn peek_default(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, DEFAULT_COLUMN).await } - pub async fn push_req(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, REQ_COLUMN, data).await + pub async fn push_v1(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { + self.push(pubkey_id, PJ_V1_COLUMN, data).await } - pub async fn push_res(&self, pubkey_id: &str, data: Vec) -> RedisResult<()> { - self.push(pubkey_id, RES_COLUMN, data).await + pub async fn peek_v1(&self, pubkey_id: &str) -> Option>> { + self.peek_with_timeout(pubkey_id, PJ_V1_COLUMN).await } async fn push(&self, pubkey_id: &str, channel_type: &str, data: Vec) -> RedisResult<()> { diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index eb1a2b65..194cf89a 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -220,9 +220,9 @@ async fn handle_v2( let path_segments: Vec<&str> = path.split('/').collect(); debug!("handle_v2: {:?}", &path_segments); match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", id]) => post_fallback_v2(id, body, pool).await, - (Method::GET, &["", id]) => get_fallback(id, pool).await, - (Method::PUT, &["", id]) => post_payjoin(id, body, pool).await, + (Method::POST, &["", id]) => post_subdir(id, body, pool).await, + (Method::GET, &["", id]) => get_subdir(id, pool).await, + (Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await, _ => Ok(not_found()), } } @@ -294,27 +294,49 @@ async fn post_fallback_v1( Err(_) => return Ok(bad_request_body_res), }; - let v2_compat_body = full(format!("{}\n{}", body_str, query)); - post_fallback(id, v2_compat_body, pool, none_response).await + let v2_compat_body = format!("{}\n{}", body_str, query); + let id = shorten_string(id); + pool.push_default(&id, v2_compat_body.into()) + .await + .map_err(|e| HandlerError::BadRequest(e.into()))?; + match pool.peek_v1(&id).await { + Some(result) => match result { + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => Err(HandlerError::BadRequest(e.into())), + }, + None => Ok(none_response), + } } -async fn post_fallback_v2( +async fn put_payjoin_v1( id: &str, body: BoxBody, pool: DbPool, ) -> Result>, HandlerError> { - trace!("Post fallback v2"); - let none_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; - post_fallback(id, body, pool, none_response).await + trace!("Put_payjoin_v1"); + let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; + + let id = shorten_string(id); + let req = + body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); + if req.len() > MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); + } + + match pool.push_v1(&id, req.into()).await { + Ok(_) => Ok(ok_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), + } } -async fn post_fallback( +async fn post_subdir( id: &str, body: BoxBody, pool: DbPool, - none_response: Response>, ) -> Result>, HandlerError> { - tracing::trace!("Post fallback"); + let none_response = Response::builder().status(StatusCode::OK).body(empty())?; + tracing::trace!("Post subdir"); + let id = shorten_string(id); let req = body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); @@ -322,27 +344,19 @@ async fn post_fallback( return Err(HandlerError::PayloadTooLarge); } - match pool.push_req(&id, req.into()).await { - Ok(_) => (), - Err(e) => return Err(HandlerError::BadRequest(e.into())), - }; - - match pool.peek_res(&id).await { - Some(result) => match result { - Ok(buffered_res) => Ok(Response::new(full(buffered_res))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(none_response), + match pool.push_default(&id, req.into()).await { + Ok(_) => Ok(none_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), } } -async fn get_fallback( +async fn get_subdir( id: &str, pool: DbPool, ) -> Result>, HandlerError> { trace!("GET fallback"); let id = shorten_string(id); - match pool.peek_req(&id).await { + match pool.peek_default(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(full(buffered_req))), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -351,22 +365,6 @@ async fn get_fallback( } } -async fn post_payjoin( - id: &str, - body: BoxBody, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("POST payjoin"); - let id = shorten_string(id); - let res = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - - match pool.push_res(&id, res.into()).await { - Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(empty())?), - Err(e) => Err(HandlerError::BadRequest(e.into())), - } -} - fn not_found() -> Response> { let mut res = Response::default(); *res.status_mut() = StatusCode::NOT_FOUND; diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 1649a645..eca9acf8 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -93,7 +93,7 @@ impl ActiveSession { } } - // OHTTP Encapsulated HTTP GET request for the Original PSBT + /// Extratct an OHTTP Encapsulated HTTP GET request for the Original PSBT pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> { if SystemTime::now() > self.context.expiry { return Err(InternalSessionError::Expired(self.context.expiry).into()); @@ -482,22 +482,34 @@ impl PayjoinProposal { #[cfg(feature = "v2")] pub fn extract_v2_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> { - let body = match &self.context.e { - Some(e) => { - let payjoin_bytes = self.inner.payjoin_psbt.serialize(); - log::debug!("THERE IS AN e: {:?}", e); - crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e) - } - None => Ok(self.extract_v1_req().as_bytes().to_vec()), - }?; - let subdir_path = subdir_path_from_pubkey(self.context.s.public_key()); - let post_payjoin_target = - self.context.directory.join(&subdir_path).map_err(|e| Error::Server(e.into()))?; - log::debug!("Payjoin post target: {}", post_payjoin_target.as_str()); + let target_resource: Url; + let body: Vec; + let method: &str; + + if let Some(e) = &self.context.e { + // Prepare v2 payload + let payjoin_bytes = self.inner.payjoin_psbt.serialize(); + let sender_subdir = subdir_path_from_pubkey(e); + target_resource = + self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; + body = crate::v2::encrypt_message_b(payjoin_bytes, &self.context.s, e).unwrap(); + method = "POST"; + } else { + // Prepare v2 wrapped and backwards-compatible v1 payload + body = self.extract_v1_req().as_bytes().to_vec(); + let receiver_subdir = subdir_path_from_pubkey(self.context.s.public_key()); + target_resource = self + .context + .directory + .join(&receiver_subdir) + .map_err(|e| Error::Server(e.into()))?; + method = "PUT"; + } + log::debug!("Payjoin PSBT target: {}", target_resource.as_str()); let (body, ctx) = crate::v2::ohttp_encapsulate( &mut self.context.ohttp_keys, - "PUT", - post_payjoin_target.as_str(), + method, + target_resource.as_str(), Some(&body), )?; let url = self.context.ohttp_relay.clone(); diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 99f10584..3aa9175f 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -26,6 +26,8 @@ use std::str::FromStr; +#[cfg(feature = "v2")] +use bitcoin::base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; use bitcoin::psbt::Psbt; use bitcoin::{Amount, FeeRate, Script, ScriptBuf, TxOut, Weight}; pub use error::{CreateRequestError, ResponseError, ValidationError}; @@ -37,7 +39,7 @@ use url::Url; use crate::psbt::{InputPair, PsbtExt}; use crate::request::Request; #[cfg(feature = "v2")] -use crate::v2::{HpkePublicKey, HpkeSecretKey}; +use crate::v2::{HpkeKeyPair, HpkePublicKey}; use crate::PjUri; // See usize casts @@ -49,7 +51,7 @@ mod error; type InternalResult = Result; #[derive(Clone)] -pub struct RequestBuilder<'a> { +pub struct SenderBuilder<'a> { psbt: Psbt, uri: PjUri<'a>, disable_output_substitution: bool, @@ -63,7 +65,7 @@ pub struct RequestBuilder<'a> { min_fee_rate: FeeRate, } -impl<'a> RequestBuilder<'a> { +impl<'a> SenderBuilder<'a> { /// Prepare an HTTP request and request context to process the response /// /// An HTTP client will own the Request data while Context sticks around so @@ -98,10 +100,7 @@ impl<'a> RequestBuilder<'a> { // The minfeerate parameter is set if the contribution is available in change. // // This method fails if no recommendation can be made or if the PSBT is malformed. - pub fn build_recommended( - self, - min_fee_rate: FeeRate, - ) -> Result { + pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result { // TODO support optional batched payout scripts. This would require a change to // build() which now checks for a single payee. let mut payout_scripts = std::iter::once(self.uri.address.script_pubkey()); @@ -179,7 +178,7 @@ impl<'a> RequestBuilder<'a> { change_index: Option, min_fee_rate: FeeRate, clamp_fee_contribution: bool, - ) -> Result { + ) -> Result { self.fee_contribution = Some((max_fee_contribution, change_index)); self.clamp_fee_contribution = clamp_fee_contribution; self.min_fee_rate = min_fee_rate; @@ -193,7 +192,7 @@ impl<'a> RequestBuilder<'a> { pub fn build_non_incentivizing( mut self, min_fee_rate: FeeRate, - ) -> Result { + ) -> Result { // since this is a builder, these should already be cleared // but we'll reset them to be sure self.fee_contribution = None; @@ -202,7 +201,7 @@ impl<'a> RequestBuilder<'a> { self.build() } - fn build(self) -> Result { + fn build(self) -> Result { let mut psbt = self.psbt.validate().map_err(InternalCreateRequestError::InconsistentOriginalPsbt)?; psbt.validate_input_utxos(true) @@ -221,7 +220,7 @@ impl<'a> RequestBuilder<'a> { )?; clear_unneeded_fields(&mut psbt); - Ok(RequestContext { + Ok(Sender { psbt, endpoint, disable_output_substitution, @@ -229,14 +228,14 @@ impl<'a> RequestBuilder<'a> { payee, min_fee_rate: self.min_fee_rate, #[cfg(feature = "v2")] - e: crate::v2::HpkeKeyPair::gen_keypair().secret_key().clone(), + e: crate::v2::HpkeKeyPair::gen_keypair(), }) } } #[derive(Clone, PartialEq, Eq)] #[cfg_attr(feature = "v2", derive(Serialize, Deserialize))] -pub struct RequestContext { +pub struct Sender { psbt: Psbt, endpoint: Url, disable_output_substitution: bool, @@ -244,12 +243,12 @@ pub struct RequestContext { min_fee_rate: FeeRate, payee: ScriptBuf, #[cfg(feature = "v2")] - e: crate::v2::HpkeSecretKey, + e: crate::v2::HpkeKeyPair, } -impl RequestContext { +impl Sender { /// Extract serialized V1 Request and Context froma Payjoin Proposal - pub fn extract_v1(&self) -> Result<(Request, ContextV1), CreateRequestError> { + pub fn extract_v1(&self) -> Result<(Request, V1Context), CreateRequestError> { let url = serialize_url( self.endpoint.clone(), self.disable_output_substitution, @@ -261,12 +260,14 @@ impl RequestContext { let body = self.psbt.to_string().as_bytes().to_vec(); Ok(( Request::new_v1(url, body), - ContextV1 { - original_psbt: self.psbt.clone(), - disable_output_substitution: self.disable_output_substitution, - fee_contribution: self.fee_contribution, - payee: self.payee.clone(), - min_fee_rate: self.min_fee_rate, + V1Context { + psbt_context: PsbtContext { + original_psbt: self.psbt.clone(), + disable_output_substitution: self.disable_output_substitution, + fee_contribution: self.fee_contribution, + payee: self.payee.clone(), + min_fee_rate: self.min_fee_rate, + }, }, )) } @@ -278,10 +279,10 @@ impl RequestContext { /// /// The `ohttp_relay` merely passes the encrypted payload to the ohttp gateway of the receiver #[cfg(feature = "v2")] - pub fn extract_v2( + pub fn extract_highest_version( &mut self, ohttp_relay: Url, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; if let Some(expiry) = self.endpoint.exp() { @@ -291,11 +292,11 @@ impl RequestContext { } match self.extract_rs_pubkey() { - Ok(rs) => self.extract_v2_strict(ohttp_relay, rs), + Ok(rs) => self.extract_v2(ohttp_relay, rs), Err(e) => { log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e); let (req, context_v1) = self.extract_v1()?; - Ok((req, ContextV2 { context_v1, rs: None, e: None, ohttp_res: None })) + Ok((req, Context::V1(context_v1))) } } } @@ -305,11 +306,11 @@ impl RequestContext { /// This method requires the `rs` pubkey to be extracted from the endpoint /// and has no fallback to v1. #[cfg(feature = "v2")] - fn extract_v2_strict( + fn extract_v2( &mut self, ohttp_relay: Url, rs: HpkePublicKey, - ) -> Result<(Request, ContextV2), CreateRequestError> { + ) -> Result<(Request, Context), CreateRequestError> { use crate::uri::UrlExt; let url = self.endpoint.clone(); let body = serialize_v2_body( @@ -318,35 +319,33 @@ impl RequestContext { self.fee_contribution, self.min_fee_rate, )?; - let body = crate::v2::encrypt_message_a(body, &self.e.clone(), &rs) + let body = crate::v2::encrypt_message_a(body, &self.e.secret_key().clone(), &rs) .map_err(InternalCreateRequestError::Hpke)?; let mut ohttp = self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; - let (body, ohttp_res) = + let (body, ohttp_ctx) = crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body)) .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); Ok(( Request::new_v2(ohttp_relay, body), - ContextV2 { - context_v1: ContextV1 { + Context::V2(V2PostContext { + endpoint: self.endpoint.clone(), + psbt_ctx: PsbtContext { original_psbt: self.psbt.clone(), disable_output_substitution: self.disable_output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), min_fee_rate: self.min_fee_rate, }, - rs: Some(self.extract_rs_pubkey()?), - e: Some(self.e.clone()), - ohttp_res: Some(ohttp_res), - }, + hpke_ctx: HpkeContext { rs, e: self.e.clone() }, + ohttp_ctx, + }), )) } #[cfg(feature = "v2")] fn extract_rs_pubkey(&self) -> Result { - use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD; - use bitcoin::base64::Engine; use error::ParseSubdirectoryError; let subdirectory = self @@ -366,25 +365,137 @@ impl RequestContext { pub fn endpoint(&self) -> &Url { &self.endpoint } } +pub enum Context { + V1(V1Context), + #[cfg(feature = "v2")] + V2(V2PostContext), +} + +pub struct V1Context { + psbt_context: PsbtContext, +} + +impl V1Context { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + self.psbt_context.process_response(response) + } +} + +#[cfg(feature = "v2")] +pub struct V2PostContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, + ohttp_ctx: ohttp::ClientResponse, +} + +#[cfg(feature = "v2")] +impl V2PostContext { + pub fn process_response( + self, + response: &mut impl std::io::Read, + ) -> Result { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + let response = crate::v2::ohttp_decapsulate(self.ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + println!("post response status: {:?}", response.status()); + match response.status() { + http::StatusCode::OK => { + // return OK with new Typestate + Ok(V2GetContext { + endpoint: self.endpoint, + psbt_ctx: self.psbt_ctx, + hpke_ctx: self.hpke_ctx, + }) + } + _ => Err(InternalValidationError::UnexpectedStatusCode)?, + } + } +} + +#[cfg(feature = "v2")] +pub struct V2GetContext { + endpoint: Url, + psbt_ctx: PsbtContext, + hpke_ctx: HpkeContext, +} + +#[cfg(feature = "v2")] +impl V2GetContext { + pub fn extract_req( + &self, + ohttp_relay: Url, + ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { + use crate::uri::UrlExt; + let mut url = self.endpoint.clone(); + let subdir = + BASE64_URL_SAFE_NO_PAD.encode(self.hpke_ctx.e.public_key().to_compressed_bytes()); + url.set_path(&subdir); + println!("sender subdir from sender: {:?}", &url); + let body = crate::v2::encrypt_message_a( + Vec::new(), + &self.hpke_ctx.e.secret_key().clone(), + &self.hpke_ctx.rs.clone(), + ) + .map_err(InternalCreateRequestError::Hpke)?; + let mut ohttp = + self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?; + let (body, ohttp_ctx) = + crate::v2::ohttp_encapsulate(&mut ohttp, "GET", url.as_str(), Some(&body)) + .map_err(InternalCreateRequestError::OhttpEncapsulation)?; + + Ok((Request::new_v2(ohttp_relay, body), ohttp_ctx)) + } + + pub fn process_response( + &self, + response: &mut impl std::io::Read, + ohttp_ctx: ohttp::ClientResponse, + ) -> Result, ResponseError> { + let mut res_buf = Vec::new(); + response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; + println!("get response body length: {:?}", &res_buf.len()); + let response = crate::v2::ohttp_decapsulate(ohttp_ctx, &res_buf) + .map_err(InternalValidationError::OhttpEncapsulation)?; + println!("get response status: {:?}", &response.status()); + let body = match response.status() { + http::StatusCode::OK => response.body().to_vec(), + http::StatusCode::ACCEPTED => return Ok(None), + _ => return Err(InternalValidationError::UnexpectedStatusCode)?, + }; + let psbt = crate::v2::decrypt_message_b( + &body, + self.hpke_ctx.rs.clone(), + self.hpke_ctx.e.secret_key().clone(), + ) + .map_err(InternalValidationError::Hpke)?; + + let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; + let processed_proposal = self.psbt_ctx.clone().process_proposal(proposal)?; + Ok(Some(processed_proposal)) + } +} + /// Data required for validation of response. /// /// This type is used to process the response. Get it from [`RequestBuilder`](crate::send::RequestBuilder)'s build methods. /// Then you only need to call [`.process_response()`](crate::send::Context::process_response()) on it to continue BIP78 flow. #[derive(Debug, Clone)] -pub struct ContextV1 { +pub struct PsbtContext { original_psbt: Psbt, disable_output_substitution: bool, fee_contribution: Option<(bitcoin::Amount, usize)>, min_fee_rate: FeeRate, payee: ScriptBuf, } - #[cfg(feature = "v2")] -pub struct ContextV2 { - context_v1: ContextV1, - rs: Option, - e: Option, - ohttp_res: Option, +struct HpkeContext { + rs: HpkePublicKey, + e: HpkeKeyPair, } macro_rules! check_eq { @@ -405,43 +516,7 @@ macro_rules! ensure { }; } -#[cfg(feature = "v2")] -impl ContextV2 { - /// Decodes and validates the response. - /// - /// Call this method with response from receiver to continue BIP-??? flow. - /// A successful response can either be None if the directory has not response yet or Some(Psbt). - /// - /// If the response is some valid PSBT you should sign and broadcast. - #[inline] - pub fn process_response( - self, - response: &mut impl std::io::Read, - ) -> Result, ResponseError> { - match (self.ohttp_res, self.rs, self.e) { - (Some(ohttp_res), Some(rs), Some(e)) => { - let mut res_buf = Vec::new(); - response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let response = crate::v2::ohttp_decapsulate(ohttp_res, &res_buf) - .map_err(InternalValidationError::OhttpEncapsulation)?; - let body = match response.status() { - http::StatusCode::OK => response.body().to_vec(), - http::StatusCode::ACCEPTED => return Ok(None), - _ => return Err(InternalValidationError::UnexpectedStatusCode)?, - }; - let psbt = crate::v2::decrypt_message_b(&body, rs, e) - .map_err(InternalValidationError::Hpke)?; - - let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; - let processed_proposal = self.context_v1.process_proposal(proposal)?; - Ok(Some(processed_proposal)) - } - _ => self.context_v1.process_response(response).map(Some), - } - } -} - -impl ContextV1 { +impl PsbtContext { /// Decodes and validates the response. /// /// Call this method with response from receiver to continue BIP78 flow. If the response is @@ -826,11 +901,11 @@ mod test { const ORIGINAL_PSBT: &str = "cHNidP8BAHMCAAAAAY8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////AtyVuAUAAAAAF6kUHehJ8GnSdBUOOv6ujXLrWmsJRDCHgIQeAAAAAAAXqRR3QJbbz0hnQ8IvQ0fptGn+votneofTAAAAAAEBIKgb1wUAAAAAF6kU3k4ekGHKWRNbA1rV5tR5kEVDVNCHAQcXFgAUx4pFclNVgo1WWAdN1SYNX8tphTABCGsCRzBEAiB8Q+A6dep+Rz92vhy26lT0AjZn4PRLi8Bf9qoB/CMk0wIgP/Rj2PWZ3gEjUkTlhDRNAQ0gXwTO7t9n+V14pZ6oljUBIQMVmsAaoNWHVMS02LfTSe0e388LNitPa1UQZyOihY+FFgABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUAAA="; const PAYJOIN_PROPOSAL: &str = "cHNidP8BAJwCAAAAAo8nutGgJdyYGXWiBEb45Hoe9lWGbkxh/6bNiOJdCDuDAAAAAAD+////jye60aAl3JgZdaIERvjkeh72VYZuTGH/ps2I4l0IO4MBAAAAAP7///8CJpW4BQAAAAAXqRQd6EnwadJ0FQ46/q6NcutaawlEMIcACT0AAAAAABepFHdAltvPSGdDwi9DR+m0af6+i2d6h9MAAAAAAQEgqBvXBQAAAAAXqRTeTh6QYcpZE1sDWtXm1HmQRUNU0IcBBBYAFMeKRXJTVYKNVlgHTdUmDV/LaYUwIgYDFZrAGqDVh1TEtNi300ntHt/PCzYrT2tVEGcjooWPhRYYSFzWUDEAAIABAACAAAAAgAEAAAAAAAAAAAEBIICEHgAAAAAAF6kUyPLL+cphRyyI5GTUazV0hF2R2NWHAQcXFgAUX4BmVeWSTJIEwtUb5TlPS/ntohABCGsCRzBEAiBnu3tA3yWlT0WBClsXXS9j69Bt+waCs9JcjWtNjtv7VgIge2VYAaBeLPDB6HGFlpqOENXMldsJezF9Gs5amvDQRDQBIQJl1jz1tBt8hNx2owTm+4Du4isx0pmdKNMNIjjaMHFfrQABABYAFEb2Giu6c4KO5YW0pfw3lGp9jMUUIgICygvBWB5prpfx61y1HDAwo37kYP3YRJBvAjtunBAur3wYSFzWUDEAAIABAACAAAAAgAEAAAABAAAAAAA="; - fn create_v1_context() -> super::ContextV1 { + fn create_v1_context() -> super::PsbtContext { let original_psbt = Psbt::from_str(ORIGINAL_PSBT).unwrap(); eprintln!("original: {:#?}", original_psbt); let payee = original_psbt.unsigned_tx.output[1].script_pubkey.clone(); - let ctx = super::ContextV1 { + let ctx = super::PsbtContext { original_psbt, disable_output_substitution: false, fee_contribution: Some((bitcoin::Amount::from_sat(182), 0)), @@ -881,20 +956,15 @@ mod test { #[test] #[cfg(feature = "v2")] fn req_ctx_ser_de_roundtrip() { - use hpke::Deserializable; - use super::*; - let req_ctx = RequestContext { + let req_ctx = Sender { psbt: Psbt::from_str(ORIGINAL_PSBT).unwrap(), endpoint: Url::parse("http://localhost:1234").unwrap(), disable_output_substitution: false, fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), - e: HpkeSecretKey( - ::PrivateKey::from_bytes(&[0x01; 32]) - .unwrap(), - ), + e: HpkeKeyPair::gen_keypair(), }; let serialized = serde_json::to_string(&req_ctx).unwrap(); let deserialized = serde_json::from_str(&serialized).unwrap(); diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 9c6ce3b7..c3099e21 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -11,7 +11,7 @@ mod integration { use bitcoind::bitcoincore_rpc::{self, RpcApi}; use log::{log_enabled, Level}; use once_cell::sync::{Lazy, OnceCell}; - use payjoin::send::RequestBuilder; + use payjoin::send::SenderBuilder; use payjoin::{PjUri, PjUriBuilder, Request, Uri}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use url::Url; @@ -50,7 +50,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -90,6 +90,7 @@ mod integration { use bitcoin::Address; use http::StatusCode; use payjoin::receive::v2::{ActiveSession, PayjoinProposal, UncheckedProposal}; + use payjoin::send::Context; use payjoin::{OhttpKeys, PjUri, UriExt}; use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; @@ -200,9 +201,9 @@ mod integration { Some(std::time::SystemTime::now()), ) .build(); - let mut expired_req_ctx = RequestBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? + let mut expired_req_ctx = SenderBuilder::from_psbt_and_uri(psbt, expired_pj_uri)? .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; - match expired_req_ctx.extract_v2(directory.to_owned()) { + match expired_req_ctx.extract_highest_version(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), _ => assert!(false, "Expired send session should error"), @@ -271,10 +272,14 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_sweep_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; let (Request { url, body, content_type, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; + req_ctx.extract_highest_version(directory.to_owned())?; + let send_ctx = match send_ctx { + Context::V2(ctx) => ctx, + _ => panic!("V2 context expected"), + }; let response = agent .post(url.clone()) .header("Content-Type", content_type) @@ -284,10 +289,10 @@ mod integration { .unwrap(); log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); - let response_body = + let send_ctx = send_ctx.process_response(&mut response.bytes().await?.to_vec().as_slice())?; - // No response body yet since we are async and pushed fallback_psbt to the buffer - assert!(response_body.is_none()); + // POST Original PSBT + // FIXME should be none // ********************** // Inside the Receiver: @@ -301,7 +306,12 @@ mod integration { let mut payjoin_proposal = handle_directory_proposal(&receiver, proposal); assert!(!payjoin_proposal.is_output_substitution_disabled()); let (req, ctx) = payjoin_proposal.extract_v2_req()?; - let response = agent.post(req.url).body(req.body).send().await?; + let response = agent + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await?; let res = response.bytes().await?.to_vec(); payjoin_proposal.process_res(res, ctx)?; @@ -309,11 +319,18 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts // Replay post fallback to get the response - let (Request { url, body, .. }, send_ctx) = - req_ctx.extract_v2(directory.to_owned())?; - let response = agent.post(url).body(body).send().await?; + let (Request { url, body, content_type, .. }, ohttp_ctx) = + send_ctx.extract_req(directory.to_owned())?; + let response = agent + .post(url.clone()) + .header("Content-Type", content_type) + .body(body.clone()) + .send() + .await + .unwrap(); + log::info!("Response: {:#?}", &response); let checked_payjoin_proposal_psbt = send_ctx - .process_response(&mut response.bytes().await?.to_vec().as_slice())? + .process_response(&mut response.bytes().await?.to_vec().as_slice(), ohttp_ctx)? .unwrap(); let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -352,9 +369,9 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? + let mut req_ctx = SenderBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(FeeRate::BROADCAST_MIN)?; - let (req, ctx) = req_ctx.extract_v2(EXAMPLE_URL.to_owned())?; + let (req, ctx) = req_ctx.extract_highest_version(EXAMPLE_URL.to_owned())?; let headers = HeaderMock::new(&req.body, req.content_type); // ********************** @@ -366,8 +383,11 @@ mod integration { // ********************** // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts - let checked_payjoin_proposal_psbt = - ctx.process_response(&mut response.as_bytes())?.unwrap(); + let ctx = match ctx { + Context::V1(ctx) => ctx, + _ => panic!("V1 context expected"), + }; + let checked_payjoin_proposal_psbt = ctx.process_response(&mut response.as_bytes())?; let payjoin_tx = extract_pj_tx(&sender, checked_payjoin_proposal_psbt)?; sender.send_raw_transaction(&payjoin_tx)?; @@ -428,7 +448,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; let (Request { url, body, content_type, .. }, send_ctx) = - RequestBuilder::from_psbt_and_uri(psbt, pj_uri)? + SenderBuilder::from_psbt_and_uri(psbt, pj_uri)? .build_with_additional_fee( Amount::from_sat(10000), None, @@ -757,7 +777,7 @@ mod integration { let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); let max_additional_fee = Amount::from_sat(1000); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(max_additional_fee, None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -844,7 +864,7 @@ mod integration { .unwrap(); let psbt = build_original_psbt(&sender, &uri)?; log::debug!("Original psbt: {:#?}", psbt); - let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt.clone(), uri)? + let (req, ctx) = SenderBuilder::from_psbt_and_uri(psbt.clone(), uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type);