Skip to content

Commit

Permalink
Merge pull request scylladb#1038 from muzarski/counter_batch_fix
Browse files Browse the repository at this point in the history
conn: copy a batch type in prepare_batch
  • Loading branch information
wprzytula authored Jul 10, 2024
2 parents a188ef8 + 5a0d6e7 commit ea0a4d0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
11 changes: 11 additions & 0 deletions scylla/src/statement/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ impl Batch {
}
}

/// Creates an empty batch, with the configuration of existing batch.
pub(crate) fn new_from(batch: &Batch) -> Batch {
let batch_type = batch.get_type();
let config = batch.config.clone();
Batch {
batch_type,
config,
..Default::default()
}
}

/// Creates a new, empty `Batch` of `batch_type` type with the provided statements.
pub fn new_with_statements(batch_type: BatchType, statements: Vec<BatchStatement>) -> Self {
Self {
Expand Down
3 changes: 1 addition & 2 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -914,8 +914,7 @@ impl Connection {
prepared_queries.insert(query, prepared);
}

let mut batch: Cow<Batch> = Cow::Owned(Default::default());
batch.to_mut().config = init_batch.config.clone();
let mut batch: Cow<Batch> = Cow::Owned(Batch::new_from(init_batch));
for stmt in &init_batch.statements {
match stmt {
BatchStatement::Query(query) => match prepared_queries.get(query.contents.as_str())
Expand Down
51 changes: 51 additions & 0 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,57 @@ async fn test_prepared_statement() {
}
}

#[tokio::test]
async fn test_counter_batch() {
use crate::frame::value::Counter;
use scylla_cql::frame::request::batch::BatchType;

setup_tracing();
let session = Arc::new(create_new_session_builder().build().await.unwrap());
let ks = unique_keyspace_name();

session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query(
format!(
"CREATE TABLE IF NOT EXISTS {}.t_batch (key int PRIMARY KEY, value counter)",
ks
),
&[],
)
.await
.unwrap();

let statement_str = format!("UPDATE {}.t_batch SET value = value + ? WHERE key = ?", ks);
let query = Query::from(statement_str);
let prepared = session.prepare(query.clone()).await.unwrap();

let mut counter_batch = Batch::new(BatchType::Counter);
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());
counter_batch.append_statement(query.clone());
counter_batch.append_statement(prepared.clone());

// Check that we do not get a server error - the driver
// should send a COUNTER batch instead of a LOGGED (default) one.
session
.batch(
&counter_batch,
(
(Counter(1), 1),
(Counter(2), 2),
(Counter(3), 3),
(Counter(4), 4),
(Counter(5), 5),
(Counter(6), 6),
),
)
.await
.unwrap();
}

#[tokio::test]
async fn test_batch() {
setup_tracing();
Expand Down

0 comments on commit ea0a4d0

Please sign in to comment.