diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index b77b84ef..e975dd13 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -184,29 +184,73 @@ mod integration { use reqwest::{Client, ClientBuilder, Error, Response}; use testcontainers_modules::redis::Redis; use testcontainers_modules::testcontainers::clients::Cli; + use tokio::sync::OnceCell as AsyncOnceCell; use super::*; static TESTS_TIMEOUT: Lazy = Lazy::new(|| Duration::from_secs(20)); static WAIT_SERVICE_INTERVAL: Lazy = Lazy::new(|| Duration::from_secs(3)); + static DIRECTORY_PORT: Lazy = Lazy::new(find_free_port); + static OHTTP_RELAY_PORT: Lazy = Lazy::new(find_free_port); + // Shared test infrastructure + static TEST_INFRASTRUCTURE: AsyncOnceCell = AsyncOnceCell::const_new(); + + struct TestInfrastructure { + directory: Url, + ohttp_relay: Url, + agent: Arc, + cert: Vec, + } + + impl TestInfrastructure { + async fn new() -> Result { + let (cert, key) = local_cert_key(); + let directory = Url::parse(&format!("https://localhost:{}", *DIRECTORY_PORT))?; + let ohttp_relay = Url::parse(&format!("http://localhost:{}", *OHTTP_RELAY_PORT))?; + let gateway_origin = http::Uri::from_str(directory.as_str())?; + + // Start services in background tasks + let _directory_handle = + tokio::spawn(init_directory(*DIRECTORY_PORT, (cert.clone(), key))); + let _relay_handle = + tokio::spawn(ohttp_relay::listen_tcp(*OHTTP_RELAY_PORT, gateway_origin)); + + // Create HTTP agent + let agent = Arc::new(http_agent(cert.clone())?); + + // Wait for services to be ready + wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?; + wait_for_service_ready(directory.clone(), agent.clone()).await?; + + Ok(Self { directory, ohttp_relay, agent, cert }) + } + } + + async fn init_infrastructure() -> &'static TestInfrastructure { + TEST_INFRASTRUCTURE + .get_or_init(|| async { + TestInfrastructure::new() + .await + .expect("Failed to initialize test infrastructure") + }) + .await + } #[tokio::test] async fn test_bad_ohttp_keys() { let bad_ohttp_keys = OhttpKeys::from_str("AQO6SMScPUqSo60A7MY6Ak2hDO0CGAxz7BLYp60syRu0gw") .expect("Invalid OhttpKeys"); - - let (cert, key) = local_cert_key(); - let port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap(); - tokio::select!( - err = init_directory(port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), - res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => { - assert_eq!( - res.unwrap().headers().get("content-type").unwrap(), - "application/problem+json" - ); - } + let infra = init_infrastructure().await; + let res = try_request_with_bad_keys( + infra.directory.clone(), + bad_ohttp_keys, + infra.cert.clone(), + ) + .await; + assert_eq!( + res.unwrap().headers().get("content-type").unwrap(), + "application/problem+json" ); async fn try_request_with_bad_keys( @@ -214,8 +258,7 @@ mod integration { bad_ohttp_keys: OhttpKeys, cert_der: Vec, ) -> Result { - let agent = Arc::new(http_agent(cert_der.clone()).unwrap()); - wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); + let agent = Arc::new(http_agent(cert_der).unwrap()); let mock_ohttp_relay = directory.clone(); // pass through to directory let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4") .unwrap() @@ -230,18 +273,14 @@ mod integration { #[tokio::test] async fn test_session_expiration() { init_tracing(); - let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); - let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); - tokio::select!( - err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), - res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) - ); + let infra = init_infrastructure().await; + let res = do_expiration_tests( + infra.ohttp_relay.clone(), + infra.directory.clone(), + infra.cert.clone(), + ) + .await; + assert!(res.is_ok(), "v2 send receive failed: {:#?}", res); async fn do_expiration_tests( ohttp_relay: Url, @@ -250,8 +289,6 @@ mod integration { ) -> Result<(), BoxError> { let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; let agent = Arc::new(http_agent(cert_der.clone())?); - wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); - wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); let ohttp_keys = payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) .await?; @@ -298,18 +335,14 @@ mod integration { #[tokio::test] async fn v2_to_v2() { init_tracing(); - let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); - let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); - tokio::select!( - err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), - res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) - ); + let infra = init_infrastructure().await; + let res = do_v2_send_receive( + infra.ohttp_relay.clone(), + infra.directory.clone(), + infra.cert.clone(), + ) + .await; + assert!(res.is_ok(), "v2 send receive failed: {:#?}", res); async fn do_v2_send_receive( ohttp_relay: Url, @@ -318,8 +351,6 @@ mod integration { ) -> Result<(), BoxError> { let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; let agent = Arc::new(http_agent(cert_der.clone())?); - wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); - wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); let ohttp_keys = payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) .await?; @@ -430,18 +461,14 @@ mod integration { #[tokio::test] async fn v2_to_v2_mixed_input_script_types() { init_tracing(); - let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); - let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); - tokio::select!( - err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), - res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res) - ); + let infra = init_infrastructure().await; + let res = do_v2_send_receive( + infra.ohttp_relay.clone(), + infra.directory.clone(), + infra.cert.clone(), + ) + .await; + assert!(res.is_ok(), "v2 send receive failed: {:#?}", res); async fn do_v2_send_receive( ohttp_relay: Url, @@ -450,15 +477,12 @@ mod integration { ) -> Result<(), BoxError> { let (bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; let agent = Arc::new(http_agent(cert_der.clone())?); - wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await.unwrap(); - wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap(); let ohttp_keys = payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) .await?; // ********************** // Inside the Receiver: // make utxos with different script types - let legacy_address = receiver.get_new_address(None, Some(AddressType::Legacy))?.assume_checked(); let nested_segwit_address = @@ -647,18 +671,11 @@ mod integration { #[tokio::test] async fn v1_to_v2() { init_tracing(); - let (cert, key) = local_cert_key(); - let ohttp_relay_port = find_free_port(); - let ohttp_relay = - Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap(); - let directory_port = find_free_port(); - let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap(); - let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap(); - tokio::select!( - err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err), - err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err), - res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()), - ); + let infra = init_infrastructure().await; + let res = + do_v1_to_v2(infra.ohttp_relay.clone(), infra.directory.clone(), infra.cert.clone()) + .await; + assert!(res.is_ok()); async fn do_v1_to_v2( ohttp_relay: Url, @@ -667,8 +684,6 @@ mod integration { ) -> Result<(), BoxError> { let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; let agent: Arc = Arc::new(http_agent(cert_der.clone())?); - wait_for_service_ready(ohttp_relay.clone(), agent.clone()).await?; - wait_for_service_ready(directory.clone(), agent.clone()).await?; let ohttp_keys = payjoin::io::fetch_ohttp_keys(ohttp_relay, directory.clone(), cert_der.clone()) .await?; @@ -780,13 +795,19 @@ mod integration { async fn init_directory( port: u16, local_cert_key: (Vec, Vec), - ) -> Result<(), BoxError> { + ) -> Result<(), Box> { let docker: Cli = Cli::default(); let timeout = Duration::from_secs(2); let db = docker.run(Redis); let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379)); println!("Database running on {}", db.get_host_port_ipv4(6379)); - payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await + payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key) + .await + .map_err(|e| { + let err_string = e.to_string(); + Box::new(std::io::Error::new(std::io::ErrorKind::Other, err_string)) + as Box + }) } // generates or gets a DER encoded localhost cert and key.