Skip to content

Commit

Permalink
cassandra 5.0 vector type CREATE/INSERT support
Browse files Browse the repository at this point in the history
makes progress towards: #1014

The vector type is introduced by the currently in beta cassandra 5.
See: https://cassandra.apache.org/doc/latest/cassandra/reference/vector-data-type.html
Scylla does not support vector types and so the tests are setup to only
compile/run with a new cassandra_tests config.

This commit does not add support for retrieving the data via a SELECT.
That was omitted to reduce scope and will be implemented in follow up
work.
  • Loading branch information
rukai committed Nov 19, 2024
1 parent f59908c commit 2685ab9
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cassandra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: cargo build --verbose --tests --features "full-serialization"
- name: Run tests on cassandra
run: |
CDC='disabled' RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
CDC='disabled' RUSTFLAGS="--cfg cassandra_tests" RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
- name: Stop the cluster
if: ${{ always() }}
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop
Expand Down
2 changes: 1 addition & 1 deletion scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ harness = false
[lints.rust]
unnameable_types = "warn"
unreachable_pub = "warn"
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)'] }
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)', 'cfg(cassandra_tests)'] }
105 changes: 105 additions & 0 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3165,3 +3165,108 @@ async fn test_api_migration_session_sharing() {
assert!(matched);
}
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_metadata() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

session.refresh_metadata().await.unwrap();
let metadata = session.get_cluster_data();
let columns = &metadata.keyspaces[&ks].tables["t"].columns;
assert_eq!(
columns["b"].type_,
CqlType::Vector {
type_: Box::new(CqlType::Native(NativeType::Int)),
dimensions: 4,
},
);
assert_eq!(
columns["c"].type_,
CqlType::Vector {
type_: Box::new(CqlType::Native(NativeType::Text)),
dimensions: 2,
},
);
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_unprepared() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

session
.query_unpaged(
format!(
"INSERT INTO {}.t (a, b, c) VALUES (1, [1, 2, 3, 4], ['foo', 'bar'])",
ks
),
&[],
)
.await
.unwrap();

// TODO: Implement and test SELECT statements and bind values (`?`)
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_prepared() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

let prepared_statement = session
.prepare(format!(
"INSERT INTO {}.t (a, b, c) VALUES (?, [11, 12, 13, 14], ['afoo', 'abar'])",
ks
))
.await
.unwrap();
session
.execute_unpaged(&prepared_statement, &(2,))
.await
.unwrap();

// TODO: Implement and test SELECT statements and bind values (`?`)
}
47 changes: 47 additions & 0 deletions scylla/src/transport/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ enum PreCqlType {
type_: PreCollectionType,
},
Tuple(Vec<PreCqlType>),
Vector {
type_: Box<PreCqlType>,
/// matches the datatype used by the java driver:
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
dimensions: i32,
},
UserDefinedType {
frozen: bool,
name: String,
Expand All @@ -211,6 +217,10 @@ impl PreCqlType {
.map(|t| t.into_cql_type(keyspace_name, udts))
.collect(),
),
PreCqlType::Vector { type_, dimensions } => CqlType::Vector {
type_: Box::new(type_.into_cql_type(keyspace_name, udts)),
dimensions,
},
PreCqlType::UserDefinedType { frozen, name } => {
let definition = match udts
.get(keyspace_name)
Expand All @@ -236,6 +246,12 @@ pub enum CqlType {
type_: CollectionType,
},
Tuple(Vec<CqlType>),
Vector {
type_: Box<CqlType>,
/// matches the datatype used by the java driver:
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
dimensions: i32,
},
UserDefinedType {
frozen: bool,
// Using Arc here in order not to have many copies of the same definition
Expand Down Expand Up @@ -1137,6 +1153,7 @@ fn topo_sort_udts(udts: &mut Vec<UdtRowWithParsedFieldTypes>) -> Result<(), Quer
PreCqlType::Tuple(types) => types
.iter()
.for_each(|type_| do_with_referenced_udts(what, type_)),
PreCqlType::Vector { type_, .. } => do_with_referenced_udts(what, type_),
PreCqlType::UserDefinedType { name, .. } => what(name),
}
}
Expand Down Expand Up @@ -1637,6 +1654,22 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_
})?;

Ok((PreCqlType::Tuple(types), p))
} else if let Ok(p) = p.accept("vector<") {
let (inner_type, p) = parse_cql_type(p)?;

let p = p.skip_white();
let p = p.accept(",")?;
let p = p.skip_white();
let (size, p) = p.parse_i32()?;
let p = p.skip_white();
let p = p.accept(">")?;

let typ = PreCqlType::Vector {
type_: Box::new(inner_type),
dimensions: size,
};

Ok((typ, p))
} else if let Ok((typ, p)) = parse_native_type(p) {
Ok((PreCqlType::Native(typ), p))
} else if let Ok((name, p)) = parse_user_defined_type(p) {
Expand Down Expand Up @@ -1827,6 +1860,20 @@ mod tests {
PreCqlType::Native(NativeType::Varint),
]),
),
(
"vector<int, 5>",
PreCqlType::Vector {
type_: Box::new(PreCqlType::Native(NativeType::Int)),
dimensions: 5,
},
),
(
"vector<text, 1234>",
PreCqlType::Vector {
type_: Box::new(PreCqlType::Native(NativeType::Text)),
dimensions: 1234,
},
),
(
"com.scylladb.types.AwesomeType",
PreCqlType::UserDefinedType {
Expand Down
15 changes: 15 additions & 0 deletions scylla/src/utils/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ impl<'s> ParserState<'s> {
me
}

/// Parses a sequence of digits and '-' as an integer.
/// Consumes characters until it finds a character that is not a digit or '-'.
///
/// An error is returned if:
/// * The first character is not a digit or '-'
/// * The integer is larger than i32
pub(crate) fn parse_i32(self) -> ParseResult<(i32, Self)> {
let (digits, p) = self.take_while(|c| c.is_ascii_digit() || c == '-');
if let Ok(value) = digits.parse() {
Ok((value, p))
} else {
Err(p.error(ParseErrorCause::Other("Expected 32-bit signed integer")))
}
}

/// Skips characters from the beginning while they satisfy given predicate
/// and returns new parser state which
pub(crate) fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) {
Expand Down

0 comments on commit 2685ab9

Please sign in to comment.