diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index ce719b9ed..3c423c33c 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -33,10 +33,10 @@ use crate::{ }, executor::IpaRuntime, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, TransportIdentity, }, - net::{http_serde, Error, CRYPTO_PROVIDER}, + net::{error::ShardQueryStatusMismatchError, http_serde, Error, CRYPTO_PROVIDER}, protocol::{Gate, QueryId}, }; @@ -509,6 +509,31 @@ impl IpaHttpClient { }) .collect() } + + /// This API is used by leader shards in MPC to request query status information on peers. + /// If a given peer has status that doesn't match the one provided by the leader, it responds + /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. + /// + /// # Errors + /// If the request has illegal arguments, or fails to be delivered + pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { + let req = http_serde::query::status_match::try_into_http_request( + &data, + self.scheme.clone(), + self.authority.clone(), + )?; + let resp = self.request(req).await?; + + match resp.status() { + StatusCode::OK => Ok(()), + StatusCode::PRECONDITION_FAILED => { + let bytes = response_to_bytes(resp).await?; + let err = serde_json::from_slice::(&bytes)?; + Err(err.into()) + } + _ => Err(Error::from_failed_resp(resp).await), + } + } } fn make_http_connector() -> HttpConnector { @@ -537,7 +562,7 @@ pub(crate) mod tests { ff::{FieldType, Fp31}, helpers::{ make_owned_handler, query::QueryType::TestMultiply, BytesStream, HelperIdentity, - HelperResponse, RequestHandler, RoleAssignment, Transport, MESSAGE_PAYLOAD_SIZE_BYTES, + HelperResponse, RequestHandler, RoleAssignment, MESSAGE_PAYLOAD_SIZE_BYTES, }, net::test::TestServer, protocol::step::TestExecutionStep, @@ -734,7 +759,7 @@ pub(crate) mod tests { resp_ok(resp).await.unwrap(); let mut stream = transport - .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) + .receive(HelperIdentity::ONE, &(QueryId, expected_step.clone())) .into_bytes_stream(); assert_eq!( diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index 6a04e8282..e5f188158 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,7 +4,8 @@ use axum::{ }; use crate::{ - error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex, + error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, query::QueryStatus, + sharding::ShardIndex, }; #[derive(thiserror::Error, Debug)] @@ -59,8 +60,13 @@ pub enum Error { #[source] inner: hyper_util::client::legacy::Error, }, - #[error("{error}")] + #[error("{code}: {error}")] Application { code: StatusCode, error: BoxError }, + #[error(transparent)] + ShardQueryStatusMismatch { + #[from] + error: ShardQueryStatusMismatchError, + }, } impl Error { @@ -142,6 +148,12 @@ pub struct ShardError { pub source: Error, } +#[derive(Debug, thiserror::Error, serde::Deserialize, serde::Serialize)] +#[error("Query status mismatch. Actual status: {actual}")] +pub struct ShardQueryStatusMismatchError { + pub actual: QueryStatus, +} + impl IntoResponse for Error { fn into_response(self) -> Response { let status_code = match self { @@ -165,6 +177,13 @@ impl IntoResponse for Error { | Self::MissingExtension(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::Application { code, .. } => code, + Self::ShardQueryStatusMismatch { error } => { + return ( + StatusCode::PRECONDITION_FAILED, + serde_json::to_string(&error).unwrap(), + ) + .into_response() + } }; (status_code, self.to_string()).into_response() } diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index d49df1b19..2b9dad085 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -608,4 +608,48 @@ pub mod query { pub const AXUM_PATH: &str = "/:query_id/kill"; } + + pub mod status_match { + use serde::{Deserialize, Serialize}; + + use crate::{helpers::query::CompareStatusRequest, query::QueryStatus}; + + #[derive(Serialize, Deserialize)] + pub struct StatusQueryString { + pub status: QueryStatus, + } + + impl StatusQueryString { + fn url_encode(&self) -> String { + // todo: serde urlencoded + format!("status={}", self.status) + } + } + + impl From for StatusQueryString { + fn from(value: QueryStatus) -> Self { + Self { status: value } + } + } + + pub fn try_into_http_request( + req: &CompareStatusRequest, + scheme: axum::http::uri::Scheme, + authority: axum::http::uri::Authority, + ) -> crate::net::http_serde::OutgoingRequest { + let uri = axum::http::uri::Uri::builder() + .scheme(scheme) + .authority(authority) + .path_and_query(format!( + "{}/{}/status-match?{}", + crate::net::http_serde::query::BASE_AXUM_PATH, + req.query_id.as_ref(), + StatusQueryString::from(req.status).url_encode(), + )) + .build()?; + Ok(hyper::Request::get(uri).body(axum::body::Body::empty())?) + } + + pub const AXUM_PATH: &str = "/:query_id/status-match"; + } } diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index dc99ebff5..c8ab75875 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -3,18 +3,21 @@ mod query; use axum::Router; -use crate::net::{http_serde, transport::MpcHttpTransport, ShardHttpTransport}; +use crate::{ + net::{http_serde, transport::MpcHttpTransport, HttpTransport, Shard}, + sync::Arc, +}; pub fn mpc_router(transport: MpcHttpTransport) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new() .merge(query::query_router(transport.clone())) - .merge(query::h2h_router(transport)), + .merge(query::h2h_router(transport.inner_transport)), ) } -pub fn shard_router(transport: ShardHttpTransport) -> Router { +pub fn shard_router(transport: Arc>) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, Router::new().merge(query::s2s_router(transport)), diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 55fe5c054..8a9881bb7 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -4,6 +4,7 @@ mod kill; mod prepare; mod results; mod status; +mod status_match; mod step; use std::marker::PhantomData; @@ -21,8 +22,8 @@ use tower::{layer::layer_fn, Service}; use crate::{ net::{ - server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard, - ShardHttpTransport, + server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, + HttpTransport, Shard, }, sync::Arc, }; @@ -48,19 +49,20 @@ pub fn query_router(transport: MpcHttpTransport) -> Router { /// particular query, to coordinate servicing that query. // // It might make sense to split the query and h2h handlers into two modules. -pub fn h2h_router(transport: MpcHttpTransport) -> Router { +pub fn h2h_router(transport: Arc>) -> Router { Router::new() - .merge(step::router(Arc::clone(&transport.inner_transport))) - .merge(prepare::router(transport.inner_transport)) + .merge(step::router(Arc::clone(&transport))) + .merge(prepare::router(transport)) .layer(layer_fn(HelperAuthentication::<_, Helper>::new)) } /// Construct router for shard-to-shard communications similar to [`h2h_router`]. -pub fn s2s_router(transport: ShardHttpTransport) -> Router { +pub fn s2s_router(transport: Arc>) -> Router { Router::new() - .merge(step::router(Arc::clone(&transport.inner_transport))) - .merge(prepare::router(Arc::clone(&transport.inner_transport))) - .merge(results::router(transport.inner_transport)) + .merge(step::router(Arc::clone(&transport))) + .merge(prepare::router(Arc::clone(&transport))) + .merge(results::router(Arc::clone(&transport))) + .merge(status_match::router(transport)) .layer(layer_fn(HelperAuthentication::<_, Shard>::new)) } @@ -125,12 +127,11 @@ pub mod test_helpers { use std::{any::Any, sync::Arc}; use axum::body::Body; - use http_body_util::BodyExt; use hyper::{http::request, StatusCode}; use crate::{ helpers::{HelperIdentity, RequestHandler}, - net::test::TestServer, + net::{test::TestServer, Helper}, }; /// Helper trait for optionally adding an extension to a request. @@ -178,14 +179,6 @@ pub mod test_helpers { req: hyper::Request, handler: Arc>, ) -> bytes::Bytes { - let test_server = TestServer::builder() - .with_request_handler(handler) - .build() - .await; - let resp = test_server.server.handle_req(req).await; - let status = resp.status(); - assert_eq!(StatusCode::OK, status); - - resp.into_body().collect().await.unwrap().to_bytes() + TestServer::::oneshot_success(req, handler).await } } diff --git a/ipa-core/src/net/server/handlers/query/status_match.rs b/ipa-core/src/net/server/handlers/query/status_match.rs new file mode 100644 index 000000000..5b2081c5e --- /dev/null +++ b/ipa-core/src/net/server/handlers/query/status_match.rs @@ -0,0 +1,227 @@ +use axum::{ + extract::{Path, Query}, + routing::get, + Extension, Router, +}; +use hyper::StatusCode; + +use crate::{ + helpers::{query::CompareStatusRequest, ApiError, BodyStream}, + net::{ + http_serde::query::status_match::{ + StatusQueryString, {self}, + }, + server::Error, + HttpTransport, Shard, + }, + protocol::QueryId, + query::QueryStatusError, + sync::Arc, +}; + +async fn handler( + transport: Extension>>, + Path(query_id): Path, + Query(StatusQueryString { status }): Query, +) -> Result<(), Error> { + let req = CompareStatusRequest { query_id, status }; + match Arc::clone(&transport) + .dispatch(req, BodyStream::empty()) + .await + { + Ok(_) => Ok(()), + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. })) => { + Err(crate::net::error::ShardQueryStatusMismatchError { actual: my_status }.into()) + } + Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), + } +} + +pub fn router(transport: Arc>) -> Router { + Router::new() + .route(status_match::AXUM_PATH, get(handler)) + .layer(Extension(transport)) +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{borrow::Borrow, sync::Arc}; + + use axum::{ + body::Body, + http::uri::{Authority, Scheme}, + }; + use hyper::StatusCode; + + use crate::{ + helpers::{ + make_owned_handler, + query::CompareStatusRequest, + routing::{Addr, RouteId}, + ApiError, BodyStream, HelperResponse, RequestHandler, + }, + net::{ + error::ShardQueryStatusMismatchError, + http_serde::query::status_match::try_into_http_request, + server::ClientIdentity, + test::{TestServer, TestServerBuilder}, + Error, Shard, + }, + protocol::QueryId, + query::{QueryStatus, QueryStatusError}, + sharding::ShardIndex, + }; + + fn for_status(status: QueryStatus) -> CompareStatusRequest { + CompareStatusRequest { + query_id: QueryId, + status, + } + } + + fn http_request>(req: B) -> hyper::Request { + try_into_http_request( + req.borrow(), + Scheme::HTTP, + Authority::from_static("localhost"), + ) + .unwrap() + } + + fn authenticated(mut req: hyper::Request) -> hyper::Request { + req.extensions_mut() + .insert(ClientIdentity(ShardIndex::from(2))); + req + } + + fn handler_status_match(expected_status: QueryStatus) -> Arc> { + make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::QueryStatus = addr.route else { + panic!("unexpected call"); + }; + let req = addr.into::().unwrap(); + assert_eq!(req.query_id, QueryId); + assert_eq!(req.status, expected_status); + Ok(HelperResponse::ok()) + }, + ) + } + + fn handler_status_mismatch( + expected_status: QueryStatus, + ) -> Arc> { + assert_ne!(expected_status, QueryStatus::Running); + + make_owned_handler( + move |addr: Addr, _data: BodyStream| async move { + let RouteId::QueryStatus = addr.route else { + panic!("unexpected call"); + }; + let req = addr.into::().unwrap(); + assert_eq!(req.query_id, QueryId); + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Running, + other_status: expected_status, + })) + }, + ) + } + + #[tokio::test] + async fn status_success() { + let expected_status = QueryStatus::Running; + let req = authenticated(http_request(for_status(expected_status))); + + TestServer::::oneshot_success(req, handler_status_match(expected_status)).await; + } + + #[tokio::test] + async fn status_client_success() { + let expected_status = QueryStatus::Running; + let test_server = TestServerBuilder::::default() + .with_request_handler(handler_status_match(expected_status)) + .build() + .await; + + test_server + .client + .status_match(for_status(expected_status)) + .await + .unwrap(); + } + + #[tokio::test] + async fn status_client_mismatch() { + let diff_status = QueryStatus::Preparing; + let test_server = TestServerBuilder::::default() + .with_request_handler(handler_status_mismatch(diff_status)) + .build() + .await; + let e = test_server + .client + .status_match(for_status(diff_status)) + .await + .unwrap_err(); + assert!(matches!( + e, + Error::ShardQueryStatusMismatch { + error: ShardQueryStatusMismatchError { + actual: QueryStatus::Running + }, + } + )); + } + + #[tokio::test] + async fn status_mismatch() { + let req_status = QueryStatus::Completed; + let handler = handler_status_mismatch(req_status); + let req = authenticated(http_request(for_status(req_status))); + + let resp = TestServer::::oneshot(req, handler).await; + assert_eq!(StatusCode::PRECONDITION_FAILED, resp.status()); + } + + #[tokio::test] + async fn other_query_error() { + let handler = make_owned_handler( + move |_addr: Addr, _data: BodyStream| async move { + Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( + QueryId, + ))) + }, + ); + let req = authenticated(http_request(for_status(QueryStatus::Running))); + + let resp = TestServer::::oneshot(req, handler).await; + assert_eq!(StatusCode::INTERNAL_SERVER_ERROR, resp.status()); + } + + #[tokio::test] + async fn unauthenticated() { + assert_eq!( + StatusCode::UNAUTHORIZED, + TestServer::::oneshot( + http_request(for_status(QueryStatus::Running)), + make_owned_handler(|_, _| async move { unimplemented!() }), + ) + .await + .status() + ); + } + + #[tokio::test] + async fn server_error() { + assert_eq!( + StatusCode::INTERNAL_SERVER_ERROR, + TestServer::::oneshot( + authenticated(http_request(for_status(QueryStatus::Running))), + make_owned_handler(|_, _| async move { Err(ApiError::BadRequest("".into())) }), + ) + .await + .status() + ); + } +} diff --git a/ipa-core/src/net/server/handlers/query/step.rs b/ipa-core/src/net/server/handlers/query/step.rs index f0e537acd..4128e933a 100644 --- a/ipa-core/src/net/server/handlers/query/step.rs +++ b/ipa-core/src/net/server/handlers/query/step.rs @@ -40,7 +40,7 @@ mod tests { use super::*; use crate::{ - helpers::{HelperIdentity, Transport, MESSAGE_PAYLOAD_SIZE_BYTES}, + helpers::{HelperIdentity, MESSAGE_PAYLOAD_SIZE_BYTES}, net::{ server::handlers::query::test_helpers::{assert_fails_with, MaybeExtensionExt}, test::TestServer, @@ -66,7 +66,7 @@ mod tests { let mut stream = test_server .transport - .receive(HelperIdentity::TWO, (QueryId, step)) + .receive(HelperIdentity::TWO, &(QueryId, step)) .into_bytes_stream(); assert_eq!( diff --git a/ipa-core/src/net/server/mod.rs b/ipa-core/src/net/server/mod.rs index 88df2cd60..bcf855d5e 100644 --- a/ipa-core/src/net/server/mod.rs +++ b/ipa-core/src/net/server/mod.rs @@ -39,10 +39,7 @@ use tower::{layer::layer_fn, Service}; use tower_http::trace::TraceLayer; use tracing::{error, Span}; -use super::{ - transport::{MpcHttpTransport, ShardHttpTransport}, - Shard, -}; +use super::{transport::MpcHttpTransport, HttpTransport, Shard}; use crate::{ config::{ NetworkConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig, ServerConfig, TlsConfig, @@ -95,11 +92,13 @@ pub struct IpaHttpServer { impl IpaHttpServer { #[must_use] pub fn new_mpc( - transport: &MpcHttpTransport, + transport: Arc>, config: ServerConfig, network_config: NetworkConfig, ) -> Self { - let router = handlers::mpc_router(transport.clone()); + let router = handlers::mpc_router(MpcHttpTransport { + inner_transport: transport, + }); IpaHttpServer { config, network_config, @@ -111,11 +110,11 @@ impl IpaHttpServer { impl IpaHttpServer { #[must_use] pub fn new_shards( - transport: &ShardHttpTransport, + transport: Arc>, config: ServerConfig, network_config: NetworkConfig, ) -> Self { - let router = handlers::shard_router(transport.clone()); + let router = handlers::shard_router(transport); IpaHttpServer { config, network_config, @@ -126,7 +125,10 @@ impl IpaHttpServer { impl IpaHttpServer { #[cfg(all(test, unit_test))] - async fn handle_req(&self, req: hyper::Request) -> axum::response::Response { + pub(crate) async fn handle_req( + &self, + req: hyper::Request, + ) -> axum::response::Response { use tower::ServiceExt; self.router.clone().oneshot(req).await.unwrap() } diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index c74c25610..ecb654de6 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -15,17 +15,21 @@ use std::{ ops::Index, }; +#[cfg(all(test, unit_test))] +use http_body_util::BodyExt; +#[cfg(all(test, unit_test))] +use hyper::StatusCode; use once_cell::sync::Lazy; use rustls_pki_types::CertificateDer; -use super::{transport::MpcHttpTransport, ConnectionFlavor, Shard}; +use super::{ConnectionFlavor, HttpTransport, Shard}; use crate::{ config::{ ClientConfig, HpkeClientConfig, HpkeServerConfig, NetworkConfig, PeerConfig, ServerConfig, TlsConfig, }, - executor::{IpaJoinHandle, IpaRuntime}, - helpers::{HandlerBox, HelperIdentity, RequestHandler, TransportIdentity}, + executor::IpaRuntime, + helpers::{HandlerBox, HelperIdentity, RequestHandler, StreamCollection, TransportIdentity}, hpke::{Deserializable as _, IpaPublicKey}, net::{ClientIdentity, Helper, IpaHttpClient, IpaHttpServer}, sharding::{ShardIndex, ShardedHelperIdentity}, @@ -229,7 +233,7 @@ impl TestApp { &self.mpc_network_config, &identities.helper, ); - let (transport, server) = MpcHttpTransport::new( + let (transport, server) = crate::net::MpcHttpTransport::new( IpaRuntime::current(), sid.helper_identity, self.mpc_server.config, @@ -465,16 +469,34 @@ impl TestConfigBuilder { TestConfig::new(self) } } -pub struct TestServer { +pub struct TestServer { pub addr: SocketAddr, - pub handle: IpaJoinHandle<()>, - pub transport: MpcHttpTransport, - pub server: IpaHttpServer, - pub client: IpaHttpClient, - pub request_handler: Option>>, + pub transport: Arc>, + pub server: IpaHttpServer, + pub client: IpaHttpClient, + pub request_handler: Option>>, } -impl TestServer { +impl TestServer { + fn new( + addr: SocketAddr, + transport: Arc>, + server: IpaHttpServer, + request_handler: Option>>, + ) -> Self { + // pick the first client because it is the one that will be used to talk to this server + let client = transport.clients.first().unwrap().clone(); + Self { + addr, + transport, + server, + client, + request_handler, + } + } +} + +impl TestServer { /// Build default set of test clients /// /// All three clients will be configured with the same default server URL, thus, @@ -488,23 +510,72 @@ impl TestServer { pub fn builder() -> TestServerBuilder { TestServerBuilder::default() } + + #[cfg(all(test, unit_test))] + pub async fn oneshot_success( + req: hyper::Request, + handler: Arc>, + ) -> bytes::Bytes { + let test_server = TestServerBuilder::::default() + .with_request_handler(handler) + .build() + .await; + let resp = test_server.server.handle_req(req).await; + let status = resp.status(); + assert_eq!(StatusCode::OK, status); + + resp.into_body().collect().await.unwrap().to_bytes() + } } -#[derive(Default)] -pub struct TestServerBuilder { - handler: Option>>, +impl TestServer { + #[cfg(all(test, unit_test))] + pub async fn oneshot( + req: hyper::Request, + handler: Arc>, + ) -> hyper::Response { + let test_server = TestServerBuilder::::default() + .with_request_handler(handler) + .build() + .await; + test_server.server.handle_req(req).await + } + + #[cfg(all(test, unit_test))] + pub async fn oneshot_success( + req: hyper::Request, + handler: Arc>, + ) -> bytes::Bytes { + let resp = Self::oneshot(req, handler).await; + let status = resp.status(); + assert_eq!(StatusCode::OK, status); + + resp.into_body().collect().await.unwrap().to_bytes() + } +} +pub struct TestServerBuilder { + handler: Option>>, metrics: Option, disable_https: bool, use_http1: bool, disable_matchkey_encryption: bool, } -impl TestServerBuilder { +impl Default for TestServerBuilder { + fn default() -> Self { + Self { + handler: None, + metrics: None, + disable_https: false, + use_http1: false, + disable_matchkey_encryption: false, + } + } +} + +impl TestServerBuilder { #[must_use] - pub fn with_request_handler( - mut self, - handler: Arc>, - ) -> Self { + pub fn with_request_handler(mut self, handler: Arc>) -> Self { self.handler = Some(handler); self } @@ -537,42 +608,157 @@ impl TestServerBuilder { self } - pub async fn build(self) -> TestServer { - let identities = - ClientIdentities::new(self.disable_https, ShardedHelperIdentity::ONE_FIRST); - let mut test_config = TestConfig::builder() + fn test_config(&self) -> TestConfig { + TestConfig::builder() .with_disable_https_option(self.disable_https) .with_use_http1_option(self.use_http1) // TODO: add disble_matchkey here - .build(); - let leaders_ring = test_config.rings.pop().unwrap(); - let first_server = leaders_ring.servers.into_iter().next().unwrap(); - let clients = IpaHttpClient::from_conf( - &IpaRuntime::current(), - &leaders_ring.network, - &identities.helper.clone_with_key(), - ); - let handler = self.handler.as_ref().map(HandlerBox::owning_ref); - let client = clients[0].clone(); - let (transport, server) = MpcHttpTransport::new( - IpaRuntime::current(), - HelperIdentity::ONE, - first_server.config, - leaders_ring.network, - &clients, + .build() + } +} + +trait TestTransportConfigurator { + type Connection: ConnectionFlavor; + const IDENTITY: ::Identity; + + fn client_identity(&self) -> ClientIdentity; + + fn make_transport( + &self, + handler: Option::Identity>>>, + test_network: &TestNetwork, + ) -> Arc> { + let handler = handler.as_ref().map(HandlerBox::owning_ref); + + let clients = test_network + .network + .peers + .iter() + .map(|peer| { + IpaHttpClient::new( + IpaRuntime::current(), + &test_network.network.client, + peer.clone(), + self.client_identity(), + ) + }) + .collect::>(); + + let transport = HttpTransport { + http_runtime: IpaRuntime::current(), + identity: Self::IDENTITY, + clients, + record_streams: StreamCollection::default(), handler, + }; + + Arc::new(transport) + } +} + +/// Pick the first helper to serve as test server +impl TestTransportConfigurator for TestServerBuilder { + type Connection = Helper; + const IDENTITY: HelperIdentity = HelperIdentity::ONE; + + fn client_identity(&self) -> ClientIdentity { + ClientIdentities::new(self.disable_https, ShardedHelperIdentity::ONE_FIRST).helper + } +} + +/// Pick the first shard to serve as test server +impl TestTransportConfigurator for TestServerBuilder { + type Connection = Shard; + const IDENTITY: ShardIndex = ShardIndex::FIRST; + + fn client_identity(&self) -> ClientIdentity { + ClientIdentities::new(self.disable_https, ShardedHelperIdentity::ONE_FIRST).shard + } +} + +trait TestServerConfigurator { + type Connection: ConnectionFlavor; + + fn configure( + transport: &Arc>, + test_config: TestConfig, + ) -> (IpaHttpServer, AddressableTestServer); +} + +impl TestServerConfigurator for IpaHttpServer { + type Connection = Shard; + + fn configure( + transport: &Arc>, + test_config: TestConfig, + ) -> (IpaHttpServer, AddressableTestServer) { + let [test_network, ..] = test_config.shards; + let first_server = test_network.servers.into_iter().next().unwrap(); + let http_server = IpaHttpServer::new_shards( + Arc::clone(transport), + first_server.config.clone(), + test_network.network, ); - let (addr, handle) = server - .start_on(&IpaRuntime::current(), first_server.socket, self.metrics) + + (http_server, first_server) + } +} + +impl TestServerConfigurator for IpaHttpServer { + type Connection = Helper; + + fn configure( + transport: &Arc>, + mut test_config: TestConfig, + ) -> (IpaHttpServer, AddressableTestServer) { + let test_network = test_config.rings.pop().unwrap(); + let first_server = test_network.servers.into_iter().next().unwrap(); + let http_server = IpaHttpServer::new_mpc( + Arc::clone(transport), + first_server.config.clone(), + test_network.network, + ); + + (http_server, first_server) + } +} + +impl TestServerBuilder { + pub async fn build(self) -> TestServer { + let test_config = self.test_config(); + + let transport = self.make_transport(self.handler.clone(), &test_config.shards[0]); + let (http_server, test_server_conf) = + IpaHttpServer::::configure(&transport, test_config); + let (addr, _handle) = http_server + .start_on( + &IpaRuntime::current(), + test_server_conf.socket, + self.metrics, + ) .await; - TestServer { - addr, - handle, - transport, - server, - client, - request_handler: self.handler, - } + + TestServer::new(addr, transport, http_server, self.handler) + } +} + +impl TestServerBuilder { + pub async fn build(self) -> TestServer { + let test_config = self.test_config(); + + let transport = + self.make_transport(self.handler.clone(), test_config.rings.first().unwrap()); + let (http_server, test_server_conf) = + IpaHttpServer::::configure(&transport, test_config); + let (addr, _handle) = http_server + .start_on( + &IpaRuntime::current(), + test_server_conf.socket, + self.metrics, + ) + .await; + + TestServer::new(addr, transport, http_server, self.handler) } } diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index de980861f..173c08831 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -28,11 +28,11 @@ use crate::{ /// Shared implementation used by [`MpcHttpTransport`] and [`ShardHttpTransport`] pub struct HttpTransport { - http_runtime: IpaRuntime, - identity: F::Identity, - clients: Vec>, - record_streams: StreamCollection, - handler: Option>, + pub(super) http_runtime: IpaRuntime, + pub(super) identity: F::Identity, + pub(super) clients: Vec>, + pub(super) record_streams: StreamCollection, + pub(super) handler: Option>, } /// HTTP transport for helper to helper traffic. @@ -122,7 +122,7 @@ impl HttpTransport { } } - fn receive>( + pub(crate) fn receive>( &self, from: F::Identity, route: &R, @@ -218,18 +218,17 @@ impl MpcHttpTransport { clients: &[IpaHttpClient; 3], handler: Option>, ) -> (Self, IpaHttpServer) { - let transport = Self { - inner_transport: Arc::new(HttpTransport { - http_runtime, - identity, - clients: clients.to_vec(), - handler, - record_streams: StreamCollection::default(), - }), - }; + let inner_transport = Arc::new(HttpTransport { + http_runtime, + identity, + clients: clients.to_vec(), + handler, + record_streams: StreamCollection::default(), + }); - let server = IpaHttpServer::new_mpc(&transport, server_config, network_config); - (transport, server) + let server = + IpaHttpServer::new_mpc(Arc::clone(&inner_transport), server_config, network_config); + (Self { inner_transport }, server) } /// Connect an inbound stream of record data. @@ -326,19 +325,23 @@ impl ShardHttpTransport { clients: Vec>, handler: Option>, ) -> (Self, IpaHttpServer) { - let transport = Self { - inner_transport: Arc::new(HttpTransport { - http_runtime, - identity: shard_id, - clients, - handler, - record_streams: StreamCollection::default(), - }), - shard_count, - }; + let inner_transport = Arc::new(HttpTransport { + http_runtime, + identity: shard_id, + clients, + handler, + record_streams: StreamCollection::default(), + }); - let server = IpaHttpServer::new_shards(&transport, server_config, network_config); - (transport, server) + let server = + IpaHttpServer::new_shards(Arc::clone(&inner_transport), server_config, network_config); + ( + Self { + inner_transport, + shard_count, + }, + server, + ) } } @@ -439,18 +442,18 @@ mod tests { .build() .await; - transport.inner_transport.record_streams.add_stream( + transport.record_streams.add_stream( (QueryId, HelperIdentity::ONE, Gate::default()), BodyStream::empty(), ); - assert_eq!(1, transport.inner_transport.record_streams.len()); + assert_eq!(1, transport.record_streams.len()); - Transport::clone_ref(&transport) + Arc::clone(&transport) .dispatch((RouteId::KillQuery, QueryId), BodyStream::empty()) .await .unwrap(); - assert!(transport.inner_transport.record_streams.is_empty()); + assert!(transport.record_streams.is_empty()); } #[tokio::test] @@ -468,7 +471,7 @@ mod tests { // Request step data reception (normally called by protocol) let mut stream = transport - .receive(HelperIdentity::TWO, (QueryId, STEP.clone())) + .receive(HelperIdentity::TWO, &(QueryId, STEP.clone())) .into_bytes_stream(); // make sure it is not ready as it hasn't received any data yet. diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 28f981222..bca5c7e1d 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -1,6 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, - fmt::{Debug, Formatter}, + fmt::{Debug, Display, Formatter}, future::Future, task::Poll, }; @@ -35,6 +35,12 @@ pub enum QueryStatus { Completed, } +impl Display for QueryStatus { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl From<&QueryState> for QueryStatus { fn from(source: &QueryState) -> Self { match source {