From 189c7770350cc16ebe30391e2e5c2f9d3ea0bd5c Mon Sep 17 00:00:00 2001 From: Zahari Dichev Date: Thu, 21 Dec 2023 11:01:14 +0000 Subject: [PATCH] alter tests Signed-off-by: Zahari Dichev --- Cargo.lock | 1 + linkerd/proxy/spire-client/Cargo.toml | 1 + linkerd/proxy/spire-client/src/client.rs | 1 - linkerd/proxy/spire-client/src/lib.rs | 137 ++++++++++++++++++----- 4 files changed, 109 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 644b3b1600..23ec884397 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1893,6 +1893,7 @@ dependencies = [ "simple_asn1", "spiffe-proto", "tokio", + "tokio-test", "tonic", "tower", "tracing", diff --git a/linkerd/proxy/spire-client/Cargo.toml b/linkerd/proxy/spire-client/Cargo.toml index 83ca98ddf2..bc4ded5662 100644 --- a/linkerd/proxy/spire-client/Cargo.toml +++ b/linkerd/proxy/spire-client/Cargo.toml @@ -25,3 +25,4 @@ asn1 = { version = "0.6", package = "simple_asn1" } [dev-dependencies] rcgen = "0.11.3" +tokio-test = "0.4" diff --git a/linkerd/proxy/spire-client/src/client.rs b/linkerd/proxy/spire-client/src/client.rs index 25563e3ec5..d294515e3f 100644 --- a/linkerd/proxy/spire-client/src/client.rs +++ b/linkerd/proxy/spire-client/src/client.rs @@ -39,7 +39,6 @@ impl tower::Service<()> for Client { let socket = self.socket.clone(); let backoff = self.backoff; Box::pin(async move { - //spiffe::workload_api::client::WorkloadApiClient // Strip the 'unix:' prefix for tonic compatibility. let stripped_path = socket .strip_prefix(UNIX_PREFIX) diff --git a/linkerd/proxy/spire-client/src/lib.rs b/linkerd/proxy/spire-client/src/lib.rs index fcb55c0146..5b8ed96189 100644 --- a/linkerd/proxy/spire-client/src/lib.rs +++ b/linkerd/proxy/spire-client/src/lib.rs @@ -91,17 +91,61 @@ mod tests { use std::collections::HashMap; use tokio::sync::watch; - fn gen_cert(subject_alt_names: Vec, serial: SerialNumber) -> DerX509 { + fn gen_svid(id: Id, subject_alt_names: Vec, serial: SerialNumber) -> Svid { let mut params = CertificateParams::default(); params.subject_alt_names = subject_alt_names; params.serial_number = Some(serial); - DerX509( - Certificate::from_params(params) - .expect("should generate cert") - .serialize_der() - .expect("should serialize"), - ) + Svid { + spiffe_id: id, + leaf: DerX509( + Certificate::from_params(params) + .expect("should generate cert") + .serialize_der() + .expect("should serialize"), + ), + private_key: Vec::default(), + intermediates: Vec::default(), + } + } + + fn svid_update(svids: Vec) -> SvidUpdate { + let mut svids_map = HashMap::default(); + for svid in svids.into_iter() { + svids_map.insert(svid.spiffe_id.clone(), svid); + } + + SvidUpdate { svids: svids_map } + } + + // TODO: use a service_fn for this mock + struct MockNewClient { + rx: watch::Receiver, + } + + impl MockNewClient { + fn new(init: SvidUpdate) -> (Self, watch::Sender) { + let (tx, rx) = watch::channel(init); + (Self { rx }, tx) + } + } + + impl tower::Service<()> for MockNewClient { + type Response = watch::Receiver; + type Error = Error; + // type Future = futures::future::BoxFuture<'static, Result>; + type Future = futures::future::BoxFuture<'static, Result>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: ()) -> Self::Future { + Box::pin(futures::future::ready(Ok(self.rx.clone()))) + } } struct MockCredentials { @@ -125,43 +169,76 @@ mod tests { } #[tokio::test(flavor = "current_thread")] - async fn valid_update() { - let serial = SerialNumber::from_slice("some-serial".as_bytes()); + async fn valid_updates() { let spiffe_san = "spiffe://some-domain/some-workload"; - let leaf = gen_cert(vec![SanType::URI(spiffe_san.into())], serial.clone()); let spiffe_id = Id::parse_uri("spiffe://some-domain/some-workload").expect("should parse"); - let (mut creds, mut rx) = MockCredentials::new(); - let svid = Svid { - spiffe_id: spiffe_id.clone(), - leaf, - private_key: Vec::default(), - intermediates: Vec::default(), - }; - let mut svids = HashMap::default(); - svids.insert(svid.spiffe_id.clone(), svid); - let update = SvidUpdate { svids }; + let (creds, mut creds_rx) = MockCredentials::new(); + + let spire = Spire::new(spiffe_id.clone(), Metrics::default()); + + let serial_1 = SerialNumber::from_slice("some-serial-1".as_bytes()); + let update_1 = svid_update(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_1.clone(), + )]); + + let (client, svid_tx) = MockNewClient::new(update_1); + tokio::spawn(spire.run(creds, client)); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_1)); + + let serial_2 = SerialNumber::from_slice("some-serial-2".as_bytes()); + let update_2 = svid_update(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_2.clone(), + )]); + + svid_tx.send(update_2).expect("should send"); - assert!(process_svid(&mut creds, update, &spiffe_id).is_ok()); - rx.changed().await.unwrap(); - assert!(*rx.borrow_and_update() == Some(serial)); + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_2)); } #[tokio::test(flavor = "current_thread")] async fn invalid_update() { + let spiffe_san = "spiffe://some-domain/some-workload"; let spiffe_id = Id::parse_uri("spiffe://some-domain/some-workload").expect("should parse"); - let (mut creds, mut rx) = MockCredentials::new(); - let svid = Svid { + + let (creds, mut creds_rx) = MockCredentials::new(); + + let spire = Spire::new(spiffe_id.clone(), Metrics::default()); + + let serial_1 = SerialNumber::from_slice("some-serial-1".as_bytes()); + let update_1 = svid_update(vec![gen_svid( + spiffe_id.clone(), + vec![SanType::URI(spiffe_san.into())], + serial_1.clone(), + )]); + + let (client, svid_tx) = MockNewClient::new(update_1); + tokio::spawn(spire.run(creds, client)); + + creds_rx.changed().await.unwrap(); + assert!(*creds_rx.borrow_and_update() == Some(serial_1.clone())); + + let invalid_svid = Svid { spiffe_id: spiffe_id.clone(), leaf: DerX509(Vec::default()), private_key: Vec::default(), intermediates: Vec::default(), }; - let mut svids = HashMap::default(); - svids.insert(svid.spiffe_id.clone(), svid); - let update = SvidUpdate { svids }; - assert!(process_svid(&mut creds, update, &spiffe_id).is_err()); - assert!(rx.borrow_and_update().is_none()); + let mut update_sent = svid_tx.subscribe(); + let update_2 = svid_update(vec![invalid_svid]); + svid_tx.send(update_2).expect("should send"); + + update_sent.changed().await.unwrap(); + + assert!(!creds_rx.has_changed().unwrap()); + assert!(*creds_rx.borrow_and_update() == Some(serial_1)); } }