diff --git a/Cargo.lock b/Cargo.lock
index be46b7b90f..00d52cc7a9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1786,6 +1786,26 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "linkerd-proxy-pool"
+version = "0.1.0"
+dependencies = [
+ "futures",
+ "linkerd-error",
+ "linkerd-proxy-core",
+ "linkerd-stack",
+ "linkerd-tracing",
+ "parking_lot",
+ "pin-project",
+ "thiserror",
+ "tokio",
+ "tokio-stream",
+ "tokio-test",
+ "tokio-util",
+ "tower-test",
+ "tracing",
+]
+
 [[package]]
 name = "linkerd-proxy-resolve"
 version = "0.1.0"
diff --git a/Cargo.toml b/Cargo.toml
index 118ce889c2..a00b13f1f8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,6 +46,7 @@ members = [
     "linkerd/proxy/dns-resolve",
     "linkerd/proxy/http",
     "linkerd/proxy/identity-client",
+    "linkerd/proxy/pool",
     "linkerd/proxy/resolve",
     "linkerd/proxy/server-policy",
     "linkerd/proxy/tap",
diff --git a/linkerd/proxy/pool/Cargo.toml b/linkerd/proxy/pool/Cargo.toml
new file mode 100644
index 0000000000..ab1f8ac7eb
--- /dev/null
+++ b/linkerd/proxy/pool/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "linkerd-proxy-pool"
+version = "0.1.0"
+authors = ["Linkerd Developers <cncf-toc@lists.cncf.io>"]
+license = "Apache-2.0"
+edition = "2021"
+publish = false
+
+[dependencies]
+futures = { version = "0.3", default-features = false }
+linkerd-error = { path = "../../error" }
+# linkerd-metrics = { path = "../../metrics" }
+linkerd-proxy-core = { path = "../core" }
+linkerd-stack = { path = "../../stack" }
+parking_lot = "0.12"
+pin-project = "1"
+# rand = "0.8"
+thiserror = "1"
+tokio = { version = "1", features = ["rt", "sync", "time"] }
+# tokio-stream = { version = "0.1", features = ["sync"] }
+tokio-util = "0.7"
+tracing = "0.1"
+
+[dev-dependencies]
+linkerd-tracing = { path = "../../tracing" }
+tokio-stream = { version = "0.1", features = ["sync"] }
+tokio-test = "0.4"
+tower-test = "0.4"
diff --git a/linkerd/proxy/pool/src/error.rs b/linkerd/proxy/pool/src/error.rs
new file mode 100644
index 0000000000..c67ac5802a
--- /dev/null
+++ b/linkerd/proxy/pool/src/error.rs
@@ -0,0 +1,29 @@
+//! Error types for the `PoolQueue` middleware.
+
+use linkerd_error::Error;
+use std::{fmt, sync::Arc};
+
+/// A shareable, terminal error produced by either a service or discovery
+/// resolution.
+#[derive(Clone, Debug)]
+pub struct TerminalFailure(Arc<Error>);
+
+// === impl TerminalFailure ===
+
+impl TerminalFailure {
+    pub(crate) fn new(inner: Error) -> TerminalFailure {
+        TerminalFailure(Arc::new(inner))
+    }
+}
+
+impl fmt::Display for TerminalFailure {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(fmt, "pool failed: {}", self.0)
+    }
+}
+
+impl std::error::Error for TerminalFailure {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        Some(&**self.0)
+    }
+}
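Because the terminal error is stored behind an `Arc`, a single failure can be shared cheaply by every queue handle and in-flight response future. A minimal illustrative test of that behavior (not part of this diff; the `"lost"` I/O error is a stand-in):

```rust
#[cfg(test)]
mod terminal_failure_sketch {
    use super::TerminalFailure;
    use linkerd_error::Error;

    #[test]
    fn shared_terminal_failure() {
        let inner: Error = std::io::Error::new(std::io::ErrorKind::Other, "lost").into();
        let failure = TerminalFailure::new(inner);

        // Clones share the same Arc'd error, so all handles report the one
        // terminal failure without copying it.
        let clone = failure.clone();
        assert_eq!(clone.to_string(), "pool failed: lost");
        assert!(std::error::Error::source(&clone).is_some());
    }
}
```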
diff --git a/linkerd/proxy/pool/src/failfast.rs b/linkerd/proxy/pool/src/failfast.rs
new file mode 100644
index 0000000000..ec474a4d41
--- /dev/null
+++ b/linkerd/proxy/pool/src/failfast.rs
@@ -0,0 +1,130 @@
+use linkerd_stack::gate;
+use std::pin::Pin;
+use tokio::time;
+
+/// Manages the failfast state for a pool.
+#[derive(Debug)]
+pub(super) struct Failfast {
+    timeout: time::Duration,
+    sleep: Pin<Box<time::Sleep>>,
+    state: Option<State>,
+    gate: gate::Tx,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub(super) enum State {
+    Waiting { since: time::Instant },
+    Failfast { since: time::Instant },
+}
+
+// === impl Failfast ===
+
+impl Failfast {
+    pub(super) fn new(timeout: time::Duration, gate: gate::Tx) -> Self {
+        Self {
+            timeout,
+            sleep: Box::pin(time::sleep(time::Duration::MAX)),
+            state: None,
+            gate,
+        }
+    }
+
+    pub(super) fn duration(&self) -> time::Duration {
+        self.timeout
+    }
+
+    /// Returns true if we are currently in a failfast state.
+    pub(super) fn is_active(&self) -> bool {
+        matches!(self.state, Some(State::Failfast { .. }))
+    }
+
+    /// Clears any waiting or failfast state.
+    pub(super) fn set_ready(&mut self) -> Option<State> {
+        let state = self.state.take()?;
+        if matches!(state, State::Failfast { .. }) {
+            tracing::trace!("Exiting failfast");
+            let _ = self.gate.open();
+        }
+        Some(state)
+    }
+
+    /// Waits for the failfast timeout to expire and enters the failfast state.
+    pub(super) async fn entered(&mut self) {
+        let since = match self.state {
+            // If we're already in failfast, then we don't need to wait.
+            Some(State::Failfast { .. }) => {
+                return;
+            }
+
+            // Ensure that the timer's been initialized.
+            Some(State::Waiting { since }) => since,
+            None => {
+                let now = time::Instant::now();
+                self.sleep.as_mut().reset(now + self.timeout);
+                self.state = Some(State::Waiting { since: now });
+                now
+            }
+        };
+
+        // Wait for the failfast timer to expire.
+        tracing::trace!("Waiting for failfast timeout");
+        self.sleep.as_mut().await;
+        tracing::trace!("Entering failfast");
+
+        // Once we enter failfast, shut the upstream gate so that we can
+        // advertise backpressure past the queue.
+        self.state = Some(State::Failfast { since });
+        let _ = self.gate.shut();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::prelude::*;
+
+    #[tokio::test(flavor = "current_thread", start_paused = true)]
+    async fn failfast() {
+        let (tx, gate_rx) = gate::channel();
+        let dur = time::Duration::from_secs(1);
+        let mut failfast = Failfast::new(dur, tx);
+
+        assert_eq!(dur, failfast.duration());
+        assert!(gate_rx.is_open());
+
+        // The failfast timeout should not be initialized until the first
+        // request is received.
+        assert!(!failfast.is_active(), "failfast should be inactive");
+
+        failfast.entered().await;
+        assert!(failfast.is_active(), "failfast should be active");
+        assert!(gate_rx.is_shut(), "gate should be shut");
+
+        failfast
+            .entered()
+            .now_or_never()
+            .expect("timeout must return immediately when in failfast");
+        assert!(failfast.is_active(), "failfast should be active");
+        assert!(gate_rx.is_shut(), "gate should be shut");
+
+        failfast.set_ready();
+        assert!(!failfast.is_active(), "failfast should be inactive");
+        assert!(gate_rx.is_open(), "gate should be open");
+
+        tokio::select! {
+            _ = time::sleep(time::Duration::from_millis(10)) => {}
+            _ = failfast.entered() => unreachable!("timed out too quickly"),
+        }
+        assert!(!failfast.is_active(), "failfast should be inactive");
+        assert!(gate_rx.is_open(), "gate should be open");
+
+        assert!(
+            matches!(failfast.state, Some(State::Waiting { .. })),
+            "failfast should be waiting"
+        );
+
+        failfast.entered().await;
+        assert!(failfast.is_active(), "failfast should be active");
+        assert!(gate_rx.is_shut(), "gate should be shut");
+    }
+}
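One detail worth calling out: rather than allocating a new timer per waiting period, `Failfast` holds a single `Pin<Box<time::Sleep>>` parked at `Duration::MAX` and rearms it with `reset()` when the first request goes pending. A standalone sketch of that reuse pattern (illustrative only, not crate code):

```rust
use std::pin::Pin;
use tokio::time;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Park one boxed timer effectively forever...
    let mut sleep: Pin<Box<time::Sleep>> = Box::pin(time::sleep(time::Duration::MAX));

    // ...then rearm it for each new deadline instead of reallocating.
    sleep
        .as_mut()
        .reset(time::Instant::now() + time::Duration::from_millis(10));
    sleep.as_mut().await;

    // reset() also works after the timer has fired, so it can be reused.
    sleep
        .as_mut()
        .reset(time::Instant::now() + time::Duration::from_millis(10));
    sleep.as_mut().await;
}
```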
diff --git a/linkerd/proxy/pool/src/future.rs b/linkerd/proxy/pool/src/future.rs
new file mode 100644
index 0000000000..70881a42dd
--- /dev/null
+++ b/linkerd/proxy/pool/src/future.rs
@@ -0,0 +1,73 @@
+use super::message;
+use futures::ready;
+use linkerd_error::Error;
+use pin_project::pin_project;
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+#[pin_project]
+/// Future that completes when the buffered service eventually services the submitted request.
+#[derive(Debug)]
+pub struct ResponseFuture<T> {
+    #[pin]
+    state: ResponseState<T>,
+}
+
+#[pin_project(project = ResponseStateProj)]
+#[derive(Debug)]
+enum ResponseState<T> {
+    Failed {
+        error: Option<Error>,
+    },
+    Rx {
+        #[pin]
+        rx: message::Rx<T>,
+    },
+    Poll {
+        #[pin]
+        fut: T,
+    },
+}
+
+impl<T> ResponseFuture<T> {
+    pub(crate) fn new(rx: message::Rx<T>) -> Self {
+        ResponseFuture {
+            state: ResponseState::Rx { rx },
+        }
+    }
+
+    pub(crate) fn failed(err: Error) -> Self {
+        ResponseFuture {
+            state: ResponseState::Failed { error: Some(err) },
+        }
+    }
+}
+
+impl<F, T, E> Future for ResponseFuture<F>
+where
+    F: Future<Output = Result<T, E>>,
+    E: Into<Error>,
+{
+    type Output = Result<T, Error>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.project();
+
+        loop {
+            match this.state.as_mut().project() {
+                ResponseStateProj::Failed { error } => {
+                    return Poll::Ready(Err(error.take().expect("polled after error")));
+                }
+                ResponseStateProj::Rx { rx } => {
+                    let fut = ready!(rx.poll(cx))
+                        .expect("worker must set a failure if it exits prematurely")?;
+                    this.state.set(ResponseState::Poll { fut });
+                }
+                ResponseStateProj::Poll { fut } => return fut.poll(cx).map_err(Into::into),
+            }
+        }
+    }
+}
diff --git a/linkerd/proxy/pool/src/lib.rs b/linkerd/proxy/pool/src/lib.rs
new file mode 100644
index 0000000000..824a0cf002
--- /dev/null
+++ b/linkerd/proxy/pool/src/lib.rs
@@ -0,0 +1,39 @@
+//! Adapted from [`tower::buffer`][buffer].
+//!
+//! [buffer]: https://github.com/tower-rs/tower/tree/bf4ea948346c59a5be03563425a7d9f04aadedf2/tower/src/buffer
+//
+// Copyright (c) 2019 Tower Contributors
+
+#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)]
+#![forbid(unsafe_code)]
+
+mod error;
+mod failfast;
+mod future;
+mod message;
+mod service;
+#[cfg(test)]
+mod tests;
+mod worker;
+
+pub use self::service::PoolQueue;
+pub use linkerd_proxy_core::Update;
+
+use linkerd_stack::Service;
+
+/// A collection of services updated from a resolution.
+pub trait Pool<T, Req>: Service<Req> {
+    /// Updates the pool's endpoints.
+    fn update_pool(&mut self, update: Update<T>);
+
+    /// Polls to update the pool while the Service is ready.
+    ///
+    /// [`Service::poll_ready`] should do the same work, but will return ready
+    /// as soon as there is at least one ready endpoint. This method will
+    /// continue to drive the pool until ready is returned (indicating that
+    /// the pool need not be updated before another request is processed).
+    fn poll_pool(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>>;
+}
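To make the trait concrete, here is a hedged sketch of a trivial `Pool` implementation: it forwards every request to one inner service and merely records the endpoint set. It assumes `Update`'s `Reset`/`Add`/`Remove`/`DoesNotExist` variants from `linkerd-proxy-core`; real implementations (the test mock later in this diff, or a load balancer) do more in both methods:

```rust
use linkerd_error::Error;
use linkerd_proxy_pool::{Pool, Update};
use linkerd_stack::Service;
use std::{
    net::SocketAddr,
    task::{Context, Poll},
};

/// A toy pool: one inner service plus a record of the current endpoints.
struct SingleEndpointPool<S, T> {
    inner: S,
    endpoints: Vec<(SocketAddr, T)>,
}

impl<S, T, Req> Service<Req> for SingleEndpointPool<S, T>
where
    S: Service<Req, Error = Error>,
{
    type Response = S::Response;
    type Error = Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Req) -> Self::Future {
        self.inner.call(req)
    }
}

impl<S, T, Req> Pool<T, Req> for SingleEndpointPool<S, T>
where
    S: Service<Req, Error = Error>,
{
    fn update_pool(&mut self, update: Update<T>) {
        // A real pool would add/remove balancer endpoints here.
        match update {
            Update::Reset(eps) => self.endpoints = eps,
            Update::Add(eps) => self.endpoints.extend(eps),
            Update::Remove(addrs) => self.endpoints.retain(|(a, _)| !addrs.contains(a)),
            Update::DoesNotExist => self.endpoints.clear(),
        }
    }

    fn poll_pool(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
        // Nothing to drive here; a balancer would poll its endpoints.
        Poll::Ready(Ok(()))
    }
}
```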
diff --git a/linkerd/proxy/pool/src/message.rs b/linkerd/proxy/pool/src/message.rs
new file mode 100644
index 0000000000..697d511490
--- /dev/null
+++ b/linkerd/proxy/pool/src/message.rs
@@ -0,0 +1,37 @@
+use linkerd_error::{Error, Result};
+use tokio::{sync::oneshot, time};
+
+/// Message sent over buffer
+#[derive(Debug)]
+pub(crate) struct Message<Req, F> {
+    pub(crate) req: Req,
+    pub(crate) tx: Tx<F>,
+    pub(crate) span: tracing::Span,
+    pub(crate) t0: time::Instant,
+}
+
+/// Response sender
+type Tx<F> = oneshot::Sender<Result<F>>;
+
+/// Response receiver
+pub(crate) type Rx<F> = oneshot::Receiver<Result<F>>;
+
+impl<Req, F> Message<Req, F> {
+    pub(crate) fn channel(req: Req) -> (Self, Rx<F>) {
+        let (tx, rx) = oneshot::channel();
+        let t0 = time::Instant::now();
+        let span = tracing::Span::current();
+        (Message { req, span, tx, t0 }, rx)
+    }
+
+    pub(crate) fn fail(self, err: impl Into<Error>) {
+        if self.tx.send(Err(err.into())).is_ok() {
+            tracing::debug!(
+                latency = (time::Instant::now() - self.t0).as_secs_f64(),
+                "Failed due to pool error"
+            );
+        } else {
+            tracing::debug!("Caller dropped");
+        }
+    }
+}
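The buffer message is the usual dispatch envelope: the caller keeps the `Rx` half and the worker must either fulfill `tx` with the pool's response future or fail it. A self-contained sketch of the same oneshot round trip (illustrative names, not crate APIs):

```rust
use tokio::sync::oneshot;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // Caller side: enqueue a request, keep the receiver half.
    let (tx, rx) = oneshot::channel::<Result<&'static str, String>>();

    // Worker side: fulfill the oneshot with either a dispatched response
    // (Ok) or a terminal error (Err) -- here, a failure.
    let _ = tx.send(Err("pool failed: discovery lost".to_string()));

    // The caller's response future observes the worker's decision.
    assert!(rx.await.expect("worker answered").is_err());
}
```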
diff --git a/linkerd/proxy/pool/src/service.rs b/linkerd/proxy/pool/src/service.rs
new file mode 100644
index 0000000000..ec87e82d90
--- /dev/null
+++ b/linkerd/proxy/pool/src/service.rs
@@ -0,0 +1,103 @@
+use crate::{
+    future::ResponseFuture,
+    message::Message,
+    worker::{self, Terminate},
+    Pool,
+};
+use futures::TryStream;
+use linkerd_error::{Error, Result};
+use linkerd_proxy_core::Update;
+use linkerd_stack::{gate, Service};
+use std::{
+    future::Future,
+    task::{Context, Poll},
+};
+use tokio::{sync::mpsc, time};
+use tokio_util::sync::PollSender;
+
+/// A shareable service backed by a dynamic endpoint.
+#[derive(Debug)]
+pub struct PoolQueue<Req, F> {
+    tx: PollSender<Message<Req, F>>,
+    terminal: Terminate,
+}
+
+impl<Req, F> PoolQueue<Req, F>
+where
+    Req: Send + 'static,
+    F: Send + 'static,
+{
+    pub fn spawn<T, R, P>(
+        capacity: usize,
+        failfast: time::Duration,
+        resolution: R,
+        pool: P,
+    ) -> gate::Gate<Self>
+    where
+        T: Clone + Eq + std::fmt::Debug + Send,
+        R: TryStream<Ok = Update<T>> + Send + Unpin + 'static,
+        R::Error: Into<Error> + Send,
+        P: Pool<T, Req, Future = F> + Send + 'static,
+        P::Error: Into<Error> + Send + Sync,
+        Req: Send + 'static,
+    {
+        let (gate_tx, gate_rx) = gate::channel();
+        let (tx, rx) = mpsc::channel(capacity);
+        let terminal = Terminate::default();
+        let inner = Self {
+            tx: PollSender::new(tx),
+            terminal: terminal.clone(),
+        };
+        worker::spawn(rx, failfast, gate_tx, terminal, resolution, pool);
+        gate::Gate::new(gate_rx, inner)
+    }
+}
+
+impl<Req, Rsp, F, E> Service<Req> for PoolQueue<Req, F>
+where
+    Req: Send + 'static,
+    F: Future<Output = Result<Rsp, E>> + Send + 'static,
+    E: Into<Error>,
+{
+    type Response = Rsp;
+    type Error = Error;
+    type Future = ResponseFuture<F>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        let poll = self.tx.poll_reserve(cx).map_err(|_| {
+            self.terminal
+                .failure()
+                .expect("worker must set a failure if it exits prematurely")
+        });
+        tracing::trace!(?poll);
+        poll
+    }
+
+    fn call(&mut self, req: Req) -> Self::Future {
+        tracing::trace!("Sending request to worker");
+        let (msg, rx) = Message::channel(req);
+        if self.tx.send_item(msg).is_err() {
+            // The channel closed since poll_ready was called, so propagate the
+            // failure in the response future.
+            return ResponseFuture::failed(
+                self.terminal
+                    .failure()
+                    .expect("worker must set a failure if it exits prematurely"),
+            );
+        }
+        ResponseFuture::new(rx)
+    }
+}
+
+impl<Req, F> Clone for PoolQueue<Req, F>
+where
+    Req: Send + 'static,
+    F: Send + 'static,
+{
+    fn clone(&self) -> Self {
+        Self {
+            terminal: self.terminal.clone(),
+            tx: self.tx.clone(),
+        }
+    }
+}
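End to end, construction looks like the tests that follow: a discovery stream and a `Pool` go in, and a `gate::Gate`-wrapped `PoolQueue` comes out, which callers use as an ordinary `Service`. A hedged usage sketch, generic over an assumed `Pool` implementation:

```rust
use linkerd_error::Error;
use linkerd_proxy_pool::{Pool, PoolQueue, Update};
use linkerd_stack::{Service, ServiceExt};
use tokio::{sync::mpsc, time};
use tokio_stream::wrappers::ReceiverStream;

// Wire a discovery stream and a `Pool` into a queue, then call the returned
// gated service once. `P` stands in for any Pool implementation.
async fn demo<P>(pool: P) -> Result<(), Error>
where
    P: Pool<(), ()> + Send + 'static,
    P::Future: Send + 'static,
    P::Error: Into<Error> + Send + Sync,
{
    let (_updates_tx, updates_rx) = mpsc::channel::<Result<Update<()>, Error>>(1);

    let mut svc = PoolQueue::spawn(
        100,                          // queue capacity
        time::Duration::from_secs(3), // failfast timeout
        ReceiverStream::new(updates_rx),
        pool,
    );

    // poll_ready reserves a slot in the queue (and stays pending while the
    // gate is shut in failfast); call hands the request to the worker task.
    let _rsp = svc.ready().await?.call(()).await?;
    Ok(())
}
```

Returning the queue pre-wrapped in a `Gate` is the design point: failfast backpressure is advertised to callers through `poll_ready` rather than by growing the queue.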
diff --git a/linkerd/proxy/pool/src/tests.rs b/linkerd/proxy/pool/src/tests.rs
new file mode 100644
index 0000000000..ad66c865fd
--- /dev/null
+++ b/linkerd/proxy/pool/src/tests.rs
@@ -0,0 +1,385 @@
+#![allow(clippy::ok_expect)]
+
+use crate::PoolQueue;
+use futures::prelude::*;
+use linkerd_proxy_core::Update;
+use linkerd_stack::{Service, ServiceExt};
+use tokio::{sync::mpsc, time};
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_test::{assert_pending, assert_ready};
+
+mod mock;
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn processes_requests() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        1,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(1);
+    assert!(poolq.ready().now_or_never().expect("ready").is_ok());
+    let call = poolq.call(());
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    call.await.expect("response");
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn processes_requests_cloned() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq0 = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+    let mut poolq1 = poolq0.clone();
+
+    handle.svc.allow(2);
+    assert!(poolq0.ready().now_or_never().expect("ready").is_ok());
+    assert!(poolq1.ready().now_or_never().expect("ready").is_ok());
+    let call0 = poolq0.call(());
+    let call1 = poolq1.call(());
+
+    let ((), respond0) = handle.svc.next_request().await.expect("request");
+    respond0.send_response(());
+    call0.await.expect("response");
+
+    let ((), respond1) = handle.svc.next_request().await.expect("request");
+    respond1.send_response(());
+    call1.await.expect("response");
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn limits_request_capacity() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq0 = PoolQueue::spawn(
+        1,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+    let mut poolq1 = poolq0.clone();
+
+    handle.svc.allow(0);
+    assert!(poolq0.ready().now_or_never().expect("ready").is_ok());
+    let mut _call0 = poolq0.call(());
+
+    assert!(
+        poolq0.ready().now_or_never().is_none(),
+        "poolq must not be ready when at capacity"
+    );
+    assert!(
+        poolq1.ready().now_or_never().is_none(),
+        "poolq must not be ready when at capacity"
+    );
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn updates_while_pending() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        1,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    tokio::task::yield_now().await;
+
+    updates
+        .try_send(Ok(Update::Reset(vec![(
+            "192.168.1.44:80".parse().unwrap(),
+            (),
+        )])))
+        .ok()
+        .expect("send update");
+    handle.set_poll(std::task::Poll::Pending);
+    tokio::task::yield_now().await;
+
+    handle.set_poll(std::task::Poll::Ready(Ok(())));
+    handle.svc.allow(1);
+    tokio::task::yield_now().await;
+
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    call.await.expect("response");
+}
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn updates_while_idle() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut _poolq = PoolQueue::spawn(
+        1,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    updates
+        .try_send(Ok(Update::Reset(vec![(
+            "192.168.1.44:80".parse().unwrap(),
+            (),
+        )])))
+        .ok()
+        .expect("send update");
+
+    tokio::task::yield_now().await;
+    assert_eq!(
+        handle.rx.try_recv().expect("must receive update"),
+        Update::Reset(vec![("192.168.1.44:80".parse().unwrap(), (),)])
+    );
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn complete_resolution() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        1,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    // When we drop the update stream, everything continues to work as long as
+    // the pool is ready.
+    handle.svc.allow(1);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    drop(updates);
+    tokio::task::yield_now().await;
+
+    let call = poolq.call(());
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    assert!(call.await.is_ok());
+
+    handle.svc.allow(1);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    assert!(call.await.is_ok());
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn error_resolution() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call0 = poolq.call(());
+
+    updates
+        .try_send(Err(mock::ResolutionError))
+        .ok()
+        .expect("send update");
+
+    call0.await.expect_err("response should fail");
+
+    assert!(
+        poolq.ready().await.is_err(),
+        "poolq must error after failed resolution"
+    );
+
+    poolq
+        .ready()
+        .await
+        .err()
+        .expect("poolq must error after failed resolution");
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn error_pool_while_pending() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    handle.set_poll(std::task::Poll::Pending);
+
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    tokio::task::yield_now().await;
+
+    handle.set_poll(std::task::Poll::Ready(Err(mock::PoolError)));
+    tokio::task::yield_now().await;
+    call.now_or_never()
+        .expect("response should fail immediately")
+        .expect_err("response should fail");
+
+    tracing::info!("Awaiting readiness failure");
+    tokio::task::yield_now().await;
+    poolq
+        .ready()
+        .now_or_never()
+        .expect("poolq readiness should fail immediately")
+        .err()
+        .expect("poolq must error after pool error");
+}
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn error_after_ready() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    updates
+        .try_send(Err(mock::ResolutionError))
+        .ok()
+        .expect("send update");
+    tokio::task::yield_now().await;
+    poolq.call(()).await.expect_err("response should fail");
+
+    poolq
+        .ready()
+        .await
+        .err()
+        .expect("poolq must error after pool error");
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn terminates() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    assert_pending!(handle.svc.poll_request());
+
+    drop(poolq);
+
+    assert!(
+        call.await.is_err(),
+        "call should fail when queue is dropped"
+    );
+    assert!(updates.is_closed());
+    assert!(
+        assert_ready!(handle.svc.poll_request(), "poll_request should be ready").is_none(),
+        "poll_request should return None"
+    );
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn failfast() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    time::sleep(time::Duration::from_secs(1)).await;
+    assert!(call.await.is_err(), "call should failfast");
+    if time::timeout(time::Duration::from_secs(1), poolq.ready())
+        .await
+        .is_ok()
+    {
+        panic!("queue should not be ready while in failfast");
+    }
+
+    handle.svc.allow(1);
+    tokio::task::yield_now().await;
+    tracing::info!("Waiting for poolq to exit failfast");
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    // A delay doesn't impact failfast behavior when the pool is ready.
+    time::sleep(time::Duration::from_secs(1)).await;
+    let call = poolq.call(());
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    assert!(call.await.is_ok(), "call should not failfast");
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+}
+
+#[tokio::test(flavor = "current_thread", start_paused = true)]
+async fn failfast_interrupted() {
+    let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace");
+
+    let (pool, mut handle) = mock::pool::<(), (), ()>();
+    let (_updates, u) = mpsc::channel::<Result<Update<()>, mock::ResolutionError>>(1);
+    let mut poolq = PoolQueue::spawn(
+        10,
+        time::Duration::from_secs(1),
+        ReceiverStream::from(u),
+        pool,
+    );
+
+    handle.svc.allow(0);
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+    let call = poolq.call(());
+    // Wait for half a failfast timeout and then allow the request to be
+    // processed.
+    time::sleep(time::Duration::from_secs_f64(0.5)).await;
+    handle.svc.allow(1);
+    let ((), respond) = handle.svc.next_request().await.expect("request");
+    respond.send_response(());
+    assert!(call.await.is_ok(), "call should not failfast");
+    assert!(poolq.ready().await.is_ok(), "poolq must be ready");
+}
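All of these tests run with `start_paused = true`, so the one-second failfast timeouts elapse in virtual rather than wall-clock time: tokio auto-advances the paused clock to the next timer deadline once every task is idle. A minimal illustration of that clock behavior, independent of this crate:

```rust
use tokio::time;

#[tokio::test(flavor = "current_thread", start_paused = true)]
async fn paused_clock_advances_instantly() {
    let t0 = time::Instant::now();
    // No real waiting: with the clock paused, the runtime advances virtual
    // time as soon as no task can otherwise make progress.
    time::sleep(time::Duration::from_secs(3600)).await;
    assert!(time::Instant::now() - t0 >= time::Duration::from_secs(3600));
}
```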
diff --git a/linkerd/proxy/pool/src/tests/mock.rs b/linkerd/proxy/pool/src/tests/mock.rs
new file mode 100644
index 0000000000..17ba0a739a
--- /dev/null
+++ b/linkerd/proxy/pool/src/tests/mock.rs
@@ -0,0 +1,98 @@
+use linkerd_error::Error;
+use linkerd_proxy_core::Update;
+use parking_lot::Mutex;
+use std::{
+    sync::Arc,
+    task::{Context, Poll, Waker},
+};
+use tokio::sync::mpsc;
+use tower_test::mock;
+
+pub fn pool<T, Req, Rsp>() -> (MockPool<T, Req, Rsp>, PoolHandle<T, Req, Rsp>) {
+    let state = Arc::new(Mutex::new(State {
+        poll: Poll::Ready(Ok(())),
+        waker: None,
+    }));
+    let (updates_tx, updates_rx) = mpsc::unbounded_channel();
+    let (mock, svc) = mock::pair();
+    let h = PoolHandle {
+        rx: updates_rx,
+        state: state.clone(),
+        svc,
+    };
+    let p = MockPool {
+        tx: updates_tx,
+        state,
+        svc: mock,
+    };
+    (p, h)
+}
+
+pub struct MockPool<T, Req, Rsp> {
+    tx: mpsc::UnboundedSender<Update<T>>,
+    state: Arc<Mutex<State>>,
+    svc: mock::Mock<Req, Rsp>,
+}
+
+pub struct PoolHandle<T, Req, Rsp> {
+    state: Arc<Mutex<State>>,
+    pub rx: mpsc::UnboundedReceiver<Update<T>>,
+    pub svc: mock::Handle<Req, Rsp>,
+}
+
+struct State {
+    poll: Poll<Result<(), PoolError>>,
+    waker: Option<Waker>,
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[error("mock pool error")]
+pub struct PoolError;
+
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[error("mock resolution error")]
+pub struct ResolutionError;
+
+impl<T, Req, Rsp> crate::Pool<T, Req> for MockPool<T, Req, Rsp> {
+    fn update_pool(&mut self, update: Update<T>) {
+        self.tx.send(update).ok().unwrap();
+    }
+
+    fn poll_pool(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
+        let mut s = self.state.lock();
+        s.waker.replace(cx.waker().clone());
+        s.poll.map_err(Into::into)
+    }
+}
+
+impl<T, Req, Rsp> linkerd_stack::Service<Req> for MockPool<T, Req, Rsp> {
+    type Response = Rsp;
+    type Error = Error;
+    type Future = mock::future::ResponseFuture<Rsp>;
+
+    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if let Poll::Ready(res) = self.svc.poll_ready(cx) {
+            return Poll::Ready(res);
+        }
+        // Drive the pool when the service isn't ready.
+        let _ = crate::Pool::poll_pool(self, cx)?;
+        Poll::Pending
+    }
+
+    fn call(&mut self, req: Req) -> Self::Future {
+        self.svc.call(req)
+    }
+}
+
+impl<T, Req, Rsp> PoolHandle<T, Req, Rsp> {
+    pub fn set_poll(&self, poll: Poll<Result<(), PoolError>>) {
+        let mut s = self.state.lock();
+        s.poll = poll;
+        if let Some(w) = s.waker.take() {
+            tracing::trace!("Wake");
+            w.wake();
+        } else {
+            tracing::trace!("No waker");
+        }
+    }
+}
diff --git a/linkerd/proxy/pool/src/worker.rs b/linkerd/proxy/pool/src/worker.rs
new file mode 100644
index 0000000000..9533dc548f
--- /dev/null
+++ b/linkerd/proxy/pool/src/worker.rs
@@ -0,0 +1,361 @@
+use std::future::poll_fn;
+
+use crate::{
+    error,
+    failfast::{self, Failfast},
+    message::Message,
+    Pool,
+};
+use futures::{future, TryStream, TryStreamExt};
+use linkerd_error::{Error, Result};
+use linkerd_proxy_core::Update;
+use linkerd_stack::{gate, FailFastError, ServiceExt};
+use parking_lot::RwLock;
+use std::sync::Arc;
+use tokio::{sync::mpsc, task::JoinHandle, time};
+use tracing::{debug_span, Instrument};
+
+/// Provides a copy of the terminal failure error to all handles.
+#[derive(Clone, Debug, Default)]
+pub(crate) struct Terminate {
+    inner: Arc<RwLock<Option<error::TerminalFailure>>>,
+}
+
+#[derive(Debug)]
+struct Worker<R, P> {
+    pool: PoolDriver<P>,
+    discovery: Discovery<R>,
+}
+
+/// Manages the pool's readiness state, handling failfast timeouts.
+#[derive(Debug)]
+struct PoolDriver<P> {
+    pool: P,
+    failfast: Failfast,
+}
+
+/// Processes endpoint updates from service discovery.
+#[derive(Debug)]
+struct Discovery<R> {
+    resolution: R,
+    closed: bool,
+}
+
+/// Spawns a task that simultaneously updates a pool of services from a
+/// discovery stream and dispatches requests to it.
+///
+/// If the pool service does not become ready within the failfast timeout, then
+/// requests are failed with a FailFastError until the pool becomes ready.
+/// While in the failfast state, the provided gate is shut so that the caller
+/// may exert backpressure to prevent requests from being added to the queue.
+pub(crate) fn spawn<T, Req, R, P>(
+    mut reqs_rx: mpsc::Receiver<Message<Req, P::Future>>,
+    failfast: time::Duration,
+    gate: gate::Tx,
+    terminal: Terminate,
+    updates_rx: R,
+    pool: P,
+) -> JoinHandle<Result<()>>
+where
+    Req: Send + 'static,
+    T: Clone + Eq + std::fmt::Debug + Send,
+    R: TryStream<Ok = Update<T>> + Unpin + Send + 'static,
+    R::Error: Into<Error> + Send,
+    P: Pool<T, Req> + Send + 'static,
+    P::Future: Send + 'static,
+    P::Error: Into<Error> + Send,
+{
+    tokio::spawn(
+        async move {
+            let mut worker = Worker {
+                pool: PoolDriver::new(pool, Failfast::new(failfast, gate)),
+                discovery: Discovery::new(updates_rx),
+            };
+
+            loop {
+                // Drive the pool with discovery updates while waiting for a
+                // request.
+                //
+                // NOTE: We do NOT require that the pool become ready before
+                // processing a request, so this technically means that the
+                // queue supports capacity + 1 items. This behavior is
+                // inherited from tower::buffer. Correcting this is not worth
+                // the complexity.
+                let msg = tokio::select! {
+                    biased;
+
+                    // If either the discovery stream or the pool fail, close
+                    // the request stream and process any remaining requests.
+                    e = worker.drive_pool() => {
+                        terminal.close(reqs_rx, error::TerminalFailure::new(e)).await;
+                        return Ok(());
+                    }
+
+                    msg = reqs_rx.recv() => match msg {
+                        Some(msg) => msg,
+                        None => {
+                            tracing::debug!("Callers dropped");
+                            return Ok(());
+                        }
+                    },
+                };
+
+                // Wait for the pool to be ready to process a request. If this
+                // fails, fail the request and any others that remain queued.
+                tracing::trace!("Waiting for inner service readiness");
+                if let Err(e) = worker.ready_pool_for_request().await {
+                    let error = error::TerminalFailure::new(e);
+                    msg.fail(error.clone());
+                    terminal.close(reqs_rx, error).await;
+                    return Ok(());
+                }
+                tracing::trace!("Pool ready");
+
+                // Process requests, either by dispatching them to the pool or
+                // by serving errors directly.
+                let Message { req, tx, span, t0 } = msg;
+                let call = {
+                    // Preserve the original request's tracing context in
+                    // the inner call.
+                    let _enter = span.enter();
+                    worker.pool.call(req)
+                };
+
+                if tx.send(call).is_ok() {
+                    // TODO(ver) track histogram from t0 until the request is dispatched.
+                    tracing::trace!(
+                        latency = (time::Instant::now() - t0).as_secs_f64(),
+                        "Dispatched"
+                    );
+                } else {
+                    tracing::debug!("Caller dropped");
+                }
+            }
+        }
+        .instrument(debug_span!("pool")),
+    )
+}
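The `biased` select in the loop above is deliberate: pool and discovery progress is always checked before a new request is accepted, so a terminal failure is observed ahead of queued work. A tiny standalone illustration of `biased` arm ordering:

```rust
use tokio::sync::mpsc;

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u32>(1);
    tx.try_send(1).unwrap();

    // With `biased;`, arms are polled strictly top to bottom, so when several
    // are ready the first listed arm always wins (no randomized fairness).
    let winner = tokio::select! {
        biased;
        _ = std::future::ready(()) => "first",
        _ = rx.recv() => "second",
    };
    assert_eq!(winner, "first");
}
```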
+
+// === impl Worker ===
+
+impl<T, R, P> Worker<R, P>
+where
+    T: Clone + Eq + std::fmt::Debug,
+    R: TryStream<Ok = Update<T>> + Unpin,
+    R::Error: Into<Error>,
+{
+    /// Drives the pool, processing discovery updates.
+    ///
+    /// This never returns unless the pool or discovery stream fails.
+    async fn drive_pool<Req>(&mut self) -> Error
+    where
+        P: Pool<T, Req>,
+        P::Error: Into<Error>,
+    {
+        tracing::trace!("Discovering while awaiting requests");
+
+        loop {
+            let update = tokio::select! {
+                e = self.pool.drive() => return e,
+                res = self.discovery.discover() => match res {
+                    Err(e) => return e,
+                    Ok(up) => up,
+                },
+            };
+
+            tracing::debug!(?update, "Discovered");
+            self.pool.pool.update_pool(update);
+        }
+    }
+
+    /// Waits for [`Service::poll_ready`], while also processing service
+    /// discovery updates (e.g. to provide new available endpoints).
+    async fn ready_pool_for_request<Req>(&mut self) -> Result<(), Error>
+    where
+        P: Pool<T, Req>,
+        P::Error: Into<Error>,
+    {
+        loop {
+            let update = tokio::select! {
+                // Tests, especially, depend on discovery updates being
+                // processed before ready returning.
+                biased;
+                res = self.discovery.discover() => res?,
+                res = self.pool.ready_or_failfast() => return res,
+            };
+
+            tracing::debug!(?update, "Discovered");
+            self.pool.pool.update_pool(update);
+        }
+    }
+}
+
+// === impl Discovery ===
+
+impl<T, R> Discovery<R>
+where
+    T: Clone + Eq + std::fmt::Debug,
+    R: TryStream<Ok = Update<T>> + Unpin,
+    R::Error: Into<Error>,
+{
+    fn new(resolution: R) -> Self {
+        Self {
+            resolution,
+            closed: false,
+        }
+    }
+
+    /// Await the next service discovery update.
+    ///
+    /// If the discovery stream has closed, this never returns.
+    async fn discover(&mut self) -> Result<Update<T>, Error> {
+        if self.closed {
+            // Never returns.
+            return futures::future::pending().await;
+        }
+
+        match self.resolution.try_next().await {
+            Ok(Some(up)) => Ok(up),
+
+            Ok(None) => {
+                tracing::debug!("Resolution stream closed");
+                self.closed = true;
+                // Never returns.
+                futures::future::pending().await
+            }
+
+            Err(e) => {
+                let error = e.into();
+                tracing::debug!(%error, "Resolution stream failed");
+                self.closed = true;
+                Err(error)
+            }
+        }
+    }
+}
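`Discovery::discover` uses a fuse-to-pending idiom: once the stream ends, the future never resolves again, so the worker's `select!` loops quietly stop considering that branch while requests keep flowing. The same shape in isolation (illustrative names, not crate code):

```rust
use futures::{future, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

// Yield the next item, or suspend forever once the stream has ended so a
// surrounding select! permanently ignores this branch.
async fn next_or_park<T>(rx: &mut ReceiverStream<T>, closed: &mut bool) -> T {
    if *closed {
        return future::pending().await;
    }
    match rx.next().await {
        Some(item) => item,
        None => {
            *closed = true;
            future::pending().await
        }
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let (tx, rx) = mpsc::channel::<u32>(1);
    let (mut rx, mut closed) = (ReceiverStream::new(rx), false);

    tx.try_send(7).unwrap();
    assert_eq!(next_or_park(&mut rx, &mut closed).await, 7);

    drop(tx); // End the stream: next_or_park now pends forever...
    tokio::select! {
        _ = next_or_park(&mut rx, &mut closed) => unreachable!("stream ended"),
        _ = std::future::ready(()) => {} // ...so other branches win.
    }
}
```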

// === impl PoolDriver ===
+
+impl<P> PoolDriver<P> {
+    fn new(pool: P, failfast: Failfast) -> Self {
+        Self { pool, failfast }
+    }
+
+    /// Drives the inner pool via [`Pool::poll_pool`], ensuring that the
+    /// failfast state is cleared if appropriate. This allows the pool to
+    /// advance its internal state while the worker awaits new work.
+    ///
+    /// If the service is in failfast, this clears the failfast state on
+    /// readiness.
+    ///
+    /// This only returns if the pool fails.
+    async fn drive<T, Req>(&mut self) -> Error
+    where
+        P: Pool<T, Req>,
+        P::Error: Into<Error>,
+    {
+        if self.failfast.is_active() {
+            tracing::trace!("Waiting to leave failfast");
+            let res = self.pool.ready().await;
+            match self.failfast.set_ready() {
+                Some(failfast::State::Failfast { since }) => {
+                    tracing::info!(
+                        elapsed = (time::Instant::now() - since).as_secs_f64(),
+                        "Available; exited failfast"
+                    );
+                }
+                _ => unreachable!("must be in failfast"),
+            }
+            if let Err(e) = res {
+                return e.into();
+            }
+        }
+
+        tracing::trace!("Driving pool");
+        if let Err(e) = poll_fn(|cx| self.pool.poll_pool(cx)).await {
+            return e.into();
+        }
+
+        tracing::trace!("Pool driven");
+        future::pending().await
+    }
+
+    /// Waits for the inner pool's [`Service::poll_ready`] to be ready,
+    /// entering the failfast state if readiness is not reached within the
+    /// failfast timeout.
+    async fn ready_or_failfast<T, Req>(&mut self) -> Result<(), Error>
+    where
+        P: Pool<T, Req>,
+        P::Error: Into<Error>,
+    {
+        tokio::select! {
+            biased;
+
+            res = self.pool.ready() => {
+                match self.failfast.set_ready() {
+                    None => tracing::trace!("Ready"),
+                    Some(failfast::State::Waiting { since }) => {
+                        tracing::trace!(
+                            elapsed = (time::Instant::now() - since).as_secs_f64(),
+                            "Available"
+                        );
+                    }
+                    Some(failfast::State::Failfast { since }) => {
+                        // Note: It is exceptionally unlikely that we will exit
+                        // failfast here, since the below `failfast.entered()`
+                        // will return immediately when in the failfast state.
+                        tracing::info!(
+                            elapsed = (time::Instant::now() - since).as_secs_f64(),
+                            "Available; exiting failfast"
+                        );
+                    }
+                }
+                if let Err(e) = res {
+                    return Err(e.into());
+                }
+            }
+
+            () = self.failfast.entered() => {
+                tracing::info!(
+                    timeout = self.failfast.duration().as_secs_f64(),
+                    "Unavailable; entering failfast",
+                );
+            }
+        }
+
+        Ok(())
+    }
+
+    fn call<T, Req>(&mut self, req: Req) -> Result<P::Future>
+    where
+        P: Pool<T, Req>,
+        P::Error: Into<Error>,
+    {
+        // If we've tripped failfast, fail the request.
+        if self.failfast.is_active() {
+            return Err(FailFastError::default().into());
+        }
+
+        // Otherwise dispatch the request to the pool.
+        Ok(self.pool.call(req))
+    }
+}
+
+// === impl Terminate ===
+
+impl Terminate {
+    #[inline]
+    pub(super) fn failure(&self) -> Option<Error> {
+        (*self.inner.read()).clone().map(Into::into)
+    }
+
+    async fn close<Req, F>(
+        self,
+        mut reqs_rx: mpsc::Receiver<Message<Req, F>>,
+        error: error::TerminalFailure,
+    ) {
+        tracing::debug!(%error, "Closing pool");
+        *self.inner.write() = Some(error.clone());
+        reqs_rx.close();
+
+        while let Some(msg) = reqs_rx.recv().await {
+            msg.fail(error.clone());
+        }
+
+        tracing::debug!("Closed");
+    }
+}
diff --git a/linkerd/stack/src/failfast.rs b/linkerd/stack/src/failfast.rs
index 0f26dc02f2..fb96bd7dcd 100644
--- a/linkerd/stack/src/failfast.rs
+++ b/linkerd/stack/src/failfast.rs
@@ -44,7 +44,7 @@ pub struct FailFast {
 }
 
 /// An error representing that an operation timed out.
-#[derive(Debug, Error)]
+#[derive(Debug, Default, Error)]
 #[error("service in fail-fast")]
 pub struct FailFastError(());