Skip to content

Commit

Permalink
Send v1 to v1 receivers if v2 unsupported (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould authored Aug 9, 2024
2 parents ff350d5 + 3d29412 commit 908aa5c
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 198 deletions.
7 changes: 1 addition & 6 deletions payjoin-cli/src/app/v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl AppTrait for App {
println!("Sending fallback request to {}", &req.url);
let response = http
.post(req.url)
.header("Content-Type", payjoin::V1_REQ_CONTENT_TYPE)
.header("Content-Type", req.content_type)
.body(body.clone())
.send()
.await
Expand Down Expand Up @@ -329,11 +329,6 @@ impl App {
},
Some(bitcoin::FeeRate::MIN),
)?;
let payjoin_proposal_psbt = payjoin_proposal.psbt();
println!(
"Responded with Payjoin proposal {}",
payjoin_proposal_psbt.clone().extract_tx_unchecked_fee_rate().compute_txid()
);
Ok(payjoin_proposal)
}
}
8 changes: 4 additions & 4 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl AppTrait for App {
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
Expand Down Expand Up @@ -156,7 +156,7 @@ impl App {
let http = http_agent()?;
let res = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
Expand Down Expand Up @@ -219,7 +219,7 @@ impl App {
let http = http_agent()?;
let response = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
Expand Down Expand Up @@ -251,7 +251,7 @@ impl App {
let http = http_agent()?;
let ohttp_response = http
.post(req.url)
.header("Content-Type", payjoin::V2_REQ_CONTENT_TYPE)
.header("Content-Type", req.content_type)
.body(req.body)
.send()
.await
Expand Down
6 changes: 3 additions & 3 deletions payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl SessionInitializer {
self.context.directory.as_str(),
Some(subdirectory.as_bytes()),
)?;
let req = Request { url, body };
let req = Request::new_v2(url, body);
Ok((req, ctx))
}

Expand Down Expand Up @@ -130,7 +130,7 @@ impl ActiveSession {
let (body, ohttp_ctx) =
self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?;
let url = self.context.ohttp_relay.clone();
let req = Request { url, body };
let req = Request::new_v2(url, body);
Ok((req, ohttp_ctx))
}

Expand Down Expand Up @@ -479,7 +479,7 @@ impl PayjoinProposal {
Some(&body),
)?;
let url = self.context.ohttp_relay.clone();
let req = Request { url, body };
let req = Request::new_v2(url, body);
Ok((req, ctx))
}

Expand Down
21 changes: 18 additions & 3 deletions payjoin/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub const V1_REQ_CONTENT_TYPE: &str = "text/plain";
pub const V2_REQ_CONTENT_TYPE: &str = "message/ohttp-req";

/// Represents data that needs to be transmitted to the receiver or payjoin directory.
/// Ensure the `Content-Length` is set to the length of `body`. (most libraries do this automatically)
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct Request {
Expand All @@ -14,10 +15,24 @@ pub struct Request {
/// This is full URL with scheme etc - you can pass it right to `reqwest` or a similar library.
pub url: Url,

/// The `Content-Type` header to use for the request.
///
/// `text/plain` for v1 requests and `message/ohttp-req` for v2 requests.
pub content_type: &'static str,

/// Bytes to be sent to the receiver.
///
/// This is properly encoded PSBT, already in base64. You only need to make sure `Content-Type`
/// is appropriate (`text/plain` for v1 requests and 'message/ohttp-req' for v2)
/// and `Content-Length` is `body.len()` (most libraries do the latter automatically).
/// This is properly encoded PSBT payload either in base64 in v1 or an OHTTP encapsulated payload in v2.
pub body: Vec<u8>,
}

impl Request {
pub fn new_v1(url: Url, body: Vec<u8>) -> Self {
Self { url, content_type: V1_REQ_CONTENT_TYPE, body }
}

#[cfg(feature = "v2")]
pub fn new_v2(url: Url, body: Vec<u8>) -> Self {
Self { url, content_type: V2_REQ_CONTENT_TYPE, body }
}
}
97 changes: 60 additions & 37 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ impl Eq for RequestContext {}

impl RequestContext {
/// Extract serialized V1 Request and Context froma Payjoin Proposal
pub fn extract_v1(self) -> Result<(Request, ContextV1), CreateRequestError> {
pub fn extract_v1(&self) -> Result<(Request, ContextV1), CreateRequestError> {
let url = serialize_url(
self.endpoint,
self.endpoint.clone(),
self.disable_output_substitution,
self.fee_contribution,
self.min_fee_rate,
Expand All @@ -290,20 +290,20 @@ impl RequestContext {
.map_err(InternalCreateRequestError::Url)?;
let body = self.psbt.to_string().as_bytes().to_vec();
Ok((
Request { url, body },
Request::new_v1(url, body),
ContextV1 {
original_psbt: self.psbt,
original_psbt: self.psbt.clone(),
disable_output_substitution: self.disable_output_substitution,
fee_contribution: self.fee_contribution,
payee: self.payee,
payee: self.payee.clone(),
input_type: self.input_type,
sequence: self.sequence,
min_fee_rate: self.min_fee_rate,
},
))
}

/// Extract serialized Request and Context from a Payjoin Proposal.
/// Extract serialized Request and Context from a Payjoin Proposal. Automatically selects the correct version.
///
/// In order to support polling, this may need to be called many times to be encrypted with
/// new unique nonces to make independent OHTTP requests.
Expand All @@ -321,7 +321,28 @@ impl RequestContext {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;

match self.extract_rs_pubkey() {
Ok(rs) => self.extract_v2_strict(ohttp_relay, rs),
Err(e) => {
log::warn!("Failed to extract `rs` pubkey, falling back to v1: {}", e);
let (req, context_v1) = self.extract_v1()?;
Ok((req, ContextV2 { context_v1, e: None, ohttp_res: None }))
}
}
}

/// Extract serialized Request and Context from a Payjoin Proposal.
///
/// This method requires the `rs` pubkey to be extracted from the endpoint
/// and has no fallback to v1.
#[cfg(feature = "v2")]
fn extract_v2_strict(
&mut self,
ohttp_relay: Url,
rs: PublicKey,
) -> Result<(Request, ContextV2), CreateRequestError> {
use crate::uri::UrlExt;
let url = self.endpoint.clone();
let body = serialize_v2_body(
&self.psbt,
Expand All @@ -338,8 +359,7 @@ impl RequestContext {
.map_err(InternalCreateRequestError::OhttpEncapsulation)?;
log::debug!("ohttp_relay_url: {:?}", ohttp_relay);
Ok((
Request { url: ohttp_relay, body },
// this method may be called more than once to re-construct the ohttp, therefore we must clone (or TODO memoize)
Request::new_v2(ohttp_relay, body),
ContextV2 {
context_v1: ContextV1 {
original_psbt: self.psbt.clone(),
Expand All @@ -350,32 +370,30 @@ impl RequestContext {
sequence: self.sequence,
min_fee_rate: self.min_fee_rate,
},
e: self.e,
ohttp_res,
e: Some(self.e),
ohttp_res: Some(ohttp_res),
},
))
}

#[cfg(feature = "v2")]
fn rs_pubkey_from_dir_endpoint(endpoint: &Url) -> Result<PublicKey, CreateRequestError> {
fn extract_rs_pubkey(&self) -> Result<PublicKey, error::ParseSubdirectoryError> {
use bitcoin::base64::prelude::BASE64_URL_SAFE_NO_PAD;
use bitcoin::base64::Engine;
use error::ParseSubdirectoryError;

use crate::send::error::ParseSubdirectoryError;

let subdirectory = endpoint
let subdirectory = self
.endpoint
.path_segments()
.ok_or(ParseSubdirectoryError::MissingSubdirectory)?
.next()
.ok_or(ParseSubdirectoryError::MissingSubdirectory)?
.to_string();
.and_then(|mut segments| segments.next())
.ok_or(ParseSubdirectoryError::MissingSubdirectory)?;

let pubkey_bytes = BASE64_URL_SAFE_NO_PAD
.decode(subdirectory)
.map_err(ParseSubdirectoryError::SubdirectoryNotBase64)?;

bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes)
.map_err(ParseSubdirectoryError::SubdirectoryInvalidPubkey)
.map_err(CreateRequestError::from)
}

pub fn endpoint(&self) -> &Url { &self.endpoint }
Expand Down Expand Up @@ -517,8 +535,8 @@ pub struct ContextV1 {
#[cfg(feature = "v2")]
pub struct ContextV2 {
context_v1: ContextV1,
e: bitcoin::secp256k1::SecretKey,
ohttp_res: ohttp::ClientResponse,
e: Option<bitcoin::secp256k1::SecretKey>,
ohttp_res: Option<ohttp::ClientResponse>,
}

macro_rules! check_eq {
Expand Down Expand Up @@ -552,21 +570,26 @@ impl ContextV2 {
self,
response: &mut impl std::io::Read,
) -> Result<Option<Psbt>, ResponseError> {
let mut res_buf = Vec::new();
response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?;
let response = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf)
.map_err(InternalValidationError::OhttpEncapsulation)?;
let mut body = match response.status() {
http::StatusCode::OK => response.body().to_vec(),
http::StatusCode::ACCEPTED => return Ok(None),
_ => return Err(InternalValidationError::UnexpectedStatusCode)?,
};
let psbt = crate::v2::decrypt_message_b(&mut body, self.e)
.map_err(InternalValidationError::HpkeError)?;

let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?;
let processed_proposal = self.context_v1.process_proposal(proposal)?;
Ok(Some(processed_proposal))
match (self.ohttp_res, self.e) {
(Some(ohttp_res), Some(e)) => {
let mut res_buf = Vec::new();
response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?;
let response = crate::v2::ohttp_decapsulate(ohttp_res, &res_buf)
.map_err(InternalValidationError::OhttpEncapsulation)?;
let mut body = match response.status() {
http::StatusCode::OK => response.body().to_vec(),
http::StatusCode::ACCEPTED => return Ok(None),
_ => return Err(InternalValidationError::UnexpectedStatusCode)?,
};
let psbt = crate::v2::decrypt_message_b(&mut body, e)
.map_err(InternalValidationError::HpkeError)?;

let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?;
let processed_proposal = self.context_v1.process_proposal(proposal)?;
Ok(Some(processed_proposal))
}
_ => self.context_v1.process_response(response).map(Some),
}
}
}

Expand Down
Loading

0 comments on commit 908aa5c

Please sign in to comment.