diff --git a/.github/workflows/cassandra.yml b/.github/workflows/cassandra.yml index 4926ece5d6..6c8a71874d 100644 --- a/.github/workflows/cassandra.yml +++ b/.github/workflows/cassandra.yml @@ -31,7 +31,7 @@ jobs: run: cargo build --verbose --tests --features "full-serialization" - name: Run tests on cassandra run: | - CDC='disabled' RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements + CDC='disabled' RUSTFLAGS="--cfg cassandra_tests" RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements - name: Stop the cluster if: ${{ always() }} run: docker compose -f test/cluster/cassandra/docker-compose.yml stop diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index ab891c72f8..16c12ea358 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -2885,3 +2885,45 @@ async fn test_manual_primary_key_computation() { .await; } } + +#[cfg(cassandra_tests)] +#[tokio::test] +async fn test_vector_type() { + setup_tracing(); + let session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + + session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap(); + session + .query( + format!( + "CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector, c vector)", + ks + ), + &[], + ) + .await + .unwrap(); + + session + .query( + format!( + "INSERT INTO {}.t (a, b, c) VALUES (1, [1, 2, 3, 4], ['foo', 'bar'])", + ks + ), + &[], + ) + .await + .unwrap(); + + let prepared_statement = session + .prepare(format!( + "INSERT INTO {}.t (a, b, c) VALUES (?, [11, 12, 13, 14], ['afoo', 'abar'])", + ks + )) + .await + .unwrap(); + session.execute(&prepared_statement, &(2,)).await.unwrap(); + + // TODO: Implement and test SELECT statements and bind values (`?`) +} diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index b468050c0b..696bbe0c1a 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -184,6 +184,12 @@ enum PreCqlType { type_: PreCollectionType, }, Tuple(Vec), + Vector { + type_: Box, + /// matches the datatype used by the java driver: + /// + dimensions: i32, + }, UserDefinedType { frozen: bool, name: String, @@ -207,6 +213,10 @@ impl PreCqlType { .map(|t| t.into_cql_type(keyspace_name, udts)) .collect(), ), + PreCqlType::Vector { type_, dimensions } => CqlType::Vector { + type_: Box::new(type_.into_cql_type(keyspace_name, udts)), + dimensions, + }, PreCqlType::UserDefinedType { frozen, name } => { let definition = match udts .get(keyspace_name) @@ -232,6 +242,12 @@ pub enum CqlType { type_: CollectionType, }, Tuple(Vec), + Vector { + type_: Box, + /// matches the datatype used by the java driver: + /// + dimensions: i32, + }, UserDefinedType { frozen: bool, // Using Arc here in order not to have many copies of the same definition @@ -1093,6 +1109,7 @@ fn topo_sort_udts(udts: &mut Vec) -> Result<(), Quer PreCqlType::Tuple(types) => types .iter() .for_each(|type_| do_with_referenced_udts(what, type_)), + PreCqlType::Vector { type_, .. } => do_with_referenced_udts(what, type_), PreCqlType::UserDefinedType { name, .. } => what(name), } } @@ -1602,6 +1619,22 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_ })?; Ok((PreCqlType::Tuple(types), p)) + } else if let Ok(p) = p.accept("vector<") { + let (inner_type, p) = parse_cql_type(p)?; + + let p = p.skip_white(); + let p = p.accept(",")?; + let p = p.skip_white(); + let (size, p) = p.parse_i32()?; + let p = p.skip_white(); + let p = p.accept(">")?; + + let typ = PreCqlType::Vector { + type_: Box::new(inner_type), + dimensions: size, + }; + + Ok((typ, p)) } else if let Ok((typ, p)) = parse_native_type(p) { Ok((PreCqlType::Native(typ), p)) } else if let Ok((name, p)) = parse_user_defined_type(p) { @@ -1787,6 +1820,20 @@ mod tests { PreCqlType::Native(NativeType::Varint), ]), ), + ( + "vector", + PreCqlType::Vector { + type_: Box::new(PreCqlType::Native(NativeType::Int)), + dimensions: 5, + }, + ), + ( + "vector", + PreCqlType::Vector { + type_: Box::new(PreCqlType::Native(NativeType::Text)), + dimensions: 1234, + }, + ), ( "com.scylladb.types.AwesomeType", PreCqlType::UserDefinedType { diff --git a/scylla/src/utils/parse.rs b/scylla/src/utils/parse.rs index 1c5e59ecb7..58c22b084a 100644 --- a/scylla/src/utils/parse.rs +++ b/scylla/src/utils/parse.rs @@ -87,6 +87,21 @@ impl<'s> ParserState<'s> { me } + /// Parses a sequence of digits and '-' as an integer. + /// Consumes characters until it finds a character that is not a digit or '-'. + /// + /// An error is returned if: + /// * The first character is not a digit or '-' + /// * The the integer is larger than i32 + pub(crate) fn parse_i32(self) -> ParseResult<(i32, Self)> { + let (digits, p) = self.take_while(|c| c.is_ascii_digit() || c == '-'); + if let Ok(value) = digits.parse() { + Ok((value, p)) + } else { + Err(p.error(ParseErrorCause::Expected("integer of max length 2**32"))) + } + } + /// Skips characters from the beginning while they satisfy given predicate /// and returns new parser state which pub(crate) fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) { diff --git a/test/cluster/cassandra/docker-compose.yml b/test/cluster/cassandra/docker-compose.yml index aa46efd1f6..ac25c5f2f5 100644 --- a/test/cluster/cassandra/docker-compose.yml +++ b/test/cluster/cassandra/docker-compose.yml @@ -10,12 +10,12 @@ networks: - subnet: 172.42.0.0/16 services: cassandra1: - image: cassandra + image: cassandra:5.0-beta1 healthcheck: - test: ["CMD", "cqlsh", "-e", "describe keyspaces" ] - interval: 5s - timeout: 5s - retries: 60 + test: [ "CMD", "cqlsh", "-e", "describe keyspaces" ] + interval: 5s + timeout: 5s + retries: 60 networks: public: ipv4_address: 172.42.0.2 @@ -24,12 +24,12 @@ services: - HEAP_NEWSIZE=512M - MAX_HEAP_SIZE=2048M cassandra2: - image: cassandra + image: cassandra:5.0-beta1 healthcheck: - test: ["CMD", "cqlsh", "-e", "describe keyspaces" ] - interval: 5s - timeout: 5s - retries: 60 + test: [ "CMD", "cqlsh", "-e", "describe keyspaces" ] + interval: 5s + timeout: 5s + retries: 60 networks: public: ipv4_address: 172.42.0.3 @@ -42,12 +42,12 @@ services: cassandra1: condition: service_healthy cassandra3: - image: cassandra + image: cassandra:5.0-beta1 healthcheck: - test: ["CMD", "cqlsh", "-e", "describe keyspaces" ] - interval: 5s - timeout: 5s - retries: 60 + test: [ "CMD", "cqlsh", "-e", "describe keyspaces" ] + interval: 5s + timeout: 5s + retries: 60 networks: public: ipv4_address: 172.42.0.4