diff --git a/payjoin-cli/seen_inputs.json b/payjoin-cli/seen_inputs.json new file mode 100644 index 00000000..7928b45b --- /dev/null +++ b/payjoin-cli/seen_inputs.json @@ -0,0 +1 @@ +["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"][["c93eb8f0c617f1150bdf311f594774c7c50a9518e954b83b5424753426d91a5e:1"] \ No newline at end of file diff --git a/payjoin-cli/src/app.rs b/payjoin-cli/src/app.rs index dcada8c0..248582b2 100644 --- a/payjoin-cli/src/app.rs +++ b/payjoin-cli/src/app.rs @@ -91,16 +91,20 @@ impl App { &self, client: &reqwest::blocking::Client, enroll_context: &mut EnrollContext, - ) -> Result { + ) -> Result { loop { - let (payjoin_get_body, context) = enroll_context.payjoin_get_body(); + let (payjoin_get_body, context) = enroll_context + .payjoin_get_body() + .map_err(|e| anyhow!("Failed to create payjoin GET body: {}", e))?; let ohttp_response = client.post(&self.config.ohttp_proxy).body(payjoin_get_body).send()?; let ohttp_response = ohttp_response.bytes()?; - let proposal = - enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap(); + let proposal = enroll_context + .parse_relay_response(ohttp_response.as_ref(), context) + .map_err(|e| anyhow!("parse error {}", e))?; + log::debug!("got response"); match proposal { - Some(proposal) => return Ok(proposal), + Some(proposal) => break Ok(proposal), None => std::thread::sleep(std::time::Duration::from_secs(5)), } } @@ -229,17 +233,19 @@ impl App { .build() .with_context(|| "Failed to build reqwest http client")?; log::debug!("Awaiting request"); - let _enroll = client.post(&self.config.pj_endpoint).body(context.enroll_body()).send()?; + let (body, _) = context.enroll_body().unwrap(); + let _enroll = client.post(&self.config.pj_endpoint).body(body).send()?; log::debug!("Awaiting proposal"); let res = self.long_poll_get(&client, &mut context)?; log::debug!("Received request"); - let payjoin_proposal = self - .process_proposal(proposal) - .map_err(|e| anyhow!("Failed to process UncheckedProposal {}", e))?; - let payjoin_endpoint = format!("{}/{}/receive", self.config.pj_endpoint, pubkey_base64); - let (body, ohttp_ctx) = - payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &payjoin_endpoint); + let payjoin_proposal = + self.process_proposal(res).map_err(|e| anyhow!("Failed to process proposal {}", e))?; + log::debug!("Posting payjoin back"); + let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.payjoin_subdir()); + let (body, ohttp_ctx) = payjoin_proposal + .extract_v2_req(&self.config.ohttp_config, &receive_endpoint) + .map_err(|e| anyhow!("v2 req extraction failed {}", e))?; let res = client .post(&self.config.ohttp_proxy) .body(body) diff --git a/payjoin-relay/src/main.rs b/payjoin-relay/src/main.rs index de347917..1d069784 100644 --- a/payjoin-relay/src/main.rs +++ b/payjoin-relay/src/main.rs @@ -1,6 +1,5 @@ use std::env; use std::net::SocketAddr; -use std::str::FromStr; use std::sync::Arc; use anyhow::Result; @@ -73,7 +72,7 @@ fn init_ohttp() -> Result { let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?; let encoded_config = server_config.encode()?; let b64_config = base64::encode_config( - encoded_config, + &encoded_config, base64::Config::new(base64::CharacterSet::UrlSafe, false), ); info!("ohttp server config base64 UrlSafe: {:?}", b64_config); @@ -119,29 +118,36 @@ async fn handle_ohttp( let (bhttp_req, res_ctx) = ohttp_locked.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?; let mut cursor = std::io::Cursor::new(bhttp_req); - let req = bhttp::Message::read_bhttp(&mut cursor).unwrap(); + let req = + bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?; let uri = Uri::builder() - .scheme(req.control().scheme().unwrap()) - .authority(req.control().authority().unwrap()) - .path_and_query(req.control().path().unwrap()) - .build() - .unwrap(); + .scheme(req.control().scheme().unwrap_or_default()) + .authority(req.control().authority().unwrap_or_default()) + .path_and_query(req.control().path().unwrap_or_default()) + .build()?; let body = req.content().to_vec(); - let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap()); + let mut http_req = + Request::builder().uri(uri).method(req.control().method().unwrap_or_default()); for header in req.header().fields() { http_req = http_req.header(header.name(), header.value()) } - let request = http_req.body(Body::from(body)).unwrap(); + let request = http_req.body(Body::from(body))?; let response = handle_v2(pool, request).await?; let (parts, body) = response.into_parts(); let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); - let full_body = hyper::body::to_bytes(body).await.unwrap(); + let full_body = hyper::body::to_bytes(body) + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))?; bhttp_res.write_content(&full_body); let mut bhttp_bytes = Vec::new(); - bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap(); - let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap(); + bhttp_res + .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes) + .map_err(|e| HandlerError::InternalServerError(e.into()))?; + let ohttp_res = res_ctx + .encapsulate(&bhttp_bytes) + .map_err(|e| HandlerError::InternalServerError(e.into()))?; Ok(Response::new(Body::from(ohttp_res))) } @@ -162,16 +168,22 @@ async fn handle_v2(pool: DbPool, req: Request) -> Result, H enum HandlerError { PayloadTooLarge, - InternalServerError, - BadRequest, + InternalServerError(Box), + BadRequest(Box), } impl HandlerError { fn to_response(&self) -> Response { let status = match self { HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE, - HandlerError::BadRequest => StatusCode::BAD_REQUEST, - _ => StatusCode::INTERNAL_SERVER_ERROR, + Self::InternalServerError(e) => { + error!("Internal server error: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + } + Self::BadRequest(e) => { + error!("Bad request: {}", e); + StatusCode::BAD_REQUEST + } }; let mut res = Response::new(Body::empty()); @@ -181,17 +193,19 @@ impl HandlerError { } impl From for HandlerError { - fn from(_: hyper::http::Error) -> Self { HandlerError::InternalServerError } + fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) } } async fn post_enroll(body: Body) -> Result, HandlerError> { let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false); - let bytes = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::BadRequest)?; - let base64_id = String::from_utf8(bytes.to_vec()).map_err(|_| HandlerError::BadRequest)?; - let pubkey_bytes: Vec = - base64::decode_config(base64_id, b64_config).map_err(|_| HandlerError::BadRequest)?; + let bytes = + hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?; + let base64_id = + String::from_utf8(bytes.to_vec()).map_err(|e| HandlerError::BadRequest(e.into()))?; + let pubkey_bytes: Vec = base64::decode_config(base64_id, b64_config) + .map_err(|e| HandlerError::BadRequest(e.into()))?; let pubkey = bitcoin::secp256k1::PublicKey::from_slice(&pubkey_bytes) - .map_err(|_| HandlerError::BadRequest)?; + .map_err(|e| HandlerError::BadRequest(e.into()))?; tracing::info!("Enrolled valid pubkey: {:?}", pubkey); Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?) } @@ -226,20 +240,23 @@ async fn post_fallback( ) -> Result, HandlerError> { tracing::trace!("Post fallback"); let id = shorten_string(id); - let req = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?; + let req = hyper::body::to_bytes(body) + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))?; + if req.len() > MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); } match pool.push_req(&id, req.into()).await { Ok(_) => (), - Err(_) => return Err(HandlerError::BadRequest), + Err(e) => return Err(HandlerError::BadRequest(e.into())), }; match pool.peek_res(&id).await { Some(result) => match result { Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), }, None => Ok(none_response), } @@ -250,7 +267,7 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result, HandlerE match pool.peek_req(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(Body::from(buffered_req))), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), }, None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?), } @@ -258,11 +275,13 @@ async fn get_fallback(id: &str, pool: DbPool) -> Result, HandlerE async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result, HandlerError> { let id = shorten_string(id); - let res = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?; + let res = hyper::body::to_bytes(body) + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))?; match pool.push_res(&id, res.into()).await { Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), } } diff --git a/payjoin-relay/tests/integration.rs b/payjoin-relay/tests/integration.rs index f79c3a1f..c3fc6e09 100644 --- a/payjoin-relay/tests/integration.rs +++ b/payjoin-relay/tests/integration.rs @@ -82,7 +82,7 @@ mod integration { // Enroll with relay let mut enroll_ctx = EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL); - let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body(); + let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().expect("Failed to enroll"); let _ohttp_response = http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request"); log::debug!("Enrolled receiver"); @@ -150,7 +150,8 @@ mod integration { // ********************** // Inside the Receiver: // GET fallback_psbt - let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body(); + let (payjoin_get_body, ohttp_req_ctx) = + enroll_ctx.payjoin_get_body().expect("Failed to get fallback"); let ohttp_response = http .post(RELAY_URL) .body(payjoin_get_body) @@ -162,11 +163,12 @@ mod integration { ); let proposal = enroll_ctx.parse_relay_response(reader, ohttp_req_ctx).unwrap().unwrap(); let payjoin_proposal = handle_proposal(proposal, receiver); - - let (body, _ohttp_ctx) = payjoin_proposal.extract_v2_req( - &ohttp_config_base64, - &format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()), - ); + let (body, _ohttp_ctx) = payjoin_proposal + .extract_v2_req( + &ohttp_config_base64, + &format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()), + ) + .expect("Failed to extract v2 req"); let _ohttp_response = http.post(RELAY_URL).body(body).send().await.expect("Failed to post payjoin_psbt"); @@ -174,6 +176,7 @@ mod integration { // Inside the Sender: // Sender checks, signs, finalizes, extracts, and broadcasts log::info!("replay POST fallback psbt for payjoin_psbt response"); + log::info!("Req body {:#?}", &req.body); let response = http .post(req.url.as_str()) .body(req.body.clone()) @@ -256,7 +259,7 @@ mod integration { // Enroll with relay let mut enroll_ctx = EnrollContext::from_relay_config(&RELAY_URL, &ohttp_config_base64, &RELAY_URL); - let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body(); + let (enroll_body, _ohttp_req_ctx) = enroll_ctx.enroll_body().unwrap(); let enroll = http.post(RELAY_URL).body(enroll_body).send().await.expect("Failed to send request"); @@ -331,7 +334,7 @@ mod integration { .expect("Failed to build reqwest http client"); let proposal = loop { - let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body(); + let (payjoin_get_body, ohttp_req_ctx) = enroll_ctx.payjoin_get_body().unwrap(); let enc_response = http .post(RELAY_URL) .body(payjoin_get_body) @@ -355,10 +358,12 @@ mod integration { debug!("handle relay response"); let response = handle_proposal(proposal, receiver); debug!("Post payjoin_psbt to relay"); - let (body, _ohttp_ctx) = response.extract_v2_req( - &ohttp_config_base64, - &format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()), - ); + let (body, _ohttp_ctx) = response + .extract_v2_req( + &ohttp_config_base64, + &format!("{}/{}", &RELAY_URL, &enroll_ctx.payjoin_subdir()), + ) + .unwrap(); // Respond with payjoin psbt within the time window the sender is willing to wait let response = http.post(RELAY_URL).body(body).send().await; debug!("POSTed with payjoin_psbt"); diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index 380f5339..4b7d42a8 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -7,6 +7,9 @@ pub enum Error { BadRequest(RequestError), // To be returned as HTTP 500 Server(Box), + // V2 d/encapsulation failed + #[cfg(feature = "v2")] + V2(crate::v2::Error), } impl fmt::Display for Error { @@ -14,6 +17,8 @@ impl fmt::Display for Error { match &self { Self::BadRequest(e) => e.fmt(f), Self::Server(e) => write!(f, "Internal Server Error: {}", e), + #[cfg(feature = "v2")] + Self::V2(e) => e.fmt(f), } } } @@ -23,6 +28,8 @@ impl error::Error for Error { match &self { Self::BadRequest(_) => None, Self::Server(e) => Some(e.as_ref()), + #[cfg(feature = "v2")] + Self::V2(e) => Some(e), } } } @@ -31,6 +38,15 @@ impl From for Error { fn from(e: RequestError) -> Self { Error::BadRequest(e) } } +impl From for Error { + fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) } +} + +impl From for Error { + #[cfg(feature = "v2")] + fn from(e: crate::v2::Error) -> Self { Error::V2(e) } +} + /// Error that may occur when the request from sender is malformed. /// /// This is currently opaque type because we aren't sure which variants will stay. diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index 2a8658b8..6ede74ff 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -324,35 +324,32 @@ impl EnrollContext { pub fn payjoin_subdir(&self) -> String { format!("{}/{}", self.subdirectory(), "payjoin") } - pub fn enroll_body(&mut self) -> (Vec, ohttp::ClientResponse) { - let (ohttp_req, ctx) = crate::v2::ohttp_encapsulate( + pub fn enroll_body(&mut self) -> Result<(Vec, ohttp::ClientResponse), crate::v2::Error> { + crate::v2::ohttp_encapsulate( &self.ohttp_config, "POST", self.relay_url.as_str(), Some(&self.subdirectory().as_bytes()), - ); - - (ohttp_req, ctx) + ) } - pub fn payjoin_get_body(&mut self) -> (Vec, ohttp::ClientResponse) { + pub fn payjoin_get_body( + &mut self, + ) -> Result<(Vec, ohttp::ClientResponse), crate::v2::Error> { let fallback_endpoint = format!("{}{}", &self.relay_url, self.subdirectory()); log::debug!("{}", fallback_endpoint.as_str()); - let (ohttp_req, ctx) = - crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &fallback_endpoint, None); - - (ohttp_req, ctx) + crate::v2::ohttp_encapsulate(&self.ohttp_config, "GET", &fallback_endpoint, None) } pub fn parse_relay_response( &self, mut body: impl std::io::Read, context: ohttp::ClientResponse, - ) -> Result, RequestError> { + ) -> Result, Error> { let mut buf = Vec::new(); let _ = body.read_to_end(&mut buf); log::trace!("decapsulating relay response"); - let response = crate::v2::ohttp_decapsulate(context, &buf); + let response = crate::v2::ohttp_decapsulate(context, &buf)?; if response.is_empty() { log::debug!("response is empty"); return Ok(None); @@ -369,7 +366,7 @@ impl EnrollContext { })) } Err(_) => { - let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key()); + let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key())?; let mut proposal = serde_json::from_slice::(&proposal) .map_err(InternalRequestError::Json)?; proposal.psbt = @@ -953,23 +950,30 @@ impl PayjoinProposal { &self, ohttp_config: &str, receive_endpoint: &str, - ) -> (Vec, ohttp::ClientResponse) { + ) -> Result<(Vec, ohttp::ClientResponse), Error> { let body = match self.v2_context { Some(e) => { let mut payjoin_bytes = self.payjoin_psbt.serialize(); crate::v2::encrypt_message_b(&mut payjoin_bytes, e) } - None => self.extract_v1_req().as_bytes().to_vec(), - }; - let ohttp_config = base64::decode_config(ohttp_config, base64::URL_SAFE).unwrap(); + None => Ok(self.extract_v1_req().as_bytes().to_vec()), + }?; + let ohttp_config = base64::decode_config(ohttp_config, base64::URL_SAFE) + .map_err(InternalRequestError::Base64)?; dbg!(receive_endpoint); - crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body)) + let (req, ctx) = + crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body))?; + Ok((req, ctx)) } - #[cfg(feature = "v2")] - pub fn deserialize_res(&self, res: Vec, ohttp_context: ohttp::ClientResponse) -> Vec { + pub fn deserialize_res( + &self, + res: Vec, + ohttp_context: ohttp::ClientResponse, + ) -> Result, Error> { // display success or failure - crate::v2::ohttp_decapsulate(ohttp_context, &res) + let res = crate::v2::ohttp_decapsulate(ohttp_context, &res)?; + Ok(res) } } diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 33e7ea50..6e70a822 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -16,13 +16,22 @@ pub struct ValidationError { #[derive(Debug)] pub(crate) enum InternalValidationError { - Psbt(bitcoin::psbt::PsbtParseError), + PsbtParse(bitcoin::psbt::PsbtParseError), Io(std::io::Error), InvalidInputType(InputTypeError), InvalidProposedInput(crate::psbt::PrevTxOutError), - VersionsDontMatch { proposed: i32, original: i32 }, - LockTimesDontMatch { proposed: LockTime, original: LockTime }, - SenderTxinSequenceChanged { proposed: Sequence, original: Sequence }, + VersionsDontMatch { + proposed: i32, + original: i32, + }, + LockTimesDontMatch { + proposed: LockTime, + original: LockTime, + }, + SenderTxinSequenceChanged { + proposed: Sequence, + original: Sequence, + }, SenderTxinContainsNonWitnessUtxo, SenderTxinContainsWitnessUtxo, SenderTxinContainsFinalScriptSig, @@ -32,7 +41,10 @@ pub(crate) enum InternalValidationError { ReceiverTxinNotFinalized, ReceiverTxinMissingUtxoInfo, MixedSequence, - MixedInputTypes { proposed: InputType, original: InputType }, + MixedInputTypes { + proposed: InputType, + original: InputType, + }, MissingOrShuffledInputs, TxOutContainsKeyPaths, FeeContributionExceedsMaximum, @@ -44,6 +56,10 @@ pub(crate) enum InternalValidationError { PayeeTookContributedFee, FeeContributionPaysOutputSizeIncrease, FeeRateBelowMinimum, + #[cfg(feature = "v2")] + V2(crate::v2::Error), + #[cfg(feature = "v2")] + Psbt(bitcoin::psbt::Error), } impl From for ValidationError { @@ -58,7 +74,7 @@ impl fmt::Display for ValidationError { use InternalValidationError::*; match &self.internal { - Psbt(e) => write!(f, "couldn't decode PSBT: {}", e), + PsbtParse(e) => write!(f, "couldn't decode PSBT: {}", e), Io(e) => write!(f, "couldn't read PSBT: {}", e), InvalidInputType(e) => write!(f, "invalid transaction input type: {}", e), InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e), @@ -86,6 +102,10 @@ impl fmt::Display for ValidationError { PayeeTookContributedFee => write!(f, "payee tried to take fee contribution for himself"), FeeContributionPaysOutputSizeIncrease => write!(f, "fee contribution pays for additional outputs"), FeeRateBelowMinimum => write!(f, "the fee rate of proposed transaction is below minimum"), + #[cfg(feature = "v2")] + V2(e) => write!(f, "v2 error: {}", e), + #[cfg(feature = "v2")] + Psbt(e) => write!(f, "psbt error: {}", e), } } } @@ -95,7 +115,7 @@ impl std::error::Error for ValidationError { use InternalValidationError::*; match &self.internal { - Psbt(error) => Some(error), + PsbtParse(error) => Some(error), Io(error) => Some(error), InvalidInputType(error) => Some(error), InvalidProposedInput(error) => Some(error), @@ -123,6 +143,10 @@ impl std::error::Error for ValidationError { PayeeTookContributedFee => None, FeeContributionPaysOutputSizeIncrease => None, FeeRateBelowMinimum => None, + #[cfg(feature = "v2")] + V2(error) => Some(error), + #[cfg(feature = "v2")] + Psbt(error) => Some(error), } } } @@ -152,6 +176,8 @@ pub(crate) enum InternalCreateRequestError { UriDoesNotSupportPayjoin, PrevTxOut(crate::psbt::PrevTxOutError), InputType(crate::input_type::InputTypeError), + #[cfg(feature = "v2")] + V2(crate::v2::Error), } impl fmt::Display for CreateRequestError { @@ -174,6 +200,8 @@ impl fmt::Display for CreateRequestError { UriDoesNotSupportPayjoin => write!(f, "the URI does not support payjoin"), PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e), InputType(e) => write!(f, "invalid input type: {}", e), + #[cfg(feature = "v2")] + V2(e) => write!(f, "v2 error: {}", e), } } } @@ -198,6 +226,8 @@ impl std::error::Error for CreateRequestError { UriDoesNotSupportPayjoin => None, PrevTxOut(error) => Some(error), InputType(error) => Some(error), + #[cfg(feature = "v2")] + V2(error) => Some(error), } } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index ac869237..f579f2d3 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -402,13 +402,15 @@ impl<'a> RequestContext<'a> { self.fee_contribution, self.min_fee_rate, ); - let (body, e) = crate::v2::encrypt_message_a(&body, rs); + let (body, e) = + crate::v2::encrypt_message_a(&body, rs).map_err(InternalCreateRequestError::V2)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( &self.uri.extras.ohttp_config.as_ref().unwrap().encode().unwrap(), "POST", url.as_str(), Some(&body), - ); + ) + .map_err(InternalCreateRequestError::V2)?; log::debug!("ohttp_proxy_url: {:?}", ohttp_proxy_url); let url = Url::parse(ohttp_proxy_url).map_err(InternalCreateRequestError::Url)?; Ok(( @@ -505,12 +507,14 @@ impl ContextV2 { ) -> Result, ValidationError> { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let mut res_buf = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf); - let psbt = crate::v2::decrypt_message_b(&mut res_buf, self.e); + let mut res_buf = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf) + .map_err(InternalValidationError::V2)?; + let psbt = crate::v2::decrypt_message_b(&mut res_buf, self.e) + .map_err(InternalValidationError::V2)?; if psbt.is_empty() { return Ok(None); } - let proposal = Psbt::deserialize(&psbt).expect("PSBT deserialization failed"); + let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; let processed_proposal = self.context_v1.process_proposal(proposal)?; Ok(Some(processed_proposal)) } @@ -528,7 +532,7 @@ impl ContextV1 { ) -> Result { let mut res_str = String::new(); response.read_to_string(&mut res_str).map_err(InternalValidationError::Io)?; - let proposal = Psbt::from_str(&res_str).map_err(InternalValidationError::Psbt)?; + let proposal = Psbt::from_str(&res_str).map_err(InternalValidationError::PsbtParse)?; // process in non-generic function self.process_proposal(proposal).map(Into::into).map_err(Into::into) diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index ae163642..ef32637e 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -19,6 +19,8 @@ pub fn subdir(path: &str) -> String { pubkey_id } +use std::{error, fmt}; + use bitcoin::secp256k1::ecdh::SharedSecret; use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use chacha20poly1305::aead::{Aead, KeyInit, OsRng, Payload}; @@ -29,18 +31,18 @@ use chacha20poly1305::{AeadCore, ChaCha20Poly1305, Nonce}; /// <- Receiver S /// -> Sender E, ES(payload), payload protected by knowledge of receiver key /// <- Receiver E, EE(payload), payload protected by knowledge of sender & receiver key -pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> (Vec, SecretKey) { +pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> Result<(Vec, SecretKey), Error> { let secp = Secp256k1::new(); let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); let es = SharedSecret::new(&s, &e_sec); - let cipher = - ChaCha20Poly1305::new_from_slice(&es.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let payload = Payload { msg, aad }; log::debug!("payload.msg: {:?}", payload.msg); log::debug!("payload.aad: {:?}", payload.aad); - let c_t: Vec = cipher.encrypt(&nonce, payload).expect("encryption failed"); + let c_t: Vec = cipher.encrypt(&nonce, payload)?; log::debug!("c_t: {:?}", c_t); // let ct_payload = Payload { // msg: &c_t[..], @@ -54,19 +56,19 @@ pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> (Vec, SecretKey) { message_a.extend(&nonce[..]); log::debug!("nonce: {:?}", nonce); message_a.extend(&c_t[..]); - (message_a, e_sec) + Ok((message_a, e_sec)) } -pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> (Vec, PublicKey) { +pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> Result<(Vec, PublicKey), Error> { // let message a = [pubkey/AD][nonce][authentication tag][ciphertext] - let e = PublicKey::from_slice(&message_a[..33]).expect("invalid public key"); + let e = PublicKey::from_slice(&message_a[..33])?; log::debug!("e: {:?}", e); let nonce = Nonce::from_slice(&message_a[33..45]); log::debug!("nonce: {:?}", nonce); let es = SharedSecret::new(&e, &s); log::debug!("es: {:?}", es); - let cipher = - ChaCha20Poly1305::new_from_slice(&es.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let c_t = &message_a[45..]; let aad = &e.serialize(); log::debug!("c_t: {:?}", c_t); @@ -74,37 +76,37 @@ pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> (Vec, PublicKey) let payload = Payload { msg: &c_t, aad }; log::debug!("payload.msg: {:?}", payload.msg); log::debug!("payload.aad: {:?}", payload.aad); - let buffer = cipher.decrypt(&nonce, payload).expect("decryption failed"); - (buffer, e) + let buffer = cipher.decrypt(&nonce, payload)?; + Ok((buffer, e)) } -pub fn encrypt_message_b(msg: &mut Vec, re_pub: PublicKey) -> Vec { +pub fn encrypt_message_b(msg: &mut Vec, re_pub: PublicKey) -> Result, Error> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] let secp = Secp256k1::new(); let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); let ee = SharedSecret::new(&re_pub, &e_sec); - let cipher = - ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let nonce = Nonce::from_slice(&[0u8; 12]); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let payload = Payload { msg, aad }; - let c_t = cipher.encrypt(nonce, payload).expect("encryption failed"); + let c_t = cipher.encrypt(nonce, payload)?; let mut message_b = e_pub.serialize().to_vec(); message_b.extend(&nonce[..]); message_b.extend(&c_t[..]); - message_b + Ok(message_b) } -pub fn decrypt_message_b(message_b: &mut Vec, e: SecretKey) -> Vec { +pub fn decrypt_message_b(message_b: &mut Vec, e: SecretKey) -> Result, Error> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] - let re = PublicKey::from_slice(&message_b[..33]).expect("invalid public key"); + let re = PublicKey::from_slice(&message_b[..33])?; let nonce = Nonce::from_slice(&message_b[33..45]); let ee = SharedSecret::new(&re, &e); - let cipher = - ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let payload = Payload { msg: &message_b[45..], aad: &re.serialize() }; - let buffer = cipher.decrypt(&nonce, payload).expect("decryption failed"); - buffer + let buffer = cipher.decrypt(&nonce, payload)?; + Ok(buffer) } pub fn ohttp_encapsulate( @@ -112,9 +114,9 @@ pub fn ohttp_encapsulate( method: &str, url: &str, body: Option<&[u8]>, -) -> (Vec, ohttp::ClientResponse) { - let ctx = ohttp::ClientRequest::from_encoded_config(ohttp_config).unwrap(); - let url = url::Url::parse(url).expect("invalid url"); +) -> Result<(Vec, ohttp::ClientResponse), Error> { + let ctx = ohttp::ClientRequest::from_encoded_config(ohttp_config)?; + let url = url::Url::parse(url)?; let mut bhttp_message = bhttp::Message::request( method.as_bytes().to_vec(), url.scheme().as_bytes().to_vec(), @@ -126,13 +128,89 @@ pub fn ohttp_encapsulate( } let mut bhttp_req = Vec::new(); let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); - ctx.encapsulate(&bhttp_req).expect("encapsulation failed") + let encapsulated = ctx.encapsulate(&bhttp_req)?; + Ok(encapsulated) } /// decapsulate ohttp, bhttp response and return http response body and status code -pub fn ohttp_decapsulate(res_ctx: ohttp::ClientResponse, ohttp_body: &[u8]) -> Vec { - let bhttp_body = res_ctx.decapsulate(ohttp_body).expect("decapsulation failed"); +pub fn ohttp_decapsulate( + res_ctx: ohttp::ClientResponse, + ohttp_body: &[u8], +) -> Result, Error> { + let mut bhttp_body = res_ctx.decapsulate(ohttp_body)?; let mut r = std::io::Cursor::new(bhttp_body); - let response = bhttp::Message::read_bhttp(&mut r).expect("read bhttp failed"); - response.content().to_vec() + let response = bhttp::Message::read_bhttp(&mut r)?; + Ok(response.content().to_vec()) +} + +/// Error that may occur when de/encrypting or de/capsulating a v2 message. +/// +/// This is currently opaque type because we aren't sure which variants will stay. +/// You can only display it. +#[derive(Debug)] +pub struct Error(InternalError); + +#[derive(Debug)] +pub(crate) enum InternalError { + Ohttp(ohttp::Error), + Bhttp(bhttp::Error), + ParseUrl(url::ParseError), + Secp256k1(bitcoin::secp256k1::Error), + ChaCha20Poly1305(chacha20poly1305::aead::Error), + InvalidKeyLength, +} + +impl From for Error { + fn from(value: ohttp::Error) -> Self { Self(InternalError::Ohttp(value)) } +} + +impl From for Error { + fn from(value: bhttp::Error) -> Self { Self(InternalError::Bhttp(value)) } +} + +impl From for Error { + fn from(value: url::ParseError) -> Self { Self(InternalError::ParseUrl(value)) } +} + +impl From for Error { + fn from(value: bitcoin::secp256k1::Error) -> Self { Self(InternalError::Secp256k1(value)) } +} + +impl From for Error { + fn from(value: chacha20poly1305::aead::Error) -> Self { + Self(InternalError::ChaCha20Poly1305(value)) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use InternalError::*; + + match &self.0 { + Ohttp(e) => e.fmt(f), + Bhttp(e) => e.fmt(f), + ParseUrl(e) => e.fmt(f), + Secp256k1(e) => e.fmt(f), + ChaCha20Poly1305(e) => e.fmt(f), + InvalidKeyLength => write!(f, "Invalid Length"), + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use InternalError::*; + + match &self.0 { + Ohttp(e) => Some(e), + Bhttp(e) => Some(e), + ParseUrl(e) => Some(e), + Secp256k1(e) => Some(e), + ChaCha20Poly1305(_) | InvalidKeyLength => None, + } + } +} + +impl From for Error { + fn from(value: InternalError) -> Self { Self(value) } }