diff --git a/scylla-cql/benches/benchmark.rs b/scylla-cql/benches/benchmark.rs index 2ab15f5051..ec0a26213b 100644 --- a/scylla-cql/benches/benchmark.rs +++ b/scylla-cql/benches/benchmark.rs @@ -14,6 +14,7 @@ fn make_query(contents: &str, values: SerializedValues) -> query::Query<'_> { consistency: scylla_cql::Consistency::LocalQuorum, serial_consistency: None, values: Cow::Owned(values), + skip_metadata: false, page_size: None, paging_state: None, timestamp: None, diff --git a/scylla-cql/src/frame/request/mod.rs b/scylla-cql/src/frame/request/mod.rs index 37549513e1..1a5d3511f2 100644 --- a/scylla-cql/src/frame/request/mod.rs +++ b/scylla-cql/src/frame/request/mod.rs @@ -150,6 +150,7 @@ mod tests { timestamp: None, page_size: Some(323), paging_state: Some(vec![2, 1, 3, 7].into()), + skip_metadata: false, values: { let mut vals = SerializedValues::new(); vals.add_value(&2137, &ColumnType::Int).unwrap(); @@ -177,6 +178,7 @@ mod tests { timestamp: Some(3423434), page_size: None, paging_state: None, + skip_metadata: false, values: { let mut vals = SerializedValues::new(); vals.add_value(&42, &ColumnType::Int).unwrap(); @@ -234,6 +236,7 @@ mod tests { timestamp: None, page_size: None, paging_state: None, + skip_metadata: false, values: Cow::Borrowed(SerializedValues::EMPTY), }; let query = Query { diff --git a/scylla-cql/src/frame/request/query.rs b/scylla-cql/src/frame/request/query.rs index 164118f081..31a281f512 100644 --- a/scylla-cql/src/frame/request/query.rs +++ b/scylla-cql/src/frame/request/query.rs @@ -63,6 +63,7 @@ pub struct QueryParameters<'a> { pub timestamp: Option, pub page_size: Option, pub paging_state: Option, + pub skip_metadata: bool, pub values: Cow<'a, SerializedValues>, } @@ -74,6 +75,7 @@ impl Default for QueryParameters<'_> { timestamp: None, page_size: None, paging_state: None, + skip_metadata: false, values: Cow::Borrowed(SerializedValues::EMPTY), } } @@ -88,6 +90,10 @@ impl QueryParameters<'_> { flags |= FLAG_VALUES; } + if 
self.skip_metadata { + flags |= FLAG_SKIP_METADATA; + } + if self.page_size.is_some() { flags |= FLAG_PAGE_SIZE; } @@ -143,6 +149,7 @@ impl<'q> QueryParameters<'q> { ))); } let values_flag = (flags & FLAG_VALUES) != 0; + let skip_metadata = (flags & FLAG_SKIP_METADATA) != 0; let page_size_flag = (flags & FLAG_PAGE_SIZE) != 0; let paging_state_flag = (flags & FLAG_WITH_PAGING_STATE) != 0; let serial_consistency_flag = (flags & FLAG_WITH_SERIAL_CONSISTENCY) != 0; @@ -192,6 +199,7 @@ impl<'q> QueryParameters<'q> { timestamp, page_size, paging_state, + skip_metadata, values, }) } diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index 5acec7f34b..569a14f401 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -5,13 +5,13 @@ pub mod event; pub mod result; pub mod supported; -use crate::{errors::QueryError, frame::frame_errors::ParseError}; - -use crate::frame::protocol_features::ProtocolFeatures; pub use error::Error; pub use supported::Supported; -use super::TryFromPrimitiveError; +use crate::frame::protocol_features::ProtocolFeatures; +use crate::frame::response::result::ResultMetadata; +use crate::frame::TryFromPrimitiveError; +use crate::{errors::QueryError, frame::frame_errors::ParseError}; #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[repr(u8)] @@ -64,6 +64,7 @@ impl Response { features: &ProtocolFeatures, opcode: ResponseOpcode, buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, ) -> Result { let response = match opcode { ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?), @@ -72,7 +73,7 @@ impl Response { Response::Authenticate(authenticate::Authenticate::deserialize(buf)?) 
} ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?), - ResponseOpcode::Result => Response::Result(result::deserialize(buf)?), + ResponseOpcode::Result => Response::Result(result::deserialize(buf, cached_metadata)?), ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?), ResponseOpcode::AuthChallenge => { Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 745b925c00..95cc4bea17 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -386,7 +386,7 @@ pub struct ColumnSpec { pub typ: ColumnType, } -#[derive(Debug, Default)] +#[derive(Debug, Clone, Default)] pub struct ResultMetadata { col_count: usize, pub paging_state: Option, @@ -886,17 +886,29 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult StdResult { - let metadata = deser_result_metadata(buf)?; +fn deser_rows( + buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, +) -> StdResult { + let server_metadata = deser_result_metadata(buf)?; + + let metadata = match cached_metadata { + Some(cached) => ResultMetadata { paging_state: server_metadata.paging_state, ..cached.clone() }, // paging state is per-response: keep the value parsed from this reply and reuse only the cached (prepare-time) column metadata + None => { + // No cached_metadata provided. Server is supposed to provide the result metadata. + if server_metadata.col_count != server_metadata.col_specs.len() { + return Err(ParseError::BadIncomingData(format!( + "Bad result metadata provided in the response. Expected {} column specifications, received: {}", + server_metadata.col_count, + server_metadata.col_specs.len() + ))); + } + server_metadata + } + }; let original_size = buf.len(); - // TODO: the protocol allows an optimization (which must be explicitly requested on query by - // the driver) where the column metadata is not sent with the result. - // Implement this optimization. We'll then need to take the column types by a parameter. - // Beware of races; our column types may be outdated.
- assert!(metadata.col_count == metadata.col_specs.len()); - let rows_count: usize = types::read_int(buf)?.try_into()?; let mut rows = Vec::with_capacity(rows_count); @@ -946,11 +958,14 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult { }) } -pub fn deserialize(buf: &mut &[u8]) -> StdResult { +pub fn deserialize( + buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, +) -> StdResult { use self::Result::*; Ok(match types::read_int(buf)? { 0x0001 => Void, - 0x0002 => Rows(deser_rows(buf)?), + 0x0002 => Rows(deser_rows(buf, cached_metadata)?), 0x0003 => SetKeyspace(deser_set_keyspace(buf)?), 0x0004 => Prepared(deser_prepared(buf)?), 0x0005 => SchemaChange(deser_schema_change(buf)?), diff --git a/scylla/src/statement/mod.rs b/scylla/src/statement/mod.rs index a8d034615d..642ea06ad3 100644 --- a/scylla/src/statement/mod.rs +++ b/scylla/src/statement/mod.rs @@ -16,6 +16,7 @@ pub(crate) struct StatementConfig { pub(crate) is_idempotent: bool, + pub(crate) skip_result_metadata: bool, pub(crate) tracing: bool, pub(crate) timestamp: Option, pub(crate) request_timeout: Option, diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index a3cd155e7c..ceb7712cb7 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -10,7 +10,7 @@ use std::time::Duration; use thiserror::Error; use uuid::Uuid; -use scylla_cql::frame::response::result::ColumnSpec; +use scylla_cql::frame::response::result::{ColumnSpec, PartitionKeyIndex, ResultMetadata}; use super::StatementConfig; use crate::frame::response::result::PreparedMetadata; @@ -37,6 +37,7 @@ pub struct PreparedStatement { #[derive(Debug)] struct PreparedStatementSharedData { metadata: PreparedMetadata, + result_metadata: ResultMetadata, statement: String, } @@ -59,6 +60,7 @@ impl PreparedStatement { id: Bytes, is_lwt: bool, metadata: PreparedMetadata, + result_metadata: ResultMetadata, statement: String, page_size: Option, 
config: StatementConfig, @@ -67,6 +69,7 @@ impl PreparedStatement { id, shared: Arc::new(PreparedStatementSharedData { metadata, + result_metadata, statement, }), prepare_tracing_ids: Vec::new(), @@ -270,6 +273,27 @@ impl PreparedStatement { self.config.tracing } + /// Make use of cached metadata to decode results + /// of the statement's execution. + /// + /// If true, the driver will request the server not to + /// attach the result metadata in response to the statement execution. + /// + /// The driver will cache the result metadata received from the server + /// after statement preparation and will use it + /// to deserialize the results of statement execution. + /// + /// This option is false by default. + pub fn set_use_cached_result_metadata(&mut self, use_cached_metadata: bool) { + self.config.skip_result_metadata = use_cached_metadata; + } + + /// Gets the information whether the driver uses cached metadata + /// to decode the results of the statement's execution. + pub fn get_use_cached_result_metadata(&self) -> bool { + self.config.skip_result_metadata + } + /// Sets the default timestamp for this statement in microseconds. 
/// If not None, it will replace the server side assigned timestamp as default timestamp /// If a statement contains a `USING TIMESTAMP` clause, calling this method won't change @@ -301,11 +325,31 @@ impl PreparedStatement { self.partitioner_name = partitioner_name; } - /// Access metadata about this prepared statement as returned by the database - pub fn get_prepared_metadata(&self) -> &PreparedMetadata { + /// Access metadata about the bind variables of this statement as returned by the database + pub(crate) fn get_prepared_metadata(&self) -> &PreparedMetadata { &self.shared.metadata } + /// Access column specifications of the bind variables of this statement + pub fn get_variable_col_specs(&self) -> &[ColumnSpec] { + &self.shared.metadata.col_specs + } + + /// Access info about partition key indexes of the bind variables of this statement + pub fn get_variable_pk_indexes(&self) -> &[PartitionKeyIndex] { + &self.shared.metadata.pk_indexes + } + + /// Access metadata about the result of prepared statement returned by the database + pub(crate) fn get_result_metadata(&self) -> &ResultMetadata { + &self.shared.result_metadata + } + + /// Access column specifications of the result set returned after the execution of this statement + pub fn get_result_set_col_specs(&self) -> &[ColumnSpec] { + &self.shared.result_metadata.col_specs + } + /// Get the name of the partitioner used for this statement. 
pub(crate) fn get_partitioner_name(&self) -> &PartitionerName { &self.partitioner_name diff --git a/scylla/src/transport/caching_session.rs b/scylla/src/transport/caching_session.rs index f3d0d4db88..464286602b 100644 --- a/scylla/src/transport/caching_session.rs +++ b/scylla/src/transport/caching_session.rs @@ -8,7 +8,7 @@ use crate::{QueryResult, Session}; use bytes::Bytes; use dashmap::DashMap; use futures::future::try_join_all; -use scylla_cql::frame::response::result::PreparedMetadata; +use scylla_cql::frame::response::result::{PreparedMetadata, ResultMetadata}; use scylla_cql::types::serialize::batch::BatchValues; use scylla_cql::types::serialize::row::SerializeRow; use std::collections::hash_map::RandomState; @@ -23,6 +23,7 @@ struct RawPreparedStatementData { id: Bytes, is_confirmed_lwt: bool, metadata: PreparedMetadata, + result_metadata: ResultMetadata, partitioner_name: PartitionerName, } @@ -168,6 +169,7 @@ where raw.id.clone(), raw.is_confirmed_lwt, raw.metadata.clone(), + raw.result_metadata.clone(), query.contents, page_size, query.config, @@ -195,6 +197,7 @@ where id: prepared.get_id().clone(), is_confirmed_lwt: prepared.is_confirmed_lwt(), metadata: prepared.get_prepared_metadata().clone(), + result_metadata: prepared.get_result_metadata().clone(), partitioner_name: prepared.get_partitioner_name().clone(), }; self.cache.insert(query_contents, raw); diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index b6b91b69db..6d8af2aeb5 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -2,6 +2,7 @@ use bytes::Bytes; use futures::{future::RemoteHandle, FutureExt}; use scylla_cql::errors::TranslationError; use scylla_cql::frame::request::options::Options; +use scylla_cql::frame::response::result::ResultMetadata; use scylla_cql::frame::response::Error; use scylla_cql::frame::types::SerialConsistency; use scylla_cql::types::serialize::batch::{BatchValues, BatchValuesIterator}; @@ -521,14 
+522,14 @@ impl Connection { options: HashMap, ) -> Result { Ok(self - .send_request(&request::Startup { options }, false, false) + .send_request(&request::Startup { options }, false, false, None) .await? .response) } pub(crate) async fn get_options(&self) -> Result { Ok(self - .send_request(&request::Options {}, false, false) + .send_request(&request::Options {}, false, false, None) .await? .response) } @@ -541,6 +542,7 @@ impl Connection { }, true, query.config.tracing, + None, ) .await?; @@ -552,6 +554,7 @@ impl Connection { .protocol_features .prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32), p.prepared_metadata, + p.result_metadata, query.contents.clone(), query.get_page_size(), query.config.clone(), @@ -591,7 +594,7 @@ impl Connection { &self, response: Option>, ) -> Result { - self.send_request(&request::AuthResponse { response }, false, false) + self.send_request(&request::AuthResponse { response }, false, false, None) .await } @@ -655,11 +658,12 @@ impl Connection { values: Cow::Borrowed(SerializedValues::EMPTY), page_size: query.get_page_size(), paging_state, + skip_metadata: false, timestamp: query.get_timestamp(), }, }; - self.send_request(&query_frame, true, query.config.tracing) + self.send_request(&query_frame, true, query.config.tracing, None) .await } @@ -699,12 +703,22 @@ impl Connection { values: Cow::Borrowed(values), page_size: prepared_statement.get_page_size(), timestamp: prepared_statement.get_timestamp(), + skip_metadata: prepared_statement.get_use_cached_result_metadata(), paging_state, }, }; + let cached_metadata = prepared_statement + .get_use_cached_result_metadata() + .then(|| prepared_statement.get_result_metadata()); + let query_response = self - .send_request(&execute_frame, true, prepared_statement.config.tracing) + .send_request( + &execute_frame, + true, + prepared_statement.config.tracing, + cached_metadata, + ) .await?; match &query_response.response { @@ -716,8 +730,13 @@ impl Connection { // Repreparation of 
a statement is needed self.reprepare(prepared_statement.get_statement(), prepared_statement) .await?; - self.send_request(&execute_frame, true, prepared_statement.config.tracing) - .await + self.send_request( + &execute_frame, + true, + prepared_statement.config.tracing, + cached_metadata, + ) + .await } _ => Ok(query_response), } @@ -806,7 +825,7 @@ impl Connection { loop { let query_response = self - .send_request(&batch_frame, true, batch.config.tracing) + .send_request(&batch_frame, true, batch.config.tracing, None) .await?; return match query_response.response { @@ -928,7 +947,7 @@ impl Connection { }; match self - .send_request(&register_frame, true, false) + .send_request(&register_frame, true, false, None) .await? .response { @@ -958,6 +977,7 @@ impl Connection { request: &impl SerializableRequest, compress: bool, tracing: bool, + cached_metadata: Option<&ResultMetadata>, ) -> Result { let compression = if compress { self.config.compression @@ -974,6 +994,7 @@ impl Connection { task_response, self.config.compression, &self.features.protocol_features, + cached_metadata, ) } @@ -981,6 +1002,7 @@ impl Connection { task_response: TaskResponse, compression: Option, features: &ProtocolFeatures, + cached_metadata: Option<&ResultMetadata>, ) -> Result { let body_with_ext = frame::parse_response_body_extensions( task_response.params.flags, @@ -995,8 +1017,12 @@ impl Connection { ); } - let response = - Response::deserialize(features, task_response.opcode, &mut &*body_with_ext.body)?; + let response = Response::deserialize( + features, + task_response.opcode, + &mut &*body_with_ext.body, + cached_metadata, + )?; Ok(QueryResponse { response, @@ -1350,7 +1376,7 @@ impl Connection { // future implementers.
let features = ProtocolFeatures::default(); // TODO: Use the right features - let response = Self::parse_response(task_response, compression, &features)?.response; + let response = Self::parse_response(task_response, compression, &features, None)?.response; let event = match response { Response::Event(e) => e, _ => { diff --git a/scylla/tests/integration/main.rs b/scylla/tests/integration/main.rs index 4b8920309a..7f09ae2c5a 100644 --- a/scylla/tests/integration/main.rs +++ b/scylla/tests/integration/main.rs @@ -6,4 +6,5 @@ mod new_session; mod retries; mod shards; mod silent_prepare_query; +mod skip_metadata_optimization; pub(crate) mod utils; diff --git a/scylla/tests/integration/skip_metadata_optimization.rs b/scylla/tests/integration/skip_metadata_optimization.rs new file mode 100644 index 0000000000..9523a88136 --- /dev/null +++ b/scylla/tests/integration/skip_metadata_optimization.rs @@ -0,0 +1,89 @@ +use crate::utils::test_with_3_node_cluster; +use scylla::transport::session::Session; +use scylla::SessionBuilder; +use scylla::{prepared_statement::PreparedStatement, test_utils::unique_keyspace_name}; +use scylla_cql::frame::types; +use scylla_proxy::{ + Condition, ProxyError, Reaction, ResponseFrame, ResponseReaction, ShardAwareness, TargetShard, + WorkerError, +}; +use std::sync::Arc; + +#[tokio::test] +#[ntest::timeout(20000)] +#[cfg(not(scylla_cloud_tests))] +async fn test_skip_result_metadata() { + use scylla_proxy::{ResponseOpcode, ResponseRule}; + + const NO_METADATA_FLAG: i32 = 0x0004; + + let res = test_with_3_node_cluster(ShardAwareness::QueryNode, |proxy_uris, translation_map, mut running_proxy| async move { + // DB preparation phase + let session: Session = SessionBuilder::new() + .known_node(proxy_uris[0].as_str()) + .address_translator(Arc::new(translation_map)) + .build() + .await + .unwrap(); + + let ks = unique_keyspace_name(); + session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 
'replication_factor' : 3}}", ks), &[]).await.unwrap(); + session.use_keyspace(ks, false).await.unwrap(); + session + .query("CREATE TABLE t (a int primary key, b int, c text)", &[]) + .await + .unwrap(); + session.query("INSERT INTO t (a, b, c) VALUES (1, 2, 'foo_filter_data')", &[]).await.unwrap(); + + let mut prepared = session.prepare("SELECT a, b, c FROM t").await.unwrap(); + + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + + // We inserted this string to filter responses + let body_rows = b"foo_filter_data"; + for node in running_proxy.running_nodes.iter_mut() { + let rule = ResponseRule( + Condition::ResponseOpcode(ResponseOpcode::Result).and(Condition::BodyContainsCaseSensitive(Box::new(*body_rows))), + ResponseReaction::noop().with_feedback_when_performed(tx.clone()) + ); + node.change_response_rules(Some(vec![rule])); + } + + async fn test_with_flags_predicate( + session: &Session, + prepared: &PreparedStatement, + rx: &mut tokio::sync::mpsc::UnboundedReceiver<(ResponseFrame, Option)>, + predicate: impl FnOnce(i32) -> bool + ) { + session.execute(prepared, &[]).await.unwrap(); + + let (frame, _shard) = rx.recv().await.unwrap(); + let mut buf = &*frame.body; + + // FIXME: make use of scylla_cql::frame utilities, instead of deserializing frame manually. + // This will probably be possible once https://github.com/scylladb/scylla-rust-driver/issues/462 is fixed. + match types::read_int(&mut buf).unwrap() { + 0x0002 => (), + _ => panic!("Invalid result type"), + } + let result_metadata_flags = types::read_int(&mut buf).unwrap(); + assert!(predicate(result_metadata_flags)); + } + + // Verify that server sends metadata when driver doesn't send SKIP_METADATA flag. + prepared.set_use_cached_result_metadata(false); + test_with_flags_predicate(&session, &prepared, &mut rx, |flags| flags & NO_METADATA_FLAG == 0).await; + + // Verify that server doesn't send metadata when driver sends SKIP_METADATA flag. 
+ prepared.set_use_cached_result_metadata(true); + test_with_flags_predicate(&session, &prepared, &mut rx, |flags| flags & NO_METADATA_FLAG != 0).await; + + running_proxy + }).await; + + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +}