diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index d4789b198..3c423c33c 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -384,31 +384,6 @@ impl IpaHttpClient { let resp = self.request(req).await?; resp_ok(resp).await } - - /// This API is used by leader shards in MPC to request query status information on peers. - /// If a given peer has status that doesn't match the one provided by the leader, it responds - /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. - /// - /// # Errors - /// If the request has illegal arguments, or fails to be delivered - pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { - let req = http_serde::query::status_match::try_into_http_request( - &data, - self.scheme.clone(), - self.authority.clone(), - )?; - let resp = self.request(req).await?; - - match resp.status() { - StatusCode::OK => Ok(()), - StatusCode::PRECONDITION_FAILED => { - let bytes = response_to_bytes(resp).await?; - let err = serde_json::from_slice::(&bytes)?; - Err(err.into()) - } - _ => Err(Error::from_failed_resp(resp).await), - } - } } impl IpaHttpClient { @@ -534,6 +509,31 @@ impl IpaHttpClient { }) .collect() } + + /// This API is used by leader shards in MPC to request query status information on peers. + /// If a given peer has status that doesn't match the one provided by the leader, it responds + /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. + /// + /// # Errors + /// If the request has illegal arguments, or fails to be delivered + pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { + let req = http_serde::query::status_match::try_into_http_request( + &data, + self.scheme.clone(), + self.authority.clone(), + )?; + let resp = self.request(req).await?; + + match resp.status() { + StatusCode::OK => Ok(()), + StatusCode::PRECONDITION_FAILED => { + let bytes = response_to_bytes(resp).await?; + let err = serde_json::from_slice::(&bytes)?; + Err(err.into()) + } + _ => Err(Error::from_failed_resp(resp).await), + } + } } fn make_http_connector() -> HttpConnector { diff --git a/ipa-core/src/net/server/handlers/query/status_match.rs b/ipa-core/src/net/server/handlers/query/status_match.rs index 716743bba..5b2081c5e 100644 --- a/ipa-core/src/net/server/handlers/query/status_match.rs +++ b/ipa-core/src/net/server/handlers/query/status_match.rs @@ -45,7 +45,7 @@ pub fn router(transport: Arc>) -> Router { #[cfg(all(test, unit_test))] mod tests { - use std::borrow::Borrow; + use std::{borrow::Borrow, sync::Arc}; use axum::{ body::Body, @@ -58,11 +58,14 @@ mod tests { make_owned_handler, query::CompareStatusRequest, routing::{Addr, RouteId}, - ApiError, BodyStream, HelperResponse, + ApiError, BodyStream, HelperResponse, RequestHandler, }, net::{ - http_serde::query::status_match::try_into_http_request, server::ClientIdentity, - test::TestServer, Shard, + error::ShardQueryStatusMismatchError, + http_serde::query::status_match::try_into_http_request, + server::ClientIdentity, + test::{TestServer, TestServerBuilder}, + Error, Shard, }, protocol::QueryId, query::{QueryStatus, QueryStatusError}, @@ -91,42 +94,90 @@ mod tests { req } - #[tokio::test] - async fn status_success() { - let expected_status = QueryStatus::Running; - let expected_query_id = QueryId; - - let handler = make_owned_handler( + fn handler_status_match(expected_status: QueryStatus) -> Arc> { + make_owned_handler( move |addr: Addr, _data: BodyStream| async move { let RouteId::QueryStatus = addr.route else { panic!("unexpected call"); }; let req = addr.into::().unwrap(); - assert_eq!(req.query_id, expected_query_id); + assert_eq!(req.query_id, QueryId); assert_eq!(req.status, expected_status); Ok(HelperResponse::ok()) }, - ); - let req = authenticated(http_request(for_status(expected_status))); - - TestServer::::oneshot_success(req, handler).await; + ) } - #[tokio::test] - async fn status_mismatch() { - let req_status = QueryStatus::Completed; - let handler = make_owned_handler( + fn handler_status_mismatch( + expected_status: QueryStatus, + ) -> Arc> { + assert_ne!(expected_status, QueryStatus::Running); + + make_owned_handler( move |addr: Addr, _data: BodyStream| async move { let RouteId::QueryStatus = addr.route else { panic!("unexpected call"); }; + let req = addr.into::().unwrap(); + assert_eq!(req.query_id, QueryId); Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { query_id: QueryId, my_status: QueryStatus::Running, - other_status: req_status, + other_status: expected_status, })) }, - ); + ) + } + + #[tokio::test] + async fn status_success() { + let expected_status = QueryStatus::Running; + let req = authenticated(http_request(for_status(expected_status))); + + TestServer::::oneshot_success(req, handler_status_match(expected_status)).await; + } + + #[tokio::test] + async fn status_client_success() { + let expected_status = QueryStatus::Running; + let test_server = TestServerBuilder::::default() + .with_request_handler(handler_status_match(expected_status)) + .build() + .await; + + test_server + .client + .status_match(for_status(expected_status)) + .await + .unwrap(); + } + + #[tokio::test] + async fn status_client_mismatch() { + let diff_status = QueryStatus::Preparing; + let test_server = TestServerBuilder::::default() + .with_request_handler(handler_status_mismatch(diff_status)) + .build() + .await; + let e = test_server + .client + .status_match(for_status(diff_status)) + .await + .unwrap_err(); + assert!(matches!( + e, + Error::ShardQueryStatusMismatch { + error: ShardQueryStatusMismatchError { + actual: QueryStatus::Running + }, + } + )); + } + + #[tokio::test] + async fn status_mismatch() { + let req_status = QueryStatus::Completed; + let handler = handler_status_mismatch(req_status); let req = authenticated(http_request(for_status(req_status))); let resp = TestServer::::oneshot(req, handler).await;