diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index a9ace5c31a..624b848f20 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -29,7 +29,8 @@ use std::{ net::{Ipv4Addr, Ipv6Addr}, }; -use super::errors::{BadKeyspaceName, BadQuery, DbError, QueryError}; +use super::errors::{BadKeyspaceName, DbError, QueryError}; +use super::iterator::RowIterator; use crate::batch::{Batch, BatchStatement}; use crate::frame::protocol_features::ProtocolFeatures; @@ -457,98 +458,6 @@ impl Connection { .await } - /// Performs query_single_page multiple times to query all available pages - pub async fn query_all( - &self, - query: &Query, - values: impl ValueList, - ) -> Result { - // This method is used only for driver internal queries, so no need to consult execution profile here. - self.query_all_with_consistency( - query, - values, - query - .config - .determine_consistency(self.config.default_consistency), - query.get_serial_consistency(), - ) - .await - } - - pub async fn query_all_with_consistency( - &self, - query: &Query, - values: impl ValueList, - consistency: Consistency, - serial_consistency: Option, - ) -> Result { - if query.get_page_size().is_none() { - // Page size should be set when someone wants to use paging - return Err(QueryError::BadQuery(BadQuery::Other( - "Called Connection::query_all without page size set!".to_string(), - ))); - } - - let mut final_result = QueryResult::default(); - - let serialized_values = values.serialized()?; - let mut paging_state: Option = None; - - loop { - // Send next paged query - let mut cur_result: QueryResult = self - .query_with_consistency( - query, - &serialized_values, - consistency, - serial_consistency, - paging_state, - ) - .await? - .into_query_result()?; - - // Set paging_state for the next query - paging_state = cur_result.paging_state.take(); - - // Add current query results to the final_result - final_result.merge_with_next_page_res(cur_result); - - if paging_state.is_none() { - // No more pages to query, we can return the final result - return Ok(final_result); - } - } - } - - pub async fn execute_single_page( - &self, - prepared_statement: &PreparedStatement, - values: impl ValueList, - paging_state: Option, - ) -> Result { - self.execute(prepared_statement, values, paging_state) - .await? - .into_query_result() - } - - pub async fn execute( - &self, - prepared_statement: &PreparedStatement, - values: impl ValueList, - paging_state: Option, - ) -> Result { - self.execute_with_consistency( - prepared_statement, - values, - prepared_statement - .config - .determine_consistency(self.config.default_consistency), - prepared_statement.config.serial_consistency.flatten(), - paging_state, - ) - .await - } - pub async fn execute_with_consistency( &self, prepared_statement: &PreparedStatement, @@ -591,41 +500,28 @@ impl Connection { } } - /// Performs execute_single_page multiple times to fetch all available pages - #[allow(dead_code)] - pub async fn execute_all( - &self, - prepared_statement: &PreparedStatement, + /// Executes a query and fetches its results over multiple pages, using + /// the asynchronous iterator interface. 
+ pub(crate) async fn query_iter( + self: Arc, + query: Query, values: impl ValueList, - ) -> Result { - if prepared_statement.get_page_size().is_none() { - return Err(QueryError::BadQuery(BadQuery::Other( - "Called Connection::execute_all without page size set!".to_string(), - ))); - } - - let mut final_result = QueryResult::default(); - - let serialized_values = values.serialized()?; - let mut paging_state: Option = None; - - loop { - // Send next paged query - let mut cur_result: QueryResult = self - .execute_single_page(prepared_statement, &serialized_values, paging_state) - .await?; - - // Set paging_state for the next query - paging_state = cur_result.paging_state.take(); + ) -> Result { + let serialized_values = values.serialized()?.into_owned(); - // Add current query results to the final_result - final_result.merge_with_next_page_res(cur_result); + let consistency = query + .config + .determine_consistency(self.config.default_consistency); + let serial_consistency = query.config.serial_consistency.flatten(); - if paging_state.is_none() { - // No more pages to query, we can return the final result - return Ok(final_result); - } - } + RowIterator::new_for_connection_query_iter( + query, + self, + serialized_values, + consistency, + serial_consistency, + ) + .await } #[allow(dead_code)] @@ -1559,7 +1455,6 @@ impl VerifiedKeyspaceName { #[cfg(test)] mod tests { - use scylla_cql::errors::BadQuery; use scylla_cql::frame::protocol_features::{ LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION, }; @@ -1572,12 +1467,12 @@ mod tests { use tokio::select; use tokio::sync::mpsc; - use super::super::errors::QueryError; use super::ConnectionConfig; use crate::query::Query; use crate::transport::connection::open_connection; use crate::utils::test_utils::unique_keyspace_name; - use crate::{IntoTypedRows, SessionBuilder}; + use crate::SessionBuilder; + use futures::{StreamExt, TryStreamExt}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; @@ -1596,20 +1491,20 @@ mod tests { } } - /// Tests for Connection::query_all and Connection::execute_all + /// Tests for Connection::query_iter /// 1. SELECT from an empty table. /// 2. Create table and insert ints 0..100. - /// Then use query_all and execute_all with page_size set to 7 to select all 100 rows. - /// 3. INSERT query_all should have None in result rows. - /// 4. Calling query_all with a Query that doesn't have page_size set should result in an error. + /// Then use query_iter with page_size set to 7 to select all 100 rows. + /// 3. INSERT query_iter should work and not return any rows. 
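(Editor's illustration, not part of the diff.) A minimal sketch of how the new crate-internal Connection::query_iter is consumed, mirroring the test that follows: the method takes the connection by Arc, returns a RowIterator that fetches successive pages in a background task, and the rows are drained through the futures Stream combinators. The table name, page size, and helper function are hypothetical, and the imports are assumed to match those of the test module below.

async fn collect_ints(connection: Arc<Connection>) -> Result<Vec<i32>, QueryError> {
    use futures::StreamExt;

    // Hypothetical table `example_tab (p int primary key)`; a small page size
    // forces several pages so the background paging actually kicks in.
    let query = Query::new("SELECT p FROM example_tab").with_page_size(7);
    let mut values: Vec<i32> = connection
        .query_iter(query, &[])
        .await?
        .into_typed::<(i32,)>()
        .map(|row| row.unwrap().0)
        .collect::<Vec<i32>>()
        .await;
    values.sort_unstable();
    Ok(values)
}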
#[tokio::test] - async fn connection_query_all_execute_all_test() { + async fn connection_query_iter_test() { let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string()); let addr: SocketAddr = resolve_hostname(&uri).await; let (connection, _) = super::open_connection(addr, None, ConnectionConfig::default()) .await .unwrap(); + let connection = Arc::new(connection); let ks = unique_keyspace_name(); @@ -1623,12 +1518,12 @@ mod tests { session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'SimpleStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap(); session.use_keyspace(ks.clone(), false).await.unwrap(); session - .query("DROP TABLE IF EXISTS connection_query_all_tab", &[]) + .query("DROP TABLE IF EXISTS connection_query_iter_tab", &[]) .await .unwrap(); session .query( - "CREATE TABLE IF NOT EXISTS connection_query_all_tab (p int primary key)", + "CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)", &[], ) .await @@ -1641,20 +1536,22 @@ mod tests { .unwrap(); // 1. SELECT from an empty table returns query result where rows are Some(Vec::new()) - let select_query = Query::new("SELECT p FROM connection_query_all_tab").with_page_size(7); - let empty_res = connection.query_all(&select_query, &[]).await.unwrap(); - assert!(empty_res.rows.unwrap().is_empty()); - - let mut prepared_select = connection.prepare(&select_query).await.unwrap(); - prepared_select.set_page_size(7); - let empty_res_prepared = connection.execute_all(&prepared_select, &[]).await.unwrap(); - assert!(empty_res_prepared.rows.unwrap().is_empty()); + let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7); + let empty_res = connection + .clone() + .query_iter(select_query.clone(), &[]) + .await + .unwrap() + .try_collect::>() + .await + .unwrap(); + assert!(empty_res.is_empty()); - // 2. Insert 100 and select using query_all with page_size 7 + // 2. Insert 100 and select using query_iter with page_size 7 let values: Vec = (0..100).collect(); let mut insert_futures = Vec::new(); let insert_query = - Query::new("INSERT INTO connection_query_all_tab (p) VALUES (?)").with_page_size(7); + Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7); for v in &values { insert_futures.push(connection.query_single_page(insert_query.clone(), (v,))); } @@ -1662,56 +1559,26 @@ mod tests { futures::future::try_join_all(insert_futures).await.unwrap(); let mut results: Vec = connection - .query_all(&select_query, &[]) + .clone() + .query_iter(select_query.clone(), &[]) .await .unwrap() - .rows - .unwrap() .into_typed::<(i32,)>() - .map(|r| r.unwrap().0) - .collect(); + .map(|ret| ret.unwrap().0) + .collect::>() + .await; results.sort_unstable(); // Clippy recommended to use sort_unstable instead of sort() assert_eq!(results, values); - let mut results2: Vec = connection - .execute_all(&prepared_select, &[]) + // 3. INSERT query_iter should work and not return any rows. + let insert_res1 = connection + .query_iter(insert_query, (0,)) .await .unwrap() - .rows - .unwrap() - .into_typed::<(i32,)>() - .map(|r| r.unwrap().0) - .collect(); - results2.sort_unstable(); - assert_eq!(results2, values); - - // 3. INSERT query_all should have None in result rows. 
- let insert_res1 = connection.query_all(&insert_query, (0,)).await.unwrap(); - assert!(insert_res1.rows.is_none()); - - let prepared_insert = connection.prepare(&insert_query).await.unwrap(); - let insert_res2 = connection - .execute_all(&prepared_insert, (0,)) + .try_collect::>() .await .unwrap(); - assert!(insert_res2.rows.is_none(),); - - // 4. Calling query_all with a Query that doesn't have page_size set should result in an error. - let no_page_size_query = Query::new("SELECT p FROM connection_query_all_tab"); - let no_page_res = connection.query_all(&no_page_size_query, &[]).await; - assert!(matches!( - no_page_res, - Err(QueryError::BadQuery(BadQuery::Other(_))) - )); - - let prepared_no_page_size_query = connection.prepare(&no_page_size_query).await.unwrap(); - let prepared_no_page_res = connection - .execute_all(&prepared_no_page_size_query, &[]) - .await; - assert!(matches!( - prepared_no_page_res, - Err(QueryError::BadQuery(BadQuery::Other(_))) - )); + assert!(insert_res1.is_empty()); } #[tokio::test] diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index 5db3c3fd7f..c94a3235c6 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -9,6 +9,7 @@ use std::task::{Context, Poll}; use bytes::Bytes; use futures::Stream; +use scylla_cql::frame::types::SerialConsistency; use std::result::Result; use thiserror::Error; use tokio::sync::mpsc; @@ -136,7 +137,7 @@ impl RowIterator { if query.get_page_size().is_none() { query.set_page_size(DEFAULT_ITER_PAGE_SIZE); } - let (sender, mut receiver) = mpsc::channel(1); + let (sender, receiver) = mpsc::channel(1); let consistency = query .config @@ -185,28 +186,10 @@ impl RowIterator { current_attempt_id: None, }; - let _: PageSendAttemptedProof = worker.work(cluster_data).await; + worker.work(cluster_data).await }; - tokio::task::spawn(worker_task); - - // This unwrap is safe because: - // - The future returned by worker.work sends at least one item - // to the channel (the PageSendAttemptedProof helps enforce this) - // - That future is polled in a tokio::task which isn't going to be - // cancelled - let pages_received = receiver.recv().await.unwrap()?; - - Ok(RowIterator { - current_row_idx: 0, - current_page: pages_received.rows, - page_receiver: receiver, - tracing_ids: if let Some(tracing_id) = pages_received.tracing_id { - vec![tracing_id] - } else { - Vec::new() - }, - }) + Self::new_from_worker_future(worker_task, receiver).await } pub(crate) async fn new_for_prepared_statement( @@ -215,7 +198,7 @@ impl RowIterator { if config.prepared.get_page_size().is_none() { config.prepared.set_page_size(DEFAULT_ITER_PAGE_SIZE); } - let (sender, mut receiver) = mpsc::channel(1); + let (sender, receiver) = mpsc::channel(1); let consistency = config .prepared @@ -277,10 +260,50 @@ impl RowIterator { current_attempt_id: None, }; - let _: PageSendAttemptedProof = worker.work(config.cluster_data).await; + worker.work(config.cluster_data).await + }; + + Self::new_from_worker_future(worker_task, receiver).await + } + + pub(crate) async fn new_for_connection_query_iter( + mut query: Query, + connection: Arc, + values: SerializedValues, + consistency: Consistency, + serial_consistency: Option, + ) -> Result { + if query.get_page_size().is_none() { + query.set_page_size(DEFAULT_ITER_PAGE_SIZE); + } + let (sender, receiver) = mpsc::channel::>(1); + + let worker_task = async move { + let worker = SingleConnectionRowIteratorWorker { + sender: sender.into(), + fetcher: |paging_state| { + 
connection.query_with_consistency( + &query, + &values, + consistency, + serial_consistency, + paging_state, + ) + }, + }; + worker.work().await }; - tokio::task::spawn(worker_task); + Self::new_from_worker_future(worker_task, receiver).await + } + + async fn new_from_worker_future( + worker_task: impl Future + Send + 'static, + mut receiver: mpsc::Receiver>, + ) -> Result { + tokio::task::spawn(async move { + worker_task.await; + }); // This unwrap is safe because: // - The future returned by worker.work sends at least one item @@ -319,8 +342,12 @@ impl RowIterator { // A separate module is used here so that the parent module cannot construct // SendAttemptedProof directly. mod checked_channel_sender { + use scylla_cql::{errors::QueryError, frame::response::result::Rows}; use std::marker::PhantomData; use tokio::sync::mpsc; + use uuid::Uuid; + + use super::ReceivedPage; /// A value whose existence proves that there was an attempt /// to send an item of type T through a channel. @@ -344,6 +371,28 @@ mod checked_channel_sender { (SendAttemptedProof(PhantomData), self.0.send(value).await) } } + + type ResultPage = Result; + + impl ProvingSender { + pub(crate) async fn send_empty_page( + &self, + tracing_id: Option, + ) -> ( + SendAttemptedProof, + Result<(), mpsc::error::SendError>, + ) { + let empty_page = ReceivedPage { + rows: Rows { + metadata: Default::default(), + rows_count: 0, + rows: Vec::new(), + }, + tracing_id, + }; + self.send(Ok(empty_page)).await + } + } } use checked_channel_sender::{ProvingSender, SendAttemptedProof}; @@ -481,17 +530,7 @@ where // interface isn't meant for sending writes), // we must attempt to send something because // the iterator expects it. - let (proof, _) = self - .sender - .send(Ok(ReceivedPage { - rows: Rows { - metadata: Default::default(), - rows_count: 0, - rows: Vec::new(), - }, - tracing_id: None, - })) - .await; + let (proof, _) = self.sender.send_empty_page(None).await; return proof; } }; @@ -581,17 +620,7 @@ where // so let's return an empty iterator as suggested in #631. // We must attempt to send something because the iterator expects it. - let (proof, _) = self - .sender - .send(Ok(ReceivedPage { - rows: Rows { - metadata: Default::default(), - rows_count: 0, - rows: Vec::new(), - }, - tracing_id, - })) - .await; + let (proof, _) = self.sender.send_empty_page(tracing_id).await; return Ok(proof); } Ok(_) => { @@ -686,6 +715,66 @@ where } } +/// A massively simplified version of the RowIteratorWorker. It does not have +/// any complicated logic related to retries, it just fetches pages from +/// a single connection. 
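(Editor's illustration, not part of the diff.) The SingleConnectionRowIteratorWorker introduced here boils down to a plain "fetch, forward, repeat until the server stops returning a paging state" loop driven by a caller-supplied fetcher closure. A self-contained sketch of that pattern, with a hypothetical Page type standing in for the driver's QueryResponse and Vec<u8> standing in for the Bytes paging state:

struct Page {
    rows: Vec<String>,
    paging_state: Option<Vec<u8>>,
}

async fn drain_pages<F, Fut>(fetch_page: F) -> Vec<String>
where
    F: Fn(Option<Vec<u8>>) -> Fut,
    Fut: std::future::Future<Output = Page>,
{
    let mut all_rows = Vec::new();
    // The first request carries no paging state; every later one passes the
    // state returned by the previous page.
    let mut paging_state = None;
    loop {
        let page = fetch_page(paging_state).await;
        all_rows.extend(page.rows);
        paging_state = page.paging_state;
        // No paging state in the response means this was the last page.
        if paging_state.is_none() {
            return all_rows;
        }
    }
}

Hiding the transport call behind the closure is what lets the real worker keep a single paging loop while the caller decides how each page is actually requested.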
+struct SingleConnectionRowIteratorWorker { + sender: ProvingSender>, + fetcher: Fetcher, +} + +impl SingleConnectionRowIteratorWorker +where + Fetcher: Fn(Option) -> FetchFut + Send + Sync, + FetchFut: Future> + Send, +{ + async fn work(mut self) -> PageSendAttemptedProof { + match self.do_work().await { + Ok(proof) => proof, + Err(err) => { + let (proof, _) = self.sender.send(Err(err)).await; + proof + } + } + } + + async fn do_work(&mut self) -> Result { + let mut paging_state = None; + loop { + let result = (self.fetcher)(paging_state).await?; + let response = result.into_non_error_query_response()?; + match response.response { + NonErrorResponse::Result(result::Result::Rows(mut rows)) => { + paging_state = rows.metadata.paging_state.take(); + let (proof, send_result) = self + .sender + .send(Ok(ReceivedPage { + rows, + tracing_id: response.tracing_id, + })) + .await; + if paging_state.is_none() || send_result.is_err() { + return Ok(proof); + } + } + NonErrorResponse::Result(_) => { + // We have most probably sent a modification statement (e.g. INSERT or UPDATE), + // so let's return an empty iterator as suggested in #631. + + // We must attempt to send something because the iterator expects it. + let (proof, _) = self.sender.send_empty_page(response.tracing_id).await; + return Ok(proof); + } + _ => { + return Err(QueryError::ProtocolError( + "Unexpected response to next page query", + )); + } + } + } + } +} + /// Iterator over rows returned by paged queries /// where each row is parsed as the given type\ /// Returned by `RowIterator::into_typed` diff --git a/scylla/src/transport/query_result.rs b/scylla/src/transport/query_result.rs index 426553cbaf..ca861e787b 100644 --- a/scylla/src/transport/query_result.rs +++ b/scylla/src/transport/query_result.rs @@ -131,23 +131,6 @@ impl QueryResult { .enumerate() .find(|(_id, spec)| spec.name == name) } - - /// This function is used to merge results of multiple paged queries into one.\ - /// other is the result of a new paged query.\ - /// It is merged with current result kept in self.\ - pub(crate) fn merge_with_next_page_res(&mut self, other: QueryResult) { - if let Some(other_rows) = other.rows { - match &mut self.rows { - Some(self_rows) => self_rows.extend(other_rows), - None => self.rows = Some(other_rows), - } - }; - - self.warnings.extend(other.warnings); - self.tracing_id = other.tracing_id; - self.paging_state = other.paging_state; - self.col_specs = other.col_specs; - } } /// [`QueryResult::rows()`](QueryResult::rows) or a similar function called on a bad QueryResult.\ diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index f73b957af7..6e3388d3c3 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -5,14 +5,17 @@ use crate::transport::connection::{Connection, ConnectionConfig}; use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig, PoolSize}; use crate::transport::errors::{DbError, QueryError}; use crate::transport::host_filter::HostFilter; -use crate::transport::session::{AddressTranslator, IntoTypedRows}; +use crate::transport::session::AddressTranslator; use crate::utils::parse::{ParseErrorCause, ParseResult, ParserState}; -use crate::QueryResult; -use futures::future::try_join_all; -use itertools::Itertools; +use futures::future::{self, FutureExt}; +use futures::stream::{self, StreamExt, TryStreamExt}; +use futures::Stream; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; +use scylla_cql::frame::response::result::Row; +use 
scylla_cql::frame::value::ValueList; +use scylla_macros::FromRow; use std::borrow::BorrowMut; use std::collections::HashMap; use std::fmt; @@ -335,7 +338,7 @@ impl MetadataReader { async fn fetch_metadata(&self, initial: bool) -> Result { // TODO: Timeouts? self.control_connection.wait_until_initialized().await; - let conn = &*self.control_connection.random_connection()?; + let conn = &self.control_connection.random_connection()?; let res = query_metadata( conn, @@ -448,7 +451,7 @@ impl MetadataReader { } async fn query_metadata( - conn: &Connection, + conn: &Arc, connect_port: u16, address_translator: Option<&dyn AddressTranslator>, keyspace_to_fetch: &[String], @@ -476,149 +479,191 @@ async fn query_metadata( Ok(Metadata { peers, keyspaces }) } +#[derive(FromRow)] +#[scylla_crate = "scylla_cql"] +struct NodeInfoRow { + host_id: Option, + untranslated_ip_addr: IpAddr, + datacenter: Option, + rack: Option, + tokens: Option>, +} + +#[derive(Clone, Copy)] +enum NodeInfoSource { + Local, + Peer, +} + +impl NodeInfoSource { + fn describe(&self) -> &'static str { + match self { + Self::Local => "local node", + Self::Peer => "peer", + } + } +} + async fn query_peers( - conn: &Connection, + conn: &Arc, connect_port: u16, address_translator: Option<&dyn AddressTranslator>, ) -> Result, QueryError> { let mut peers_query = Query::new("select host_id, rpc_address, data_center, rack, tokens from system.peers"); peers_query.set_page_size(1024); - let peers_query_future = conn.query_all(&peers_query, &[]); + let peers_query_stream = conn + .clone() + .query_iter(peers_query, &[]) + .into_stream() + .try_flatten() + .and_then(|row_result| future::ok((NodeInfoSource::Peer, row_result))); let mut local_query = Query::new("select host_id, rpc_address, data_center, rack, tokens from system.local"); local_query.set_page_size(1024); - let local_query_future = conn.query_all(&local_query, &[]); + let local_query_stream = conn + .clone() + .query_iter(local_query, &[]) + .into_stream() + .try_flatten() + .and_then(|row_result| future::ok((NodeInfoSource::Local, row_result))); - let (peers_res, local_res) = tokio::try_join!(peers_query_future, local_query_future)?; + let untranslated_rows = stream::select(peers_query_stream, local_query_stream); - let peers_rows = peers_res.rows.ok_or(QueryError::ProtocolError( - "system.peers query response was not Rows", - ))?; + let local_ip: IpAddr = conn.get_connect_address().ip(); + let local_address = SocketAddr::new(local_ip, connect_port); - let local_rows = local_res.rows.ok_or(QueryError::ProtocolError( - "system.local query response was not Rows", - ))?; + let translated_peers_futures = untranslated_rows.map(|row_result| async { + let (source, raw_row) = row_result?; + let row = raw_row.into_typed().map_err(|_| { + QueryError::ProtocolError("system.peers or system.local has invalid column type") + })?; + create_peer_from_row(source, row, local_address, address_translator).await + }); - let typed_peers_rows = peers_rows.into_typed::<( - Option, - IpAddr, - Option, - Option, - Option>, - )>(); + let peers = translated_peers_futures + .buffer_unordered(256) + .try_collect::>() + .await?; + Ok(peers.into_iter().flatten().collect()) +} - let local_ip: IpAddr = conn.get_connect_address().ip(); - let local_address = SocketAddr::new(local_ip, connect_port); +async fn create_peer_from_row( + source: NodeInfoSource, + row: NodeInfoRow, + local_address: SocketAddr, + address_translator: Option<&dyn AddressTranslator>, +) -> Result, QueryError> { + let NodeInfoRow { + host_id, + 
untranslated_ip_addr, + datacenter, + rack, + tokens, + } = row; + + let host_id = match host_id { + Some(host_id) => host_id, + None => { + warn!("{} (untranslated ip: {}, dc: {:?}, rack: {:?}) has Host ID set to null; skipping node.", source.describe(), untranslated_ip_addr, datacenter, rack); + return Ok(None); + } + }; - let typed_local_rows = local_rows.into_typed::<( - Option, - IpAddr, - Option, - Option, - Option>, - )>(); - - let untranslated_rows = typed_peers_rows - .map(|res| res.map(|peer_row| (false, peer_row))) - .chain(typed_local_rows.map(|res| res.map(|local_row| (true, local_row)))); - - let translated_peers_futures = untranslated_rows - .filter_map_ok(|(is_local, (host_id, ip, dc, rack, tokens))| if let Some(host_id) = host_id { - Some((is_local, (host_id, ip, dc, rack, tokens))) - } else { - let who = if is_local { "Local node" } else { "Peer" }; - warn!("{} (untranslated ip: {}, dc: {:?}, rack: {:?}) has Host ID set to null; skipping node.", who, ip, dc, rack); - None - }) - .map(|untranslated_row| async { - let (is_local, (host_id, untranslated_ip_addr, datacenter, rack, tokens)) = untranslated_row.map_err( - |_| QueryError::ProtocolError("system.peers or system.local has invalid column type") - )?; - let untranslated_address = SocketAddr::new(untranslated_ip_addr, connect_port); - - let (untranslated_address, address) = match (is_local, address_translator) { - (true, None) => { - // We need to replace rpc_address with control connection address. - (Some(untranslated_address), local_address) - }, - (true, Some(_)) => { - // The address we used to connect is most likely different and we just don't know. - (None, local_address) - }, - (false, None) => { - // The usual case - no translation. - (Some(untranslated_address), untranslated_address) - }, - (false, Some(translator)) => { - // We use the provided translator and skip the peer if there is no rule for translating it. - (Some(untranslated_address), - match translator.translate_address(&UntranslatedPeer {host_id, untranslated_address}).await { - Ok(address) => address, - Err(err) => { - warn!("Could not translate address {}; TranslationError: {:?}; node therefore skipped.", - untranslated_address, err); - return Ok::, QueryError>(None); - } - } - ) - } - }; + let connect_port = local_address.port(); + let untranslated_address = SocketAddr::new(untranslated_ip_addr, connect_port); - let tokens_str: Vec = tokens.unwrap_or_default(); + let (untranslated_address, address) = match (source, address_translator) { + (NodeInfoSource::Local, None) => { + // We need to replace rpc_address with control connection address. + (Some(untranslated_address), local_address) + } + (NodeInfoSource::Local, Some(_)) => { + // The address we used to connect is most likely different and we just don't know. + (None, local_address) + } + (NodeInfoSource::Peer, None) => { + // The usual case - no translation. + (Some(untranslated_address), untranslated_address) + } + (NodeInfoSource::Peer, Some(translator)) => { + // We use the provided translator and skip the peer if there is no rule for translating it. 
+ ( + Some(untranslated_address), + match translator + .translate_address(&UntranslatedPeer { + host_id, + untranslated_address, + }) + .await + { + Ok(address) => address, + Err(err) => { + warn!("Could not translate address {}; TranslationError: {:?}; node therefore skipped.", + untranslated_address, err); + return Ok::, QueryError>(None); + } + }, + ) + } + }; - // Parse string representation of tokens as integer values - let tokens: Vec = match tokens_str - .iter() - .map(|s| Token::from_str(s)) - .collect::, _>>() - { - Ok(parsed) => parsed, - Err(e) => { - // FIXME: we could allow the users to provide custom partitioning information - // in order for it to work with non-standard token sizes. - // Also, we could implement support for Cassandra's other standard partitioners - // like RandomPartitioner or ByteOrderedPartitioner. - trace!("Couldn't parse tokens as 64-bit integers: {}, proceeding with a dummy token. If you're using a partitioner with different token size, consider migrating to murmur3", e); - vec![Token { - value: rand::thread_rng().gen::(), - }] - } - }; + let tokens_str: Vec = tokens.unwrap_or_default(); - Ok(Some(Peer { - host_id, - untranslated_address, - address, - tokens, - datacenter, - rack, - })) - }); + // Parse string representation of tokens as integer values + let tokens: Vec = match tokens_str + .iter() + .map(|s| Token::from_str(s)) + .collect::, _>>() + { + Ok(parsed) => parsed, + Err(e) => { + // FIXME: we could allow the users to provide custom partitioning information + // in order for it to work with non-standard token sizes. + // Also, we could implement support for Cassandra's other standard partitioners + // like RandomPartitioner or ByteOrderedPartitioner. + trace!("Couldn't parse tokens as 64-bit integers: {}, proceeding with a dummy token. If you're using a partitioner with different token size, consider migrating to murmur3", e); + vec![Token { + value: rand::thread_rng().gen::(), + }] + } + }; - let peers = try_join_all(translated_peers_futures).await?; - Ok(peers.into_iter().flatten().collect()) + Ok(Some(Peer { + host_id, + untranslated_address, + address, + tokens, + datacenter, + rack, + })) } -async fn query_filter_keyspace_name( - conn: &Connection, +fn query_filter_keyspace_name( + conn: &Arc, query_str: &str, keyspaces_to_fetch: &[String], -) -> Result { +) -> impl Stream> { let keyspaces = &[keyspaces_to_fetch] as &[&[String]]; let (query_str, query_values) = if !keyspaces_to_fetch.is_empty() { (format!("{query_str} where keyspace_name in ?"), keyspaces) } else { (query_str.into(), &[] as &[&[String]]) }; + let query_values = query_values.serialized().map(|sv| sv.into_owned()); let mut query = Query::new(query_str); + let conn = conn.clone(); query.set_page_size(1024); - conn.query_all(&query, query_values).await + let fut = async move { + let query_values = query_values?; + conn.query_iter(query, query_values).await + }; + fut.into_stream().try_flatten() } async fn query_keyspaces( - conn: &Connection, + conn: &Arc, keyspaces_to_fetch: &[String], fetch_schema: bool, ) -> Result, QueryError> { @@ -626,14 +671,8 @@ async fn query_keyspaces( conn, "select keyspace_name, replication from system_schema.keyspaces", keyspaces_to_fetch, - ) - .await? 
- .rows - .ok_or(QueryError::ProtocolError( - "system_schema.keyspaces query response was not Rows", - ))?; - - let mut result = HashMap::with_capacity(rows.len()); + ); + let (mut all_tables, mut all_views, mut all_user_defined_types) = if fetch_schema { ( query_tables(conn, keyspaces_to_fetch).await?, @@ -644,8 +683,9 @@ async fn query_keyspaces( (HashMap::new(), HashMap::new(), HashMap::new()) }; - for row in rows.into_typed::<(String, HashMap)>() { - let (keyspace_name, strategy_map) = row.map_err(|_| { + rows.map(|row_result| { + let row = row_result?; + let (keyspace_name, strategy_map) = row.into_typed().map_err(|_| { QueryError::ProtocolError("system_schema.keyspaces has invalid column type") })?; @@ -656,39 +696,39 @@ async fn query_keyspaces( .remove(&keyspace_name) .unwrap_or_default(); - result.insert( - keyspace_name, - Keyspace { - strategy, - tables, - views, - user_defined_types, - }, - ); - } + let keyspace = Keyspace { + strategy, + tables, + views, + user_defined_types, + }; - Ok(result) + Ok((keyspace_name, keyspace)) + }) + .try_collect() + .await } async fn query_user_defined_types( - conn: &Connection, + conn: &Arc, keyspaces_to_fetch: &[String], ) -> Result>>, QueryError> { let rows = query_filter_keyspace_name( conn, "select keyspace_name, type_name, field_names, field_types from system_schema.types", keyspaces_to_fetch, - ) - .await? - .rows - .ok_or(QueryError::ProtocolError( - "system_schema.types query response was not Rows", - ))?; + ); - let mut result = HashMap::with_capacity(rows.len()); + let mut result = HashMap::new(); - for row in rows.into_typed::<(String, String, Vec, Vec)>() { - let (keyspace_name, type_name, field_names, field_types) = row.map_err(|_| { + rows.map(|row_result| { + let row = row_result?; + let (keyspace_name, type_name, field_names, field_types): ( + String, + String, + Vec, + Vec, + ) = row.into_typed().map_err(|_| { QueryError::ProtocolError("system_schema.types has invalid column type") })?; @@ -702,31 +742,30 @@ async fn query_user_defined_types( .entry(keyspace_name) .or_insert_with(HashMap::new) .insert(type_name, fields); - } + + Ok::<_, QueryError>(()) + }) + .try_for_each(|_| future::ok(())) + .await?; Ok(result) } async fn query_tables( - conn: &Connection, + conn: &Arc, keyspaces_to_fetch: &[String], ) -> Result>, QueryError> { let rows = query_filter_keyspace_name( conn, "SELECT keyspace_name, table_name FROM system_schema.tables", keyspaces_to_fetch, - ) - .await? 
- .rows - .ok_or(QueryError::ProtocolError( - "system_schema.tables query response was not Rows", - ))?; - - let mut result = HashMap::with_capacity(rows.len()); + ); + let mut result = HashMap::new(); let mut tables = query_tables_schema(conn, keyspaces_to_fetch).await?; - for row in rows.into_typed::<(String, String)>() { - let (keyspace_name, table_name) = row.map_err(|_| { + rows.map(|row_result| { + let row = row_result?; + let (keyspace_name, table_name) = row.into_typed().map_err(|_| { QueryError::ProtocolError("system_schema.tables has invalid column type") })?; @@ -743,31 +782,31 @@ async fn query_tables( .entry(keyspace_and_table_name.0) .or_insert_with(HashMap::new) .insert(keyspace_and_table_name.1, table); - } + + Ok::<_, QueryError>(()) + }) + .try_for_each(|_| future::ok(())) + .await?; Ok(result) } async fn query_views( - conn: &Connection, + conn: &Arc, keyspaces_to_fetch: &[String], ) -> Result>, QueryError> { let rows = query_filter_keyspace_name( conn, "SELECT keyspace_name, view_name, base_table_name FROM system_schema.views", keyspaces_to_fetch, - ) - .await? - .rows - .ok_or(QueryError::ProtocolError( - "system_schema.views query response was not Rows", - ))?; - - let mut result = HashMap::with_capacity(rows.len()); + ); + + let mut result = HashMap::new(); let mut tables = query_tables_schema(conn, keyspaces_to_fetch).await?; - for row in rows.into_typed::<(String, String, String)>() { - let (keyspace_name, view_name, base_table_name) = row.map_err(|_| { + rows.map(|row_result| { + let row = row_result?; + let (keyspace_name, view_name, base_table_name) = row.into_typed().map_err(|_| { QueryError::ProtocolError("system_schema.views has invalid column type") })?; @@ -788,13 +827,17 @@ async fn query_views( .entry(keyspace_and_view_name.0) .or_insert_with(HashMap::new) .insert(keyspace_and_view_name.1, materialized_view); - } + + Ok::<_, QueryError>(()) + }) + .try_for_each(|_| future::ok(())) + .await?; Ok(result) } async fn query_tables_schema( - conn: &Connection, + conn: &Arc, keyspaces_to_fetch: &[String], ) -> Result, QueryError> { // Upon migration from thrift to CQL, Cassandra internally creates a surrogate column "value" of @@ -804,23 +847,25 @@ async fn query_tables_schema( let rows = query_filter_keyspace_name(conn, "select keyspace_name, table_name, column_name, kind, position, type from system_schema.columns", keyspaces_to_fetch - ) - .await? 
- .rows - .ok_or(QueryError::ProtocolError( - "system_schema.columns query response was not Rows", - ))?; - - let mut tables_schema = HashMap::with_capacity(rows.len()); + ); - for row in rows.into_typed::<(String, String, String, String, i32, String)>() { - let (keyspace_name, table_name, column_name, kind, position, type_) = - row.map_err(|_| { - QueryError::ProtocolError("system_schema.columns has invalid column type") - })?; + let mut tables_schema = HashMap::new(); + + rows.map(|row_result| { + let row = row_result?; + let (keyspace_name, table_name, column_name, kind, position, type_): ( + String, + String, + String, + String, + i32, + String, + ) = row.into_typed().map_err(|_| { + QueryError::ProtocolError("system_schema.columns has invalid column type") + })?; if type_ == THRIFT_EMPTY_TYPE { - continue; + return Ok::<_, QueryError>(()); } let entry = tables_schema.entry((keyspace_name, table_name)).or_insert(( @@ -851,7 +896,11 @@ async fn query_tables_schema( kind, }, ); - } + + Ok::<_, QueryError>(()) + }) + .try_for_each(|_| future::ok(())) + .await?; let mut all_partitioners = query_table_partitioners(conn).await?; let mut result = HashMap::new(); @@ -1005,33 +1054,38 @@ fn freeze_type(type_: CqlType) -> CqlType { } async fn query_table_partitioners( - conn: &Connection, + conn: &Arc, ) -> Result>, QueryError> { let mut partitioner_query = Query::new( "select keyspace_name, table_name, partitioner from system_schema.scylla_tables", ); partitioner_query.set_page_size(1024); - let rows = match conn.query_all(&partitioner_query, &[]).await { + let rows = conn + .clone() + .query_iter(partitioner_query, &[]) + .into_stream() + .try_flatten(); + + let result = rows + .map(|row_result| { + let (keyspace_name, table_name, partitioner) = + row_result?.into_typed().map_err(|_| { + QueryError::ProtocolError("system_schema.tables has invalid column type") + })?; + Ok::<_, QueryError>(((keyspace_name, table_name), partitioner)) + }) + .try_collect::>() + .await; + + match result { // FIXME: This match catches all database errors with this error code despite the fact // that we are only interested in the ones resulting from non-existent table // system_schema.scylla_tables. // For more information please refer to https://github.com/scylladb/scylla-rust-driver/pull/349#discussion_r762050262 - Err(QueryError::DbError(DbError::Invalid, _)) => return Ok(HashMap::new()), - query_result => query_result?.rows.ok_or(QueryError::ProtocolError( - "system_schema.scylla_tables query response was not Rows", - ))?, - }; - - let mut result = HashMap::with_capacity(rows.len()); - - for row in rows.into_typed::<(String, String, Option)>() { - let (keyspace_name, table_name, partitioner) = row.map_err(|_| { - QueryError::ProtocolError("system_schema.tables has invalid column type") - })?; - result.insert((keyspace_name, table_name), partitioner); + Err(QueryError::DbError(DbError::Invalid, _)) => Ok(HashMap::new()), + result => result, } - Ok(result) } fn strategy_from_string_map(