From 58d6f212a3c763d3f8e6232fd6c6cd49502c4336 Mon Sep 17 00:00:00 2001 From: DanGould Date: Tue, 6 Aug 2024 13:42:32 -0400 Subject: [PATCH] Include content-type header in Request The extract method may either produce a v1 or v2 type depending on its state. --- payjoin-cli/src/app/v1.rs | 2 +- payjoin-cli/src/app/v2.rs | 8 ++++---- payjoin/src/receive/v2/mod.rs | 6 +++--- payjoin/src/request.rs | 21 ++++++++++++++++++--- payjoin/src/send/mod.rs | 4 ++-- payjoin/tests/integration.rs | 24 ++++++++++-------------- 6 files changed, 38 insertions(+), 27 deletions(-) diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index da73cadb..05ccec68 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -65,7 +65,7 @@ impl AppTrait for App { println!("Sending fallback request to {}", &req.url); let response = http .post(req.url) - .header("Content-Type", payjoin::V1_REQ_CONTENT_TYPE) + .header("Content-Type", req.content_type) .body(body.clone()) .send() .await diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index ea7639e4..2aba3f8f 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -92,7 +92,7 @@ impl AppTrait for App { let http = http_agent()?; let ohttp_response = http .post(req.url) - .header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE) + .header("Content-Type", req.content_type) .body(req.body) .send() .await @@ -156,7 +156,7 @@ impl App { let http = http_agent()?; let res = http .post(req.url) - .header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE) + .header("Content-Type", req.content_type) .body(req.body) .send() .await @@ -219,7 +219,7 @@ impl App { let http = http_agent()?; let response = http .post(req.url) - .header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE) + .header("Content-Type", req.content_type) .body(req.body) .send() .await @@ -251,7 +251,7 @@ impl App { let http = http_agent()?; let ohttp_response = http .post(req.url) - .header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE) + .header("Content-Type", req.content_type) .body(req.body) .send() .await diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 25cc6b4c..78a4fe06 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -91,7 +91,7 @@ impl SessionInitializer { self.context.directory.as_str(), Some(subdirectory.as_bytes()), )?; - let req = Request { url, body }; + let req = Request::new_v2(url, body); Ok((req, ctx)) } @@ -130,7 +130,7 @@ impl ActiveSession { let (body, ohttp_ctx) = self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?; let url = self.context.ohttp_relay.clone(); - let req = Request { url, body }; + let req = Request::new_v2(url, body); Ok((req, ohttp_ctx)) } @@ -479,7 +479,7 @@ impl PayjoinProposal { Some(&body), )?; let url = self.context.ohttp_relay.clone(); - let req = Request { url, body }; + let req = Request::new_v2(url, body); Ok((req, ctx)) } diff --git a/payjoin/src/request.rs b/payjoin/src/request.rs index 9e0d5571..a093aa10 100644 --- a/payjoin/src/request.rs +++ b/payjoin/src/request.rs @@ -6,6 +6,7 @@ pub const V1_REQ_CONTENT_TYPE: &str = "text/plain"; pub const V2_REQ_CONTENT_TYPE: &str = "message/ohttp-req"; /// Represents data that needs to be transmitted to the receiver or payjoin directory. +/// Ensure the `Content-Length` is set to the length of `body`. (most libraries do this automatically) #[non_exhaustive] #[derive(Debug, Clone)] pub struct Request { @@ -14,10 +15,24 @@ pub struct Request { /// This is full URL with scheme etc - you can pass it right to `reqwest` or a similar library. pub url: Url, + /// The `Content-Type` header to use for the request. + /// + /// `text/plain` for v1 requests and `message/ohttp-req` for v2 requests. + pub content_type: &'static str, + /// Bytes to be sent to the receiver. /// - /// This is properly encoded PSBT, already in base64. You only need to make sure `Content-Type` - /// is appropriate (`text/plain` for v1 requests and 'message/ohttp-req' for v2) - /// and `Content-Length` is `body.len()` (most libraries do the latter automatically). + /// This is properly encoded PSBT payload either in base64 in v1 or an OHTTP encapsulated payload in v2. pub body: Vec, } + +impl Request { + pub fn new_v1(url: Url, body: Vec) -> Self { + Self { url, content_type: V1_REQ_CONTENT_TYPE, body } + } + + #[cfg(feature = "v2")] + pub fn new_v2(url: Url, body: Vec) -> Self { + Self { url, content_type: V2_REQ_CONTENT_TYPE, body } + } +} diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 13c32c4c..bf1e2d53 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -312,7 +312,7 @@ impl RequestContext { .map_err(InternalCreateRequestError::Url)?; let body = self.psbt.to_string().as_bytes().to_vec(); Ok(( - Request { url, body }, + Request::new_v1(url, body), ContextV1 { original_psbt: self.psbt.clone(), disable_output_substitution: self.disable_output_substitution, @@ -358,7 +358,7 @@ impl RequestContext { .map_err(InternalCreateRequestError::OhttpEncapsulation)?; log::debug!("ohttp_relay_url: {:?}", ohttp_relay); return Ok(( - Request { url: ohttp_relay, body }, + Request::new_v2(ohttp_relay, body), // this method may be called more than once to re-construct the ohttp, therefore we must clone (or TODO memoize) ContextV2 { context_v1: ContextV1 { diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index e5c96d39..92e41697 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -49,7 +49,7 @@ mod integration { let (req, ctx) = RequestBuilder::from_psbt_and_uri(psbt, uri)? .build_with_additional_fee(Amount::from_sat(10000), None, FeeRate::ZERO, false)? .extract_v1()?; - let headers = HeaderMock::from_vec(&req.body); + let headers = HeaderMock::new(&req.body, req.content_type); // ********************** // Inside the Receiver: @@ -347,11 +347,11 @@ mod integration { let psbt = build_sweep_psbt(&sender, &pj_uri)?; let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(payjoin::bitcoin::FeeRate::BROADCAST_MIN)?; - let (Request { url, body, .. }, send_ctx) = + let (Request { url, body, content_type, .. }, send_ctx) = req_ctx.extract_v2(directory.to_owned())?; let response = agent .post(url.clone()) - .header("Content-Type", payjoin::V1_REQ_CONTENT_TYPE) + .header("Content-Type", content_type) .body(body.clone()) .send() .await @@ -419,7 +419,7 @@ mod integration { let mut req_ctx = RequestBuilder::from_psbt_and_uri(psbt.clone(), pj_uri.clone())? .build_recommended(payjoin::bitcoin::FeeRate::BROADCAST_MIN)?; let (req, ctx) = req_ctx.extract_v2(EXAMPLE_URL.to_owned())?; - let headers = HeaderMock::from_vec(&req.body); + let headers = HeaderMock::new(&req.body, req.content_type); // ********************** // Inside the Receiver: @@ -488,7 +488,7 @@ mod integration { .check_pj_supported() .unwrap(); let psbt = build_original_psbt(&sender, &pj_uri)?; - let (Request { url, body, .. }, send_ctx) = + let (Request { url, body, content_type, .. }, send_ctx) = RequestBuilder::from_psbt_and_uri(psbt, pj_uri)? .build_with_additional_fee( Amount::from_sat(10000), @@ -500,7 +500,7 @@ mod integration { log::info!("send fallback v1 to offline receiver fail"); let res = agent .post(url.clone()) - .header("Content-Type", payjoin::V1_REQ_CONTENT_TYPE) + .header("Content-Type", content_type) .body(body.clone()) .send() .await; @@ -543,12 +543,8 @@ mod integration { // ********************** // send fallback v1 to online receiver log::info!("send fallback v1 to online receiver should succeed"); - let response = agent - .post(url) - .header("Content-Type", payjoin::V1_REQ_CONTENT_TYPE) - .body(body) - .send() - .await?; + let response = + agent.post(url).header("Content-Type", content_type).body(body).send().await?; log::info!("Response: {:#?}", &response); assert!(response.status().is_success()); @@ -966,9 +962,9 @@ mod integration { } impl HeaderMock { - fn from_vec(body: &[u8]) -> HeaderMock { + fn new(body: &[u8], content_type: &str) -> HeaderMock { let mut h = HashMap::new(); - h.insert("content-type".to_string(), payjoin::V1_REQ_CONTENT_TYPE.to_string()); + h.insert("content-type".to_string(), content_type.to_string()); h.insert("content-length".to_string(), body.len().to_string()); HeaderMock(h) }