diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index efe95031e2..30e46ac278 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -30,6 +30,17 @@ impl Batch { } } + /// Creates an empty batch, with the configuration of existing batch. + pub(crate) fn new_from(batch: &Batch) -> Batch { + let batch_type = batch.get_type(); + let config = batch.config.clone(); + Batch { + batch_type, + config, + ..Default::default() + } + } + /// Creates a new, empty `Batch` of `batch_type` type with the provided statements. pub fn new_with_statements(batch_type: BatchType, statements: Vec) -> Self { Self { diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index cd225aae55..cdc6e730ea 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -914,8 +914,7 @@ impl Connection { prepared_queries.insert(query, prepared); } - let mut batch: Cow = Cow::Owned(Default::default()); - batch.to_mut().config = init_batch.config.clone(); + let mut batch: Cow = Cow::Owned(Batch::new_from(init_batch)); for stmt in &init_batch.statements { match stmt { BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str()) diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 159dc8840f..f4a49daf85 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -354,6 +354,57 @@ async fn test_prepared_statement() { } } +#[tokio::test] +async fn test_counter_batch() { + use crate::frame::value::Counter; + use scylla_cql::frame::request::batch::BatchType; + + setup_tracing(); + let session = Arc::new(create_new_session_builder().build().await.unwrap()); + let ks = unique_keyspace_name(); + + session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session + .query( + format!( + "CREATE TABLE IF NOT EXISTS {}.t_batch (key int PRIMARY KEY, value counter)", + ks + ), + &[], + ) + .await + .unwrap(); + + let statement_str = format!("UPDATE {}.t_batch SET value = value + ? WHERE key = ?", ks); + let query = Query::from(statement_str); + let prepared = session.prepare(query.clone()).await.unwrap(); + + let mut counter_batch = Batch::new(BatchType::Counter); + counter_batch.append_statement(query.clone()); + counter_batch.append_statement(prepared.clone()); + counter_batch.append_statement(query.clone()); + counter_batch.append_statement(prepared.clone()); + counter_batch.append_statement(query.clone()); + counter_batch.append_statement(prepared.clone()); + + // Check that we do not get a server error - the driver + // should send a COUNTER batch instead of a LOGGED (default) one. + session + .batch( + &counter_batch, + ( + (Counter(1), 1), + (Counter(2), 2), + (Counter(3), 3), + (Counter(4), 4), + (Counter(5), 5), + (Counter(6), 6), + ), + ) + .await + .unwrap(); +} + #[tokio::test] async fn test_batch() { setup_tracing();