Skip to content

Commit

Permalink
node: perform DNS lookup with timeout
Browse files Browse the repository at this point in the history
Passed the hostname_resolution_timeout down to the functions
responsible for DNS resolution logic.

Created a pub(crate) error type to distinguish between errors that
can occur during hostname resolution. Notice: it's pub(crate) since
the users of this API only emit logs in case of error. The errors are
not passed to public API.

Created a `lookup_host_with_timeout` function, and extracted
some logic here. The purpose is not to introduce complex branching
in original function.
  • Loading branch information
muzarski committed Dec 11, 2024
1 parent ee813e4 commit 86e69d8
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
5 changes: 4 additions & 1 deletion scylla/src/transport/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,7 @@ impl PoolRefiller {
mut endpoint: UntranslatedEndpoint,
) -> impl Future<Output = UntranslatedEndpoint> {
let cloud_config = self.pool_config.connection_config.cloud_config.clone();
let hostname_resolution_timeout = self.pool_config.hostname_resolution_timeout;
async move {
if let Some(cloud_config) = cloud_config {
// If we operate in the serverless Cloud, then we substitute every node's address
Expand All @@ -881,7 +882,9 @@ impl PoolRefiller {
if let Some(dc) = datacenter.as_deref() {
if let Some(dc_config) = cloud_config.get_datacenters().get(dc) {
let hostname = dc_config.get_server();
if let Ok(resolved) = resolve_hostname(hostname).await {
if let Ok(resolved) =
resolve_hostname(hostname, hostname_resolution_timeout).await
{
*address = NodeAddr::Untranslatable(resolved)
} else {
warn!(
Expand Down
51 changes: 40 additions & 11 deletions scylla/src/transport/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use itertools::Itertools;
use tokio::net::lookup_host;
use thiserror::Error;
use tokio::net::{lookup_host, ToSocketAddrs};
use tracing::warn;
use uuid::Uuid;

Expand All @@ -13,6 +14,7 @@ use crate::transport::errors::{ConnectionPoolError, QueryError};
use std::fmt::Display;
use std::io;
use std::net::IpAddr;
use std::time::Duration;
use std::{
hash::{Hash, Hasher},
net::SocketAddr,
Expand Down Expand Up @@ -267,27 +269,53 @@ pub(crate) struct ResolvedContactPoint {
pub(crate) datacenter: Option<String>,
}

#[derive(Error, Debug)]
pub(crate) enum DnsLookupError {
#[error("Failed to perform DNS lookup within {0}ms")]
Timeout(u128),
#[error("Empty address list returned by DNS for {0}")]
EmptyAddressListForHost(String),
#[error(transparent)]
IoError(#[from] io::Error),
}

/// Performs a DNS lookup with provided optional timeout.
async fn lookup_host_with_timeout<T: ToSocketAddrs>(
host: T,
hostname_resolution_timeout: Option<Duration>,
) -> Result<impl Iterator<Item = SocketAddr>, DnsLookupError> {
if let Some(timeout) = hostname_resolution_timeout {
match tokio::time::timeout(timeout, lookup_host(host)).await {
Ok(res) => res.map_err(Into::into),
// Elapsed error from tokio library does not provide any context.
Err(_) => Err(DnsLookupError::Timeout(timeout.as_millis())),
}
} else {
lookup_host(host).await.map_err(Into::into)
}
}

// Resolve the given hostname using a DNS lookup if necessary.
// The resolution may return multiple IPs and the function returns one of them.
// It prefers to return IPv4s first, and only if there are none, IPv6s.
pub(crate) async fn resolve_hostname(hostname: &str) -> Result<SocketAddr, io::Error> {
let addrs = match lookup_host(hostname).await {
pub(crate) async fn resolve_hostname(
hostname: &str,
hostname_resolution_timeout: Option<Duration>,
) -> Result<SocketAddr, DnsLookupError> {
let addrs = match lookup_host_with_timeout(hostname, hostname_resolution_timeout).await {
Ok(addrs) => itertools::Either::Left(addrs),
// Use a default port in case of error, but propagate the original error on failure
Err(e) => {
let addrs = lookup_host((hostname, 9042)).await.or(Err(e))?;
let addrs = lookup_host_with_timeout((hostname, 9042), hostname_resolution_timeout)
.await
.or(Err(e))?;
itertools::Either::Right(addrs)
}
};

addrs
.find_or_last(|addr| matches!(addr, SocketAddr::V4(_)))
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("Empty address list returned by DNS for {}", hostname),
)
})
.ok_or_else(|| DnsLookupError::EmptyAddressListForHost(hostname.to_owned()))
}

/// Transforms the given [`InternalKnownNode`]s into [`ContactPoint`]s.
Expand All @@ -296,6 +324,7 @@ pub(crate) async fn resolve_hostname(hostname: &str) -> Result<SocketAddr, io::E
/// In case of a plain IP address, parses it and uses straight.
pub(crate) async fn resolve_contact_points(
known_nodes: &[InternalKnownNode],
hostname_resolution_timeout: Option<Duration>,
) -> (Vec<ResolvedContactPoint>, Vec<String>) {
// Find IP addresses of all known nodes passed in the config
let mut initial_peers: Vec<ResolvedContactPoint> = Vec::with_capacity(known_nodes.len());
Expand Down Expand Up @@ -323,7 +352,7 @@ pub(crate) async fn resolve_contact_points(
let resolve_futures = to_resolve
.into_iter()
.map(|(hostname, datacenter)| async move {
match resolve_hostname(hostname).await {
match resolve_hostname(hostname, hostname_resolution_timeout).await {
Ok(address) => Some(ResolvedContactPoint {
address,
datacenter,
Expand Down
9 changes: 6 additions & 3 deletions scylla/src/transport/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ impl MetadataReader {
host_filter: &Option<Arc<dyn HostFilter>>,
) -> Result<Self, NewSessionError> {
let (initial_peers, resolved_hostnames) =
resolve_contact_points(&initial_known_nodes).await;
resolve_contact_points(&initial_known_nodes, hostname_resolution_timeout).await;
// Ensure there is at least one resolved node
if initial_peers.is_empty() {
return Err(NewSessionError::FailedToResolveAnyHostname(
Expand Down Expand Up @@ -574,8 +574,11 @@ impl MetadataReader {
// If no known peer is reachable, try falling back to initial contact points, in hope that
// there are some hostnames there which will resolve to reachable new addresses.
warn!("Failed to establish control connection and fetch metadata on all known peers. Falling back to initial contact points.");
let (initial_peers, _hostnames) =
resolve_contact_points(&self.initial_known_nodes).await;
let (initial_peers, _hostnames) = resolve_contact_points(
&self.initial_known_nodes,
self.hostname_resolution_timeout,
)
.await;
result = self
.retry_fetch_metadata_on_nodes(
initial,
Expand Down

0 comments on commit 86e69d8

Please sign in to comment.