Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Connection::query_all with Connection::query_iter #645

Merged
merged 13 commits into from
Mar 2, 2023
Merged
237 changes: 52 additions & 185 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use std::{
net::{Ipv4Addr, Ipv6Addr},
};

use super::errors::{BadKeyspaceName, BadQuery, DbError, QueryError};
use super::errors::{BadKeyspaceName, DbError, QueryError};
use super::iterator::RowIterator;

use crate::batch::{Batch, BatchStatement};
use crate::frame::protocol_features::ProtocolFeatures;
Expand Down Expand Up @@ -457,98 +458,6 @@ impl Connection {
.await
}

/// Performs query_single_page multiple times to query all available pages
pub async fn query_all(
&self,
query: &Query,
values: impl ValueList,
) -> Result<QueryResult, QueryError> {
// This method is used only for driver internal queries, so no need to consult execution profile here.
self.query_all_with_consistency(
query,
values,
query
.config
.determine_consistency(self.config.default_consistency),
query.get_serial_consistency(),
)
.await
}

pub async fn query_all_with_consistency(
&self,
query: &Query,
values: impl ValueList,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<QueryResult, QueryError> {
if query.get_page_size().is_none() {
// Page size should be set when someone wants to use paging
return Err(QueryError::BadQuery(BadQuery::Other(
"Called Connection::query_all without page size set!".to_string(),
)));
}

let mut final_result = QueryResult::default();

let serialized_values = values.serialized()?;
let mut paging_state: Option<Bytes> = None;

loop {
// Send next paged query
let mut cur_result: QueryResult = self
.query_with_consistency(
query,
&serialized_values,
consistency,
serial_consistency,
paging_state,
)
.await?
.into_query_result()?;

// Set paging_state for the next query
paging_state = cur_result.paging_state.take();

// Add current query results to the final_result
final_result.merge_with_next_page_res(cur_result);

if paging_state.is_none() {
// No more pages to query, we can return the final result
return Ok(final_result);
}
}
}

pub async fn execute_single_page(
&self,
prepared_statement: &PreparedStatement,
values: impl ValueList,
paging_state: Option<Bytes>,
) -> Result<QueryResult, QueryError> {
self.execute(prepared_statement, values, paging_state)
.await?
.into_query_result()
}

pub async fn execute(
&self,
prepared_statement: &PreparedStatement,
values: impl ValueList,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
self.execute_with_consistency(
prepared_statement,
values,
prepared_statement
.config
.determine_consistency(self.config.default_consistency),
prepared_statement.config.serial_consistency.flatten(),
paging_state,
)
.await
}

pub async fn execute_with_consistency(
&self,
prepared_statement: &PreparedStatement,
Expand Down Expand Up @@ -591,41 +500,28 @@ impl Connection {
}
}

/// Performs execute_single_page multiple times to fetch all available pages
#[allow(dead_code)]
pub async fn execute_all(
&self,
prepared_statement: &PreparedStatement,
/// Executes a query and fetches its results over multiple pages, using
/// the asynchronous iterator interface.
pub(crate) async fn query_iter(
self: Arc<Self>,
query: Query,
values: impl ValueList,
) -> Result<QueryResult, QueryError> {
if prepared_statement.get_page_size().is_none() {
return Err(QueryError::BadQuery(BadQuery::Other(
"Called Connection::execute_all without page size set!".to_string(),
)));
}

let mut final_result = QueryResult::default();

let serialized_values = values.serialized()?;
let mut paging_state: Option<Bytes> = None;

loop {
// Send next paged query
let mut cur_result: QueryResult = self
.execute_single_page(prepared_statement, &serialized_values, paging_state)
.await?;

// Set paging_state for the next query
paging_state = cur_result.paging_state.take();
) -> Result<RowIterator, QueryError> {
let serialized_values = values.serialized()?.into_owned();

// Add current query results to the final_result
final_result.merge_with_next_page_res(cur_result);
let consistency = query
.config
.determine_consistency(self.config.default_consistency);
let serial_consistency = query.config.serial_consistency.flatten();

if paging_state.is_none() {
// No more pages to query, we can return the final result
return Ok(final_result);
}
}
RowIterator::new_for_connection_query_iter(
query,
self,
serialized_values,
consistency,
serial_consistency,
)
.await
}

#[allow(dead_code)]
Expand Down Expand Up @@ -1559,7 +1455,6 @@ impl VerifiedKeyspaceName {

#[cfg(test)]
mod tests {
use scylla_cql::errors::BadQuery;
use scylla_cql::frame::protocol_features::{
LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION,
};
Expand All @@ -1572,12 +1467,12 @@ mod tests {
use tokio::select;
use tokio::sync::mpsc;

use super::super::errors::QueryError;
use super::ConnectionConfig;
use crate::query::Query;
use crate::transport::connection::open_connection;
use crate::utils::test_utils::unique_keyspace_name;
use crate::{IntoTypedRows, SessionBuilder};
use crate::SessionBuilder;
use futures::{StreamExt, TryStreamExt};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
Expand All @@ -1596,20 +1491,20 @@ mod tests {
}
}

/// Tests for Connection::query_all and Connection::execute_all
/// Tests for Connection::query_iter
/// 1. SELECT from an empty table.
/// 2. Create table and insert ints 0..100.
/// Then use query_all and execute_all with page_size set to 7 to select all 100 rows.
/// 3. INSERT query_all should have None in result rows.
/// 4. Calling query_all with a Query that doesn't have page_size set should result in an error.
/// Then use query_iter with page_size set to 7 to select all 100 rows.
/// 3. INSERT query_iter should work and not return any rows.
#[tokio::test]
async fn connection_query_all_execute_all_test() {
async fn connection_query_iter_test() {
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
let addr: SocketAddr = resolve_hostname(&uri).await;

let (connection, _) = super::open_connection(addr, None, ConnectionConfig::default())
.await
.unwrap();
let connection = Arc::new(connection);

let ks = unique_keyspace_name();

Expand All @@ -1623,12 +1518,12 @@ mod tests {
session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'SimpleStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap();
session.use_keyspace(ks.clone(), false).await.unwrap();
session
.query("DROP TABLE IF EXISTS connection_query_all_tab", &[])
.query("DROP TABLE IF EXISTS connection_query_iter_tab", &[])
.await
.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS connection_query_all_tab (p int primary key)",
"CREATE TABLE IF NOT EXISTS connection_query_iter_tab (p int primary key)",
&[],
)
.await
Expand All @@ -1641,77 +1536,49 @@ mod tests {
.unwrap();

// 1. SELECT from an empty table returns query result where rows are Some(Vec::new())
let select_query = Query::new("SELECT p FROM connection_query_all_tab").with_page_size(7);
let empty_res = connection.query_all(&select_query, &[]).await.unwrap();
assert!(empty_res.rows.unwrap().is_empty());

let mut prepared_select = connection.prepare(&select_query).await.unwrap();
prepared_select.set_page_size(7);
let empty_res_prepared = connection.execute_all(&prepared_select, &[]).await.unwrap();
assert!(empty_res_prepared.rows.unwrap().is_empty());
let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
let empty_res = connection
.clone()
.query_iter(select_query.clone(), &[])
.await
.unwrap()
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(empty_res.is_empty());

// 2. Insert 100 and select using query_all with page_size 7
// 2. Insert 100 and select using query_iter with page_size 7
let values: Vec<i32> = (0..100).collect();
let mut insert_futures = Vec::new();
let insert_query =
Query::new("INSERT INTO connection_query_all_tab (p) VALUES (?)").with_page_size(7);
Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7);
for v in &values {
insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
}

futures::future::try_join_all(insert_futures).await.unwrap();

let mut results: Vec<i32> = connection
.query_all(&select_query, &[])
.clone()
.query_iter(select_query.clone(), &[])
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
.map(|ret| ret.unwrap().0)
.collect::<Vec<_>>()
.await;
results.sort_unstable(); // Clippy recommended to use sort_unstable instead of sort()
assert_eq!(results, values);

let mut results2: Vec<i32> = connection
.execute_all(&prepared_select, &[])
// 3. INSERT query_iter should work and not return any rows.
let insert_res1 = connection
.query_iter(insert_query, (0,))
.await
.unwrap()
.rows
.unwrap()
.into_typed::<(i32,)>()
.map(|r| r.unwrap().0)
.collect();
results2.sort_unstable();
assert_eq!(results2, values);

// 3. INSERT query_all should have None in result rows.
let insert_res1 = connection.query_all(&insert_query, (0,)).await.unwrap();
assert!(insert_res1.rows.is_none());

let prepared_insert = connection.prepare(&insert_query).await.unwrap();
let insert_res2 = connection
.execute_all(&prepared_insert, (0,))
.try_collect::<Vec<_>>()
.await
.unwrap();
assert!(insert_res2.rows.is_none(),);

// 4. Calling query_all with a Query that doesn't have page_size set should result in an error.
let no_page_size_query = Query::new("SELECT p FROM connection_query_all_tab");
let no_page_res = connection.query_all(&no_page_size_query, &[]).await;
assert!(matches!(
no_page_res,
Err(QueryError::BadQuery(BadQuery::Other(_)))
));

let prepared_no_page_size_query = connection.prepare(&no_page_size_query).await.unwrap();
let prepared_no_page_res = connection
.execute_all(&prepared_no_page_size_query, &[])
.await;
assert!(matches!(
prepared_no_page_res,
Err(QueryError::BadQuery(BadQuery::Other(_)))
));
assert!(insert_res1.is_empty());
}

#[tokio::test]
Expand Down
Loading