Skip to content

Commit

Permalink
conn: copy a batch type in prepare_batch
Browse files Browse the repository at this point in the history
Previously, there was a bug in prepare_batch function.
We wouldn't inherit a BatchType from the `initial_batch` - in result,
we would always send LOGGED batches (since LOGGED is a default batch type).

This would result in server errors when including a counter statements
in a LOGGED batch. I added a test case which checks that it works
correctly now, and the driver sends a COUNTER batch when user asks to.
  • Loading branch information
muzarski committed Jul 10, 2024
1 parent a188ef8 commit f4962c6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
2 changes: 1 addition & 1 deletion scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ impl Connection {
prepared_queries.insert(query, prepared);
}

let mut batch: Cow<Batch> = Cow::Owned(Default::default());
let mut batch: Cow<Batch> = Cow::Owned(Batch::new(init_batch.get_type()));
batch.to_mut().config = init_batch.config.clone();
for stmt in &init_batch.statements {
match stmt {
Expand Down
51 changes: 51 additions & 0 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,57 @@ async fn test_prepared_statement() {
}
}

#[tokio::test]
async fn test_counter_batch() {
use crate::frame::value::Counter;
use scylla_cql::frame::request::batch::BatchType;

setup_tracing();
let session = Arc::new(create_new_session_builder().build().await.unwrap());
let ks = unique_keyspace_name();

session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query(
format!(
"CREATE TABLE IF NOT EXISTS {}.t_batch (key int PRIMARY KEY, value counter)",
ks
),
&[],
)
.await
.unwrap();

let statement_str = format!("UPDATE {}.t_batch SET value = value + ? WHERE key = ?", ks);
let query = Query::from(statement_str);
let prepared = session.prepare(query.clone()).await.unwrap();

let mut counter_batch = Batch::new(BatchType::Counter);
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());

// Check that we do not get a server error - the driver
// should send a COUNTER batch instead of a LOGGED (default) one.
session
.batch(
&counter_batch,
(
(Counter(1), 1),
(Counter(2), 2),
(Counter(3), 3),
(Counter(4), 4),
(Counter(5), 5),
(Counter(6), 6),
),
)
.await
.unwrap();
}

#[tokio::test]
async fn test_batch() {
setup_tracing();
Expand Down

0 comments on commit f4962c6

Please sign in to comment.