Skip to content

Commit

Permalink
Have directory give US a port
Browse files Browse the repository at this point in the history
ensure that the db: testcontainer::Container variable does not go out of scope while the directory is running.

previously the directory task itself was awaited on by init_directory, whereas in the modified code it is instead returned as part of the result due to the different return value of listen_tcp_with_tls_on_free_port. this indirection de-coupled the db variable's lifetime from that of the directory, allowing it to go out of scope earlier than expected.

Co-authored-by: Yuval Kogman <[email protected]>
  • Loading branch information
2 people authored and spacebear21 committed Dec 3, 2024
1 parent 33e117b commit 834859c
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 62 deletions.
35 changes: 26 additions & 9 deletions payjoin-cli/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ mod e2e {
use url::Url;

type Error = Box<dyn std::error::Error + 'static>;
type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, Error>;

static INIT_TRACING: OnceCell<()> = OnceCell::new();
Expand All @@ -180,16 +181,28 @@ mod e2e {
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_relay = Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
let (port, directory_future) =
init_directory(db_host, (cert.clone(), key)).await.expect("Failed to init directory");
println!("Directory server started on port IN TEST FN {}", port);
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();

// Spawn the directory server task
let directory_task = tokio::spawn(async move {
if let Err(e) = directory_future.await {
eprintln!("Directory server error: {:?}", e);
}
});
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();

let temp_dir = env::temp_dir();
let receiver_db_path = temp_dir.join("receiver_db");
let sender_db_path = temp_dir.join("sender_db");
let result: Result<()> = tokio::select! {
res = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => Err(format!("Ohttp relay is long running: {:?}", res).into()),
res = init_directory(directory_port, (cert.clone(), key)) => Err(format!("Directory server is long running: {:?}", res).into()),
res = directory_task => Err(format!("Directory server is long running: {:?}", res).into()),
res = send_receive_cli_async(ohttp_relay, directory, cert, receiver_db_path.clone(), sender_db_path.clone()) => res.map_err(|e| format!("send_receive failed: {:?}", e).into()),
};

Expand Down Expand Up @@ -476,13 +489,17 @@ mod e2e {
Err("Timeout waiting for service to be ready".into())
}

async fn init_directory(port: u16, local_cert_key: (Vec<u8>, Vec<u8>)) -> Result<()> {
let docker: Cli = Cli::default();
async fn init_directory(
db_host: String,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> std::result::Result<
(u16, tokio::task::JoinHandle<std::result::Result<(), BoxSendSyncError>>),
BoxSendSyncError,
> {
println!("Database running on {}", db_host);
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key)
.await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down
99 changes: 64 additions & 35 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,65 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message"
mod db;
use crate::db::DbPool;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

#[cfg(feature = "danger-local-https")]
pub async fn listen_tcp_with_tls_on_free_port(
db_host: String,
timeout: Duration,
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
let listener = tokio::net::TcpListener::bind("[::]:0").await?;
let port = listener.local_addr()?.port();
println!("Directory server binding to port {}", listener.local_addr()?);
let handle = listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await?;
Ok((port, handle))
}

// Helper function to avoid code duplication
async fn listen_tcp_with_tls_on_listener(
listener: tokio::net::TcpListener,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let tls_acceptor = init_tls_acceptor(tls_config)?;
// Spawn the connection handling loop in a separate task
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}
Ok(())
});
Ok(handle)
}

// Modify existing listen_tcp_with_tls to use the new helper
pub async fn listen_tcp(
port: u16,
db_host: String,
Expand Down Expand Up @@ -74,41 +133,11 @@ pub async fn listen_tcp_with_tls(
port: u16,
db_host: String,
timeout: Duration,
tls_config: (Vec<u8>, Vec<u8>),
) -> Result<(), Box<dyn std::error::Error>> {
let pool = DbPool::new(timeout, db_host).await?;
let ohttp = Arc::new(Mutex::new(init_ohttp()?));
let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let tls_acceptor = init_tls_acceptor(tls_config)?;
let listener = TcpListener::bind(bind_addr).await?;
while let Ok((stream, _)) = listener.accept().await {
let pool = pool.clone();
let ohttp = ohttp.clone();
let tls_acceptor = tls_acceptor.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(stream).await {
Ok(tls_stream) => tls_stream,
Err(e) => {
error!("TLS accept error: {}", e);
return;
}
};
if let Err(err) = http1::Builder::new()
.serve_connection(
TokioIo::new(tls_stream),
service_fn(move |req| {
serve_payjoin_directory(req, pool.clone(), ohttp.clone())
}),
)
.with_upgrades()
.await
{
error!("Error serving connection: {:?}", err);
}
});
}

Ok(())
cert_key: (Vec<u8>, Vec<u8>),
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
let addr = format!("0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key).await
}

#[cfg(feature = "danger-local-https")]
Expand Down
77 changes: 59 additions & 18 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ mod integration {

use super::*;

type BoxSendSyncError = Box<dyn std::error::Error + Send + Sync>;

static TESTS_TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(20));
static WAIT_SERVICE_INTERVAL: Lazy<Duration> = Lazy::new(|| Duration::from_secs(3));

Expand All @@ -197,10 +199,25 @@ mod integration {
.expect("Invalid OhttpKeys");

let (cert, key) = local_cert_key();
let port = find_free_port();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));

