Skip to content

Commit

Permalink
Fix/batch length (scylladb#824)
Browse files Browse the repository at this point in the history
* changed the type of the maximum number of statements in a batch query from an i16 to a u16 according to the CQL protocol spec

* add a guard when doing batch statements to prevent making calls to the server when the number of batch queries is greater than u16::MAX, as well as adding some tests
  • Loading branch information
samuelorji authored Oct 17, 2023
1 parent d931716 commit 82c1c99
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cassandra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
run: cargo build --verbose --tests
- name: Run tests on cassandra
run: |
CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info
CDC='disabled' SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose -- --skip test_views_in_schema_info --skip test_large_batch_statements
- name: Stop the cluster
if: ${{ always() }}
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop
Expand Down
4 changes: 4 additions & 0 deletions scylla-cql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ pub enum BadQuery {
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

/// Too many queries in the batch statement
#[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
TooManyQueriesInBatchStatement(usize),

/// Other reasons of bad query
#[error("{0}")]
Other(String),
Expand Down
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/frame_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum ParseError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("type not yet implemented, id: {0}")]
TypeNotImplemented(i16),
TypeNotImplemented(u16),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
#[error(transparent)]
Expand Down
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/request/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedV
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
let batch_type = buf.get_u8().try_into()?;

let statements_count: usize = types::read_short(buf)?.try_into()?;
let statements_count: usize = types::read_short(buf)?.into();
let statements_with_values = (0..statements_count)
.map(|_| {
let batch_statement = BatchStatement::deserialize(buf)?;
Expand Down
4 changes: 2 additions & 2 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult<ColumnType, ParseError> {
0x0030 => {
let keyspace_name: String = types::read_string(buf)?.to_string();
let type_name: String = types::read_string(buf)?.to_string();
let fields_size: usize = types::read_short(buf)?.try_into()?;
let fields_size: usize = types::read_short(buf)?.into();

let mut field_types: Vec<(String, ColumnType)> = Vec::with_capacity(fields_size);

Expand All @@ -455,7 +455,7 @@ fn deser_type(buf: &mut &[u8]) -> StdResult<ColumnType, ParseError> {
}
}
0x0031 => {
let len: usize = types::read_short(buf)?.try_into()?;
let len: usize = types::read_short(buf)?.into();
let mut types = Vec::with_capacity(len);
for _ in 0..len {
types.push(deser_type(buf)?);
Expand Down
20 changes: 10 additions & 10 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use uuid::Uuid;
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))]
#[repr(i16)]
#[repr(u16)]
pub enum Consistency {
Any = 0x0000,
One = 0x0001,
Expand Down Expand Up @@ -169,30 +169,30 @@ fn type_long() {
}
}

pub fn read_short(buf: &mut &[u8]) -> Result<i16, ParseError> {
let v = buf.read_i16::<BigEndian>()?;
pub fn read_short(buf: &mut &[u8]) -> Result<u16, ParseError> {
let v = buf.read_u16::<BigEndian>()?;
Ok(v)
}

pub fn write_short(v: i16, buf: &mut impl BufMut) {
buf.put_i16(v);
pub fn write_short(v: u16, buf: &mut impl BufMut) {
buf.put_u16(v);
}

pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
let v = read_short(buf)?;
let v: usize = v.try_into()?;
let v: usize = v.into();
Ok(v)
}

fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> {
let v: i16 = v.try_into()?;
let v: u16 = v.try_into()?;
write_short(v, buf);
Ok(())
}

#[test]
fn type_short() {
let vals = [i16::MIN, -1, 0, 1, i16::MAX];
let vals: [u16; 3] = [0, 1, u16::MAX];
for val in vals.iter() {
let mut buf = Vec::new();
write_short(*val, &mut buf);
Expand Down Expand Up @@ -464,11 +464,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
}

pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

#[test]
Expand Down
12 changes: 6 additions & 6 deletions scylla-cql/src/frame/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct Time(pub Duration);
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SerializedValues {
serialized_values: Vec<u8>,
values_num: i16,
values_num: u16,
contains_names: bool,
}

Expand All @@ -77,7 +77,7 @@ pub struct CqlDuration {

#[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SerializeValuesError {
#[error("Too many values to add, max 32 767 values can be sent in a request")]
#[error("Too many values to add, max 65,535 values can be sent in a request")]
TooManyValues,
#[error("Mixing named and not named values is not allowed")]
MixingNamedAndNotNamedValues,
Expand Down Expand Up @@ -134,7 +134,7 @@ impl SerializedValues {
if self.contains_names {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

Expand All @@ -158,7 +158,7 @@ impl SerializedValues {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
self.contains_names = true;
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

Expand All @@ -184,15 +184,15 @@ impl SerializedValues {
}

pub fn write_to_request(&self, buf: &mut impl BufMut) {
buf.put_i16(self.values_num);
buf.put_u16(self.values_num);
buf.put(&self.serialized_values[..]);
}

pub fn is_empty(&self) -> bool {
self.values_num == 0
}

pub fn len(&self) -> i16 {
pub fn len(&self) -> u16 {
self.values_num
}

Expand Down
2 changes: 1 addition & 1 deletion scylla/src/statement/prepared_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ impl PreparedStatement {
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum PartitionKeyExtractionError {
#[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
NoPkIndexValue(u16, i16),
NoPkIndexValue(u16, u16),
}

#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
Expand Down
68 changes: 68 additions & 0 deletions scylla/src/transport/large_batch_statements_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use assert_matches::assert_matches;

use scylla_cql::errors::{BadQuery, QueryError};

use crate::batch::BatchType;
use crate::query::Query;
use crate::{
batch::Batch,
test_utils::{create_new_session_builder, unique_keyspace_name},
QueryResult, Session,
};

#[tokio::test]
async fn test_large_batch_statements() {
let mut session = create_new_session_builder().build().await.unwrap();

let ks = unique_keyspace_name();
session = create_test_session(session, &ks).await;

let max_queries = u16::MAX as usize;
let batch_insert_result = write_batch(&session, max_queries, &ks).await;

batch_insert_result.unwrap();

let too_many_queries = u16::MAX as usize + 1;
let batch_insert_result = write_batch(&session, too_many_queries, &ks).await;
assert_matches!(
batch_insert_result.unwrap_err(),
QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(_too_many_queries)) if _too_many_queries == too_many_queries
)
}

async fn create_test_session(session: Session, ks: &String) -> Session {
session
.query(
format!("CREATE KEYSPACE {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks),
&[],
)
.await.unwrap();
session
.query(
format!(
"CREATE TABLE {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))",
ks
),
&[],
)
.await
.unwrap();
session
}

async fn write_batch(session: &Session, n: usize, ks: &String) -> Result<QueryResult, QueryError> {
let mut batch_query = Batch::new(BatchType::Unlogged);
let mut batch_values = Vec::new();
let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks);
let query = Query::new(query);
let prepared_statement = session.prepare(query).await.unwrap();
for i in 0..n {
let mut key = vec![0];
key.extend(i.to_be_bytes().as_slice());
let value = key.clone();
let values = vec![key, value];
batch_values.push(values);
batch_query.append_statement(prepared_statement.clone());
}
session.batch(&batch_query, batch_values).await
}
2 changes: 2 additions & 0 deletions scylla/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ mod silent_prepare_batch_test;
mod cql_types_test;
#[cfg(test)]
mod cql_value_test;
#[cfg(test)]
mod large_batch_statements_test;

pub use cluster::ClusterData;
pub use node::{KnownNode, Node, NodeAddr, NodeRef};
8 changes: 8 additions & 0 deletions scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize;
use crate::authentication::AuthenticatorProvider;
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
use scylla_cql::errors::BadQuery;

/// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses.
///
Expand Down Expand Up @@ -1143,6 +1144,13 @@ impl Session {
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
// If users batch statements by shard, they will be rewarded with full shard awareness

// check to ensure that we don't send a batch statement with more than u16::MAX queries
let batch_statements_length = batch.statements.len();
if batch_statements_length > u16::MAX as usize {
return Err(QueryError::BadQuery(
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
));
}
// Extract first serialized_value
let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?;
let first_serialized_value = first_serialized_value.as_deref();
Expand Down

0 comments on commit 82c1c99

Please sign in to comment.