diff --git a/payjoin-cli/src/app.rs b/payjoin-cli/src/app.rs index 3f1626b7..27bc8f1a 100644 --- a/payjoin-cli/src/app.rs +++ b/payjoin-cli/src/app.rs @@ -92,13 +92,14 @@ impl App { &self, client: &reqwest::blocking::Client, enroll_context: &mut EnrollContext, - ) -> Result { + ) -> Result { loop { - let (enroll_body, context) = enroll_context.enroll_body(); + let (enroll_body, context) = enroll_context.enroll_body()?; let ohttp_response = client.post(&self.config.ohttp_proxy).body(enroll_body).send()?; let ohttp_response = ohttp_response.bytes()?; - let proposal = - enroll_context.parse_relay_response(ohttp_response.as_ref(), context).unwrap(); + let proposal = enroll_context + .parse_relay_response(ohttp_response.as_ref(), context) + .map_err(|e| anyhow!("parse error {}", e))?; match proposal { Some(proposal) => return Ok(proposal), None => std::thread::sleep(std::time::Duration::from_secs(5)), @@ -236,8 +237,11 @@ impl App { .map_err(|e| anyhow!("Failed to parse into UncheckedProposal {}", e))?; let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.receive_subdir()); - let (body, ohttp_ctx) = - payjoin_proposal.extract_v2_req(&self.config.ohttp_config, &receive_endpoint); + let ohttp_config = + bitcoin::base64::decode_config(&self.config.ohttp_config, base64::URL_SAFE)?; + let (body, ohttp_ctx) = payjoin_proposal + .extract_v2_req(&ohttp_config, &receive_endpoint) + .map_err(|e| anyhow!("v2 req extraction failed {}", e))?; let res = client .post(&self.config.ohttp_proxy) .body(body) diff --git a/payjoin-relay/src/main.rs b/payjoin-relay/src/main.rs index f73482f0..c2c4f6ce 100644 --- a/payjoin-relay/src/main.rs +++ b/payjoin-relay/src/main.rs @@ -69,8 +69,11 @@ fn init_ohttp() -> Result { let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))?; let encoded_config = server_config.encode()?; let b64_config = base64::encode_config( - encoded_config, - base64::Config::new(base64::CharacterSet::UrlSafe, false), + &encoded_config, + base64::Config::new( + base64::CharacterSet::UrlSafe, + false, + ), ); tracing::info!("ohttp server config base64 UrlSafe: {:?}", b64_config); Ok(ohttp::Server::new(server_config)?) @@ -103,33 +106,32 @@ async fn handle_ohttp( ) -> Result, HandlerError> { // decapsulate let ohttp_body = - hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?; + hyper::body::to_bytes(body).await.map_err(|e| HandlerError::BadRequest(e.into()))?; - let (bhttp_req, res_ctx) = ohttp.decapsulate(&ohttp_body).unwrap(); + let (bhttp_req, res_ctx) = ohttp.decapsulate(&ohttp_body).map_err(|e| HandlerError::BadRequest(e.into()))?; let mut cursor = std::io::Cursor::new(bhttp_req); - let req = bhttp::Message::read_bhttp(&mut cursor).unwrap(); + let req = bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?; let uri = Uri::builder() - .scheme(req.control().scheme().unwrap()) - .authority(req.control().authority().unwrap()) - .path_and_query(req.control().path().unwrap()) - .build() - .unwrap(); + .scheme(req.control().scheme().unwrap_or_default()) + .authority(req.control().authority().unwrap_or_default()) + .path_and_query(req.control().path().unwrap_or_default()) + .build()?; let body = req.content().to_vec(); - let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap()); + let mut http_req = Request::builder().uri(uri).method(req.control().method().unwrap_or_default()); for header in req.header().fields() { http_req = http_req.header(header.name(), header.value()) } - let request = http_req.body(Body::from(body)).unwrap(); + let request = http_req.body(Body::from(body))?; let response = handle_http(pool, request).await?; let (parts, body) = response.into_parts(); let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); - let full_body = hyper::body::to_bytes(body).await.unwrap(); + let full_body = hyper::body::to_bytes(body).await.map_err(|e| HandlerError::InternalServerError(e.into()))?; bhttp_res.write_content(&full_body); let mut bhttp_bytes = Vec::new(); - bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap(); - let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap(); + bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).map_err(|e| HandlerError::InternalServerError(e.into()))?; + let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).map_err(|e| HandlerError::InternalServerError(e.into()))?; Ok(Response::new(Body::from(ohttp_res))) } @@ -149,16 +151,22 @@ async fn handle_http(pool: DbPool, req: Request) -> Result, enum HandlerError { PayloadTooLarge, - InternalServerError, - BadRequest, + InternalServerError(Box), + BadRequest(Box), } impl HandlerError { fn to_response(&self) -> Response { let status = match self { HandlerError::PayloadTooLarge => StatusCode::PAYLOAD_TOO_LARGE, - Self::BadRequest => StatusCode::BAD_REQUEST, - _ => StatusCode::INTERNAL_SERVER_ERROR, + Self::InternalServerError(e) => { + tracing::error!("Internal server error: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + } + Self::BadRequest(e) => { + tracing::error!("Bad request: {}", e); + StatusCode::BAD_REQUEST + }, }; let mut res = Response::default(); @@ -168,13 +176,13 @@ impl HandlerError { } impl From for HandlerError { - fn from(_: hyper::http::Error) -> Self { HandlerError::InternalServerError } + fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) } } async fn post_fallback(id: &str, body: Body, pool: DbPool) -> Result, HandlerError> { tracing::debug!("Post fallback"); let id = shorten_string(id); - let req = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?; + let req = hyper::body::to_bytes(body).await.map_err(|e| HandlerError::InternalServerError(e.into()))?; if req.len() > MAX_BUFFER_SIZE { return Err(HandlerError::PayloadTooLarge); @@ -182,13 +190,13 @@ async fn post_fallback(id: &str, body: Body, pool: DbPool) -> Result (), - Err(_) => return Err(HandlerError::BadRequest), + Err(e) => return Err(HandlerError::BadRequest(e.into())), }; match pool.peek_res(&id).await { Some(result) => match result { Ok(buffered_res) => Ok(Response::new(Body::from(buffered_res))), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), }, // TODO return v2 response, v1 gets its own None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?), @@ -200,7 +208,7 @@ async fn get_request(id: &str, pool: DbPool) -> Result, HandlerEr match pool.peek_req(&id).await { Some(result) => match result { Ok(buffered_req) => Ok(Response::new(Body::from(buffered_req))), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), }, None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?), } @@ -208,11 +216,11 @@ async fn get_request(id: &str, pool: DbPool) -> Result, HandlerEr async fn post_payjoin(id: &str, body: Body, pool: DbPool) -> Result, HandlerError> { let id = shorten_string(id); - let res = hyper::body::to_bytes(body).await.map_err(|_| HandlerError::InternalServerError)?; + let res = hyper::body::to_bytes(body).await.map_err(|e| HandlerError::InternalServerError(e.into()))?; match pool.push_res(&id, res.into()).await { Ok(_) => Ok(Response::builder().status(StatusCode::NO_CONTENT).body(Body::empty())?), - Err(_) => Err(HandlerError::BadRequest), + Err(e) => Err(HandlerError::BadRequest(e.into())), } } diff --git a/payjoin/src/receive/error.rs b/payjoin/src/receive/error.rs index d9dbc9e3..dcbe6ec8 100644 --- a/payjoin/src/receive/error.rs +++ b/payjoin/src/receive/error.rs @@ -7,6 +7,9 @@ pub enum Error { BadRequest(RequestError), // To be returned as HTTP 500 Server(Box), + // V2 d/encapsulation failed + #[cfg(feature = "v2")] + V2(crate::v2::Error), } impl fmt::Display for Error { @@ -14,6 +17,8 @@ impl fmt::Display for Error { match &self { Self::BadRequest(e) => e.fmt(f), Self::Server(e) => write!(f, "Internal Server Error: {}", e), + #[cfg(feature = "v2")] + Self::V2(e) => e.fmt(f), } } } @@ -23,6 +28,8 @@ impl error::Error for Error { match &self { Self::BadRequest(_) => None, Self::Server(e) => Some(e.as_ref()), + #[cfg(feature = "v2")] + Self::V2(e) => Some(e), } } } @@ -31,6 +38,15 @@ impl From for Error { fn from(e: RequestError) -> Self { Error::BadRequest(e) } } +impl From for Error { + fn from(e: InternalRequestError) -> Self { Error::BadRequest(e.into()) } +} + +impl From for Error { + #[cfg(feature = "v2")] + fn from(e: crate::v2::Error) -> Self { Error::V2(e) } +} + /// Error that may occur when the request from sender is malformed. /// /// This is currently opaque type because we aren't sure which variants will stay. diff --git a/payjoin/src/receive/mod.rs b/payjoin/src/receive/mod.rs index 3cef0d7f..6422ccb7 100644 --- a/payjoin/src/receive/mod.rs +++ b/payjoin/src/receive/mod.rs @@ -282,6 +282,7 @@ use rand::Rng; use crate::input_type::InputType; use crate::optional_parameters::Params; use crate::psbt::PsbtExt; +use crate::v2; pub trait Headers { fn get_header(&self, key: &str) -> Option<&str>; @@ -326,32 +327,30 @@ impl EnrollContext { format!("{}/{}", self.subdirectory(), crate::v2::RECEIVE) } - pub fn enroll_body(&mut self) -> (Vec, ohttp::ClientResponse) { + pub fn enroll_body(&mut self) -> Result<(Vec, ohttp::ClientResponse), crate::v2::Error> { let receive_endpoint = self.receive_subdir(); log::debug!("{}{}", self.relay_url.as_str(), receive_endpoint); - let (ohttp_req, ctx) = crate::v2::ohttp_encapsulate( + crate::v2::ohttp_encapsulate( &self.ohttp_config, "GET", format!("{}{}", self.relay_url.as_str(), receive_endpoint).as_str(), None, - ); - - (ohttp_req, ctx) + ) } pub fn parse_relay_response( &self, mut body: impl std::io::Read, context: ohttp::ClientResponse, - ) -> Result, RequestError> { + ) -> Result, Error> { let mut buf = Vec::new(); let _ = body.read_to_end(&mut buf); - let response = crate::v2::ohttp_decapsulate(context, &buf); + let response = crate::v2::ohttp_decapsulate(context, &buf)?; if response.is_empty() { log::debug!("response is empty"); return Ok(None); } - let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key()); + let (proposal, e) = crate::v2::decrypt_message_a(&response, self.s.secret_key())?; let mut proposal = serde_json::from_slice::(&proposal) .map_err(InternalRequestError::Json)?; proposal.psbt = proposal.psbt.validate().map_err(InternalRequestError::InconsistentPsbt)?; @@ -931,21 +930,26 @@ impl PayjoinProposal { #[cfg(feature = "v2")] pub fn extract_v2_req( &self, - ohttp_config: &str, + ohttp_config: &Vec, receive_endpoint: &str, - ) -> (Vec, ohttp::ClientResponse) { + ) -> Result<(Vec, ohttp::ClientResponse), Error> { let e = self.v2_context.unwrap(); // TODO make v2 only let mut payjoin_bytes = self.payjoin_psbt.serialize(); - let body = crate::v2::encrypt_message_b(&mut payjoin_bytes, e); - let ohttp_config = base64::decode_config(ohttp_config, base64::URL_SAFE).unwrap(); + let body = crate::v2::encrypt_message_b(&mut payjoin_bytes, e)?; dbg!(receive_endpoint); - crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body)) + let (req, ctx) = + crate::v2::ohttp_encapsulate(&ohttp_config, "POST", receive_endpoint, Some(&body))?; + Ok((req, ctx)) } - #[cfg(feature = "v2")] - pub fn deserialize_res(&self, res: Vec, ohttp_context: ohttp::ClientResponse) -> Vec { + pub fn deserialize_res( + &self, + res: Vec, + ohttp_context: ohttp::ClientResponse, + ) -> Result, Error> { // display success or failure - crate::v2::ohttp_decapsulate(ohttp_context, &res) + let res = crate::v2::ohttp_decapsulate(ohttp_context, &res)?; + Ok(res) } } diff --git a/payjoin/src/send/error.rs b/payjoin/src/send/error.rs index 33e7ea50..6e70a822 100644 --- a/payjoin/src/send/error.rs +++ b/payjoin/src/send/error.rs @@ -16,13 +16,22 @@ pub struct ValidationError { #[derive(Debug)] pub(crate) enum InternalValidationError { - Psbt(bitcoin::psbt::PsbtParseError), + PsbtParse(bitcoin::psbt::PsbtParseError), Io(std::io::Error), InvalidInputType(InputTypeError), InvalidProposedInput(crate::psbt::PrevTxOutError), - VersionsDontMatch { proposed: i32, original: i32 }, - LockTimesDontMatch { proposed: LockTime, original: LockTime }, - SenderTxinSequenceChanged { proposed: Sequence, original: Sequence }, + VersionsDontMatch { + proposed: i32, + original: i32, + }, + LockTimesDontMatch { + proposed: LockTime, + original: LockTime, + }, + SenderTxinSequenceChanged { + proposed: Sequence, + original: Sequence, + }, SenderTxinContainsNonWitnessUtxo, SenderTxinContainsWitnessUtxo, SenderTxinContainsFinalScriptSig, @@ -32,7 +41,10 @@ pub(crate) enum InternalValidationError { ReceiverTxinNotFinalized, ReceiverTxinMissingUtxoInfo, MixedSequence, - MixedInputTypes { proposed: InputType, original: InputType }, + MixedInputTypes { + proposed: InputType, + original: InputType, + }, MissingOrShuffledInputs, TxOutContainsKeyPaths, FeeContributionExceedsMaximum, @@ -44,6 +56,10 @@ pub(crate) enum InternalValidationError { PayeeTookContributedFee, FeeContributionPaysOutputSizeIncrease, FeeRateBelowMinimum, + #[cfg(feature = "v2")] + V2(crate::v2::Error), + #[cfg(feature = "v2")] + Psbt(bitcoin::psbt::Error), } impl From for ValidationError { @@ -58,7 +74,7 @@ impl fmt::Display for ValidationError { use InternalValidationError::*; match &self.internal { - Psbt(e) => write!(f, "couldn't decode PSBT: {}", e), + PsbtParse(e) => write!(f, "couldn't decode PSBT: {}", e), Io(e) => write!(f, "couldn't read PSBT: {}", e), InvalidInputType(e) => write!(f, "invalid transaction input type: {}", e), InvalidProposedInput(e) => write!(f, "invalid proposed transaction input: {}", e), @@ -86,6 +102,10 @@ impl fmt::Display for ValidationError { PayeeTookContributedFee => write!(f, "payee tried to take fee contribution for himself"), FeeContributionPaysOutputSizeIncrease => write!(f, "fee contribution pays for additional outputs"), FeeRateBelowMinimum => write!(f, "the fee rate of proposed transaction is below minimum"), + #[cfg(feature = "v2")] + V2(e) => write!(f, "v2 error: {}", e), + #[cfg(feature = "v2")] + Psbt(e) => write!(f, "psbt error: {}", e), } } } @@ -95,7 +115,7 @@ impl std::error::Error for ValidationError { use InternalValidationError::*; match &self.internal { - Psbt(error) => Some(error), + PsbtParse(error) => Some(error), Io(error) => Some(error), InvalidInputType(error) => Some(error), InvalidProposedInput(error) => Some(error), @@ -123,6 +143,10 @@ impl std::error::Error for ValidationError { PayeeTookContributedFee => None, FeeContributionPaysOutputSizeIncrease => None, FeeRateBelowMinimum => None, + #[cfg(feature = "v2")] + V2(error) => Some(error), + #[cfg(feature = "v2")] + Psbt(error) => Some(error), } } } @@ -152,6 +176,8 @@ pub(crate) enum InternalCreateRequestError { UriDoesNotSupportPayjoin, PrevTxOut(crate::psbt::PrevTxOutError), InputType(crate::input_type::InputTypeError), + #[cfg(feature = "v2")] + V2(crate::v2::Error), } impl fmt::Display for CreateRequestError { @@ -174,6 +200,8 @@ impl fmt::Display for CreateRequestError { UriDoesNotSupportPayjoin => write!(f, "the URI does not support payjoin"), PrevTxOut(e) => write!(f, "invalid previous transaction output: {}", e), InputType(e) => write!(f, "invalid input type: {}", e), + #[cfg(feature = "v2")] + V2(e) => write!(f, "v2 error: {}", e), } } } @@ -198,6 +226,8 @@ impl std::error::Error for CreateRequestError { UriDoesNotSupportPayjoin => None, PrevTxOut(error) => Some(error), InputType(error) => Some(error), + #[cfg(feature = "v2")] + V2(error) => Some(error), } } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index ac869237..f579f2d3 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -402,13 +402,15 @@ impl<'a> RequestContext<'a> { self.fee_contribution, self.min_fee_rate, ); - let (body, e) = crate::v2::encrypt_message_a(&body, rs); + let (body, e) = + crate::v2::encrypt_message_a(&body, rs).map_err(InternalCreateRequestError::V2)?; let (body, ohttp_res) = crate::v2::ohttp_encapsulate( &self.uri.extras.ohttp_config.as_ref().unwrap().encode().unwrap(), "POST", url.as_str(), Some(&body), - ); + ) + .map_err(InternalCreateRequestError::V2)?; log::debug!("ohttp_proxy_url: {:?}", ohttp_proxy_url); let url = Url::parse(ohttp_proxy_url).map_err(InternalCreateRequestError::Url)?; Ok(( @@ -505,12 +507,14 @@ impl ContextV2 { ) -> Result, ValidationError> { let mut res_buf = Vec::new(); response.read_to_end(&mut res_buf).map_err(InternalValidationError::Io)?; - let mut res_buf = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf); - let psbt = crate::v2::decrypt_message_b(&mut res_buf, self.e); + let mut res_buf = crate::v2::ohttp_decapsulate(self.ohttp_res, &res_buf) + .map_err(InternalValidationError::V2)?; + let psbt = crate::v2::decrypt_message_b(&mut res_buf, self.e) + .map_err(InternalValidationError::V2)?; if psbt.is_empty() { return Ok(None); } - let proposal = Psbt::deserialize(&psbt).expect("PSBT deserialization failed"); + let proposal = Psbt::deserialize(&psbt).map_err(InternalValidationError::Psbt)?; let processed_proposal = self.context_v1.process_proposal(proposal)?; Ok(Some(processed_proposal)) } @@ -528,7 +532,7 @@ impl ContextV1 { ) -> Result { let mut res_str = String::new(); response.read_to_string(&mut res_str).map_err(InternalValidationError::Io)?; - let proposal = Psbt::from_str(&res_str).map_err(InternalValidationError::Psbt)?; + let proposal = Psbt::from_str(&res_str).map_err(InternalValidationError::PsbtParse)?; // process in non-generic function self.process_proposal(proposal).map(Into::into).map_err(Into::into) diff --git a/payjoin/src/v2.rs b/payjoin/src/v2.rs index 0d807e39..46ddce29 100644 --- a/payjoin/src/v2.rs +++ b/payjoin/src/v2.rs @@ -20,6 +20,8 @@ pub fn subdir(path: &str) -> String { pubkey_id } +use std::{error, fmt}; + use bitcoin::secp256k1::ecdh::SharedSecret; use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use chacha20poly1305::aead::{Aead, KeyInit, OsRng, Payload}; @@ -30,18 +32,18 @@ use chacha20poly1305::{AeadCore, ChaCha20Poly1305, Nonce}; /// <- Receiver S /// -> Sender E, ES(payload), payload protected by knowledge of receiver key /// <- Receiver E, EE(payload), payload protected by knowledge of sender & receiver key -pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> (Vec, SecretKey) { +pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> Result<(Vec, SecretKey), Error> { let secp = Secp256k1::new(); let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); let es = SharedSecret::new(&s, &e_sec); - let cipher = - ChaCha20Poly1305::new_from_slice(&es.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let payload = Payload { msg, aad }; log::debug!("payload.msg: {:?}", payload.msg); log::debug!("payload.aad: {:?}", payload.aad); - let c_t: Vec = cipher.encrypt(&nonce, payload).expect("encryption failed"); + let c_t: Vec = cipher.encrypt(&nonce, payload)?; log::debug!("c_t: {:?}", c_t); // let ct_payload = Payload { // msg: &c_t[..], @@ -55,19 +57,19 @@ pub fn encrypt_message_a(msg: &[u8], s: PublicKey) -> (Vec, SecretKey) { message_a.extend(&nonce[..]); log::debug!("nonce: {:?}", nonce); message_a.extend(&c_t[..]); - (message_a, e_sec) + Ok((message_a, e_sec)) } -pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> (Vec, PublicKey) { +pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> Result<(Vec, PublicKey), Error> { // let message a = [pubkey/AD][nonce][authentication tag][ciphertext] - let e = PublicKey::from_slice(&message_a[..33]).expect("invalid public key"); + let e = PublicKey::from_slice(&message_a[..33])?; log::debug!("e: {:?}", e); let nonce = Nonce::from_slice(&message_a[33..45]); log::debug!("nonce: {:?}", nonce); let es = SharedSecret::new(&e, &s); log::debug!("es: {:?}", es); - let cipher = - ChaCha20Poly1305::new_from_slice(&es.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&es.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let c_t = &message_a[45..]; let aad = &e.serialize(); log::debug!("c_t: {:?}", c_t); @@ -75,37 +77,37 @@ pub fn decrypt_message_a(message_a: &[u8], s: SecretKey) -> (Vec, PublicKey) let payload = Payload { msg: &c_t, aad }; log::debug!("payload.msg: {:?}", payload.msg); log::debug!("payload.aad: {:?}", payload.aad); - let buffer = cipher.decrypt(&nonce, payload).expect("decryption failed"); - (buffer, e) + let buffer = cipher.decrypt(&nonce, payload)?; + Ok((buffer, e)) } -pub fn encrypt_message_b(msg: &mut Vec, re_pub: PublicKey) -> Vec { +pub fn encrypt_message_b(msg: &mut Vec, re_pub: PublicKey) -> Result, Error> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] let secp = Secp256k1::new(); let (e_sec, e_pub) = secp.generate_keypair(&mut OsRng); let ee = SharedSecret::new(&re_pub, &e_sec); - let cipher = - ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let nonce = Nonce::from_slice(&[0u8; 12]); // key es encrypts only 1 message so 0 is unique let aad = &e_pub.serialize(); let payload = Payload { msg, aad }; - let c_t = cipher.encrypt(nonce, payload).expect("encryption failed"); + let c_t = cipher.encrypt(nonce, payload)?; let mut message_b = e_pub.serialize().to_vec(); message_b.extend(&nonce[..]); message_b.extend(&c_t[..]); - message_b + Ok(message_b) } -pub fn decrypt_message_b(message_b: &mut Vec, e: SecretKey) -> Vec { +pub fn decrypt_message_b(message_b: &mut Vec, e: SecretKey) -> Result, Error> { // let message b = [pubkey/AD][nonce][authentication tag][ciphertext] - let re = PublicKey::from_slice(&message_b[..33]).expect("invalid public key"); + let re = PublicKey::from_slice(&message_b[..33])?; let nonce = Nonce::from_slice(&message_b[33..45]); let ee = SharedSecret::new(&re, &e); - let cipher = - ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()).expect("cipher creation failed"); + let cipher = ChaCha20Poly1305::new_from_slice(&ee.secret_bytes()) + .map_err(|_| InternalError::InvalidKeyLength)?; let payload = Payload { msg: &message_b[45..], aad: &re.serialize() }; - let buffer = cipher.decrypt(&nonce, payload).expect("decryption failed"); - buffer + let buffer = cipher.decrypt(&nonce, payload)?; + Ok(buffer) } pub fn ohttp_encapsulate( @@ -113,9 +115,9 @@ pub fn ohttp_encapsulate( method: &str, url: &str, body: Option<&[u8]>, -) -> (Vec, ohttp::ClientResponse) { - let ctx = ohttp::ClientRequest::from_encoded_config(ohttp_config).unwrap(); - let url = url::Url::parse(url).expect("invalid url"); +) -> Result<(Vec, ohttp::ClientResponse), Error> { + let ctx = ohttp::ClientRequest::from_encoded_config(ohttp_config)?; + let url = url::Url::parse(url)?; let mut bhttp_message = bhttp::Message::request( method.as_bytes().to_vec(), url.scheme().as_bytes().to_vec(), @@ -127,13 +129,89 @@ pub fn ohttp_encapsulate( } let mut bhttp_req = Vec::new(); let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req); - ctx.encapsulate(&bhttp_req).expect("encapsulation failed") + let encapsulated = ctx.encapsulate(&bhttp_req)?; + Ok(encapsulated) } /// decapsulate ohttp, bhttp response and return http response body and status code -pub fn ohttp_decapsulate(res_ctx: ohttp::ClientResponse, ohttp_body: &[u8]) -> Vec { - let bhttp_body = res_ctx.decapsulate(ohttp_body).expect("decapsulation failed"); +pub fn ohttp_decapsulate( + res_ctx: ohttp::ClientResponse, + ohttp_body: &[u8], +) -> Result, Error> { + let mut bhttp_body = res_ctx.decapsulate(ohttp_body)?; let mut r = std::io::Cursor::new(bhttp_body); - let response = bhttp::Message::read_bhttp(&mut r).expect("read bhttp failed"); - response.content().to_vec() + let response = bhttp::Message::read_bhttp(&mut r)?; + Ok(response.content().to_vec()) +} + +/// Error that may occur when de/encrypting or de/capsulating a v2 message. +/// +/// This is currently opaque type because we aren't sure which variants will stay. +/// You can only display it. +#[derive(Debug)] +pub struct Error(InternalError); + +#[derive(Debug)] +pub(crate) enum InternalError { + Ohttp(ohttp::Error), + Bhttp(bhttp::Error), + ParseUrl(url::ParseError), + Secp256k1(bitcoin::secp256k1::Error), + ChaCha20Poly1305(chacha20poly1305::aead::Error), + InvalidKeyLength, +} + +impl From for Error { + fn from(value: ohttp::Error) -> Self { Self(InternalError::Ohttp(value)) } +} + +impl From for Error { + fn from(value: bhttp::Error) -> Self { Self(InternalError::Bhttp(value)) } +} + +impl From for Error { + fn from(value: url::ParseError) -> Self { Self(InternalError::ParseUrl(value)) } +} + +impl From for Error { + fn from(value: bitcoin::secp256k1::Error) -> Self { Self(InternalError::Secp256k1(value)) } +} + +impl From for Error { + fn from(value: chacha20poly1305::aead::Error) -> Self { + Self(InternalError::ChaCha20Poly1305(value)) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use InternalError::*; + + match &self.0 { + Ohttp(e) => e.fmt(f), + Bhttp(e) => e.fmt(f), + ParseUrl(e) => e.fmt(f), + Secp256k1(e) => e.fmt(f), + ChaCha20Poly1305(e) => e.fmt(f), + InvalidKeyLength => write!(f, "Invalid Length"), + } + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + use InternalError::*; + + match &self.0 { + Ohttp(e) => Some(e), + Bhttp(e) => Some(e), + ParseUrl(e) => Some(e), + Secp256k1(e) => Some(e), + ChaCha20Poly1305(_) | InvalidKeyLength => None, + } + } +} + +impl From for Error { + fn from(value: InternalError) -> Self { Self(value) } }