From 792dea6c9b86baae6cf7d66c91d49ed88b4a30a0 Mon Sep 17 00:00:00 2001 From: Bryan Chen Date: Sat, 18 May 2024 21:59:52 +1200 Subject: [PATCH] Revert "Refactor endpoint (#178)" This reverts commit 7fa313276c5dbe46f038a6c752f08ad77b7fc50e. --- src/extensions/api/tests.rs | 51 +-- src/extensions/client/endpoint.rs | 481 ++++++----------------- src/extensions/client/health.rs | 101 ++++- src/extensions/client/mod.rs | 87 ++-- src/extensions/client/tests.rs | 56 +-- src/middlewares/methods/block_tag.rs | 9 +- src/middlewares/methods/inject_params.rs | 9 +- src/tests/merge_subscription.rs | 11 +- src/tests/upstream.rs | 11 +- 9 files changed, 271 insertions(+), 545 deletions(-) diff --git a/src/extensions/api/tests.rs b/src/extensions/api/tests.rs index a82b730..1c0ff0c 100644 --- a/src/extensions/api/tests.rs +++ b/src/extensions/api/tests.rs @@ -1,6 +1,6 @@ use jsonrpsee::server::ServerHandle; use serde_json::json; -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use std::{net::SocketAddr, sync::Arc}; use tokio::sync::mpsc; use super::eth::EthApi; @@ -61,14 +61,7 @@ async fn create_client() -> ( ) { let (addr, server, head_rx, finalized_head_rx, block_hash_rx) = create_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); (client, server, head_rx, finalized_head_rx, block_hash_rx) } @@ -175,14 +168,7 @@ async fn rotate_endpoint_on_stale() { let (addr, server, mut head_rx, _, mut block_rx) = create_server().await; let (addr2, server2, mut head_rx2, _, mut block_rx2) = create_server().await; - let client = Client::new( - [format!("ws://{addr}"), format!("ws://{addr2}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}"), format!("ws://{addr2}")]).unwrap(); let api = SubstrateApi::new(Arc::new(client), std::time::Duration::from_millis(100)); let head = api.get_head(); @@ -245,14 +231,7 @@ async fn rotate_endpoint_on_head_mismatch() { let (addr1, server1, mut head_rx1, mut finalized_head_rx1, mut block_rx1) = create_server().await; let (addr2, server2, mut head_rx2, mut finalized_head_rx2, mut block_rx2) = create_server().await; - let client = Client::new( - [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr1}"), format!("ws://{addr2}")]).unwrap(); let client = Arc::new(client); let api = SubstrateApi::new(client.clone(), std::time::Duration::from_millis(100)); @@ -353,16 +332,7 @@ async fn rotate_endpoint_on_head_mismatch() { #[tokio::test] async fn substrate_background_tasks_abort_on_drop() { let (addr, _server, mut head_rx, mut finalized_head_rx, _) = create_server().await; - let client = Arc::new( - Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(), - ); + let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); let api = SubstrateApi::new(client, std::time::Duration::from_millis(100)); // background tasks started @@ -382,16 +352,7 @@ async fn substrate_background_tasks_abort_on_drop() { #[tokio::test] async fn eth_background_tasks_abort_on_drop() { let (addr, _server, mut subscription_rx, mut block_rx) = create_eth_server().await; - let client = Arc::new( - Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(), - ); + let client = Arc::new(Client::with_endpoints([format!("ws://{addr}")]).unwrap()); let api = EthApi::new(client, std::time::Duration::from_millis(100)); diff --git a/src/extensions/client/endpoint.rs b/src/extensions/client/endpoint.rs index 5d6118b..f266d29 100644 --- a/src/extensions/client/endpoint.rs +++ b/src/extensions/client/endpoint.rs @@ -1,13 +1,14 @@ -use super::health::{self, Event, Health}; -use crate::extensions::client::{get_backoff_time, HealthCheckConfig}; +use super::health::{Event, Health}; +use crate::{ + extensions::client::{get_backoff_time, HealthCheckConfig}, + utils::errors, +}; use jsonrpsee::{ async_client::Client, core::client::{ClientT, Subscription, SubscriptionClientT}, - core::JsonValue, ws_client::WsClientBuilder, }; use std::{ - fmt::{Debug, Formatter}, sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -15,68 +16,12 @@ use std::{ time::Duration, }; -enum Message { - Request { - method: String, - params: Vec, - response: tokio::sync::oneshot::Sender>, - timeout: Duration, - }, - Subscribe { - subscribe: String, - params: Vec, - unsubscribe: String, - response: tokio::sync::oneshot::Sender, jsonrpsee::core::client::Error>>, - timeout: Duration, - }, - Reconnect, -} - -impl Debug for Message { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Message::Request { - method, - params, - response: _, - timeout, - } => write!(f, "Request({method}, {params:?}, _, {timeout:?})"), - Message::Subscribe { - subscribe, - params, - unsubscribe, - response: _, - timeout, - } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {timeout:?})"), - Message::Reconnect => write!(f, "Reconnect"), - } - } -} - -enum State { - Initial, - OnError(health::Event), - Connect(Option), - HandleMessage(Arc, Message), - WaitForMessage(Arc), -} - -impl Debug for State { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - State::Initial => write!(f, "Initial"), - State::OnError(e) => write!(f, "OnError({e:?})"), - State::Connect(m) => write!(f, "Connect({m:?})"), - State::HandleMessage(_c, m) => write!(f, "HandleMessage(_, {m:?})"), - State::WaitForMessage(_c) => write!(f, "WaitForMessage(_)"), - } - } -} - pub struct Endpoint { url: String, health: Arc, - message_tx: tokio::sync::mpsc::Sender, + client_rx: tokio::sync::watch::Receiver>>, + reconnect_tx: tokio::sync::mpsc::Sender<()>, + on_client_ready: Arc, background_tasks: Vec>, connect_counter: Arc, } @@ -90,279 +35,78 @@ impl Drop for Endpoint { impl Endpoint { pub fn new( url: String, - request_timeout: Duration, - connection_timeout: Duration, + request_timeout: Option, + connection_timeout: Option, health_config: Option, ) -> Self { - tracing::info!("New endpoint: {url}"); - - let health = Arc::new(Health::new(url.clone())); + let (client_tx, client_rx) = tokio::sync::watch::channel(None); + let (reconnect_tx, mut reconnect_rx) = tokio::sync::mpsc::channel(1); + let on_client_ready = Arc::new(tokio::sync::Notify::new()); + let health = Arc::new(Health::new(url.clone(), health_config)); let connect_counter = Arc::new(AtomicU32::new(0)); - let (message_tx, message_rx) = tokio::sync::mpsc::channel::(4096); - let mut endpoint = Self { - url: url.clone(), - health: health.clone(), - message_tx, - background_tasks: vec![], - connect_counter: connect_counter.clone(), - }; + let url_ = url.clone(); + let health_ = health.clone(); + let on_client_ready_ = on_client_ready.clone(); + let connect_counter_ = connect_counter.clone(); - endpoint.start_background_task( - url, - request_timeout, - connection_timeout, - connect_counter, - message_rx, - health, - ); - if let Some(config) = health_config { - endpoint.start_health_monitor_task(config); - } - - endpoint - } - - fn start_background_task( - &mut self, - url: String, - request_timeout: Duration, - connection_timeout: Duration, - connect_counter: Arc, - mut message_rx: tokio::sync::mpsc::Receiver, - health: Arc, - ) { - let handler = tokio::spawn(async move { + // This task will try to connect to the endpoint and keep the connection alive + let connection_task = tokio::spawn(async move { let connect_backoff_counter = Arc::new(AtomicU32::new(0)); - let mut state = State::Initial; - loop { - tracing::trace!("{url} {state:?}"); - - let new_state = match state { - State::Initial => { - connect_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tracing::info!("Connecting endpoint: {url_}"); + connect_counter_.fetch_add(1, Ordering::Relaxed); + + let client = WsClientBuilder::default() + .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) + .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(&url_); + + match client.await { + Ok(client) => { + let client = Arc::new(client); + health_.update(Event::ConnectionSuccessful); + _ = client_tx.send(Some(client.clone())); + on_client_ready_.notify_waiters(); + tracing::info!("Endpoint connected: {url_}"); + connect_backoff_counter.store(0, Ordering::Relaxed); - // wait for messages before connecting - let msg = match message_rx.recv().await { - Some(Message::Reconnect) => None, - Some(msg @ Message::Request { .. } | msg @ Message::Subscribe { .. }) => Some(msg), - None => { - let url = url.clone(); - // channel is closed? exit - tracing::debug!("Endpoint {url} channel closed"); - return; - } - }; - State::Connect(msg) - } - State::OnError(evt) => { - health.update(evt); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; - State::Initial - } - State::Connect(msg) => { - // TODO: make the params configurable - let client = WsClientBuilder::default() - .request_timeout(request_timeout) - .connection_timeout(connection_timeout) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(url.clone()) - .await; - - match client { - Ok(client) => { - connect_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - health.update(Event::ConnectionSuccessful); - if let Some(msg) = msg { - State::HandleMessage(Arc::new(client), msg) - } else { - State::WaitForMessage(Arc::new(client)) - } - } - Err(err) => { - tracing::debug!("Endpoint {url} connection error: {err}"); - State::OnError(health::Event::ConnectionClosed) - } - } - } - State::HandleMessage(client, msg) => match msg { - Message::Request { - method, - params, - response, - timeout, - } => { - // don't block on making the request - let url = url.clone(); - let health = health.clone(); - let client2 = client.clone(); - tokio::spawn(async move { - let resp = match tokio::time::timeout( - timeout, - client2.request::>(&method, params), - ) - .await - { - Ok(resp) => resp, - Err(_) => { - tracing::warn!("Endpoint {url} request timeout: {method} timeout: {timeout:?}"); - health.update(Event::RequestTimeout); - Err(jsonrpsee::core::client::Error::RequestTimeout) - } - }; - if let Err(err) = &resp { - health.on_error(err); - } - - if response.send(resp).is_err() { - tracing::error!("Unable to send response to message channel"); - } - }); - - State::WaitForMessage(client) - } - Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - timeout, - } => { - // don't block on making the request - let url = url.clone(); - let health = health.clone(); - let client2 = client.clone(); - tokio::spawn(async move { - let resp = match tokio::time::timeout( - timeout, - client2.subscribe::>( - &subscribe, - params, - &unsubscribe, - ), - ) - .await - { - Ok(resp) => resp, - Err(_) => { - tracing::warn!("Endpoint {url} subscription timeout: {subscribe}"); - health.update(Event::RequestTimeout); - Err(jsonrpsee::core::client::Error::RequestTimeout) - } - }; - if let Err(err) = &resp { - health.on_error(err); - } - - if response.send(resp).is_err() { - tracing::error!("Unable to send response to message channel"); - } - }); - - State::WaitForMessage(client) - } - Message::Reconnect => State::Initial, - }, - State::WaitForMessage(client) => { tokio::select! { - msg = message_rx.recv() => { - match msg { - Some(msg) => State::HandleMessage(client, msg), - None => { - // channel is closed? exit - tracing::debug!("Endpoint {url} channel closed"); - return - } - } - + _ = reconnect_rx.recv() => { + tracing::debug!("Endpoint reconnect requested: {url_}"); }, - () = client.on_disconnect() => { - tracing::debug!("Endpoint {url} disconnected"); - State::OnError(health::Event::ConnectionClosed) + _ = client.on_disconnect() => { + tracing::debug!("Endpoint disconnected: {url_}"); } } } - }; - - state = new_state; - } - }); - - self.background_tasks.push(handler); - } - - fn start_health_monitor_task(&mut self, config: HealthCheckConfig) { - let message_tx = self.message_tx.clone(); - let health = self.health.clone(); - let url = self.url.clone(); - - let handler = tokio::spawn(async move { - let health_response = config.response.clone(); - let interval = Duration::from_secs(config.interval_sec); - let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); - let max_response_time: Duration = Duration::from_millis(config.healthy_response_time_ms * 2); - - loop { - // Wait for the next interval - tokio::time::sleep(interval).await; - - let request_start = std::time::Instant::now(); - - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = message_tx - .send(Message::Request { - method: config.health_method.clone(), - params: vec![], - response: response_tx, - timeout: max_response_time, - }) - .await; - - if let Err(err) = res { - tracing::error!("{url} Unexpected error in message channel: {err}"); - } - - let res = match response_rx.await { - Ok(resp) => resp, Err(err) => { - tracing::error!("{url} Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) - } - }; - - match res { - Ok(response) => { - let duration = request_start.elapsed(); - - // Check response - if let Some(ref health_response) = health_response { - if !health_response.validate(&response) { - health.update(Event::Unhealthy); - continue; - } - } - - // Check response time - if duration > healthy_response_time { - tracing::warn!("{url} response time is too long: {duration:?}"); - health.update(Event::SlowResponse); - continue; - } - - health.update(Event::ResponseOk); - } - Err(err) => { - health.on_error(&err); + health_.on_error(&err); + _ = client_tx.send(None); + tracing::warn!("Unable to connect to endpoint: {url_} error: {err}"); } } + // Wait a bit before trying to reconnect + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; } }); - self.background_tasks.push(handler); + // This task will check the health of the endpoint and update the health score + let health_checker = Health::monitor(health.clone(), client_rx.clone(), on_client_ready.clone()); + + Self { + url, + health, + client_rx, + reconnect_tx, + on_client_ready, + background_tasks: vec![connection_task, health_checker], + connect_counter, + } } pub fn url(&self) -> &str { @@ -373,6 +117,13 @@ impl Endpoint { self.health.as_ref() } + pub async fn connected(&self) { + if self.client_rx.borrow().is_some() { + return; + } + self.on_client_ready.notified().await; + } + pub fn connect_counter(&self) -> u32 { self.connect_counter.load(Ordering::Relaxed) } @@ -382,27 +133,29 @@ impl Endpoint { method: &str, params: Vec, timeout: Duration, - ) -> Result { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = self - .message_tx - .send(Message::Request { - method: method.into(), - params, - response: response_tx, - timeout, - }) - .await; - - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } - - match response_rx.await { - Ok(resp) => resp, - Err(err) => { - tracing::error!("Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) + ) -> Result { + match tokio::time::timeout(timeout, async { + self.connected().await; + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + match client.request(method, params.clone()).await { + Ok(resp) => Ok(resp), + Err(err) => { + self.health.on_error(&err); + Err(err) + } + } + }) + .await + { + Ok(res) => res, + Err(_) => { + tracing::error!("request timed out method: {method} params: {params:?}"); + self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) } } } @@ -413,36 +166,38 @@ impl Endpoint { params: Vec, unsubscribe_method: &str, timeout: Duration, - ) -> Result, jsonrpsee::core::client::Error> { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let res = self - .message_tx - .send(Message::Subscribe { - subscribe: subscribe_method.into(), - params, - unsubscribe: unsubscribe_method.into(), - response: response_tx, - timeout, - }) - .await; - - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } - - match response_rx.await { - Ok(resp) => resp, - Err(err) => { - tracing::error!("Unexpected error in response channel: {err}"); - Err(jsonrpsee::core::client::Error::Custom("Internal server error".into())) + ) -> Result, jsonrpsee::core::Error> { + match tokio::time::timeout(timeout, async { + self.connected().await; + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + match client + .subscribe(subscribe_method, params.clone(), unsubscribe_method) + .await + { + Ok(resp) => Ok(resp), + Err(err) => { + self.health.on_error(&err); + Err(err) + } + } + }) + .await + { + Ok(res) => res, + Err(_) => { + tracing::error!("subscribe timed out subscribe: {subscribe_method} params: {params:?}"); + self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) } } } pub async fn reconnect(&self) { - let res = self.message_tx.send(Message::Reconnect).await; - if let Err(err) = res { - tracing::error!("Unexpected error in message channel: {err}"); - } + // notify the client to reconnect + self.reconnect_tx.send(()).await.unwrap(); } } diff --git a/src/extensions/client/health.rs b/src/extensions/client/health.rs index c69adc1..a6425f2 100644 --- a/src/extensions/client/health.rs +++ b/src/extensions/client/health.rs @@ -1,29 +1,31 @@ -use std::sync::atomic::{AtomicU32, Ordering}; - -const MAX_SCORE: u32 = 100; -const THRESHOLD: u32 = 50; +use crate::extensions::client::HealthCheckConfig; +use jsonrpsee::{async_client::Client, core::client::ClientT}; +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; #[derive(Debug)] pub enum Event { ResponseOk, - ConnectionSuccessful, SlowResponse, RequestTimeout, + ConnectionSuccessful, ServerError, - Unhealthy, - ConnectionClosed, + StaleChain, } impl Event { pub fn update_score(&self, current: u32) -> u32 { u32::min( match self { - Event::ConnectionSuccessful => current.saturating_add(60), Event::ResponseOk => current.saturating_add(2), - Event::SlowResponse => current.saturating_sub(20), - Event::RequestTimeout => current.saturating_sub(40), - Event::ConnectionClosed => current.saturating_sub(30), - Event::ServerError | Event::Unhealthy => 0, + Event::SlowResponse => current.saturating_sub(5), + Event::RequestTimeout | Event::ServerError | Event::StaleChain => 0, + Event::ConnectionSuccessful => MAX_SCORE / 5 * 4, // 80% of max score }, MAX_SCORE, ) @@ -33,14 +35,19 @@ impl Event { #[derive(Debug, Default)] pub struct Health { url: String, + config: Option, score: AtomicU32, unhealthy: tokio::sync::Notify, } +const MAX_SCORE: u32 = 100; +const THRESHOLD: u32 = MAX_SCORE / 2; + impl Health { - pub fn new(url: String) -> Self { + pub fn new(url: String, config: Option) -> Self { Self { url, + config, score: AtomicU32::new(0), unhealthy: tokio::sync::Notify::new(), } @@ -58,13 +65,13 @@ impl Health { } self.score.store(new_score, Ordering::Relaxed); tracing::trace!( - "{:?} score updated from: {current_score} to: {new_score} because {event:?}", + "Endpoint {:?} score updated from: {current_score} to: {new_score}", self.url ); // Notify waiters if the score has dropped below the threshold if current_score >= THRESHOLD && new_score < THRESHOLD { - tracing::warn!("{:?} became unhealthy", self.url); + tracing::warn!("Endpoint {:?} became unhealthy", self.url); self.unhealthy.notify_waiters(); } } @@ -79,7 +86,7 @@ impl Health { self.update(Event::RequestTimeout); } _ => { - tracing::warn!("{:?} responded with error: {err:?}", self.url); + tracing::warn!("Endpoint {:?} responded with error: {err:?}", self.url); self.update(Event::ServerError); } }; @@ -89,3 +96,65 @@ impl Health { self.unhealthy.notified().await; } } + +impl Health { + pub fn monitor( + health: Arc, + client_rx_: tokio::sync::watch::Receiver>>, + on_client_ready: Arc, + ) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let config = match health.config { + Some(ref config) => config, + None => return, + }; + + // Wait for the client to be ready before starting the health check + on_client_ready.notified().await; + + let method_name = config.health_method.as_ref().expect("Invalid health config"); + let health_response = config.response.clone(); + let interval = Duration::from_secs(config.interval_sec); + let healthy_response_time = Duration::from_millis(config.healthy_response_time_ms); + + let client = match client_rx_.borrow().clone() { + Some(client) => client, + None => return, + }; + + loop { + // Wait for the next interval + tokio::time::sleep(interval).await; + + let request_start = std::time::Instant::now(); + match client + .request::>(method_name, vec![]) + .await + { + Ok(response) => { + let duration = request_start.elapsed(); + + // Check response + if let Some(ref health_response) = health_response { + if !health_response.validate(&response) { + health.update(Event::StaleChain); + continue; + } + } + + // Check response time + if duration > healthy_response_time { + health.update(Event::SlowResponse); + continue; + } + + health.update(Event::ResponseOk); + } + Err(err) => { + health.on_error(&err); + } + } + } + }) + } +} diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 865be01..87e0495 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,5 +1,4 @@ use std::{ - fmt::{Debug, Formatter}, sync::{atomic::AtomicU32, Arc}, time::Duration, }; @@ -118,12 +117,23 @@ pub struct HealthCheckConfig { pub interval_sec: u64, #[serde(default = "healthy_response_time_ms")] pub healthy_response_time_ms: u64, - pub health_method: String, + pub health_method: Option, pub response: Option, } +impl Default for HealthCheckConfig { + fn default() -> Self { + Self { + interval_sec: interval_sec(), + healthy_response_time_ms: healthy_response_time_ms(), + health_method: None, + response: None, + } + } +} + pub fn interval_sec() -> u64 { - 300 + 10 } pub fn healthy_response_time_ms() -> u64 { @@ -160,6 +170,7 @@ impl HealthResponse { } } +#[derive(Debug)] enum Message { Request { method: String, @@ -177,57 +188,27 @@ enum Message { RotateEndpoint, } -impl Debug for Message { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Message::Request { - method, - params, - response: _, - retries, - } => write!(f, "Request({method}, {params:?}, _, {retries})"), - Message::Subscribe { - subscribe, - params, - unsubscribe, - response: _, - retries, - } => write!(f, "Subscribe({subscribe}, {params:?}, {unsubscribe}, _, {retries})"), - Message::RotateEndpoint => write!(f, "RotateEndpoint"), - } - } -} - #[async_trait] impl Extension for Client { type Config = ClientConfig; async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { let health_check = config.health_check.clone(); - let endpoints = if config.shuffle_endpoints { + if config.shuffle_endpoints { let mut endpoints = config.endpoints.clone(); endpoints.shuffle(&mut thread_rng()); - endpoints + Ok(Self::new(endpoints, None, None, None, health_check)?) } else { - config.endpoints.clone() - }; - - // TODO: make the params configurable - Ok(Self::new( - endpoints, - Duration::from_secs(30), - Duration::from_secs(30), - None, - health_check, - )?) + Ok(Self::new(config.endpoints.clone(), None, None, None, health_check)?) + } } } impl Client { pub fn new( endpoints: impl IntoIterator>, - request_timeout: Duration, - connection_timeout: Duration, + request_timeout: Option, + connection_timeout: Option, retries: Option, health_config: Option, ) -> Result { @@ -274,9 +255,14 @@ impl Client { let select_healtiest = |endpoints: Vec>, current_idx: usize| async move { if endpoints.len() == 1 { let selected_endpoint = endpoints[0].clone(); + // Ensure it's connected + selected_endpoint.connected().await; return (selected_endpoint, 0); } + // wait for at least one endpoint to connect + futures::future::select_all(endpoints.iter().map(|x| x.connected().boxed())).await; + let (idx, endpoint) = endpoints .iter() .enumerate() @@ -299,10 +285,13 @@ impl Client { } }; - let handle_message = |message: Message, endpoint: Arc| { + let handle_message = |message: Message, endpoint: Arc, rotation_notify: Arc| { let tx = message_tx_bg.clone(); let request_backoff_counter = request_backoff_counter.clone(); + // total timeout for a request + let task_timeout = request_timeout.unwrap_or(Duration::from_secs(30)); + tokio::spawn(async move { match message { Message::Request { @@ -318,7 +307,7 @@ impl Client { return; } - match endpoint.request(&method, params.clone(), request_timeout).await { + match endpoint.request(&method, params.clone(), task_timeout).await { result @ Ok(_) => { request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); // make sure it's still connected @@ -334,6 +323,9 @@ impl Client { | Error::Transport(_) | Error::RestartNeeded(_) | Error::MaxSlotsExceeded => { + // Make sure endpoint is rotated + rotation_notify.notified().await; + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; // make sure it's still connected @@ -378,7 +370,7 @@ impl Client { retries = retries.saturating_sub(1); match endpoint - .subscribe(&subscribe, params.clone(), &unsubscribe, request_timeout) + .subscribe(&subscribe, params.clone(), &unsubscribe, task_timeout) .await { result @ Ok(_) => { @@ -396,6 +388,9 @@ impl Client { | Error::Transport(_) | Error::RestartNeeded(_) | Error::MaxSlotsExceeded => { + // Make sure endpoint is rotated + rotation_notify.notified().await; + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; // make sure it's still connected @@ -452,8 +447,8 @@ impl Client { tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url()); selected_endpoint = new_selected_endpoint; current_endpoint_idx = new_current_endpoint_idx; + rotation_notify_bg.notify_waiters(); } - rotation_notify_bg.notify_waiters(); } message = message_rx.recv() => { tracing::trace!("Received message {message:?}"); @@ -463,7 +458,7 @@ impl Client { (selected_endpoint, current_endpoint_idx) = next_endpoint(current_endpoint_idx).await; rotation_notify_bg.notify_waiters(); } - Some(message) => handle_message(message, selected_endpoint.clone()), + Some(message) => handle_message(message, selected_endpoint.clone(), rotation_notify_bg.clone()), None => { tracing::debug!("Client dropped"); break; @@ -483,6 +478,10 @@ impl Client { }) } + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { + Self::new(endpoints, None, None, None, None) + } + pub fn endpoints(&self) -> &Vec> { self.endpoints.as_ref() } diff --git a/src/extensions/client/tests.rs b/src/extensions/client/tests.rs index cf229a8..4a7f126 100644 --- a/src/extensions/client/tests.rs +++ b/src/extensions/client/tests.rs @@ -11,14 +11,7 @@ use tokio::sync::mpsc; async fn basic_request() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let req = rx.recv().await.unwrap(); @@ -38,14 +31,7 @@ async fn basic_request() { async fn basic_subscription() { let (addr, handle, _, mut rx) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let sub = rx.recv().await.unwrap(); @@ -81,15 +67,10 @@ async fn multiple_endpoints() { format!("ws://{addr2}"), format!("ws://{addr3}"), ], - Duration::from_secs(1), - Duration::from_secs(1), None, - Some(HealthCheckConfig { - interval_sec: 1, - healthy_response_time_ms: 250, - health_method: "mock_rpc".into(), - response: None, - }), + None, + None, + Some(Default::default()), ) .unwrap(); @@ -141,14 +122,7 @@ async fn multiple_endpoints() { async fn concurrent_requests() { let (addr, handle, mut rx, _) = dummy_server().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let task = tokio::spawn(async move { let req1 = rx.recv().await.unwrap(); @@ -184,8 +158,8 @@ async fn retry_requests_successful() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), + Some(Duration::from_millis(100)), + None, Some(2), None, ) @@ -222,8 +196,8 @@ async fn retry_requests_out_of_retries() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), + Some(Duration::from_millis(100)), + None, Some(2), None, ) @@ -286,13 +260,13 @@ async fn health_check_works() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_secs(1), - Duration::from_secs(1), + None, + None, None, Some(HealthCheckConfig { interval_sec: 1, healthy_response_time_ms: 250, - health_method: "system_health".into(), + health_method: Some("system_health".into()), response: Some(HealthResponse::Contains(vec![( "isSyncing".to_string(), Box::new(HealthResponse::Eq(false.into())), @@ -333,8 +307,8 @@ async fn reconnect_on_disconnect() { let client = Client::new( [format!("ws://{addr1}"), format!("ws://{addr2}")], - Duration::from_millis(100), - Duration::from_millis(100), + Some(Duration::from_millis(100)), + None, Some(2), None, ) diff --git a/src/middlewares/methods/block_tag.rs b/src/middlewares/methods/block_tag.rs index 00aecd7..1e0c6cf 100644 --- a/src/middlewares/methods/block_tag.rs +++ b/src/middlewares/methods/block_tag.rs @@ -165,14 +165,7 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let api = EthApi::new(Arc::new(client), Duration::from_secs(100)); ( diff --git a/src/middlewares/methods/inject_params.rs b/src/middlewares/methods/inject_params.rs index bbfc0f1..bcbca9f 100644 --- a/src/middlewares/methods/inject_params.rs +++ b/src/middlewares/methods/inject_params.rs @@ -211,14 +211,7 @@ mod tests { let (addr, _server) = builder.build().await; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let api = SubstrateApi::new(Arc::new(client), Duration::from_secs(100)); ExecutionContext { diff --git a/src/tests/merge_subscription.rs b/src/tests/merge_subscription.rs index 89054b0..37a7c53 100644 --- a/src/tests/merge_subscription.rs +++ b/src/tests/merge_subscription.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use serde_json::json; use crate::{ @@ -99,14 +97,7 @@ async fn merge_subscription_works() { let subway_server = server::build(config).await.unwrap(); let addr = subway_server.addr; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let mut first_sub = client .subscribe(subscribe_mock, vec![], unsubscribe_mock) .await diff --git a/src/tests/upstream.rs b/src/tests/upstream.rs index ab63169..9c730da 100644 --- a/src/tests/upstream.rs +++ b/src/tests/upstream.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use crate::{ config::{Config, MergeStrategy, MiddlewaresConfig, RpcDefinitions, RpcSubscription}, extensions::{ @@ -75,14 +73,7 @@ async fn upstream_error_propagate() { let subway_server = server::build(config).await.unwrap(); let addr = subway_server.addr; - let client = Client::new( - [format!("ws://{addr}")], - Duration::from_secs(1), - Duration::from_secs(1), - None, - None, - ) - .unwrap(); + let client = Client::with_endpoints([format!("ws://{addr}")]).unwrap(); let result = client.subscribe(subscribe_mock, vec![], unsubscribe_mock).await; assert!(result