diff --git a/scylla/tests/integration/shards.rs b/scylla/tests/integration/shards.rs index f8aa2335b6..185b1c5c68 100644 --- a/scylla/tests/integration/shards.rs +++ b/scylla/tests/integration/shards.rs @@ -1,21 +1,23 @@ +use std::collections::HashSet; use std::sync::Arc; use crate::utils::test_with_3_node_cluster; +use scylla::transport::Node; use scylla::{test_utils::unique_keyspace_name, SessionBuilder}; +use scylla::{IntoTypedRows, Session}; use tokio::sync::mpsc; -use scylla_proxy::TargetShard; use scylla_proxy::{ - Condition, Reaction, RequestOpcode, RequestReaction, RequestRule, ShardAwareness, + Condition, ProxyError, Reaction, RequestFrame, RequestOpcode, RequestReaction, RequestRule, + ResponseFrame, ResponseOpcode, ResponseReaction, ResponseRule, ShardAwareness, TargetShard, + WorkerError, }; -use scylla_proxy::{ProxyError, RequestFrame, WorkerError}; +use uuid::Uuid; #[tokio::test] #[ntest::timeout(30000)] #[cfg(not(scylla_cloud_tests))] async fn test_consistent_shard_awareness() { - use std::collections::HashSet; - let res = test_with_3_node_cluster(ShardAwareness::QueryNode, |proxy_uris, translation_map, mut running_proxy| async move { let (feedback_txs, mut feedback_rxs): (Vec<_>, Vec<_>) = (0..3).map(|_| { @@ -80,3 +82,218 @@ async fn test_consistent_shard_awareness() { Err(err) => panic!("{}", err), } } + +#[derive(scylla::FromRow)] +struct SelectedTablet { + last_token: i64, + replicas: Vec<(Uuid, i32)>, +} + +struct Tablet { + first_token: i64, + last_token: i64, + replicas: Vec<(Arc, i32)>, +} + +async fn get_tablets(session: &Session, ks: String, table: String) -> Vec { + let cluster_data = session.get_cluster_data(); + let endpoints = cluster_data.get_nodes_info(); + for endpoint in endpoints.iter() { + println!( + "id: {}, address: {}", + endpoint.host_id, + endpoint.address.ip() + ); + } + + let selected_tablets_rows = session.query( + "select last_token, replicas from system.tablets WHERE keyspace_name = ? and table_name = ? ALLOW FILTERING", + &(ks.as_str(), table.as_str())).await.unwrap().rows.unwrap(); + + let mut selected_tablets = selected_tablets_rows + .into_typed::() + .map(|x| x.unwrap()) + .collect::>(); + selected_tablets.sort_unstable_by(|a, b| a.last_token.cmp(&b.last_token)); + + let mut tablets = Vec::new(); + let mut first_token = i64::MIN; + for tablet in selected_tablets { + let replicas = tablet + .replicas + .iter() + .map(|(uuid, shard)| { + ( + Arc::clone( + endpoints + .get(endpoints.iter().position(|e| e.host_id == *uuid).unwrap()) + .unwrap(), + ), + *shard, + ) + }) + .collect(); + let raw_tablet = Tablet { + first_token, + last_token: tablet.last_token, + replicas, + }; + first_token = tablet.last_token.wrapping_add(1); + tablets.push(raw_tablet); + } + + tablets +} + +#[tokio::test] +#[ntest::timeout(30000)] +#[cfg(not(scylla_cloud_tests))] +async fn test_tablet_shard_awareness() { + use scylla::load_balancing::DefaultPolicy; + // use tracing_subscriber; + + // let filter = tracing_subscriber::EnvFilter::builder() + // .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) + // .from_env().unwrap() + // .add_directive("scylla_proxy::proxy=warn".parse().unwrap()); + + // tracing_subscriber::fmt().with_env_filter(filter).init(); + + const TABLET_COUNT: usize = 16; + + let res = test_with_3_node_cluster( + ShardAwareness::QueryNode, + |proxy_uris, translation_map, mut running_proxy| async move { + let (feedback_txs, mut feedback_rxs): (Vec<_>, Vec<_>) = (0..3) + .map(|_| mpsc::unbounded_channel::<(ResponseFrame, Option)>()) + .unzip(); + for (i, tx) in feedback_txs.iter().cloned().enumerate() { + running_proxy.running_nodes[i].change_response_rules(Some(vec![ResponseRule( + Condition::ResponseOpcode(ResponseOpcode::Result) + .and(Condition::not(Condition::ConnectionRegisteredAnyEvent)), + ResponseReaction::noop().with_feedback_when_performed(tx), + )])); + } + let lbp = DefaultPolicy::builder().build(); + let execution_profile = scylla::ExecutionProfile::builder() + .load_balancing_policy(lbp) + .build(); + let session = SessionBuilder::new() + .known_node(proxy_uris[0].as_str()) + .address_translator(Arc::new(translation_map)) + .default_execution_profile_handle(execution_profile.into_handle()) + .build() + .await + .unwrap(); + let ks = unique_keyspace_name(); + + /* Prepare schema */ + session + .query( + format!( + "CREATE KEYSPACE IF NOT EXISTS {} + WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3}} + AND tablets = {{ 'initial': {} }}", + ks, + TABLET_COUNT + ), + &[], + ) + .await + .unwrap(); + session + .query( + format!( + "CREATE TABLE IF NOT EXISTS {}.t (a int, b int, c text, primary key (a, b))", + ks + ), + &[], + ) + .await + .unwrap(); + + let tablets = get_tablets(&session, ks.clone(), "t".to_string()).await; + for tablet in tablets.iter() { + println!("[{}, {}]: {:?}", + tablet.first_token, + tablet.last_token, + tablet.replicas.iter().map(|(replica, shard)| { + (replica.address.ip(), shard) + }).collect::>()); + } + + let prepared = session + .prepare(format!( + "INSERT INTO {}.t (a, b, c) VALUES (?, ?, 'abc')", + ks + )) + .await + .unwrap(); + + let mut present_tablets = [false; TABLET_COUNT]; + let mut value_lists = vec![]; + for i in 0..1000 { + let token_value = prepared.calculate_token(&(i, 1)).unwrap().unwrap().value; + let tablet_idx = tablets.iter().position(|tablet| tablet.first_token <= token_value && token_value <= tablet.last_token).unwrap(); + if !present_tablets[tablet_idx] { + let values = (i, 1); + let tablet = &tablets[tablet_idx]; + println!("Values: {:?}, token: {}, tablet index: {}, tablet: [{}, {}]: {:?}", + values, + token_value, + tablet_idx, + tablet.first_token, + tablet.last_token, + tablet.replicas.iter().map(|(replica, shard)| { + (replica.address.ip(), shard) + }).collect::>() + ); + value_lists.push(values); + present_tablets[tablet_idx] = true; + } + } + + assert!(present_tablets.iter().all(|x| *x)); + + + fn count_tablet_feedbacks ( + rx: &mut mpsc::UnboundedReceiver<(ResponseFrame, Option)>, + ) -> usize { + std::iter::from_fn(|| rx.try_recv().ok()).map(|(frame, _shard)| { + let response = scylla_cql::frame::parse_response_body_extensions(frame.params.flags, None, frame.body).unwrap(); + match response.custom_payload { + Some(map) => map.contains_key("tablets-routing-v1"), + None => false + } + }).filter(|b| *b).count() + } + + for values in value_lists.iter() { + println!("{:?}, token: {}", values, prepared.calculate_token(&values).unwrap().unwrap().value); + for _ in 0..10 { + session.execute(&prepared, values).await.unwrap(); + } + + let feedbacks: usize = feedback_rxs.iter_mut().map(count_tablet_feedbacks).sum(); + assert!(feedbacks > 0); + } + + for values in value_lists.iter() { + println!("{:?}, token: {}", values, prepared.calculate_token(&values).unwrap().unwrap().value); + for _ in 0..10 { + session.execute(&prepared, values).await.unwrap(); + } + let feedbacks: usize = feedback_rxs.iter_mut().map(count_tablet_feedbacks).sum(); + assert_eq!(feedbacks, 0); + } + + running_proxy + }, + ) + .await; + match res { + Ok(()) => (), + Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (), + Err(err) => panic!("{}", err), + } +}