diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 309fe5815..9467c4717 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -167,6 +167,7 @@ impl Cluster { let mut metadata_reader = MetadataReader::new( known_nodes, + pool_config.hostname_resolution_timeout, control_connection_repair_sender, pool_config.connection_config.clone(), pool_config.keepalive_interval, diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index 4b3de60c5..e3385333a 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -60,6 +60,7 @@ pub(crate) struct PoolConfig { pub(crate) pool_size: PoolSize, pub(crate) can_use_shard_aware_port: bool, pub(crate) keepalive_interval: Option, + pub(crate) hostname_resolution_timeout: Option, } impl Default for PoolConfig { @@ -69,6 +70,7 @@ impl Default for PoolConfig { pool_size: Default::default(), can_use_shard_aware_port: true, keepalive_interval: None, + hostname_resolution_timeout: None, } } } @@ -865,6 +867,7 @@ impl PoolRefiller { mut endpoint: UntranslatedEndpoint, ) -> impl Future { let cloud_config = self.pool_config.connection_config.cloud_config.clone(); + let hostname_resolution_timeout = self.pool_config.hostname_resolution_timeout; async move { if let Some(cloud_config) = cloud_config { // If we operate in the serverless Cloud, then we substitute every node's address @@ -879,7 +882,9 @@ impl PoolRefiller { if let Some(dc) = datacenter.as_deref() { if let Some(dc_config) = cloud_config.get_datacenters().get(dc) { let hostname = dc_config.get_server(); - if let Ok(resolved) = resolve_hostname(hostname).await { + if let Ok(resolved) = + resolve_hostname(hostname, hostname_resolution_timeout).await + { *address = NodeAddr::Untranslatable(resolved) } else { warn!( diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs index ba2b3d9f4..20f9ade41 100644 --- a/scylla/src/transport/node.rs +++ b/scylla/src/transport/node.rs @@ -1,5 +1,6 @@ use itertools::Itertools; -use tokio::net::lookup_host; +use thiserror::Error; +use tokio::net::{lookup_host, ToSocketAddrs}; use tracing::warn; use uuid::Uuid; @@ -13,6 +14,7 @@ use crate::transport::errors::{ConnectionPoolError, QueryError}; use std::fmt::Display; use std::io; use std::net::IpAddr; +use std::time::Duration; use std::{ hash::{Hash, Hasher}, net::SocketAddr, @@ -267,27 +269,53 @@ pub(crate) struct ResolvedContactPoint { pub(crate) datacenter: Option, } +#[derive(Error, Debug)] +pub(crate) enum DnsLookupError { + #[error("Failed to perform DNS lookup within {0}ms")] + Timeout(u128), + #[error("Empty address list returned by DNS for {0}")] + EmptyAddressListForHost(String), + #[error(transparent)] + IoError(#[from] io::Error), +} + +/// Performs a DNS lookup with provided optional timeout. +async fn lookup_host_with_timeout( + host: T, + hostname_resolution_timeout: Option, +) -> Result, DnsLookupError> { + if let Some(timeout) = hostname_resolution_timeout { + match tokio::time::timeout(timeout, lookup_host(host)).await { + Ok(res) => res.map_err(Into::into), + // Elapsed error from tokio library does not provide any context. + Err(_) => Err(DnsLookupError::Timeout(timeout.as_millis())), + } + } else { + lookup_host(host).await.map_err(Into::into) + } +} + // Resolve the given hostname using a DNS lookup if necessary. // The resolution may return multiple IPs and the function returns one of them. // It prefers to return IPv4s first, and only if there are none, IPv6s. -pub(crate) async fn resolve_hostname(hostname: &str) -> Result { - let addrs = match lookup_host(hostname).await { +pub(crate) async fn resolve_hostname( + hostname: &str, + hostname_resolution_timeout: Option, +) -> Result { + let addrs = match lookup_host_with_timeout(hostname, hostname_resolution_timeout).await { Ok(addrs) => itertools::Either::Left(addrs), // Use a default port in case of error, but propagate the original error on failure Err(e) => { - let addrs = lookup_host((hostname, 9042)).await.or(Err(e))?; + let addrs = lookup_host_with_timeout((hostname, 9042), hostname_resolution_timeout) + .await + .or(Err(e))?; itertools::Either::Right(addrs) } }; addrs .find_or_last(|addr| matches!(addr, SocketAddr::V4(_))) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - format!("Empty address list returned by DNS for {}", hostname), - ) - }) + .ok_or_else(|| DnsLookupError::EmptyAddressListForHost(hostname.to_owned())) } /// Transforms the given [`InternalKnownNode`]s into [`ContactPoint`]s. @@ -296,6 +324,7 @@ pub(crate) async fn resolve_hostname(hostname: &str) -> Result, ) -> (Vec, Vec) { // Find IP addresses of all known nodes passed in the config let mut initial_peers: Vec = Vec::with_capacity(known_nodes.len()); @@ -323,7 +352,7 @@ pub(crate) async fn resolve_contact_points( let resolve_futures = to_resolve .into_iter() .map(|(hostname, datacenter)| async move { - match resolve_hostname(hostname).await { + match resolve_hostname(hostname, hostname_resolution_timeout).await { Ok(address) => Some(ResolvedContactPoint { address, datacenter, diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 11b413dad..a1b61e6a8 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -300,6 +300,10 @@ pub struct SessionConfig { /// It is true by default but can be disabled if successive schema-altering statements should be performed. pub refresh_metadata_on_auto_schema_agreement: bool, + /// DNS hostname resolution timeout. + /// If `None`, the driver will wait for hostname resolution indefinitely. + pub hostname_resolution_timeout: Option, + /// The address translator is used to translate addresses received from ScyllaDB nodes /// (either with cluster metadata or with an event) to addresses that can be used to /// actually connect to those nodes. This may be needed e.g. when there is NAT @@ -380,6 +384,7 @@ impl SessionConfig { ssl_context: None, authenticator: None, connect_timeout: Duration::from_secs(5), + hostname_resolution_timeout: Some(Duration::from_secs(5)), connection_pool_size: Default::default(), disallow_shard_aware_port: false, keyspaces_to_fetch: Vec::new(), @@ -1095,6 +1100,7 @@ where pool_size: config.connection_pool_size, can_use_shard_aware_port: !config.disallow_shard_aware_port, keepalive_interval: config.keepalive_interval, + hostname_resolution_timeout: config.hostname_resolution_timeout, }; let cluster = Cluster::new( diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs index 404e27733..c369baa76 100644 --- a/scylla/src/transport/session_builder.rs +++ b/scylla/src/transport/session_builder.rs @@ -806,6 +806,27 @@ impl GenericSessionBuilder { self } + /// Changes DNS hostname resolution timeout. + /// The default is 5 seconds. + /// + /// # Example + /// ``` + /// # use scylla::{Session, SessionBuilder}; + /// # use std::time::Duration; + /// # async fn example() -> Result<(), Box> { + /// let session: Session = SessionBuilder::new() + /// .known_node("127.0.0.1:9042") + /// .hostname_resolution_timeout(Duration::from_secs(10)) + /// .build() // Turns SessionBuilder into Session + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub fn hostname_resolution_timeout(mut self, duration: Duration) -> Self { + self.config.hostname_resolution_timeout = Some(duration); + self + } + /// Sets the host filter. The host filter decides whether any connections /// should be opened to the node or not. The driver will also avoid /// those nodes when re-establishing the control connection. diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index 7bbfc6b2a..e123975df 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -41,6 +41,7 @@ use super::node::{InternalKnownNode, NodeAddr, ResolvedContactPoint}; pub(crate) struct MetadataReader { connection_config: ConnectionConfig, keepalive_interval: Option, + hostname_resolution_timeout: Option, control_connection_endpoint: UntranslatedEndpoint, control_connection: NodeConnectionPool, @@ -470,6 +471,7 @@ impl MetadataReader { #[allow(clippy::too_many_arguments)] pub(crate) async fn new( initial_known_nodes: Vec, + hostname_resolution_timeout: Option, control_connection_repair_requester: broadcast::Sender<()>, mut connection_config: ConnectionConfig, keepalive_interval: Option, @@ -479,7 +481,7 @@ impl MetadataReader { host_filter: &Option>, ) -> Result { let (initial_peers, resolved_hostnames) = - resolve_contact_points(&initial_known_nodes).await; + resolve_contact_points(&initial_known_nodes, hostname_resolution_timeout).await; // Ensure there is at least one resolved node if initial_peers.is_empty() { return Err(NewSessionError::FailedToResolveAnyHostname( @@ -503,6 +505,7 @@ impl MetadataReader { control_connection_endpoint.clone(), connection_config.clone(), keepalive_interval, + hostname_resolution_timeout, control_connection_repair_requester.clone(), ); @@ -510,6 +513,7 @@ impl MetadataReader { control_connection_endpoint, control_connection, keepalive_interval, + hostname_resolution_timeout, connection_config, known_peers: initial_peers .into_iter() @@ -570,8 +574,11 @@ impl MetadataReader { // If no known peer is reachable, try falling back to initial contact points, in hope that // there are some hostnames there which will resolve to reachable new addresses. warn!("Failed to establish control connection and fetch metadata on all known peers. Falling back to initial contact points."); - let (initial_peers, _hostnames) = - resolve_contact_points(&self.initial_known_nodes).await; + let (initial_peers, _hostnames) = resolve_contact_points( + &self.initial_known_nodes, + self.hostname_resolution_timeout, + ) + .await; result = self .retry_fetch_metadata_on_nodes( initial, @@ -630,6 +637,7 @@ impl MetadataReader { self.control_connection_endpoint.clone(), self.connection_config.clone(), self.keepalive_interval, + self.hostname_resolution_timeout, self.control_connection_repair_requester.clone(), ); @@ -730,6 +738,7 @@ impl MetadataReader { self.control_connection_endpoint.clone(), self.connection_config.clone(), self.keepalive_interval, + self.hostname_resolution_timeout, self.control_connection_repair_requester.clone(), ); } @@ -741,11 +750,13 @@ impl MetadataReader { endpoint: UntranslatedEndpoint, connection_config: ConnectionConfig, keepalive_interval: Option, + hostname_resolution_timeout: Option, refresh_requester: broadcast::Sender<()>, ) -> NodeConnectionPool { let pool_config = PoolConfig { connection_config, keepalive_interval, + hostname_resolution_timeout, // We want to have only one connection to receive events from pool_size: PoolSize::PerHost(NonZeroUsize::new(1).unwrap()),