diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index bd16171f..dcf5554b 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -362,14 +362,15 @@ async fn unwrap_ohttp_keys_or_else_fetch(config: &AppConfig) -> Result, ) -> Result { - use reqwest::{Client, Proxy}; - let ohttp_keys_url = payjoin_directory.join("/ohttp-keys")?; let proxy = Proxy::all(ohttp_relay.as_str())?; - #[cfg(not(feature = "_danger-local-https"))] let client = Client::builder().proxy(proxy).build()?; - #[cfg(feature = "_danger-local-https")] + let res = client.get(ohttp_keys_url).send().await?; + let body = res.bytes().await?.to_vec(); + OhttpKeys::decode(&body).map_err(|e| Error(InternalError::InvalidOhttpKeys(e.to_string()))) +} + +/// Fetch the ohttp keys from the specified payjoin directory via proxy. +/// +/// * `ohttp_relay`: The http CONNNECT method proxy to request the ohttp keys from a payjoin +/// directory. Proxying requests for ohttp keys ensures a client IP address is never revealed to +/// the payjoin directory. +/// +/// * `payjoin_directory`: The payjoin directory from which to fetch the ohttp keys. This +/// directory stores and forwards payjoin client payloads. +/// +/// * `cert_der`: The DER-encoded certificate to use for local HTTPS connections. +#[cfg(feature = "_danger-local-https")] +pub async fn fetch_ohttp_keys_with_cert( + ohttp_relay: Url, + payjoin_directory: Url, + cert_der: Vec, +) -> Result { + let ohttp_keys_url = payjoin_directory.join("/ohttp-keys")?; + let proxy = Proxy::all(ohttp_relay.as_str())?; let client = Client::builder() .danger_accept_invalid_certs(true) .use_rustls_tls() @@ -46,7 +61,6 @@ enum InternalError { Io(std::io::Error), #[cfg(feature = "_danger-local-https")] Rustls(rustls::Error), - #[cfg(feature = "v2")] InvalidOhttpKeys(String), } @@ -72,7 +86,6 @@ impl std::fmt::Display for Error { Reqwest(e) => e.fmt(f), ParseUrl(e) => e.fmt(f), Io(e) => e.fmt(f), - #[cfg(feature = "v2")] InvalidOhttpKeys(e) => { write!(f, "Invalid ohttp keys returned from payjoin directory: {}", e) } @@ -90,7 +103,6 @@ impl std::error::Error for Error { Reqwest(e) => Some(e), ParseUrl(e) => Some(e), Io(e) => Some(e), - #[cfg(feature = "v2")] InvalidOhttpKeys(_) => None, #[cfg(feature = "_danger-local-https")] Rustls(e) => Some(e), diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index a82e581a..dda37056 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -1,4 +1,4 @@ -#[cfg(all(feature = "send", feature = "receive"))] +#[cfg(all(feature = "send", feature = "receive", feature = "_danger-local-https"))] mod integration { use std::collections::HashMap; use std::env; @@ -171,8 +171,7 @@ mod integration { } } - #[cfg(feature = "_danger-local-https")] - #[cfg(feature = "v2")] + #[cfg(all(feature = "io", feature = "v2"))] mod v2 { use std::sync::Arc; use std::time::Duration; @@ -252,9 +251,12 @@ mod integration { let agent = Arc::new(http_agent(cert_der.clone())?); wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); - let ohttp_keys = - payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) - .await?; + let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert( + ohttp_relay, + directory.clone(), + cert_der, + ) + .await?; // ********************** // Inside the Receiver: @@ -321,9 +323,12 @@ mod integration { let agent = Arc::new(http_agent(cert_der.clone())?); wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); - let ohttp_keys = - payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) - .await?; + let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert( + ohttp_relay, + directory.clone(), + cert_der.clone(), + ) + .await?; // ********************** // Inside the Receiver: let address = receiver.get_new_address(None, None)?.assume_checked(); @@ -450,9 +455,12 @@ mod integration { let agent = Arc::new(http_agent(cert_der.clone())?); wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); - let ohttp_keys = - payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) - .await?; + let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert( + ohttp_relay, + directory.clone(), + cert_der, + ) + .await?; // ********************** // Inside the Receiver: // make utxos with different script types @@ -662,9 +670,12 @@ mod integration { let agent: Arc = Arc::new(http_agent(cert_der.clone())?); wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?; wait_for_service_ready(directory.clone(), agent.clone()).await?; - let ohttp_keys = - payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) - .await?; + let ohttp_keys = payjoin::io::fetch_ohttp_keys_with_cert( + ohttp_relay, + directory.clone(), + cert_der.clone(), + ) + .await?; let address = receiver.get_new_address(None, None)?.assume_checked(); let mut session = initialize_session(address, directory, ohttp_keys.clone(), None);