diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index c8c4ec104d..ac803ee49d 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -12,6 +12,8 @@ use crate::frame::protocol_features::ProtocolFeatures; pub use error::Error; pub use supported::Supported; +use self::result::ResultMetadata; + #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] #[repr(u8)] pub enum ResponseOpcode { @@ -42,6 +44,7 @@ impl Response { features: &ProtocolFeatures, opcode: ResponseOpcode, buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, ) -> Result { let response = match opcode { ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?), @@ -50,7 +53,7 @@ impl Response { Response::Authenticate(authenticate::Authenticate::deserialize(buf)?) } ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?), - ResponseOpcode::Result => Response::Result(result::deserialize(buf)?), + ResponseOpcode::Result => Response::Result(result::deserialize(buf, cached_metadata)?), ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?), ResponseOpcode::AuthChallenge => { Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index ab5a588fba..95cc4bea17 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -886,17 +886,29 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult StdResult { - let metadata = deser_result_metadata(buf)?; +fn deser_rows( + buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, +) -> StdResult { + let server_metadata = deser_result_metadata(buf)?; + + let metadata = match cached_metadata { + Some(metadata) => metadata.clone(), + None => { + // No cached_metadata provided. Server is supposed to provide the result metadata. + if server_metadata.col_count != server_metadata.col_specs.len() { + return Err(ParseError::BadIncomingData(format!( + "Bad result metadata provided in the response. Expected {} column specifications, received: {}", + server_metadata.col_count, + server_metadata.col_specs.len() + ))); + } + server_metadata + } + }; let original_size = buf.len(); - // TODO: the protocol allows an optimization (which must be explicitly requested on query by - // the driver) where the column metadata is not sent with the result. - // Implement this optimization. We'll then need to take the column types by a parameter. - // Beware of races; our column types may be outdated. - assert!(metadata.col_count == metadata.col_specs.len()); - let rows_count: usize = types::read_int(buf)?.try_into()?; let mut rows = Vec::with_capacity(rows_count); @@ -946,11 +958,14 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult { }) } -pub fn deserialize(buf: &mut &[u8]) -> StdResult { +pub fn deserialize( + buf: &mut &[u8], + cached_metadata: Option<&ResultMetadata>, +) -> StdResult { use self::Result::*; Ok(match types::read_int(buf)? { 0x0001 => Void, - 0x0002 => Rows(deser_rows(buf)?), + 0x0002 => Rows(deser_rows(buf, cached_metadata)?), 0x0003 => SetKeyspace(deser_set_keyspace(buf)?), 0x0004 => Prepared(deser_prepared(buf)?), 0x0005 => SchemaChange(deser_schema_change(buf)?), diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index edaa1aad8e..86f039d1f9 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -2,6 +2,7 @@ use bytes::Bytes; use futures::{future::RemoteHandle, FutureExt}; use scylla_cql::errors::TranslationError; use scylla_cql::frame::request::options::Options; +use scylla_cql::frame::response::result::ResultMetadata; use scylla_cql::frame::response::Error; use scylla_cql::frame::types::SerialConsistency; use scylla_cql::types::serialize::batch::{BatchValues, BatchValuesIterator}; @@ -521,14 +522,14 @@ impl Connection { options: HashMap, ) -> Result { Ok(self - .send_request(&request::Startup { options }, false, false) + .send_request(&request::Startup { options }, false, false, None) .await? .response) } pub(crate) async fn get_options(&self) -> Result { Ok(self - .send_request(&request::Options {}, false, false) + .send_request(&request::Options {}, false, false, None) .await? .response) } @@ -541,6 +542,7 @@ impl Connection { }, true, query.config.tracing, + None, ) .await?; @@ -592,7 +594,7 @@ impl Connection { &self, response: Option>, ) -> Result { - self.send_request(&request::AuthResponse { response }, false, false) + self.send_request(&request::AuthResponse { response }, false, false, None) .await } @@ -661,7 +663,7 @@ impl Connection { }, }; - self.send_request(&query_frame, true, query.config.tracing) + self.send_request(&query_frame, true, query.config.tracing, None) .await } @@ -706,8 +708,17 @@ impl Connection { }, }; + let cached_metadata = prepared_statement + .get_skip_result_metadata() + .then(|| prepared_statement.get_result_metadata()); + let query_response = self - .send_request(&execute_frame, true, prepared_statement.config.tracing) + .send_request( + &execute_frame, + true, + prepared_statement.config.tracing, + cached_metadata, + ) .await?; match &query_response.response { @@ -719,8 +730,13 @@ impl Connection { // Repreparation of a statement is needed self.reprepare(prepared_statement.get_statement(), prepared_statement) .await?; - self.send_request(&execute_frame, true, prepared_statement.config.tracing) - .await + self.send_request( + &execute_frame, + true, + prepared_statement.config.tracing, + cached_metadata, + ) + .await } _ => Ok(query_response), } @@ -809,7 +825,7 @@ impl Connection { loop { let query_response = self - .send_request(&batch_frame, true, batch.config.tracing) + .send_request(&batch_frame, true, batch.config.tracing, None) .await?; return match query_response.response { @@ -931,7 +947,7 @@ impl Connection { }; match self - .send_request(®ister_frame, true, false) + .send_request(®ister_frame, true, false, None) .await? .response { @@ -961,6 +977,7 @@ impl Connection { request: &impl SerializableRequest, compress: bool, tracing: bool, + cached_metadata: Option<&ResultMetadata>, ) -> Result { let compression = if compress { self.config.compression @@ -977,6 +994,7 @@ impl Connection { task_response, self.config.compression, &self.features.protocol_features, + cached_metadata, ) } @@ -984,6 +1002,7 @@ impl Connection { task_response: TaskResponse, compression: Option, features: &ProtocolFeatures, + cached_metadata: Option<&ResultMetadata>, ) -> Result { let body_with_ext = frame::parse_response_body_extensions( task_response.params.flags, @@ -998,8 +1017,12 @@ impl Connection { ); } - let response = - Response::deserialize(features, task_response.opcode, &mut &*body_with_ext.body)?; + let response = Response::deserialize( + features, + task_response.opcode, + &mut &*body_with_ext.body, + cached_metadata, + )?; Ok(QueryResponse { response, @@ -1353,7 +1376,7 @@ impl Connection { // future implementers. let features = ProtocolFeatures::default(); // TODO: Use the right features - let response = Self::parse_response(task_response, compression, &features)?.response; + let response = Self::parse_response(task_response, compression, &features, None)?.response; let event = match response { Response::Event(e) => e, _ => {