diff --git a/src/extensions/client/http.rs b/src/extensions/client/http.rs index f8497b2..582b7d8 100644 --- a/src/extensions/client/http.rs +++ b/src/extensions/client/http.rs @@ -16,32 +16,16 @@ pub struct HttpClient { } impl HttpClient { - pub fn new(endpoints: Vec) -> Result<(Option, Vec), Error> { - let mut other_urls = vec![]; + pub fn new(endpoints: &[String]) -> Result { let clients = endpoints - .into_iter() - .filter_map(|url| { - let t_url = url.to_lowercase(); - if t_url.starts_with("http://") || t_url.starts_with("https://") { - Some(RpcClient::builder().build(url)) - } else { - other_urls.push(url); - None - } - }) + .iter() + .map(|url| RpcClient::builder().build(url)) .collect::, _>>()?; - if clients.is_empty() { - Ok((None, other_urls)) - } else { - Ok(( - Some(Self { - clients, - last_sent: AtomicUsize::new(0), - }), - other_urls, - )) - } + Ok(Self { + clients, + last_sent: AtomicUsize::new(0), + }) } /// Sends a request to one of the clients diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index 063214c..a696a3b 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,57 +1,30 @@ -use std::{ - sync::{ - atomic::{AtomicU32, AtomicUsize}, - Arc, - }, - time::Duration, -}; +use std::time::Duration; use anyhow::anyhow; use async_trait::async_trait; -use futures::TryFutureExt; use garde::Validate; -use jsonrpsee::{ - core::{ - client::{ClientT, Error, Subscription, SubscriptionClientT}, - JsonValue, - }, - ws_client::{WsClient, WsClientBuilder}, +use jsonrpsee::core::{ + client::{Error, Subscription}, + JsonValue, }; -use opentelemetry::trace::FutureExt; use rand::{seq::SliceRandom, thread_rng}; use serde::Deserialize; -use tokio::sync::Notify; use super::ExtensionRegistry; -use crate::{ - extensions::Extension, - middlewares::CallResult, - utils::{self, errors}, -}; +use crate::{extensions::Extension, middlewares::CallResult, utils::errors}; mod http; #[cfg(test)] pub mod mock; #[cfg(test)] mod tests; - -const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); +#[allow(dead_code)] +mod ws; pub struct Client { endpoints: Vec, http_client: Option, - sender: Option>, - rotation_notify: Option>, - retries: u32, - background_task: Option>, -} - -impl Drop for Client { - fn drop(&mut self) { - if let Some(background_task) = self.background_task.take() { - background_task.abort(); - } - } + ws_client: Option, } #[derive(Deserialize, Validate, Debug)] @@ -59,7 +32,7 @@ impl Drop for Client { pub struct ClientConfig { #[garde(inner(custom(validate_endpoint)))] pub endpoints: Vec, - #[serde(default = "bool_true")] + #[serde(default = "ws::bool_true")] pub shuffle_endpoints: bool, } @@ -73,64 +46,19 @@ fn validate_endpoint(endpoint: &str, _context: &()) -> garde::Result { impl ClientConfig { pub async fn all_endpoints_can_be_connected(&self) -> bool { - let join_handles: Vec<_> = self - .endpoints - .iter() - .map(|endpoint| { - let endpoint = endpoint.clone(); - tokio::spawn(async move { - match check_endpoint_connection(&endpoint).await { - Ok(_) => { - tracing::info!("Connected to endpoint: {endpoint}"); - true - } - Err(err) => { - tracing::error!("Failed to connect to endpoint: {endpoint}, error: {err:?}",); - false - } - } - }) - }) - .collect(); - let mut ok_all = true; - for join_handle in join_handles { - let ok = join_handle.await.unwrap_or_else(|e| { - tracing::error!("Failed to join: {e:?}"); - false - }); - if !ok { - ok_all = false - } - } - ok_all - } -} -// simple connection check with default client params and no retries -async fn check_endpoint_connection(endpoint: &str) -> Result<(), anyhow::Error> { - let _ = WsClientBuilder::default().build(&endpoint).await?; - Ok(()) -} + let (ws_clients, _) = Client::get_urls(&self.endpoints); -pub fn bool_true() -> bool { - true -} + if ws_clients.is_empty() { + return true; + } -#[derive(Debug)] -enum Message { - Request { - method: String, - params: Vec, - response: tokio::sync::oneshot::Sender>, - retries: u32, - }, - Subscribe { - subscribe: String, - params: Vec, - unsubscribe: String, - response: tokio::sync::oneshot::Sender, Error>>, - retries: u32, - }, - RotateEndpoint, + ws::ClientConfig { + endpoints: ws_clients, + shuffle_endpoints: self.shuffle_endpoints, + } + .all_endpoints_can_be_connected() + .await + } } #[async_trait] @@ -155,291 +83,53 @@ impl Client { connection_timeout: Option, retries: Option, ) -> Result { - let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); - let endpoints_ = endpoints.clone(); - - if endpoints.is_empty() { + let endpoints = endpoints + .into_iter() + .map(|e| e.as_ref().to_string()) + .collect::>(); + let (ws_endpoints, http_endpoints) = Self::get_urls(&endpoints); + if ws_endpoints.is_empty() && http_endpoints.is_empty() { return Err(anyhow!("No endpoints provided")); } - let (http_client, ws_endpoints) = http::HttpClient::new(endpoints)?; - - if ws_endpoints.is_empty() { - return Ok(Self { - http_client, - endpoints: endpoints_, - sender: None, // No websocket - rotation_notify: None, - retries: retries.unwrap_or(3), - background_task: None, - }); - } - - tracing::debug!("New client with endpoints: {:?}", ws_endpoints); - - let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); - - let message_tx_bg = message_tx.clone(); - - let rotation_notify = Arc::new(Notify::new()); - let rotation_notify_bg = rotation_notify.clone(); - - let background_task = tokio::spawn(async move { - let connect_backoff_counter = Arc::new(AtomicU32::new(0)); - let request_backoff_counter = Arc::new(AtomicU32::new(0)); - - let current_endpoint = AtomicUsize::new(0); - - let connect_backoff_counter2 = connect_backoff_counter.clone(); - let build_ws = || async { - let build = || { - let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let url = &ws_endpoints[current_endpoint % ws_endpoints.len()]; - - tracing::info!("Connecting to endpoint: {}", url); - - // TODO: make those configurable - WsClientBuilder::default() - .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) - .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(url) - .map_err(|e| (e, url.to_string())) - }; - - loop { - match build().await { - Ok(ws) => { - let ws = Arc::new(ws); - tracing::info!("Endpoint connected"); - connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); - break ws; - } - Err((e, url)) => { - tracing::warn!("Unable to connect to endpoint: '{url}' error: {e}"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; - } - } - } - }; - - let mut ws = build_ws().await; - - let handle_message = |message: Message, ws: Arc| { - let tx = message_tx_bg.clone(); - let request_backoff_counter = request_backoff_counter.clone(); - - // total timeout for a request - let task_timeout = request_timeout - .unwrap_or(Duration::from_secs(30)) - // buffer 5 seconds for the request to be processed - .saturating_add(Duration::from_secs(5)); - - tokio::spawn(async move { - match message { - Message::Request { - method, - params, - response, - mut retries, - } => { - retries = retries.saturating_sub(1); - - // make sure it's still connected - if response.is_closed() { - return; - } - - if let Ok(result) = - tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await - { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Request failed: {:?}", err); - match err { - Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Request { - method, - params, - response, - retries, - }) - .await - .expect("Failed to send request message"); - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); - } - } - } - } - } else { - tracing::error!("request timed out method: {} params: {:?}", method, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); - } - } - Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - mut retries, - } => { - retries = retries.saturating_sub(1); - - if let Ok(result) = tokio::time::timeout( - task_timeout, - ws.subscribe(&subscribe, params.clone(), &unsubscribe), - ) - .await - { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); - } - Err(err) => { - tracing::debug!("Subscribe failed: {:?}", err); - match err { - Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - retries, - }) - .await - .expect("Failed to send subscribe message") - } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); - } - } - } - } - } else { - tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); - } - } - Message::RotateEndpoint => { - unreachable!() - } - } - }); - }; - - loop { - tokio::select! { - _ = ws.on_disconnect() => { - tracing::info!("Endpoint disconnected"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; - ws = build_ws().await; - } - message = message_rx.recv() => { - tracing::trace!("Received message {message:?}"); - match message { - Some(Message::RotateEndpoint) => { - rotation_notify_bg.notify_waiters(); - tracing::info!("Rotate endpoint"); - ws = build_ws().await; - } - Some(message) => handle_message(message, ws.clone()), - None => { - tracing::debug!("Client dropped"); - break; - } - } - }, - }; - } - }); - - if let Some(0) = retries { - return Err(anyhow!("Retries need to be at least 1")); - } - Ok(Self { - http_client, - endpoints: endpoints_, - sender: Some(message_tx), - rotation_notify: Some(rotation_notify), - retries: retries.unwrap_or(3), - background_task: Some(background_task), + endpoints, + http_client: if http_endpoints.is_empty() { + None + } else { + Some(http::HttpClient::new(&http_endpoints)?) + }, + ws_client: if ws_endpoints.is_empty() { + None + } else { + Some(ws::Client::new( + &ws_endpoints, + request_timeout, + connection_timeout, + retries, + )?) + }, }) } + pub fn get_urls(endpoints: impl IntoIterator>) -> (Vec, Vec) { + let endpoints = endpoints + .into_iter() + .map(|e| e.as_ref().to_string()) + .collect::>(); + ( + endpoints + .iter() + .filter(|e| e.starts_with("ws://") || e.starts_with("wss://")) + .map(|c| c.to_string()) + .collect::>(), + endpoints + .into_iter() + .filter(|e| e.starts_with("http://") || e.starts_with("https://")) + .collect::>(), + ) + } + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { Self::new(endpoints, None, None, None) } @@ -450,28 +140,11 @@ impl Client { pub async fn request(&self, method: &str, params: Vec) -> CallResult { if let Some(http_client) = &self.http_client { - return http_client.request(method, params).await; - } - - if let Some(sender) = self.sender.as_ref() { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - sender - .send(Message::Request { - method: method.into(), - params, - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; - - rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) - } - .with_context(TRACER.context(method.to_string())) - .await + http_client.request(method, params).await + } else if let Some(ws_client) = &self.ws_client { + ws_client.request(method, params).await } else { - Err(errors::internal_error("No sender")) + Err(errors::internal_error("No upstream client")) } } @@ -481,73 +154,23 @@ impl Client { params: Vec, unsubscribe: &str, ) -> Result, Error> { - if let Some(sender) = self.sender.as_ref() { - async move { - let (tx, rx) = tokio::sync::oneshot::channel(); - sender - .send(Message::Subscribe { - subscribe: subscribe.into(), - params, - unsubscribe: unsubscribe.into(), - response: tx, - retries: self.retries, - }) - .await - .map_err(errors::internal_error)?; - - rx.await.map_err(errors::internal_error)? - } - .with_context(TRACER.context(subscribe.to_string())) - .await + if let Some(ws_client) = &self.ws_client { + ws_client.subscribe(subscribe, params, unsubscribe).await } else { Err(Error::Call(errors::internal_error("No websocket connection"))) } } pub async fn rotate_endpoint(&self) { - if let Some(sender) = self.sender.as_ref() { - sender - .send(Message::RotateEndpoint) - .await - .expect("Failed to rotate endpoint"); + if let Some(ws_client) = &self.ws_client { + ws_client.rotate_endpoint().await; } } /// Returns a future that resolves when the endpoint is rotated. pub async fn on_rotation(&self) { - if let Some(rotation_notify) = self.rotation_notify.as_ref() { - rotation_notify.notified().await + if let Some(ws_client) = &self.ws_client { + ws_client.on_rotation().await; } } } - -fn get_backoff_time(counter: &Arc) -> Duration { - let min_time = 100u64; - let step = 100u64; - let max_count = 10u32; - - let backoff_count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - let backoff_count = backoff_count.min(max_count) as u64; - let backoff_time = backoff_count * backoff_count * step; - - Duration::from_millis(backoff_time + min_time) -} - -#[test] -fn test_get_backoff_time() { - let counter = Arc::new(AtomicU32::new(0)); - - let mut times = Vec::new(); - - for _ in 0..12 { - times.push(get_backoff_time(&counter)); - } - - let times = times.into_iter().map(|t| t.as_millis()).collect::>(); - - assert_eq!( - times, - vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] - ); -} diff --git a/src/extensions/client/ws.rs b/src/extensions/client/ws.rs new file mode 100644 index 0000000..b3eef7d --- /dev/null +++ b/src/extensions/client/ws.rs @@ -0,0 +1,515 @@ +use std::{ + sync::{ + atomic::{AtomicU32, AtomicUsize}, + Arc, + }, + time::Duration, +}; + +use anyhow::anyhow; +use async_trait::async_trait; +use futures::TryFutureExt; +use garde::Validate; +use jsonrpsee::{ + core::{ + client::{ClientT, Error, Subscription, SubscriptionClientT}, + JsonValue, + }, + ws_client::{WsClient, WsClientBuilder}, +}; +use opentelemetry::trace::FutureExt; +use rand::{seq::SliceRandom, thread_rng}; +use serde::Deserialize; +use tokio::sync::Notify; + +use super::ExtensionRegistry; +use crate::{ + extensions::Extension, + middlewares::CallResult, + utils::{self, errors}, +}; + +const TRACER: utils::telemetry::Tracer = utils::telemetry::Tracer::new("client"); + +#[derive(Debug)] +pub struct Client { + endpoints: Vec, + sender: tokio::sync::mpsc::Sender, + rotation_notify: Arc, + retries: u32, + background_task: tokio::task::JoinHandle<()>, +} + +impl Drop for Client { + fn drop(&mut self) { + self.background_task.abort(); + } +} + +#[derive(Deserialize, Validate, Debug)] +#[garde(allow_unvalidated)] +pub struct ClientConfig { + #[garde(inner(custom(validate_endpoint)))] + pub endpoints: Vec, + #[serde(default = "bool_true")] + pub shuffle_endpoints: bool, +} + +fn validate_endpoint(endpoint: &str, _context: &()) -> garde::Result { + endpoint + .parse::() + .map_err(|_| garde::Error::new(format!("Invalid endpoint format: {}", endpoint)))?; + + Ok(()) +} + +impl ClientConfig { + pub async fn all_endpoints_can_be_connected(&self) -> bool { + let join_handles: Vec<_> = self + .endpoints + .iter() + .map(|endpoint| { + let endpoint = endpoint.clone(); + tokio::spawn(async move { + match check_endpoint_connection(&endpoint).await { + Ok(_) => { + tracing::info!("Connected to endpoint: {endpoint}"); + true + } + Err(err) => { + tracing::error!("Failed to connect to endpoint: {endpoint}, error: {err:?}",); + false + } + } + }) + }) + .collect(); + let mut ok_all = true; + for join_handle in join_handles { + let ok = join_handle.await.unwrap_or_else(|e| { + tracing::error!("Failed to join: {e:?}"); + false + }); + if !ok { + ok_all = false + } + } + ok_all + } +} +// simple connection check with default client params and no retries +async fn check_endpoint_connection(endpoint: &str) -> Result<(), anyhow::Error> { + let _ = WsClientBuilder::default().build(&endpoint).await?; + Ok(()) +} + +pub fn bool_true() -> bool { + true +} + +#[derive(Debug)] +enum Message { + Request { + method: String, + params: Vec, + response: tokio::sync::oneshot::Sender>, + retries: u32, + }, + Subscribe { + subscribe: String, + params: Vec, + unsubscribe: String, + response: tokio::sync::oneshot::Sender, Error>>, + retries: u32, + }, + RotateEndpoint, +} + +#[async_trait] +impl Extension for Client { + type Config = ClientConfig; + + async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + if config.shuffle_endpoints { + let mut endpoints = config.endpoints.clone(); + endpoints.shuffle(&mut thread_rng()); + Ok(Self::new(endpoints, None, None, None)?) + } else { + Ok(Self::new(config.endpoints.clone(), None, None, None)?) + } + } +} + +impl Client { + pub fn new( + endpoints: impl IntoIterator>, + request_timeout: Option, + connection_timeout: Option, + retries: Option, + ) -> Result { + let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); + + if endpoints.is_empty() { + return Err(anyhow!("No endpoints provided")); + } + + tracing::debug!("New client with endpoints: {:?}", endpoints); + + let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); + + let message_tx_bg = message_tx.clone(); + + let rotation_notify = Arc::new(Notify::new()); + let rotation_notify_bg = rotation_notify.clone(); + let endpoints_ = endpoints.clone(); + + let background_task = tokio::spawn(async move { + let connect_backoff_counter = Arc::new(AtomicU32::new(0)); + let request_backoff_counter = Arc::new(AtomicU32::new(0)); + + let current_endpoint = AtomicUsize::new(0); + + let connect_backoff_counter2 = connect_backoff_counter.clone(); + let build_ws = || async { + let build = || { + let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let url = &endpoints[current_endpoint % endpoints.len()]; + + tracing::info!("Connecting to endpoint: {}", url); + + // TODO: make those configurable + WsClientBuilder::default() + .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) + .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(url) + .map_err(|e| (e, url.to_string())) + }; + + loop { + match build().await { + Ok(ws) => { + let ws = Arc::new(ws); + tracing::info!("Endpoint connected"); + connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); + break ws; + } + Err((e, url)) => { + tracing::warn!("Unable to connect to endpoint: '{url}' error: {e}"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; + } + } + } + }; + + let mut ws = build_ws().await; + + let handle_message = |message: Message, ws: Arc| { + let tx = message_tx_bg.clone(); + let request_backoff_counter = request_backoff_counter.clone(); + + // total timeout for a request + let task_timeout = request_timeout + .unwrap_or(Duration::from_secs(30)) + // buffer 5 seconds for the request to be processed + .saturating_add(Duration::from_secs(5)); + + tokio::spawn(async move { + match message { + Message::Request { + method, + params, + response, + mut retries, + } => { + retries = retries.saturating_sub(1); + + // make sure it's still connected + if response.is_closed() { + return; + } + + if let Ok(result) = + tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await + { + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Request failed: {:?}", err); + match err { + Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Request { + method, + params, + response, + retries, + }) + .await + .expect("Failed to send request message"); + } + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); + } + } + } + } + } else { + tracing::error!("request timed out method: {} params: {:?}", method, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); + } + } + Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + mut retries, + } => { + retries = retries.saturating_sub(1); + + if let Ok(result) = tokio::time::timeout( + task_timeout, + ws.subscribe(&subscribe, params.clone(), &unsubscribe), + ) + .await + { + match result { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Subscribe failed: {:?}", err); + match err { + Error::RequestTimeout | Error::Transport(_) | Error::RestartNeeded(_) => { + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + if matches!(err, Error::RequestTimeout) { + tx.send(Message::RotateEndpoint) + .await + .expect("Failed to send rotate message"); + } + + tx.send(Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + retries, + }) + .await + .expect("Failed to send subscribe message") + } + err => { + // make sure it's still connected + if response.is_closed() { + return; + } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); + } + } + } + } + } else { + tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); + // make sure it's still connected + if response.is_closed() { + return; + } + let _ = response.send(Err(Error::RequestTimeout)); + } + } + Message::RotateEndpoint => { + unreachable!() + } + } + }); + }; + + loop { + tokio::select! { + _ = ws.on_disconnect() => { + tracing::info!("Endpoint disconnected"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; + ws = build_ws().await; + } + message = message_rx.recv() => { + tracing::trace!("Received message {message:?}"); + match message { + Some(Message::RotateEndpoint) => { + rotation_notify_bg.notify_waiters(); + tracing::info!("Rotate endpoint"); + ws = build_ws().await; + } + Some(message) => handle_message(message, ws.clone()), + None => { + tracing::debug!("Client dropped"); + break; + } + } + }, + }; + } + }); + + if let Some(0) = retries { + return Err(anyhow!("Retries need to be at least 1")); + } + + Ok(Self { + endpoints: endpoints_, + sender: message_tx, + rotation_notify, + retries: retries.unwrap_or(3), + background_task, + }) + } + + pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { + Self::new(endpoints, None, None, None) + } + + pub fn endpoints(&self) -> &Vec { + &self.endpoints + } + + pub async fn request(&self, method: &str, params: Vec) -> CallResult { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.sender + .send(Message::Request { + method: method.into(), + params, + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)?.map_err(errors::map_error) + } + .with_context(TRACER.context(method.to_string())) + .await + } + + pub async fn subscribe( + &self, + subscribe: &str, + params: Vec, + unsubscribe: &str, + ) -> Result, Error> { + async move { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.sender + .send(Message::Subscribe { + subscribe: subscribe.into(), + params, + unsubscribe: unsubscribe.into(), + response: tx, + retries: self.retries, + }) + .await + .map_err(errors::internal_error)?; + + rx.await.map_err(errors::internal_error)? + } + .with_context(TRACER.context(subscribe.to_string())) + .await + } + + pub async fn rotate_endpoint(&self) { + self.sender + .send(Message::RotateEndpoint) + .await + .expect("Failed to rotate endpoint"); + } + + /// Returns a future that resolves when the endpoint is rotated. + pub async fn on_rotation(&self) { + self.rotation_notify.notified().await + } +} + +fn get_backoff_time(counter: &Arc) -> Duration { + let min_time = 100u64; + let step = 100u64; + let max_count = 10u32; + + let backoff_count = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + + let backoff_count = backoff_count.min(max_count) as u64; + let backoff_time = backoff_count * backoff_count * step; + + Duration::from_millis(backoff_time + min_time) +} + +#[test] +fn test_get_backoff_time() { + let counter = Arc::new(AtomicU32::new(0)); + + let mut times = Vec::new(); + + for _ in 0..12 { + times.push(get_backoff_time(&counter)); + } + + let times = times.into_iter().map(|t| t.as_millis()).collect::>(); + + assert_eq!( + times, + vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] + ); +}