let (port, directory_future) = init_directory(db_host, (cert.clone(), key))
.await
.expect("Failed to init directory");
println!("Directory server started on port IN TEST FN {}", port);
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();

// Spawn the directory server task
let directory_task = tokio::spawn(async move {
if let Err(e) = directory_future.await {
eprintln!("Directory server error: {:?}", e);
}
});

tokio::select!(
err = init_directory(port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = directory_task => panic!("Directory server exited early: {:?}", err),
res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => {
assert_eq!(
res.unwrap().headers().get("content-type").unwrap(),
Expand All @@ -214,8 +231,10 @@ mod integration {
bad_ohttp_keys: OhttpKeys,
cert_der: Vec<u8>,
) -> Result<Response, Error> {
println!("Trying request with bad keys");
let agent = Arc::new(http_agent(cert_der.clone()).unwrap());
wait_for_service_ready(directory.clone(), agent.clone()).await.unwrap();
println!("Service ready");
let mock_ohttp_relay = directory.clone(); // pass through to directory
let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4")
.unwrap()
Expand All @@ -234,12 +253,18 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));

let (directory_port, directory_handle) = init_directory(db_host, (cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = directory_handle => panic!("Directory server exited early: {:?}", err),
res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -303,12 +328,18 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));

let (directory_port, directory_future) = init_directory(db_host, (cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = directory_future => panic!("Directory server exited early: {:?}", err),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -432,12 +463,18 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));

let (directory_port, directory_future) = init_directory(db_host, (cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = directory_future => panic!("Directory server exited early: {:?}", err),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -644,12 +681,17 @@ mod integration {
let ohttp_relay_port = find_free_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let docker: Cli = Cli::default();
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
let (directory_port, directory_future) = init_directory(db_host, (cert.clone(), key))
.await
.expect("Failed to init directory");
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = directory_future => panic!("Directory server exited early: {:?}", err),
res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()),
);

Expand Down Expand Up @@ -771,15 +813,14 @@ mod integration {
}

async fn init_directory(
port: u16,
db_host: String,
local_cert_key: (Vec<u8>, Vec<u8>),
) -> Result<(), BoxError> {
let docker: Cli = Cli::default();
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxSendSyncError>>), BoxSendSyncError>
{
println!("Database running on {}", db_host);
let timeout = Duration::from_secs(2);
let db = docker.run(Redis);
let db_host = format!("127.0.0.1:{}", db.get_host_port_ipv4(6379));
println!("Database running on {}", db.get_host_port_ipv4(6379));
payjoin_directory::listen_tcp_with_tls(port, db_host, timeout, local_cert_key).await
payjoin_directory::listen_tcp_with_tls_on_free_port(db_host, timeout, local_cert_key)
.await
}

// generates or gets a DER encoded localhost cert and key.
Expand Down Expand Up @@ -920,7 +961,7 @@ mod integration {
while start.elapsed() < *TESTS_TIMEOUT {
let request_result =
agent.get(health_url.as_str()).send().await.map_err(|_| "Bad request")?;

println!("awaiting Service ready: {:?}", request_result.status());
match request_result.status() {
StatusCode::OK => return Ok(()),
StatusCode::NOT_FOUND => return Err("Endpoint not found"),
Expand Down

0 comments on commit 834859c

Please sign in to comment.