diff --git a/README.md b/README.md index c9682925bb..15d82dec7f 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,10 @@ let uri = "127.0.0.1:9042"; let session: Session = SessionBuilder::new().known_node(uri).build().await?; -if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row in rows.into_typed::<(i32, i32, String)>() { - let (a, b, c) = row?; - println!("a, b, c: {}, {}, {}", a, b, c); - } +let result = session.query("SELECT a, b, c FROM ks.t", &[]).await?; +let mut iter = result.rows::<(i32, i32, String)>()?; +while let Some((a, b, c)) = iter.next().transpose()? { + println!("a, b, c: {}, {}, {}", a, b, c); } ``` diff --git a/docs/source/SUMMARY.md b/docs/source/SUMMARY.md index 471e8efdad..1ec428b8de 100644 --- a/docs/source/SUMMARY.md +++ b/docs/source/SUMMARY.md @@ -7,6 +7,9 @@ - [Running Scylla using Docker](quickstart/scylla-docker.md) - [Connecting and running a simple query](quickstart/example.md) +- [Migration guides](migration-guides/migration-guides.md) + - [Adjusting deserialization code from 0.8 and older](migration-guides/post-0.8-deserialization.md) + - [Connecting to the cluster](connecting/connecting.md) - [Compression](connecting/compression.md) - [Authentication](connecting/authentication.md) diff --git a/docs/source/contents.rst b/docs/source/contents.rst index 0e0446baf7..5bc4a37c9e 100644 --- a/docs/source/contents.rst +++ b/docs/source/contents.rst @@ -13,6 +13,7 @@ retry-policy/retry-policy speculative-execution/speculative metrics/metrics + migration-guides/migration-guides logging/logging tracing/tracing schema/schema diff --git a/docs/source/data-types/blob.md b/docs/source/data-types/blob.md index c213da882c..c3e9d40377 100644 --- a/docs/source/data-types/blob.md +++ b/docs/source/data-types/blob.md @@ -17,10 +17,10 @@ session .await?; // Read blobs from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Vec,)>() { - 
let (blob_value,): (Vec,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Vec,)>()?; +while let Some((blob_value,)) = iter.next().transpose()? { + println!("{:?}", blob_value); } # Ok(()) # } diff --git a/docs/source/data-types/collections.md b/docs/source/data-types/collections.md index 43301d31d2..cc1c256158 100644 --- a/docs/source/data-types/collections.md +++ b/docs/source/data-types/collections.md @@ -17,10 +17,10 @@ session .await?; // Read a list of ints from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Vec,)>() { - let (list_value,): (Vec,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Vec,)>()?; +while let Some((list_value,)) = iter.next().transpose()? { + println!("{:?}", list_value); } # Ok(()) # } @@ -43,10 +43,10 @@ session .await?; // Read a set of ints from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Vec,)>() { - let (set_value,): (Vec,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Vec,)>()?; +while let Some((list_value,)) = iter.next().transpose()? { + println!("{:?}", list_value); } # Ok(()) # } @@ -67,10 +67,10 @@ session .await?; // Read a set of ints from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(HashSet,)>() { - let (set_value,): (HashSet,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(HashSet,)>()?; +while let Some((list_value,)) = iter.next().transpose()? 
{ + println!("{:?}", list_value); } # Ok(()) # } @@ -91,10 +91,10 @@ session .await?; // Read a set of ints from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(BTreeSet,)>() { - let (set_value,): (BTreeSet,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(BTreeSet,)>()?; +while let Some((list_value,)) = iter.next().transpose()? { + println!("{:?}", list_value); } # Ok(()) # } @@ -120,10 +120,10 @@ session .await?; // Read a map from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(HashMap,)>() { - let (map_value,): (HashMap,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(HashMap,)>()?; +while let Some((map_value,)) = iter.next().transpose()? { + println!("{:?}", map_value); } # Ok(()) # } @@ -146,10 +146,10 @@ session .await?; // Read a map from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(BTreeMap,)>() { - let (map_value,): (BTreeMap,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(BTreeMap,)>()?; +while let Some((map_value,)) = iter.next().transpose()? 
{ + println!("{:?}", map_value); } # Ok(()) # } diff --git a/docs/source/data-types/counter.md b/docs/source/data-types/counter.md index 0f31b6cba7..9379e45a2c 100644 --- a/docs/source/data-types/counter.md +++ b/docs/source/data-types/counter.md @@ -11,11 +11,11 @@ use scylla::IntoTypedRows; use scylla::frame::value::Counter; // Read counter from the table -if let Some(rows) = session.query("SELECT c FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Counter,)>() { - let (counter_value,): (Counter,) = row?; - let counter_int_value: i64 = counter_value.0; - } +let result = session.query("SELECT c FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Counter,)>()?; +while let Some((counter_value,)) = iter.next().transpose()? { + let counter_int_value: i64 = counter_value.0; + println!("{}", counter_int_value); } # Ok(()) # } diff --git a/docs/source/data-types/date.md b/docs/source/data-types/date.md index 6d3384c6af..c39a3f8641 100644 --- a/docs/source/data-types/date.md +++ b/docs/source/data-types/date.md @@ -1,12 +1,13 @@ # Date -For most use cases `Date` can be represented as +For most use cases `Date` can be represented as [`chrono::NaiveDate`](https://docs.rs/chrono/0.4.19/chrono/naive/struct.NaiveDate.html).\ `NaiveDate` supports dates from -262145-1-1 to 262143-12-31. For dates outside of this range you can use the raw `u32` representation. ### Using `chrono::NaiveDate`: + ```rust # extern crate scylla; # extern crate chrono; @@ -23,16 +24,17 @@ session .await?; // Read NaiveDate from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(NaiveDate,)>() { - let (date_value,): (NaiveDate,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(NaiveDate,)>()?; +while let Some((date_value,)) = iter.next().transpose()? 
{ + println!("{:?}", date_value); } # Ok(()) # } ``` ### Using raw `u32` representation + Internally `Date` is represented as number of days since -5877641-06-23 i.e. 2^31 days before unix epoch. ```rust @@ -50,14 +52,11 @@ session .await?; // Read raw Date from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows { - let date_value: u32 = match row.columns[0] { - Some(CqlValue::Date(date_value)) => date_value, - _ => panic!("Should be a date!") - }; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Date,)>()?; +while let Some((date_value,)) = iter.next().transpose()? { + println!("{:?}", date_value); } # Ok(()) # } -``` \ No newline at end of file +``` diff --git a/docs/source/data-types/decimal.md b/docs/source/data-types/decimal.md index d3d7b45feb..02cad1db5f 100644 --- a/docs/source/data-types/decimal.md +++ b/docs/source/data-types/decimal.md @@ -18,10 +18,10 @@ session .await?; // Read a decimal from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(BigDecimal,)>() { - let (decimal_value,): (BigDecimal,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(BigDecimal,)>()?; +while let Some((decimal_value,)) = iter.next().transpose()? 
{ + println!("{:?}", decimal_value); } # Ok(()) # } diff --git a/docs/source/data-types/duration.md b/docs/source/data-types/duration.md index 7526a478b3..a59abf3bb5 100644 --- a/docs/source/data-types/duration.md +++ b/docs/source/data-types/duration.md @@ -9,17 +9,17 @@ use scylla::IntoTypedRows; use scylla::frame::value::CqlDuration; -// Insert some ip address into the table +// Insert some duration into the table let to_insert: CqlDuration = CqlDuration { months: 1, days: 2, nanoseconds: 3 }; session .query("INSERT INTO keyspace.table (a) VALUES(?)", (to_insert,)) .await?; -// Read inet from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(CqlDuration,)>() { - let (cql_duration,): (CqlDuration,) = row?; - } +// Read duration from the table +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(CqlDuration,)>()?; +while let Some((duration_value,)) = iter.next().transpose()? { + println!("{:?}", duration_value); } # Ok(()) # } diff --git a/docs/source/data-types/inet.md b/docs/source/data-types/inet.md index c585aefc05..7b016ad86d 100644 --- a/docs/source/data-types/inet.md +++ b/docs/source/data-types/inet.md @@ -16,10 +16,10 @@ session .await?; // Read inet from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(IpAddr,)>() { - let (inet_value,): (IpAddr,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(IpAddr,)>()?; +while let Some((inet_value,)) = iter.next().transpose()? 
{ + println!("{:?}", inet_value); } # Ok(()) # } diff --git a/docs/source/data-types/primitive.md b/docs/source/data-types/primitive.md index e521e5e6c7..0c1041ddbb 100644 --- a/docs/source/data-types/primitive.md +++ b/docs/source/data-types/primitive.md @@ -1,6 +1,7 @@ # Bool, Tinyint, Smallint, Int, Bigint, Float, Double ### Bool + `Bool` is represented as rust `bool` ```rust @@ -17,16 +18,17 @@ session .await?; // Read a bool from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(bool,)>() { - let (bool_value,): (bool,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(bool,)>()?; +while let Some((bool_value,)) = iter.next().transpose()? { + println!("{}", bool_value); } # Ok(()) # } ``` ### Tinyint + `Tinyint` is represented as rust `i8` ```rust @@ -43,16 +45,17 @@ session .await?; // Read a tinyint from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(i8,)>() { - let (tinyint_value,): (i8,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(i8,)>()?; +while let Some((tinyint_value,)) = iter.next().transpose()? { + println!("{:?}", tinyint_value); } # Ok(()) # } ``` ### Smallint + `Smallint` is represented as rust `i16` ```rust @@ -69,16 +72,17 @@ session .await?; // Read a smallint from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(i16,)>() { - let (smallint_value,): (i16,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(i16,)>()?; +while let Some((smallint_value,)) = iter.next().transpose()? 
{ + println!("{}", smallint_value); } # Ok(()) # } ``` ### Int + `Int` is represented as rust `i32` ```rust @@ -95,16 +99,17 @@ session .await?; // Read an int from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(i32,)>() { - let (int_value,): (i32,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(i32,)>()?; +while let Some((int_value,)) = iter.next().transpose()? { + println!("{}", int_value); } # Ok(()) # } ``` ### Bigint + `Bigint` is represented as rust `i64` ```rust @@ -121,16 +126,17 @@ session .await?; // Read a bigint from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(i64,)>() { - let (bigint_value,): (i64,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(i64,)>()?; +while let Some((bigint_value,)) = iter.next().transpose()? { + println!("{:?}", bigint_value); } # Ok(()) # } ``` -### Float +### Float + `Float` is represented as rust `f32` ```rust @@ -147,16 +153,17 @@ session .await?; // Read a float from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(f32,)>() { - let (float_value,): (f32,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(f32,)>()?; +while let Some((float_value,)) = iter.next().transpose()? 
{ + println!("{:?}", float_value); } # Ok(()) # } ``` ### Double + `Double` is represented as rust `f64` ```rust @@ -173,11 +180,11 @@ session .await?; // Read a double from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(f64,)>() { - let (double_value,): (f64,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(f64,)>()?; +while let Some((double_value,)) = iter.next().transpose()? { + println!("{:?}", double_value); } # Ok(()) # } -``` \ No newline at end of file +``` diff --git a/docs/source/data-types/text.md b/docs/source/data-types/text.md index 68479d233f..9001fb52e2 100644 --- a/docs/source/data-types/text.md +++ b/docs/source/data-types/text.md @@ -21,10 +21,10 @@ session .await?; // Read ascii/text/varchar from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(String,)>() { - let (text_value,): (String,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(String,)>()?; +while let Some((text_value,)) = iter.next().transpose()? { + println!("{}", text_value); } # Ok(()) # } diff --git a/docs/source/data-types/time.md b/docs/source/data-types/time.md index 6f46f9dae1..fa1aaba7d3 100644 --- a/docs/source/data-types/time.md +++ b/docs/source/data-types/time.md @@ -23,10 +23,10 @@ session .await?; // Read time from the table, no need for a wrapper here -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Duration,)>() { - let (time_value,): (Duration,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Duration,)>()?; +while let Some((time_value,)) = iter.next().transpose()? 
{ + println!("{:?}", time_value); } # Ok(()) # } diff --git a/docs/source/data-types/timestamp.md b/docs/source/data-types/timestamp.md index d61aec2aec..d31d68e84d 100644 --- a/docs/source/data-types/timestamp.md +++ b/docs/source/data-types/timestamp.md @@ -23,10 +23,10 @@ session .await?; // Read timestamp from the table, no need for a wrapper here -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Duration,)>() { - let (timestamp_value,): (Duration,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Duration,)>()?; +while let Some((timestamp_value,)) = iter.next().transpose()? { + println!("{:?}", timestamp_value); } # Ok(()) # } diff --git a/docs/source/data-types/tuple.md b/docs/source/data-types/tuple.md index 74a41de947..32c586a91e 100644 --- a/docs/source/data-types/tuple.md +++ b/docs/source/data-types/tuple.md @@ -1,4 +1,5 @@ # Tuple + `Tuple` is represented as rust tuples of max 16 elements. ```rust @@ -15,14 +16,13 @@ session .await?; // Read a tuple of int and string from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<((i32, String),)>() { - let (tuple_value,): ((i32, String),) = row?; - - let int_value: i32 = tuple_value.0; - let string_value: String = tuple_value.1; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<((i32, String),)>()?; +while let Some((tuple_value,)) = iter.next().transpose()? 
{ + let int_value: i32 = tuple_value.0; + let string_value: String = tuple_value.1; + println!("({}, {})", int_value, string_value); } # Ok(()) # } -``` \ No newline at end of file +``` diff --git a/docs/source/data-types/udt.md b/docs/source/data-types/udt.md index 85c401fae5..bbef21985c 100644 --- a/docs/source/data-types/udt.md +++ b/docs/source/data-types/udt.md @@ -14,12 +14,13 @@ To use this type in the driver create a matching struct and derive `IntoUserType # use std::error::Error; # async fn check_only_compiles(session: &Session) -> Result<(), Box> { use scylla::IntoTypedRows; -use scylla::macros::{FromUserType, IntoUserType}; +use scylla::macros::{IntoUserType, DeserializeCql}; use scylla::cql_to_rust::FromCqlVal; +use scylla::types::deserialize::value::DeserializeCql; // Define custom struct that matches User Defined Type created earlier // wrapping field in Option will gracefully handle null field values -#[derive(Debug, IntoUserType, FromUserType)] +#[derive(Debug, IntoUserType, DeserializeCql)] struct MyType { int_val: i32, text_val: Option, @@ -38,10 +39,10 @@ session .await?; // Read MyType from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(MyType,)>() { - let (my_type_value,): (MyType,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(MyType,)>()?; +while let Some((my_type_value,)) = iter.next().transpose()? 
{ + println!("{:?}", my_type_value); } # Ok(()) # } diff --git a/docs/source/data-types/uuid.md b/docs/source/data-types/uuid.md index c3cfde2725..b8eb25b80d 100644 --- a/docs/source/data-types/uuid.md +++ b/docs/source/data-types/uuid.md @@ -18,10 +18,10 @@ session .await?; // Read uuid/timeuuid from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(Uuid,)>() { - let (uuid_value,): (Uuid,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(Uuid,)>()?; +while let Some((uuid_value,)) = iter.next().transpose()? { + println!("{:?}", uuid_value); } # Ok(()) # } diff --git a/docs/source/data-types/varint.md b/docs/source/data-types/varint.md index b90c9a5ccb..cca45215ed 100644 --- a/docs/source/data-types/varint.md +++ b/docs/source/data-types/varint.md @@ -18,10 +18,10 @@ session .await?; // Read a varint from the table -if let Some(rows) = session.query("SELECT a FROM keyspace.table", &[]).await?.rows { - for row in rows.into_typed::<(BigInt,)>() { - let (varint_value,): (BigInt,) = row?; - } +let result = session.query("SELECT a FROM keyspace.table", &[]).await?; +let mut iter = result.rows::<(BigInt,)>()?; +while let Some((varint_value,)) = iter.next().transpose()? 
{ + println!("{:?}", varint_value); } # Ok(()) # } diff --git a/docs/source/index.md b/docs/source/index.md index c5e1191b1f..0ab28240ca 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -23,3 +23,4 @@ Although optimized for Scylla, the driver is also compatible with [Apache Cassan * [Logging](logging/logging.md) - Viewing and integrating logs produced by the driver * [Query tracing](tracing/tracing.md) - Tracing query execution * [Database schema](schema/schema.md) - Fetching and inspecting database schema +* [Migration guides](migration-guides/migration-guides.md) - How to update the code that used an older version of this driver diff --git a/docs/source/migration-guides/migration-guides.md b/docs/source/migration-guides/migration-guides.md new file mode 100644 index 0000000000..d257b461df --- /dev/null +++ b/docs/source/migration-guides/migration-guides.md @@ -0,0 +1,11 @@ +# Migration guides + +- [Migrating from 0.8 to the new deserialization framework](post-0.8-deserialization.md) + +```eval_rst +.. toctree:: + :hidden: + :glob: + + post-0.8-deserialization +``` \ No newline at end of file diff --git a/docs/source/migration-guides/post-0.8-deserialization.md b/docs/source/migration-guides/post-0.8-deserialization.md new file mode 100644 index 0000000000..63a7fc4be0 --- /dev/null +++ b/docs/source/migration-guides/post-0.8-deserialization.md @@ -0,0 +1,264 @@ +# Post-0.8 deserialization API migration guide + +After 0.8, a new deserialization API has been introduced. The new API improves type safety and performance of the old one, so it is highly recommended to switch to it. However, deserialization is an area of the API that users frequently interact with: deserialization traits appear in generic code and custom implementations have been written. In order to make migration easier, the driver still offers the old API, which - while opt-in - can be very easily switched to after version upgrade. 
Furthermore, a number of facilities have been introduced which help migrate the user code to the new API piece-by-piece. + +The old API and migration facilities will be removed in the next major release (2.0). + +## Introduction + +### Old traits + +The legacy API works by deserializing rows in the query response to a sequence of `Row`s. The `Row` is just a `Vec<Option<CqlValue>>`, where `CqlValue` is an enum that is able to represent any CQL value. + +The user can request this type-erased representation to be converted into something useful. There are two traits that power this: + +__`FromRow`__ + +```rust +# extern crate scylla; +# use scylla::frame::response::cql_to_rust::FromRowError; +# use scylla::frame::response::result::Row; +pub trait FromRow: Sized { + fn from_row(row: Row) -> Result<Self, FromRowError>; +} +``` + +__`FromCqlVal`__ + +```rust +# extern crate scylla; +# use scylla::frame::response::cql_to_rust::FromCqlValError; +// The `T` parameter is supposed to be either `CqlValue` or `Option<CqlValue>` +pub trait FromCqlVal<T>: Sized { + fn from_cql(cql_val: T) -> Result<Self, FromCqlValError>; +} +``` + +These traits are implemented for some common types: + +- `FromRow` is implemented for tuples up to 16 elements, +- `FromCqlVal` is implemented for a bunch of types, and each CQL type can be converted to one of them. + +While it's possible to implement those manually, the driver provides procedural macros for automatic derivation in some cases: + +- `FromRow` - implements `FromRow` for a struct. +- `FromUserType` - generates an implementation of `FromCqlVal` for the struct, trying to parse the CQL value as a UDT. + +Note: the macros above have a default behavior that is different from what `FromRow` and `FromUserType` do. + +### New traits + +The new API introduces two analogous traits that, instead of consuming pre-parsed `Vec<Option<CqlValue>>`, are given raw, serialized data with full information about its type. This leads to better performance and allows for better type safety. 
+ +The new traits are: + +__`DeserializeRow<'frame>`__ + +```rust +# extern crate scylla; +# use scylla::types::deserialize::row::ColumnIterator; +# use scylla::frame::frame_errors::ParseError; +# use scylla::frame::response::result::ColumnSpec; +pub trait DeserializeRow<'frame> +where + Self: Sized, +{ + fn type_check(specs: &[ColumnSpec]) -> Result<(), ParseError>; + fn deserialize(row: ColumnIterator<'frame>) -> Result; +} +``` + +__`DeserializeCql<'frame>`__ + +```rust +# extern crate scylla; +# use scylla::types::deserialize::row::ColumnIterator; +# use scylla::types::deserialize::FrameSlice; +# use scylla::frame::frame_errors::ParseError; +# use scylla::frame::response::result::ColumnType; +pub trait DeserializeCql<'frame> +where + Self: Sized, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError>; + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result; +} +``` + +The above traits have been implemented for the same set of types as `FromRow` and `FromCqlVal`, respectively. Notably, `DeserializeRow` is implemented for `Row`, and `DeserializeCql` is implemented for `CqlValue`. + +There are also `DeserializeRow` and `DeserializeCql` derive macros, analogous to `FromRow` and `FromUserType`, respectively - but with slightly different defaults (explained later in this doc page). + +## Updating the code to use the new API + +Some of the core types have been updated to use the new traits. Updating the code to use the new API should be straightforward. + +### Basic queries + +Sending queries with the single page API should work similarly as before. The `Session::query`, `Session::execute` and `Session::batch` functions have the same interface as before, the only exception being that they return a new, updated `QueryResult`. + +Consuming rows from a result will require only minimal changes if you are using helper methods of the `QueryResult`. 
Now, there is no distinction between "typed" and "non-typed" methods; all methods that return rows need to have the type specified. For example, previously there used to be both `rows(self)` and `rows_typed(self)`, now there is only a single `rows>(&self)`. Another thing worth mentioning is that the returned iterator now _borrows_ from the `QueryResult` instead of consuming it. + +Note that the `QueryResult::rows` field is not available anymore. If you used to access it directly, you need to change your code to use the helper methods instead. + +Before: + +```rust +# extern crate scylla; +# use scylla::Legacy08Session; +# use std::error::Error; +# async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { +let iter = session + .query("SELECT name, age FROM my_keyspace.people", &[]) + .await? + .rows_typed::<(String, i32)>()?; +for row in iter { + let (name, age) = row?; + println!("{} has age {}", name, age); +} +# Ok(()) +# } +``` + +After: + +```rust +# extern crate scylla; +# use scylla::Session; +# use std::error::Error; +# async fn check_only_compiles(session: &Session) -> Result<(), Box> { +// 1. Note that the result must be assigned to a variable here, and only then +// an iterator created. +let result = session + .query("SELECT name, age FROM my_keyspace.people", &[]) + .await?; + +// 2. Note that `rows` is used here, not `rows_typed`. +for row in result.rows::<(String, i32)>()? { + let (name, age) = row?; + println!("{} has age {}", name, age); +} +# Ok(()) +# } +``` + +### Iterator queries + +The `Session::query_iter` and `Session::execute_iter` have been adjusted, too. They now return a `RawIterator` (notice it's "Raw" instead of "Row") - an intermediate object which needs to be converted into `TypedRowIterator` first before being actually iterated over. 
+ +This particular example should work without any changes: + +```rust +# extern crate scylla; +# extern crate futures; +# use scylla::Session; +# use std::error::Error; +# use scylla::IntoTypedRows; +# use futures::stream::StreamExt; +# async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> { +let mut rows_stream = session + .query_iter("SELECT name, age FROM my_keyspace.people", &[]) + .await? + .into_typed::<(String, i32)>(); + +while let Some(next_row_res) = rows_stream.next().await { + let (a, b): (String, i32) = next_row_res?; + println!("a, b: {}, {}", a, b); +} +# Ok(()) +# } +``` + +### Procedural macros + +As mentioned in the Introduction section, the driver provides new procedural macros for the `DeserializeRow` and `DeserializeCql` traits that are meant to replace `FromRow` and `FromUserType`, respectively. The new macros are designed to be slightly more type-safe by matching column/UDT field names to Rust field names dynamically. This is different behavior from what the old macros used to do, but the new macros can be configured with `#[attributes]` to simulate the old behavior. + +__`FromRow` vs. `DeserializeRow`__ + +The impl generated by `FromRow` expects columns to be in the same order as the struct fields. The `FromRow` trait does not have information about column names, so it cannot match them with the struct field names. You can use `enforce_order` and `no_field_name_verification` attributes to achieve such behavior via the `DeserializeRow` trait. + +__`FromUserType` vs. `DeserializeCql`__ + +The impl generated by `FromUserType` expects UDT fields to be in the same order as the struct fields. Field names should be the same both in the UDT and in the struct. You can use the `enforce_order` attribute to achieve such behavior via the `DeserializeCql` trait. 
+ +### Adjusting custom impls of deserialization traits + +If you have a custom type with a hand-written `impl FromRow` or `impl FromCqlVal`, the best thing to do is to just write a new impl for `DeserializeRow` or `DeserializeCql` manually. Although it's technically possible to implement the new traits by using the existing implementation of the old ones, rolling out a new implementation will avoid performance problems related to the inefficient `CqlValue` representation. + +## Accessing the old API + +Most important types related to deserialization of the old API have been renamed and contain a `Legacy08` prefix in their names: + +- `Session` -> `Legacy08Session` +- `CachingSession` -> `Legacy08CachingSession` +- `RowIterator` -> `Legacy08RowIterator` +- `TypedRowIterator` -> `Legacy08TypedRowIterator` +- `QueryResult` -> `Legacy08QueryResult` + +If you intend to quickly migrate your application by using the old API, you can just import the legacy stuff and alias it as the new one, e.g.: + +```rust +# extern crate scylla; +use scylla::Legacy08Session as Session; +``` + +In order to create the `Legacy08Session` instead of the new `Session`, you need to use `SessionBuilder`'s `build_legacy()` method instead of `build()`: + +```rust +# extern crate scylla; +# use scylla::{Legacy08Session, SessionBuilder}; +# use std::error::Error; +# async fn check_only_compiles() -> Result<(), Box> { +let session: Legacy08Session = SessionBuilder::new() + .known_node("127.0.0.1") + .build_legacy() + .await?; +# Ok(()) +# } +``` + +## Mixing the old and the new API + +It is possible to use different APIs in different parts of the program. The `Session` allows to create a `Legacy08Session` object that has the old API but shares all resources with the session that has the new API (and vice versa - you can create a new API session from the old API session). 
+ +```rust +# extern crate scylla; +# use scylla::{Legacy08Session, Session}; +# use std::error::Error; +# async fn check_only_compiles(new_api_session: &Session) -> Result<(), Box> { +// All of the session objects below will use the same resources: connections, +// metadata, current keyspace, etc. +let old_api_session: Legacy08Session = new_api_session.make_shared_session_with_legacy_api(); +let another_new_api_session: Session = old_api_session.make_shared_session_with_new_api(); +# Ok(()) +# } +``` + +In addition to that, it is possible to convert a `QueryResult` to `Legacy08QueryResult`: + +```rust +# extern crate scylla; +# use scylla::{QueryResult, Legacy08QueryResult}; +# use std::error::Error; +# async fn check_only_compiles(result: QueryResult) -> Result<(), Box> { +let result: QueryResult = result; +let legacy_result: Legacy08QueryResult = result.into_legacy_result()?; +# Ok(()) +# } +``` + +... and `RawIterator` into `Legacy08RowIterator`: + +```rust +# extern crate scylla; +# use scylla::transport::iterator::{RawIterator, Legacy08RowIterator}; +# use std::error::Error; +# async fn check_only_compiles(iter: RawIterator) -> Result<(), Box> { +let iter: RawIterator = iter; +let legacy_result: Legacy08RowIterator = iter.into_legacy(); +# Ok(()) +# } +``` diff --git a/docs/source/queries/paged.md b/docs/source/queries/paged.md index dab3672210..83d33199b5 100644 --- a/docs/source/queries/paged.md +++ b/docs/source/queries/paged.md @@ -113,7 +113,7 @@ use scylla::query::Query; let paged_query = Query::new("SELECT a, b, c FROM ks.t").with_page_size(6); let res1 = session.query(paged_query.clone(), &[]).await?; let res2 = session - .query_paged(paged_query.clone(), &[], res1.paging_state) + .query_paged(paged_query.clone(), &[], res1.paging_state()) .await?; # Ok(()) # } @@ -132,7 +132,7 @@ let paged_prepared = session .await?; let res1 = session.execute(&paged_prepared, &[]).await?; let res2 = session - .execute_paged(&paged_prepared, &[], res1.paging_state) + 
.execute_paged(&paged_prepared, &[], res1.paging_state()) .await?; # Ok(()) # } diff --git a/docs/source/queries/result.md b/docs/source/queries/result.md index 6350eab9ad..686bf83efe 100644 --- a/docs/source/queries/result.md +++ b/docs/source/queries/result.md @@ -2,62 +2,13 @@ `Session::query` and `Session::execute` return a `QueryResult` with rows represented as `Option>`. -### Basic representation -`Row` is a basic representation of a received row. It can be used by itself, but it's a bit awkward to use: -```rust -# extern crate scylla; -# use scylla::Session; -# use std::error::Error; -# async fn check_only_compiles(session: &Session) -> Result<(), Box> { -if let Some(rows) = session.query("SELECT a from ks.tab", &[]).await?.rows { - for row in rows { - let int_value: i32 = row.columns[0].as_ref().unwrap().as_int().unwrap(); - } -} -# Ok(()) -# } -``` - -### Parsing using `into_typed` -The driver provides a way to parse a row as a tuple of Rust types: -```rust -# extern crate scylla; -# use scylla::Session; -# use std::error::Error; -# async fn check_only_compiles(session: &Session) -> Result<(), Box> { -use scylla::IntoTypedRows; - -// Parse row as a single column containing an int value -if let Some(rows) = session.query("SELECT a from ks.tab", &[]).await?.rows { - for row in rows { - let (int_value,): (i32,) = row.into_typed::<(i32,)>()?; - } -} - -// rows.into_typed() converts a Vec of Rows to an iterator of parsing results -if let Some(rows) = session.query("SELECT a from ks.tab", &[]).await?.rows { - for row in rows.into_typed::<(i32,)>() { - let (int_value,): (i32,) = row?; - } -} - -// Parse row as two columns containing an int and text columns -if let Some(rows) = session.query("SELECT a, b from ks.tab", &[]).await?.rows { - for row in rows.into_typed::<(i32, String)>() { - let (int_value, text_value): (i32, String) = row?; - } -} -# Ok(()) -# } -``` - ## Parsing using convenience methods 
[`QueryResult`](https://docs.rs/scylla/latest/scylla/transport/query_result/struct.QueryResult.html) provides convenience methods for parsing rows. Here are a few of them: -* `rows_typed::()` - returns the rows parsed as the given type -* `maybe_first_row_typed::` - returns `Option` containing first row from the result -* `first_row_typed::` - same as `maybe_first_row`, but fails without the first row -* `single_row_typed::` - same as `first_row`, but fails when there is more than one row +* `rows::()` - returns the rows parsed as the given type +* `maybe_first_row::` - returns `Option` containing first row from the result +* `first_row::` - same as `maybe_first_row`, but fails without the first row +* `single_row::` - same as `first_row`, but fails when there is more than one row * `result_not_rows()` - ensures that query response was not `rows`, helps avoid bugs @@ -67,11 +18,10 @@ Here are a few of them: # use std::error::Error; # async fn check_only_compiles(session: &Session) -> Result<(), Box> { // Parse row as a single column containing an int value -let rows = session +let result = session .query("SELECT a from ks.tab", &[]) - .await? - .rows_typed::<(i32,)>()?; // Same as .rows()?.into_typed() -for row in rows { + .await?; +for row in result.rows::<(i32,)>()? { let (int_value,): (i32,) = row?; } @@ -79,7 +29,7 @@ for row in rows { let first_int_val: Option<(i32,)> = session .query("SELECT a from ks.tab", &[]) .await? 
- .maybe_first_row_typed::<(i32,)>()?; + .maybe_first_row::<(i32,)>()?; // no_rows fails when the response is rows session.query("INSERT INTO ks.tab (a) VALUES (0)", &[]).await?.result_not_rows()?; @@ -99,10 +49,9 @@ To properly handle `NULL` values parse column as an `Option<>`: use scylla::IntoTypedRows; // Parse row as two columns containing an int and text which might be null -if let Some(rows) = session.query("SELECT a, b from ks.tab", &[]).await?.rows { - for row in rows.into_typed::<(i32, Option)>() { - let (int_value, str_or_null): (i32, Option) = row?; - } +let result = session.query("SELECT a, b from ks.tab", &[]).await?; +for row in result.rows::<(i32, Option)>()? { + let (int_value, str_or_null): (i32, Option) = row?; } # Ok(()) # } @@ -113,7 +62,7 @@ It is possible to receive row as a struct with fields matching the columns.\ The struct must: * have the same number of fields as the number of queried columns * have field types matching the columns being received -* derive `FromRow` +* derive `DeserializeRow` Field names don't need to match column names. ```rust @@ -122,20 +71,19 @@ Field names don't need to match column names. # use std::error::Error; # async fn check_only_compiles(session: &Session) -> Result<(), Box> { use scylla::IntoTypedRows; -use scylla::macros::FromRow; -use scylla::frame::response::cql_to_rust::FromRow; +use scylla::macros::DeserializeRow; +use scylla::types::deserialize::row::DeserializeRow; -#[derive(FromRow)] +#[derive(DeserializeRow)] struct MyRow { age: i32, - name: Option + name: Option, } // Parse row as two columns containing an int and text which might be null -if let Some(rows) = session.query("SELECT a, b from ks.tab", &[]).await?.rows { - for row in rows.into_typed::() { - let my_row: MyRow = row?; - } +let result = session.query("SELECT a, b from ks.tab", &[]).await?; +for row in result.rows::()? 
{ + let my_row: MyRow = row?; } # Ok(()) # } diff --git a/docs/source/queries/simple.md b/docs/source/queries/simple.md index 25190338dd..c45ac0b126 100644 --- a/docs/source/queries/simple.md +++ b/docs/source/queries/simple.md @@ -69,8 +69,9 @@ Here the first `?` will be filled with `2` and the second with `"Some text"`. See [Query values](values.md) for more information about sending values in queries ### Query result -`Session::query` returns `QueryResult` with rows represented as `Option>`.\ -Each row can be parsed as a tuple of rust types using `into_typed`: +`Session::query` returns `QueryResult`. +The result can then be operated on via helper methods which verify that the result is of appropriate type. +Here, we use the `rows` method to check that the response indeed contains rows with a single `int` column: ```rust # extern crate scylla; # use scylla::Session; @@ -79,13 +80,12 @@ Each row can be parsed as a tuple of rust types using `into_typed`: use scylla::IntoTypedRows; // Query rows from the table and print them -if let Some(rows) = session.query("SELECT a FROM ks.tab", &[]).await?.rows { - // Parse each row as a tuple containing single i32 - for row in rows.into_typed::<(i32,)>() { - let read_row: (i32,) = row?; - println!("Read a value from row: {}", read_row.0); - } +let result = session.query("SELECT a FROM ks.tab", &[]).await?; +let mut iter = result.rows::<(i32,)>()?; +while let Some(read_row) = iter.next().transpose()? 
{ + println!("Read a value from row: {}", read_row.0); } + # Ok(()) # } ``` diff --git a/docs/source/quickstart/example.md b/docs/source/quickstart/example.md index 905f03e8cd..b5e0009063 100644 --- a/docs/source/quickstart/example.md +++ b/docs/source/quickstart/example.md @@ -43,12 +43,10 @@ async fn main() -> Result<(), Box> { .await?; // Query rows from the table and print them - if let Some(rows) = session.query("SELECT a FROM ks.extab", &[]).await?.rows { - // Parse each row as a tuple containing single i32 - for row in rows.into_typed::<(i32,)>() { - let read_row: (i32,) = row?; - println!("Read a value from row: {}", read_row.0); - } + let result = session.query("SELECT a FROM ks.extab", &[]).await?; + let mut iter = result.rows::<(i32,)>()?; + while let Some(read_row) = iter.next().transpose()? { + println!("Read a value from row: {}", read_row.0); } Ok(()) diff --git a/docs/source/tracing/basic.md b/docs/source/tracing/basic.md index 4ee5bc5737..648334bd02 100644 --- a/docs/source/tracing/basic.md +++ b/docs/source/tracing/basic.md @@ -20,7 +20,7 @@ let mut query: Query = Query::new("INSERT INTO ks.tab (a) VALUES(4)"); query.set_tracing(true); let res: QueryResult = session.query(query, &[]).await?; -let tracing_id: Option = res.tracing_id; +let tracing_id: Option = res.tracing_id(); if let Some(id) = tracing_id { // Query tracing info from system_traces.sessions and system_traces.events @@ -52,7 +52,7 @@ let mut prepared: PreparedStatement = session prepared.set_tracing(true); let res: QueryResult = session.execute(&prepared, &[]).await?; -let tracing_id: Option = res.tracing_id; +let tracing_id: Option = res.tracing_id(); if let Some(id) = tracing_id { // Query tracing info from system_traces.sessions and system_traces.events @@ -83,7 +83,7 @@ batch.append_statement("INSERT INTO ks.tab (a) VALUES(4)"); batch.set_tracing(true); let res: QueryResult = session.batch(&batch, ((),)).await?; -let tracing_id: Option = res.tracing_id; +let tracing_id: Option = 
res.tracing_id(); if let Some(id) = tracing_id { // Query tracing info from system_traces.sessions and system_traces.events diff --git a/docs/source/tracing/paged.md b/docs/source/tracing/paged.md index e69d4f3361..46e503fa22 100644 --- a/docs/source/tracing/paged.md +++ b/docs/source/tracing/paged.md @@ -13,7 +13,6 @@ If tracing is enabled the row iterator will contain a list of tracing ids for al # use std::error::Error; # async fn check_only_compiles(session: &Session) -> Result<(), Box> { use scylla::query::Query; -use scylla::transport::iterator::RowIterator; use scylla::tracing::TracingInfo; use futures::StreamExt; use uuid::Uuid; @@ -23,7 +22,10 @@ let mut query: Query = Query::new("INSERT INTO ks.tab (a) VALUES(4)"); query.set_tracing(true); // Create a paged query iterator and fetch pages -let mut row_iterator: RowIterator = session.query_iter(query, &[]).await?; +let mut row_iterator = session + .query_iter(query, &[]) + .await? + .into_typed::<(i32,)>(); while let Some(_row) = row_iterator.next().await { // Receive rows } @@ -49,7 +51,6 @@ for id in tracing_ids { # use std::error::Error; # async fn check_only_compiles(session: &Session) -> Result<(), Box> { use scylla::prepared_statement::PreparedStatement; -use scylla::transport::iterator::RowIterator; use scylla::tracing::TracingInfo; use futures::StreamExt; use uuid::Uuid; @@ -63,7 +64,10 @@ let mut prepared: PreparedStatement = session prepared.set_tracing(true); // Create a paged query iterator and fetch pages -let mut row_iterator: RowIterator = session.execute_iter(prepared, &[]).await?; +let mut row_iterator = session + .execute_iter(prepared, &[]) + .await? 
+ .into_typed::<(i32,)>(); while let Some(_row) = row_iterator.next().await { // Receive rows } diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 3a3d8f6f41..7844b833d7 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -11,6 +11,7 @@ openssl = "0.10.32" rustyline = "9" rustyline-derive = "0.6" scylla = {path = "../scylla", features = ["ssl", "cloud"]} +scylla-cql = {path = "../scylla-cql"} tokio = {version = "1.1.0", features = ["full"]} tracing = "0.1.25" tracing-subscriber = { version = "0.3.14", features = ["env-filter"] } diff --git a/examples/allocations.rs b/examples/allocations.rs index 6cae728b73..2f9d728d72 100644 --- a/examples/allocations.rs +++ b/examples/allocations.rs @@ -1,5 +1,6 @@ use anyhow::Result; -use scylla::{statement::prepared_statement::PreparedStatement, Session, SessionBuilder}; +use scylla::transport::session::Session; +use scylla::{statement::prepared_statement::PreparedStatement, SessionBuilder}; use std::io::Write; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; diff --git a/examples/auth.rs b/examples/auth.rs index 61c563da57..f374ab086b 100644 --- a/examples/auth.rs +++ b/examples/auth.rs @@ -10,7 +10,7 @@ async fn main() -> Result<()> { let session = SessionBuilder::new() .known_node(uri) .user("cassandra", "cassandra") - .build() + .build_legacy() .await .unwrap(); diff --git a/examples/basic.rs b/examples/basic.rs index f5dcec9538..f2ac344591 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -1,6 +1,6 @@ use anyhow::Result; use scylla::macros::FromRow; -use scylla::transport::session::{IntoTypedRows, Session}; +use scylla::transport::session::Session; use scylla::SessionBuilder; use std::env; @@ -43,11 +43,10 @@ async fn main() -> Result<()> { .await?; // Rows can be parsed as tuples - if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row in rows.into_typed::<(i32, i32, String)>() { - let (a, b, c) = row?; - println!("a, b, c: {}, {}, {}", a, 
b, c); - } + let result = session.query("SELECT a, b, c FROM ks.t", &[]).await?; + let mut iter = result.rows::<(i32, i32, String)>()?; + while let Some((a, b, c)) = iter.next().transpose()? { + println!("a, b, c: {}, {}, {}", a, b, c); } // Or as custom structs that derive FromRow @@ -58,24 +57,10 @@ async fn main() -> Result<()> { _c: String, } - if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row_data in rows.into_typed::() { - let row_data = row_data?; - println!("row_data: {:?}", row_data); - } - } - - // Or simply as untyped rows - if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row in rows { - let a = row.columns[0].as_ref().unwrap().as_int().unwrap(); - let b = row.columns[1].as_ref().unwrap().as_int().unwrap(); - let c = row.columns[2].as_ref().unwrap().as_text().unwrap(); - println!("a, b, c: {}, {}, {}", a, b, c); - - // Alternatively each row can be parsed individually - // let (a2, b2, c2) = row.into_typed::<(i32, i32, String)>() ?; - } + let result = session.query("SELECT a, b, c FROM ks.t", &[]).await?; + let mut iter = result.rows::<(i32, i32, String)>()?; + while let Some(row_data) = iter.next().transpose()? { + println!("row_data: {:?}", row_data); } let metrics = session.get_metrics(); diff --git a/examples/clone.rs b/examples/clone.rs index 19dcc1c454..f926e47197 100644 --- a/examples/clone.rs +++ b/examples/clone.rs @@ -39,13 +39,12 @@ async fn main() -> Result<()> { .unwrap(); } - if let Some(rows) = sessions[42] + let num_rows = sessions[42] .query("SELECT a, b, c FROM ks.t", &[]) .await? 
- .rows - { - println!("Read {} rows", rows.len()); - } + .rows_num() + .unwrap(); + println!("Read {} rows", num_rows); Ok(()) } diff --git a/examples/cloud.rs b/examples/cloud.rs index f432d8a4f2..6edbf31b58 100644 --- a/examples/cloud.rs +++ b/examples/cloud.rs @@ -12,7 +12,7 @@ async fn main() -> Result<()> { .unwrap_or("examples/config_data.yaml".to_owned()); let session = CloudSessionBuilder::new(Path::new(&config_path)) .unwrap() - .build() + .build_legacy() .await .unwrap(); diff --git a/examples/compare-tokens.rs b/examples/compare-tokens.rs index 423c0d37bd..f837b72065 100644 --- a/examples/compare-tokens.rs +++ b/examples/compare-tokens.rs @@ -1,8 +1,9 @@ use anyhow::Result; use scylla::frame::value::ValueList; use scylla::transport::partitioner::{Murmur3Partitioner, Partitioner}; +use scylla::transport::session::Session; use scylla::transport::NodeAddr; -use scylla::{load_balancing, Session, SessionBuilder}; +use scylla::{load_balancing, SessionBuilder}; use std::env; #[tokio::main] @@ -47,18 +48,10 @@ async fn main() -> Result<()> { .collect::>() ); - let qt = session - .query(format!("SELECT token(pk) FROM ks.t where pk = {}", pk), &[]) + let (qt,) = session + .query("SELECT token(pk) FROM ks.t where pk = ?", (pk,)) .await? 
- .rows - .unwrap() - .get(0) - .expect("token query no rows!") - .columns[0] - .as_ref() - .expect("token query null value!") - .as_bigint() - .expect("token wrong type!"); + .single_row()?; assert_eq!(t, qt); println!("token for {}: {}", pk, t); } diff --git a/examples/cql-time-types.rs b/examples/cql-time-types.rs index a532901dc6..1f38aefa2e 100644 --- a/examples/cql-time-types.rs +++ b/examples/cql-time-types.rs @@ -3,9 +3,8 @@ use anyhow::Result; use chrono::{Duration, NaiveDate}; -use scylla::frame::response::result::CqlValue; use scylla::frame::value::{Date, Time, Timestamp}; -use scylla::transport::session::{IntoTypedRows, Session}; +use scylla::transport::session::Session; use scylla::SessionBuilder; use std::env; @@ -36,15 +35,10 @@ async fn main() -> Result<()> { .query("INSERT INTO ks.dates (d) VALUES (?)", (example_date,)) .await?; - if let Some(rows) = session.query("SELECT d from ks.dates", &[]).await?.rows { - for row in rows.into_typed::<(NaiveDate,)>() { - let (read_date,): (NaiveDate,) = match row { - Ok(read_date) => read_date, - Err(_) => continue, // We might read a date that does not fit in NaiveDate, skip it - }; - - println!("Read a date: {:?}", read_date); - } + let result = session.query("SELECT d from ks.dates", &[]).await?; + let mut iter = result.rows::<(NaiveDate,)>()?; + while let Some((read_date,)) = iter.next().transpose()? 
{ + println!("Read a date: {:?}", read_date); } // Dates outside this range must be represented in the raw form - an u32 describing days since -5877641-06-23 @@ -53,15 +47,10 @@ async fn main() -> Result<()> { .query("INSERT INTO ks.dates (d) VALUES (?)", (example_big_date,)) .await?; - if let Some(rows) = session.query("SELECT d from ks.dates", &[]).await?.rows { - for row in rows { - let read_days: u32 = match row.columns[0] { - Some(CqlValue::Date(days)) => days, - _ => panic!("oh no"), - }; - - println!("Read a date as raw days: {}", read_days); - } + let result = session.query("SELECT d from ks.dates", &[]).await?; + let mut iter = result.rows::<(Date,)>()?; + while let Some((read_days,)) = iter.next().transpose()? { + println!("Read a date as raw days: {}", read_days.0); } // Time - nanoseconds since midnight in range 0..=86399999999999 @@ -79,12 +68,10 @@ async fn main() -> Result<()> { .query("INSERT INTO ks.times (t) VALUES (?)", (Time(example_time),)) .await?; - if let Some(rows) = session.query("SELECT t from ks.times", &[]).await?.rows { - for row in rows.into_typed::<(Duration,)>() { - let (read_time,): (Duration,) = row?; - - println!("Read a time: {:?}", read_time); - } + let result = session.query("SELECT t from ks.times", &[]).await?; + let mut iter = result.rows::<(Duration,)>()?; + while let Some((read_time,)) = iter.next().transpose()? { + println!("Read a time: {:?}", read_time); } // Timestamp - milliseconds since unix epoch - 1970-01-01 @@ -105,16 +92,10 @@ async fn main() -> Result<()> { ) .await?; - if let Some(rows) = session - .query("SELECT t from ks.timestamps", &[]) - .await? - .rows - { - for row in rows.into_typed::<(Duration,)>() { - let (read_time,): (Duration,) = row?; - - println!("Read a timestamp: {:?}", read_time); - } + let result = session.query("SELECT t from ks.timestamps", &[]).await?; + let mut iter = result.rows::<(Duration,)>()?; + while let Some((read_time,)) = iter.next().transpose()? 
{ + println!("Read a timestamp: {:?}", read_time); } Ok(()) diff --git a/examples/cqlsh-rs.rs b/examples/cqlsh-rs.rs index 877b4af596..0161fd1b8e 100644 --- a/examples/cqlsh-rs.rs +++ b/examples/cqlsh-rs.rs @@ -3,8 +3,10 @@ use rustyline::completion::{Completer, Pair}; use rustyline::error::ReadlineError; use rustyline::{CompletionType, Config, Context, Editor}; use rustyline_derive::{Helper, Highlighter, Hinter, Validator}; +use scylla::transport::session::Session; use scylla::transport::Compression; -use scylla::{QueryResult, Session, SessionBuilder}; +use scylla::{QueryResult, SessionBuilder}; +use scylla_cql::frame::response::result::Row; use std::env; #[derive(Helper, Highlighter, Validator, Hinter)] @@ -174,11 +176,12 @@ impl Completer for CqlHelper { } fn print_result(result: &QueryResult) { - if result.rows.is_none() { + if !result.is_rows() { println!("OK"); return; } - for row in result.rows.as_ref().unwrap() { + for row in result.rows::().unwrap() { + let row = row.unwrap(); for column in &row.columns { print!("|"); print!( diff --git a/examples/custom_deserialization.rs b/examples/custom_deserialization.rs index 7514c6733c..65c6e4c099 100644 --- a/examples/custom_deserialization.rs +++ b/examples/custom_deserialization.rs @@ -2,7 +2,8 @@ use anyhow::Result; use scylla::cql_to_rust::{FromCqlVal, FromCqlValError}; use scylla::frame::response::result::CqlValue; use scylla::macros::impl_from_cql_value_from_method; -use scylla::{Session, SessionBuilder}; +use scylla::transport::session::Session; +use scylla::SessionBuilder; use std::env; #[tokio::main] @@ -40,6 +41,7 @@ async fn main() -> Result<()> { let (v,) = session .query("SELECT v FROM ks.t WHERE pk = 1", ()) .await? + .into_legacy_result()? .single_row_typed::<(MyType,)>()?; assert_eq!(v, MyType("asdf".to_owned())); @@ -64,6 +66,7 @@ async fn main() -> Result<()> { let (v,) = session .query("SELECT v FROM ks.t WHERE pk = 1", ()) .await? + .into_legacy_result()? 
.single_row_typed::<(MyOtherType,)>()?; assert_eq!(v, MyOtherType("asdf".to_owned())); diff --git a/examples/get_by_name.rs b/examples/get_by_name.rs index d02af56ced..830e710126 100644 --- a/examples/get_by_name.rs +++ b/examples/get_by_name.rs @@ -1,6 +1,7 @@ use anyhow::{anyhow, Result}; use scylla::transport::session::Session; use scylla::SessionBuilder; +use scylla_cql::frame::response::result::Row; use std::env; #[tokio::main] @@ -44,9 +45,14 @@ async fn main() -> Result<()> { let (value_idx, _) = query_result .get_column_spec("value") .ok_or_else(|| anyhow!("No value column found"))?; + let rows = query_result + .rows::() + .unwrap() + .collect::, _>>() + .unwrap(); println!("ck | value"); println!("---------------------"); - for row in query_result.rows.ok_or_else(|| anyhow!("no rows found"))? { + for row in rows { println!("{:?} | {:?}", row.columns[ck_idx], row.columns[value_idx]); } diff --git a/examples/query_history.rs b/examples/query_history.rs index 7ce4dd60c1..36229f4082 100644 --- a/examples/query_history.rs +++ b/examples/query_history.rs @@ -1,7 +1,6 @@ //! This example shows how to collect history of query execution. 
use anyhow::Result; -use futures::StreamExt; use scylla::history::{HistoryCollector, StructuredHistory}; use scylla::query::Query; use scylla::transport::session::Session; diff --git a/examples/schema_agreement.rs b/examples/schema_agreement.rs index 96dfb3f9ab..8d48839369 100644 --- a/examples/schema_agreement.rs +++ b/examples/schema_agreement.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use scylla::transport::session::{IntoTypedRows, Session}; +use scylla::transport::session::Session; use scylla::SessionBuilder; use std::env; use std::time::Duration; @@ -63,12 +63,12 @@ async fn main() -> Result<()> { .await?; // Rows can be parsed as tuples - if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row in rows.into_typed::<(i32, i32, String)>() { - let (a, b, c) = row?; - println!("a, b, c: {}, {}, {}", a, b, c); - } + let result = session.query("SELECT a, b, c FROM ks.t", &[]).await?; + let mut iter = result.rows::<(i32, i32, String)>()?; + while let Some((a, b, c)) = iter.next().transpose()? 
{ + println!("a, b, c: {}, {}, {}", a, b, c); } + println!("Ok."); let schema_version = session.fetch_schema_version().await?; diff --git a/examples/select-paging.rs b/examples/select-paging.rs index ac6e2826aa..d03cede870 100644 --- a/examples/select-paging.rs +++ b/examples/select-paging.rs @@ -44,24 +44,24 @@ async fn main() -> Result<()> { let res1 = session.query(paged_query.clone(), &[]).await?; println!( "Paging state: {:#?} ({} rows)", - res1.paging_state, - res1.rows.unwrap().len() + res1.paging_state(), + res1.rows_num().unwrap(), ); let res2 = session - .query_paged(paged_query.clone(), &[], res1.paging_state) + .query_paged(paged_query.clone(), &[], res1.paging_state()) .await?; println!( "Paging state: {:#?} ({} rows)", - res2.paging_state, - res2.rows.unwrap().len() + res2.paging_state(), + res2.rows_num().unwrap(), ); let res3 = session - .query_paged(paged_query.clone(), &[], res2.paging_state) + .query_paged(paged_query.clone(), &[], res2.paging_state()) .await?; println!( "Paging state: {:#?} ({} rows)", - res3.paging_state, - res3.rows.unwrap().len() + res3.paging_state(), + res3.rows_num().unwrap(), ); let paged_prepared = session @@ -70,24 +70,24 @@ async fn main() -> Result<()> { let res4 = session.execute(&paged_prepared, &[]).await?; println!( "Paging state from the prepared statement execution: {:#?} ({} rows)", - res4.paging_state, - res4.rows.unwrap().len() + res4.paging_state(), + res4.rows_num().unwrap(), ); let res5 = session - .execute_paged(&paged_prepared, &[], res4.paging_state) + .execute_paged(&paged_prepared, &[], res4.paging_state()) .await?; println!( "Paging state from the second prepared statement execution: {:#?} ({} rows)", - res5.paging_state, - res5.rows.unwrap().len() + res5.paging_state(), + res5.rows_num().unwrap(), ); let res6 = session - .execute_paged(&paged_prepared, &[], res5.paging_state) + .execute_paged(&paged_prepared, &[], res5.paging_state()) .await?; println!( "Paging state from the third prepared 
statement execution: {:#?} ({} rows)", - res6.paging_state, - res6.rows.unwrap().len() + res6.paging_state(), + res6.rows_num().unwrap(), ); println!("Ok."); diff --git a/examples/tls.rs b/examples/tls.rs index c2410673ca..399be7cfbd 100644 --- a/examples/tls.rs +++ b/examples/tls.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use scylla::transport::session::{IntoTypedRows, Session}; +use scylla::transport::session::Session; use scylla::SessionBuilder; use std::env; use std::fs; @@ -80,12 +80,12 @@ async fn main() -> Result<()> { .await?; // Rows can be parsed as tuples - if let Some(rows) = session.query("SELECT a, b, c FROM ks.t", &[]).await?.rows { - for row in rows.into_typed::<(i32, i32, String)>() { - let (a, b, c) = row?; - println!("a, b, c: {}, {}, {}", a, b, c); - } + let result = session.query("SELECT a, b, c FROM ks.t", &[]).await?; + let mut iter = result.rows::<(i32, i32, String)>()?; + while let Some((a, b, c)) = iter.next().transpose()? { + println!("a, b, c: {}, {}, {}", a, b, c); } + println!("Ok."); Ok(()) diff --git a/examples/tower.rs b/examples/tower.rs index 1c3bb2112a..0701e1db86 100644 --- a/examples/tower.rs +++ b/examples/tower.rs @@ -1,3 +1,5 @@ +use scylla::transport::session::Session; +use scylla_cql::frame::response::result::Row; use std::env; use std::future::Future; use std::pin::Pin; @@ -7,7 +9,7 @@ use std::task::Poll; use tower::Service; struct SessionService { - session: Arc, + session: Arc, } // A trivial service implementation for sending parameterless simple string requests to Scylla. @@ -56,14 +58,14 @@ async fn main() -> anyhow::Result<()> { println!( "Tables:\n{}", - resp.rows()? - .into_iter() - .map(|r| format!( + resp.rows::()? 
+ .map(|r| r.map(|r| format!( "\t{}.{}", print_text(&r.columns[0]), print_text(&r.columns[1]) - )) - .collect::>() + ))) + .collect::, _>>() + .unwrap() .join("\n") ); Ok(()) diff --git a/examples/tracing.rs b/examples/tracing.rs index eb8db39f39..55d1775aa9 100644 --- a/examples/tracing.rs +++ b/examples/tracing.rs @@ -2,13 +2,12 @@ // query() prepare() execute() batch() query_iter() and execute_iter() can be traced use anyhow::{anyhow, Result}; -use futures::StreamExt; use scylla::batch::Batch; use scylla::statement::{ prepared_statement::PreparedStatement, query::Query, Consistency, SerialConsistency, }; use scylla::tracing::{GetTracingConfig, TracingInfo}; -use scylla::transport::iterator::RowIterator; +use scylla::transport::iterator::RawIterator; use scylla::QueryResult; use scylla::{Session, SessionBuilder}; use std::env; @@ -41,7 +40,7 @@ async fn main() -> Result<()> { // QueryResult will contain a tracing_id which can be used to query tracing information let query_result: QueryResult = session.query(query.clone(), &[]).await?; let query_tracing_id: Uuid = query_result - .tracing_id + .tracing_id() .ok_or_else(|| anyhow!("Tracing id is None!"))?; // Get tracing information for this query and print it @@ -77,13 +76,13 @@ async fn main() -> Result<()> { prepared.set_tracing(true); let execute_result: QueryResult = session.execute(&prepared, &[]).await?; - println!("Execute tracing id: {:?}", execute_result.tracing_id); + println!("Execute tracing id: {:?}", execute_result.tracing_id()); // PAGED QUERY_ITER EXECUTE_ITER // It's also possible to trace paged queries like query_iter or execute_iter // After iterating through all rows iterator.get_tracing_ids() will give tracing ids // for all page queries - let mut row_iterator: RowIterator = session.query_iter(query, &[]).await?; + let mut row_iterator: RawIterator = session.query_iter(query, &[]).await?; while let Some(_row) = row_iterator.next().await { // Receive rows @@ -103,7 +102,7 @@ async fn main() -> 
Result<()> { // Run the batch and print its tracing_id let batch_result: QueryResult = session.batch(&batch, ((),)).await?; - println!("Batch tracing id: {:?}\n", batch_result.tracing_id); + println!("Batch tracing id: {:?}\n", batch_result.tracing_id()); // CUSTOM // GetTracingConfig allows to specify a custom settings for querying tracing info diff --git a/examples/user-defined-type.rs b/examples/user-defined-type.rs index bbbc0466ec..80db3027af 100644 --- a/examples/user-defined-type.rs +++ b/examples/user-defined-type.rs @@ -1,6 +1,6 @@ use anyhow::Result; -use scylla::macros::{FromUserType, IntoUserType}; -use scylla::{IntoTypedRows, Session, SessionBuilder}; +use scylla::macros::{DeserializeCql, IntoUserType}; +use scylla::{Session, SessionBuilder}; use std::env; #[tokio::main] @@ -29,7 +29,7 @@ async fn main() -> Result<()> { // Define custom struct that matches User Defined Type created earlier // wrapping field in Option will gracefully handle null field values - #[derive(Debug, IntoUserType, FromUserType)] + #[derive(Debug, IntoUserType, DeserializeCql)] struct MyType { int_val: i32, text_val: Option, @@ -46,11 +46,10 @@ async fn main() -> Result<()> { .await?; // And read like any normal value - if let Some(rows) = session.query("SELECT my FROM ks.udt_tab", &[]).await?.rows { - for row in rows.into_typed::<(MyType,)>() { - let (my_val,) = row?; - println!("{:?}", my_val) - } + let result = session.query("SELECT my FROM ks.udt_tab", &[]).await?; + let mut iter = result.rows::<(MyType,)>()?; + while let Some((my_val,)) = iter.next().transpose()? 
{ + println!("{:?}", my_val); } println!("Ok."); diff --git a/examples/value_list.rs b/examples/value_list.rs index 8e64c6d5da..75805cc794 100644 --- a/examples/value_list.rs +++ b/examples/value_list.rs @@ -52,10 +52,14 @@ async fn main() { .await .unwrap(); - let q = session + let rows = session .query("SELECT * FROM ks.my_type", &[]) .await + .unwrap() + .rows::<(i32, String)>() + .unwrap() + .collect::, _>>() .unwrap(); - println!("Q: {:?}", q.rows); + println!("Q: {:?}", rows); } diff --git a/scylla-cql/Cargo.toml b/scylla-cql/Cargo.toml index ca19cfd4b7..b26fc96b0c 100644 --- a/scylla-cql/Cargo.toml +++ b/scylla-cql/Cargo.toml @@ -28,6 +28,7 @@ serde = { version = "1.0", optional = true } [dev-dependencies] criterion = "0.3" +uuid = { version = "1.0", features = ["v4"] } [[bench]] name = "benchmark" diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index c8c4ec104d..8e0d14c542 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -41,8 +41,9 @@ impl Response { pub fn deserialize( features: &ProtocolFeatures, opcode: ResponseOpcode, - buf: &mut &[u8], + buf_bytes: bytes::Bytes, ) -> Result { + let buf = &mut &*buf_bytes; let response = match opcode { ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?), ResponseOpcode::Ready => Response::Ready, @@ -50,7 +51,7 @@ impl Response { Response::Authenticate(authenticate::Authenticate::deserialize(buf)?) } ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?), - ResponseOpcode::Result => Response::Result(result::deserialize(buf)?), + ResponseOpcode::Result => Response::Result(result::deserialize(buf_bytes)?), ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?), ResponseOpcode::AuthChallenge => { Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?) 
diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 80e4aa9bc4..15e668149f 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -3,6 +3,8 @@ use crate::frame::response::event::SchemaChangeEvent; use crate::frame::types::vint_decode; use crate::frame::value::{Counter, CqlDuration}; use crate::frame::{frame_errors::ParseError, types}; +use crate::types::deserialize::row::{ColumnIterator, DeserializeRow}; +use crate::types::deserialize::{FrameSlice, RowIterator, TypedRowIterator}; use bigdecimal::BigDecimal; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, Bytes}; @@ -373,6 +375,146 @@ impl Row { } } +/// Rows response, in partially serialized form. +// TODO: We could provide ResultMetadata in a similar, lazily +// deserialized form - now it can be a source of allocations +#[derive(Debug, Default)] +pub struct RawRows { + metadata: ResultMetadata, + rows_count: usize, + raw_rows: Bytes, +} + +impl RawRows { + /// Returns the metadata associated with this response (paging state + /// and column specifications). + #[inline] + pub fn metadata(&self) -> &ResultMetadata { + &self.metadata + } + + /// Consumes the `RawRows` and returns metadata associated with the + /// response. + #[inline] + pub fn into_metadata(self) -> ResultMetadata { + self.metadata + } + + /// Returns the number of rows that these `RawRows` contain. + #[inline] + pub fn rows_count(&self) -> usize { + self.rows_count + } + + /// Returns the serialized size of the `RawRows`. + #[inline] + pub fn rows_size(&self) -> usize { + self.raw_rows.len() + } + + /// Creates a typed iterator over the rows that lazily deserializes + /// rows in the result. + /// + /// Returns Err if the schema of returned result doesn't match R. 
+ #[inline] + pub fn rows_iter<'r, R: DeserializeRow<'r>>( + &'r self, + ) -> StdResult, ParseError> { + let slice = FrameSlice::new(&self.raw_rows); + let raw = RowIterator::new(self.rows_count, &self.metadata.col_specs, slice); + TypedRowIterator::new(raw) + } + + /// Converts the `RawRows` into `Rows` - a legacy, inefficient representation. + /// + /// Provided only to make migration to the new deserialization API + /// more convenient - this function will be deprecated and removed + /// in future releases. + pub fn into_legacy_rows(self) -> StdResult { + let rows = self.rows_iter::()?.collect::>()?; + Ok(Rows { + metadata: self.metadata, + rows_count: self.rows_count, + rows, + serialized_size: self.raw_rows.len(), + }) + } +} + +// Technically not an iterator because it returns items that borrow from it, +// and the std Iterator interface does not allow for that. +// TODO: Move to row.rs +/// A _lending_ iterator over serialized rows. +/// +/// This type is similar to `RowIterator`, but keeps ownership of the serialized +/// result. Because it returns `ColumnIterator`s that need to borrow from it, +/// it does not implement the `Iterator` trait (there is no type in the standard +/// library to represent this concept yet). +#[derive(Debug)] +pub struct RawRowsLendingIterator { + metadata: ResultMetadata, + remaining: usize, + at: usize, + raw_rows: Bytes, +} + +impl RawRowsLendingIterator { + /// Creates a new `RawRowsLendingIterator`, consuming given `RawRows`. + #[inline] + pub fn new(raw_rows: RawRows) -> Self { + Self { + metadata: raw_rows.metadata, + remaining: raw_rows.rows_count, + at: 0, + raw_rows: raw_rows.raw_rows, + } + } + + /// Returns a `ColumnIterator` that represents the next row. + /// + /// Note: the `ColumnIterator` borrows from the `RawRowsLendingIterator`. + /// The column iterator must be consumed before the rows iterator can + /// continue. 
+ #[inline] + #[allow(clippy::should_implement_trait)] // https://github.com/rust-lang/rust-clippy/issues/5004 + pub fn next(&mut self) -> Option> { + self.remaining = self.remaining.checked_sub(1)?; + + let mut mem = &self.raw_rows[self.at..]; + + // Skip the row here, manually + for _ in 0..self.metadata.col_specs.len() { + if let Err(err) = types::read_bytes_opt(&mut mem) { + return Some(Err(err)); + } + } + + let slice = FrameSlice::new_subslice(&self.raw_rows[self.at..], &self.raw_rows); + let iter = ColumnIterator::new(&self.metadata.col_specs, slice); + self.at = self.raw_rows.len() - mem.len(); + Some(Ok(iter)) + } + + #[inline] + pub fn size_hint(&self) -> (usize, Option) { + (0, Some(self.remaining)) + } + + /// Returns the metadata associated with the response (paging state and + /// column specifications). + #[inline] + pub fn metadata(&self) -> &ResultMetadata { + &self.metadata + } + + /// Returns the remaining number of rows that this iterator is expected + /// to produce. + #[inline] + pub fn rows_remaining(&self) -> usize { + self.remaining + } +} + #[derive(Debug)] pub struct Rows { pub metadata: ResultMetadata, @@ -385,7 +527,7 @@ pub struct Rows { #[derive(Debug)] pub enum Result { Void, - Rows(Rows), + Rows(RawRows), SetKeyspace(SetKeyspace), Prepared(Prepared), SchemaChange(SchemaChange), @@ -835,11 +977,10 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult StdResult { +fn deser_rows(buf_bytes: Bytes) -> StdResult { + let buf = &mut &*buf_bytes; let metadata = deser_result_metadata(buf)?; - let original_size = buf.len(); - // TODO: the protocol allows an optimization (which must be explicitly requested on query by // the driver) where the column metadata is not sent with the result. // Implement this optimization. We'll then need to take the column types by a parameter. 
@@ -848,24 +989,10 @@ fn deser_rows(buf: &mut &[u8]) -> StdResult { let rows_count: usize = types::read_int(buf)?.try_into()?; - let mut rows = Vec::with_capacity(rows_count); - for _ in 0..rows_count { - let mut columns = Vec::with_capacity(metadata.col_count); - for i in 0..metadata.col_count { - let v = if let Some(mut b) = types::read_bytes_opt(buf)? { - Some(deser_cql_value(&metadata.col_specs[i].typ, &mut b)?) - } else { - None - }; - columns.push(v); - } - rows.push(Row { columns }); - } - Ok(Rows { + Ok(RawRows { metadata, rows_count, - rows, - serialized_size: original_size - buf.len(), + raw_rows: buf_bytes.slice_ref(buf), }) } @@ -895,11 +1022,12 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult { }) } -pub fn deserialize(buf: &mut &[u8]) -> StdResult { +pub fn deserialize(buf_bytes: Bytes) -> StdResult { + let buf = &mut &*buf_bytes; use self::Result::*; Ok(match types::read_int(buf)? { 0x0001 => Void, - 0x0002 => Rows(deser_rows(buf)?), + 0x0002 => Rows(deser_rows(buf_bytes.slice_ref(buf))?), 0x0003 => SetKeyspace(deser_set_keyspace(buf)?), 0x0004 => Prepared(deser_prepared(buf)?), 0x0005 => SchemaChange(deser_schema_change(buf)?), diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index c7daff52f0..572f34e0fb 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -227,11 +227,14 @@ pub fn write_bytes(v: &[u8], buf: &mut impl BufMut) -> Result<(), ParseError> { Ok(()) } -pub fn write_bytes_opt(v: Option<&Vec>, buf: &mut impl BufMut) -> Result<(), ParseError> { +pub fn write_bytes_opt( + v: Option>, + buf: &mut impl BufMut, +) -> Result<(), ParseError> { match v { Some(bytes) => { - write_int_length(bytes.len(), buf)?; - buf.put_slice(bytes); + write_int_length(bytes.as_ref().len(), buf)?; + buf.put_slice(bytes.as_ref()); } None => write_int(-1, buf), } diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 47b58b4f4e..fa18e9f131 100644 --- a/scylla-cql/src/lib.rs +++ 
b/scylla-cql/src/lib.rs @@ -3,6 +3,8 @@ pub mod frame; #[macro_use] pub mod macros; +pub mod types; + pub use crate::frame::response::cql_to_rust; pub use crate::frame::response::cql_to_rust::FromRow; @@ -10,6 +12,7 @@ pub use crate::frame::types::Consistency; #[doc(hidden)] pub mod _macro_internal { + pub use crate::frame; pub use crate::frame::response::cql_to_rust::{ FromCqlVal, FromCqlValError, FromRow, FromRowError, }; @@ -18,4 +21,5 @@ pub mod _macro_internal { SerializedResult, SerializedValues, Value, ValueList, ValueTooBig, }; pub use crate::macros::*; + pub use crate::types; } diff --git a/scylla-cql/src/macros.rs b/scylla-cql/src/macros.rs index 8d60312145..3e0f4250f3 100644 --- a/scylla-cql/src/macros.rs +++ b/scylla-cql/src/macros.rs @@ -13,6 +13,167 @@ pub use scylla_macros::IntoUserType; /// #[derive(ValueList)] allows to pass struct as a list of values for a query pub use scylla_macros::ValueList; +/// Derive macro for the `DeserializeCql` trait that generates an implementation +/// which deserializes a User Defined Type with the same layout as the Rust +/// struct. +/// +/// At the moment, only structs with named fields are supported. +/// +/// This macro properly supports structs with lifetimes, meaning that you can +/// deserialize UDTs with fields that borrow memory from the serialized response. +/// +/// # Example +/// +/// A UDT defined like this: +/// +/// ```notrust +/// CREATE TYPE ks.my_udt (a i32, b text, c blob); +/// ``` +/// +/// ...can be deserialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::DeserializeCql; +/// #[derive(DeserializeCql)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyUdt<'a> { +/// a: i32, +/// b: Option, +/// c: &'a [u8], +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. 
Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// Specify a path to the `scylla` or `scylla-cql` crate to use from the +/// generated code. This attribute should be used if the crate or its API +/// is imported/re-exported under a different name. +/// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the UDT field order - when processing a field, it will look it up +/// in the Rust struct with the corresponding field and set it. However, +/// if the UDT field order is known to be the same both in the UDT +/// and the Rust struct, then the `enforce_order` annotation can be used +/// so that a more efficient implementation that does not perform lookups +/// is generated. The UDT field names will still be checked during the +/// type check phase. +/// +/// `#[scylla(no_field_name_verification)]` +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the UDT field names at +/// all. Because it only works with `enforce_order`, it will deserialize first +/// UDT field into the first struct field, second UDT field into the second +/// struct field and so on. It will still verify that the UDT field types +/// and struct field types match. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +/// +/// `#[scylla(default_when_missing)]` +/// +/// If the UDT definition does not contain this field, it will be initialized +/// with `Default::default()`.
__This attribute has no effect in `enforce_order` +/// mode.__ +/// +/// `#[scylla(rename = "field_name")]` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a UDT field with the same name. This attribute allows matching to a +/// UDT field with the provided name. +pub use scylla_macros::DeserializeCql; + +/// Derive macro for the `DeserializeRow` trait that generates an implementation +/// which deserializes a row with a similar layout to the Rust struct. +/// +/// At the moment, only structs with named fields are supported. +/// +/// This macro properly supports structs with lifetimes, meaning that you can +/// deserialize columns that borrow memory from the serialized response. +/// +/// # Example +/// +/// A table defined like this: +/// +/// ```notrust +/// CREATE TABLE ks.my_table (a PRIMARY KEY, b text, c blob); +/// ``` +/// +/// ...can be deserialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::DeserializeRow; +/// #[derive(DeserializeRow)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyUdt<'a> { +/// a: i32, +/// b: Option, +/// c: &'a [u8], +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// Specify a path to the `scylla` or `scylla-cql` crate to use from the +/// generated code. This attribute should be used if the crate or its API +/// is imported/re-exported under a different name. +/// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the column order - when processing a column, the corresponding Rust field +/// will be looked up and the column will be deserialized based on its type.
+/// However, if the column order and the Rust field order are known to be the +/// same, then the `enforce_order` annotation can be used so that a more +/// efficient implementation that does not perform lookups is generated. +/// The generated code will still check that the column and field names match. +/// +/// `#[scylla(no_field_name_verification)]` +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the column names at +/// all. Because it only works with `enforce_order`, it will deserialize first +/// column into the first field, second column into the second field and so on. +/// It will still verify that the column types and field types match. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +/// +/// `#[scylla(rename = "field_name")]` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a column with the same name. This attribute allows matching to a column +/// with the provided name. +pub use scylla_macros::DeserializeRow; + +// Reexports for derive(IntoUserType) +pub use bytes::{BufMut, Bytes, BytesMut}; diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs new file mode 100644 index 0000000000..2e53f6b28e --- /dev/null +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -0,0 +1,515 @@ +//! Framework for deserialization of data returned by database queries. +//! +//! Deserialization is based on two traits: +//! +//! - A type that implements `DeserializeCql<'frame>` can be deserialized +//! from a single _CQL value_ - i.e. an element of a row in the query result, +//! - A type that implements `DeserializeRow<'frame>` can be deserialized +//! from a single _row_ of a query result. +//! +//!
Those traits are quite similar to each other, both in the idea behind them +//! and the interface that they expose. +//! +//! # `type_check` and `deserialize` +//! +//! The deserialization process is divided into two parts: type checking and +//! actual deserialization, represented by `DeserializeCql`/`DeserializeRow`'s +//! methods called `type_check` and `deserialize`. +//! +//! The `deserialize` method can assume that `type_check` was called before, so +//! it doesn't have to verify the type again. This can be a performance gain +//! when deserializing query results with multiple rows: as each row in a result +//! has the same type, it is only necessary to call `type_check` once for the +//! whole result and then `deserialize` for each row. +//! +//! Note that `deserialize` is not an `unsafe` method - although you can be +//! sure that the driver will call `type_check` before `deserialize`, you +//! shouldn't do unsafe things based on this assumption. +//! +//! # Data ownership +//! +//! Some CQL types can be easily consumed while still partially serialized. +//! For example, types like `blob` or `text` can be just represented with +//! `&[u8]` and `&str` that just point to a part of the serialized response. +//! This is more efficient than using `Vec` or `String` because it avoids +//! an allocation and a copy, however it is less convenient because those types +//! are bound with a lifetime. +//! +//! The framework supports types that refer to the serialized response's memory +//! in three different ways: +//! +//! ## Owned types +//! +//! Some types don't borrow anything and fully own their data, e.g. `i32` or +//! `String`. They aren't constrained by any lifetime and should implement +//! the respective trait for _all_ lifetimes, i.e.: +//! +//! ```rust +//! # use scylla_cql::types::deserialize::{value::DeserializeCql, FrameSlice}; +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::frame::frame_errors::ParseError; +//! 
struct MyVec(Vec); +//! impl<'frame> DeserializeCql<'frame> for MyVec { +//! fn type_check(typ: &ColumnType) -> Result<(), ParseError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(ParseError::BadIncomingData("Expected bytes".to_string())) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! ParseError::BadIncomingData("Expected non-null value".to_string()) +//! }) +//! .map(|v| Self(v.as_slice().to_vec())) +//! } +//! } +//! ``` +//! +//! ## Borrowing types +//! +//! Some types do not fully contain their data but rather will point to some +//! bytes in the serialized response, e.g. `&str` or `&[u8]`. Those types will +//! usually contain a lifetime in their definition. In order to properly +//! implement `DeserializeCql` or `DeserializeRow` for such a type, the `impl` +//! should still have a generic lifetime parameter, but the lifetimes from the +//! type definition should be constrained with the generic lifetime parameter. +//! For example: +//! +//! ```rust +//! # use scylla_cql::types::deserialize::{value::DeserializeCql, FrameSlice}; +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::frame::frame_errors::ParseError; +//! struct MySlice<'a>(&'a [u8]); +//! impl<'a, 'frame> DeserializeCql<'frame> for MySlice<'a> +//! where +//! 'frame: 'a, +//! { +//! fn type_check(typ: &ColumnType) -> Result<(), ParseError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(ParseError::BadIncomingData("Expected bytes".to_string())) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! ParseError::BadIncomingData("Expected non-null value".to_string()) +//! }) +//! .map(|v| Self(v.as_slice())) +//! } +//! } +//! ``` +//! +//! ## Reference-counted types (`DeserializeCql` only) +//! +//! 
Internally, the driver uses the `bytes::Bytes` type to keep the contents +//! of the serialized response. It supports creating derived `Bytes` objects +//! which point to a subslice but keep the whole, original `Bytes` object alive. +//! +//! During deserialization, a type can obtain a `Bytes` subslice that points +//! to the serialized value. This approach combines advantages of the previous +//! two approaches - creating a derived `Bytes` object can be cheaper than +//! allocation and a copy (it supports `Arc`-like semantics) and the `Bytes` +//! type is not constrained by a lifetime. However, you should be aware that +//! the subslice will keep the whole `Bytes` object that holds the frame alive. +//! It is not recommended to use this approach for long-living objects because +//! it can introduce space leaks. +//! +//! Example: +//! +//! ```rust +//! # use scylla_cql::types::deserialize::{value::DeserializeCql, FrameSlice}; +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use bytes::Bytes; +//! struct MyBytes(Bytes); +//! impl<'frame> DeserializeCql<'frame> for MyBytes { +//! fn type_check(typ: &ColumnType) -> Result<(), ParseError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(ParseError::BadIncomingData("Expected bytes".to_string())) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! ParseError::BadIncomingData("Expected non-null value".to_string()) +//! }) +//! .map(|v| Self(v.to_bytes())) +//! } +//! } +//! ``` + +pub mod row; +pub mod value; + +use std::marker::PhantomData; + +use bytes::Bytes; + +use crate::frame::frame_errors::ParseError; +use crate::frame::response::result::ColumnSpec; +use crate::frame::types; + +use self::row::{ColumnIterator, DeserializeRow}; + +/// A reference to a part of the frame. 
+#[derive(Clone, Copy, Debug)] +pub struct FrameSlice<'frame> { + // The actual subslice represented by this FrameSlice. + frame_subslice: &'frame [u8], + + // Reference to the original Bytes object that this FrameSlice is derived + // from. It is used to convert the `mem` slice into a fully blown Bytes + // object via Bytes::slice_ref method. + original_frame: &'frame Bytes, +} + +static EMPTY_BYTES: Bytes = Bytes::new(); + +impl<'frame> FrameSlice<'frame> { + /// Creates a new FrameSlice from a reference of a Bytes object. + /// + /// This method is exposed to allow writing deserialization tests + /// for custom types. + #[inline] + pub fn new(frame: &'frame Bytes) -> Self { + Self { + frame_subslice: frame, + original_frame: frame, + } + } + + /// Creates a new FrameSlice that refers to a subslice of a given Bytes object. + #[inline] + pub fn new_subslice(mem: &'frame [u8], frame: &'frame Bytes) -> Self { + Self { + frame_subslice: mem, + original_frame: frame, + } + } + + /// Creates an empty FrameSlice. + #[inline] + pub fn new_empty() -> Self { + Self { + frame_subslice: &EMPTY_BYTES, + original_frame: &EMPTY_BYTES, + } + } + + /// Returns a reference to the slice. + #[inline] + pub fn as_slice(&self) -> &'frame [u8] { + self.frame_subslice + } + + /// Returns `true` if the slice has length of 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.frame_subslice.is_empty() + } + + /// Returns a reference to the Bytes object which encompasses the slice. + /// + /// The Bytes object will usually be larger than the slice returned by + /// [FrameSlice::as_slice]. If you wish to obtain a new Bytes object that + /// points only to the subslice represented by the FrameSlice object, + /// see [FrameSlice::to_bytes]. + #[inline] + pub fn as_bytes_ref(&self) -> &'frame Bytes { + self.original_frame + } + + /// Returns a new Bytes object which is a subslice of the original slice + /// object. 
+ #[inline] + pub fn to_bytes(&self) -> Bytes { + self.original_frame.slice_ref(self.frame_subslice) + } + + /// Reads and consumes a `[bytes]` item from the beginning of the frame. + /// + /// If the operation fails then the slice remains unchanged. + #[inline] + fn read_cql_bytes(&mut self) -> Result>, ParseError> { + match types::read_bytes_opt(&mut self.frame_subslice) { + Ok(Some(slice)) => Ok(Some(Self::new_subslice(slice, self.original_frame))), + Ok(None) => Ok(None), + Err(err) => Err(err), + } + } +} + +/// Iterates over the whole result, returning rows. +pub struct RowIterator<'frame> { + specs: &'frame [ColumnSpec], + remaining: usize, + slice: FrameSlice<'frame>, +} + +impl<'frame> RowIterator<'frame> { + /// Creates a new iterator over rows from a serialized response. + /// + /// - `remaining` - number of the remaining rows in the serialized response, + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a `FrameSlice` that points to the serialized rows data. + #[inline] + pub fn new(remaining: usize, specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs, + remaining, + slice, + } + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.specs + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. 
+ #[inline] + pub fn rows_remaining(&self) -> usize { + self.remaining + } +} + +impl<'frame> Iterator for RowIterator<'frame> { + type Item = Result, ParseError>; + + #[inline] + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + + let iter = ColumnIterator::new(self.specs, self.slice); + + // Skip the row here, manually + for _ in 0..self.specs.len() { + if let Err(err) = self.slice.read_cql_bytes() { + return Some(Err(err)); + } + } + + Some(Ok(iter)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.remaining)) + } +} + +/// A typed version of `RowIterator` which deserializes the rows before +/// returning them. +pub struct TypedRowIterator<'frame, R> { + inner: RowIterator<'frame>, + _phantom: PhantomData, +} + +impl<'frame, R> TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + /// Creates a new `TypedRowIterator` from given `RowIterator`. + /// + /// Calls `R::type_check` and fails if the type check fails. + #[inline] + pub fn new(raw: RowIterator<'frame>) -> Result { + R::type_check(raw.specs())?; + Ok(Self { + inner: raw, + _phantom: PhantomData, + }) + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.inner.specs() + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. 
+ #[inline] + pub fn rows_remaining(&self) -> usize { + self.inner.rows_remaining() + } +} + +impl<'frame, R> Iterator for TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = match self.inner.next() { + Some(Ok(raw)) => raw, + Some(Err(err)) => return Some(Err(err)), + None => return None, + }; + + Some(R::deserialize(raw)) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[cfg(test)] +mod tests { + use crate::frame::{ + response::result::{ColumnType, TableSpec}, + types, + }; + + use super::*; + + use bytes::{Bytes, BytesMut}; + + static CELL1: &[u8] = &[1, 2, 3]; + static CELL2: &[u8] = &[4, 5, 6, 7]; + + pub(super) fn serialize_cells( + cells: impl IntoIterator>>, + ) -> Bytes { + let mut bytes = BytesMut::new(); + for cell in cells { + types::write_bytes_opt(cell, &mut bytes).unwrap(); + } + bytes.freeze() + } + + fn spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + name: name.to_owned(), + typ, + table_spec: TableSpec { + ks_name: "ks".to_owned(), + table_name: "tbl".to_owned(), + }, + } + } + + #[test] + fn test_cql_bytes_consumption() { + let frame = serialize_cells([Some(CELL1), None, Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + assert!(!slice.is_empty()); + + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL1) + ); + assert!(!slice.is_empty()); + assert!(slice.read_cql_bytes().unwrap().is_none()); + assert!(!slice.is_empty()); + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL2) + ); + assert!(slice.is_empty()); + slice.read_cql_bytes().unwrap_err(); + assert!(slice.is_empty()); + } + + #[test] + fn test_cql_bytes_owned() { + let frame = serialize_cells([Some(CELL1), Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + + let subslice1 = slice.read_cql_bytes().unwrap().unwrap(); + let subslice2 = slice.read_cql_bytes().unwrap().unwrap(); + 
+ assert_eq!(subslice1.as_slice(), CELL1); + assert_eq!(subslice2.as_slice(), CELL2); + + assert_eq!( + subslice1.as_bytes_ref() as *const Bytes, + &frame as *const Bytes + ); + assert_eq!( + subslice2.as_bytes_ref() as *const Bytes, + &frame as *const Bytes + ); + + let subslice1_bytes = subslice1.to_bytes(); + let subslice2_bytes = subslice2.to_bytes(); + + assert_eq!(subslice1.as_slice(), subslice1_bytes.as_ref()); + assert_eq!(subslice2.as_slice(), subslice2_bytes.as_ref()); + } + + #[test] + fn test_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + let mut row1 = iter.next().unwrap().unwrap(); + let c11 = row1.next().unwrap().unwrap(); + assert_eq!(c11.slice.unwrap().as_slice(), CELL1); + let c12 = row1.next().unwrap().unwrap(); + assert_eq!(c12.slice.unwrap().as_slice(), CELL2); + assert!(row1.next().is_none()); + + let mut row2 = iter.next().unwrap().unwrap(); + let c21 = row2.next().unwrap().unwrap(); + assert_eq!(c21.slice.unwrap().as_slice(), CELL2); + let c22 = row2.next().unwrap().unwrap(); + assert_eq!(c22.slice.unwrap().as_slice(), CELL1); + assert!(row2.next().is_none()); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_row_iterator_too_few_rows() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + iter.next().unwrap().unwrap(); + assert!(iter.next().unwrap().is_err()); + } + + #[test] + fn test_typed_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(2, &specs, 
FrameSlice::new(&raw_data)); + let mut iter = TypedRowIterator::<'_, (&[u8], Vec)>::new(iter).unwrap(); + + let (c11, c12) = iter.next().unwrap().unwrap(); + assert_eq!(c11, CELL1); + assert_eq!(c12, CELL2); + + let (c21, c22) = iter.next().unwrap().unwrap(); + assert_eq!(c21, CELL2); + assert_eq!(c22, CELL1); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_typed_row_iterator_wrong_type() { + let raw_data = Bytes::new(); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(0, &specs, FrameSlice::new(&raw_data)); + assert!(TypedRowIterator::<'_, (i32, i64)>::new(iter).is_err()); + } +} diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs new file mode 100644 index 0000000000..c8f0a8df10 --- /dev/null +++ b/scylla-cql/src/types/deserialize/row.rs @@ -0,0 +1,414 @@ +//! Provides types for dealing with row deserialization. + +use super::value::DeserializeCql; +use super::FrameSlice; +use crate::frame::frame_errors::ParseError; +use crate::frame::response::result::CqlValue; +use crate::frame::response::result::{ColumnSpec, Row}; + +/// Represents a raw, unparsed column value. +#[non_exhaustive] +pub struct RawColumn<'frame> { + pub spec: &'frame ColumnSpec, + pub slice: Option>, +} + +/// Iterates over columns of a single row. +#[derive(Clone, Debug)] +pub struct ColumnIterator<'frame> { + specs: std::slice::Iter<'frame, ColumnSpec>, + slice: FrameSlice<'frame>, +} + +impl<'frame> ColumnIterator<'frame> { + /// Creates a new iterator over a single row. + /// + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a `FrameSlice` which points to the serialized row. + #[inline] + pub fn new(specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs: specs.iter(), + slice, + } + } + + /// Returns the remaining number of rows that this iterator is expected + /// to return. 
+ #[inline] + pub fn columns_remaining(&self) -> usize { + self.specs.len() + } +} + +impl<'frame> Iterator for ColumnIterator<'frame> { + type Item = Result, ParseError>; + + #[inline] + fn next(&mut self) -> Option { + let spec = self.specs.next()?; + Some( + self.slice + .read_cql_bytes() + .map(|slice| RawColumn { spec, slice }), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.specs.len())) + } +} + +/// A type that can be deserialized from a row that was returned from a query. +/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. +pub trait DeserializeRow<'frame> +where + Self: Sized, +{ + /// Checks that the schema of the result matches what this type expects. + /// + /// This function can check whether column types and names match the + /// expectations. + fn type_check(specs: &[ColumnSpec]) -> Result<(), ParseError>; + + /// Deserializes a row from given column iterator. + /// + /// This function can assume that the driver called `type_check` to verify + /// the row's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. + fn deserialize(row: ColumnIterator<'frame>) -> Result; +} + +impl<'frame> DeserializeRow<'frame> for Row { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), ParseError> { + // CqlValues accept all types, no type checking needed + Ok(()) + } + + #[inline] + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + let mut columns = Vec::with_capacity(row.size_hint().0); + while let Some(column) = row.next().transpose()? 
{ + columns.push(>::deserialize( + &column.spec.typ, + column.slice, + )?); + } + Ok(Self { columns }) + } +} + +impl<'frame> DeserializeRow<'frame> for ColumnIterator<'frame> { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), ParseError> { + Ok(()) + } + + #[inline] + fn deserialize(row: ColumnIterator<'frame>) -> Result { + Ok(row) + } +} + +macro_rules! impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeRow<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeCql<'frame>),* + { + fn type_check(specs: &[ColumnSpec]) -> Result<(), ParseError> { + if let [$($idf),*] = &specs { + $( + <$Ti as DeserializeCql<'frame>>::type_check(&$idf.typ)?; + )* + return Ok(()); + } + const TUPLE_LEN: usize = [0, $($idx),*].len() - 1; + return Err(ParseError::BadIncomingData(format!( + "Expected {} columns, but got {:?}", + TUPLE_LEN, specs.len(), + ))); + } + + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + const TUPLE_LEN: usize = [0, $($idx),*].len() - 1; + let ret = ( + $({ + let column = row.next().ok_or_else(|| ParseError::BadIncomingData( + format!("Expected {} values, got {}", TUPLE_LEN, $idx) + ))??; + <$Ti as DeserializeCql<'frame>>::deserialize(&column.spec.typ, column.slice)? + },)* + ); + if row.next().is_some() { + return Err(ParseError::BadIncomingData( + format!("Expected {} values, but got more", TUPLE_LEN) + )); + } + Ok(ret) + } + } + } +} + +macro_rules! 
impl_tuple_multiple { + (;;) => { + impl_tuple!(;;); + }; + ($TN:ident $(,$Ti:ident)*; $idx_n:literal $(,$idx:literal)*; $idf_n:ident $(,$idf:ident)*) => { + impl_tuple_multiple!($($Ti),*; $($idx),*; $($idf),*); + impl_tuple!($TN $(,$Ti)*; $idx_n $(,$idx)*; $idf_n $(,$idf)*); + } +} + +impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use scylla_macros::DeserializeRow; + + use crate::frame::frame_errors::ParseError; + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::types::deserialize::FrameSlice; + + use super::super::tests::serialize_cells; + use super::{ColumnIterator, DeserializeRow}; + + #[test] + fn test_tuple_deserialization() { + // Empty tuple + deserialize::<()>(&[], &Bytes::new()).unwrap(); + + // 1-elem tuple + let (a,) = deserialize::<(i32,)>( + &[spec("i", ColumnType::Int)], + &serialize_cells([val_int(123)]), + ) + .unwrap(); + assert_eq!(a, 123); + + // 3-elem tuple + let (a, b, c) = deserialize::<(i32, i32, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap(); + assert_eq!((a, b, c), (123, 456, 789)); + + // Make sure that column type mismatch is detected + deserialize::<(i32, String, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap_err(); + + // Make sure that borrowing types compile and work correctly + let specs = &[spec("s", ColumnType::Text)]; + let byts = serialize_cells([val_str("abc")]); + let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); + assert_eq!(s, "abc"); + } + + #[test] + fn 
test_deserialization_as_column_iterator() { + let col_specs = [ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Text), + spec("i3", ColumnType::Counter), + ]; + let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); + let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); + + let col1 = iter.next().unwrap().unwrap(); + assert_eq!(col1.spec.name, "i1"); + assert_eq!(col1.spec.typ, ColumnType::Int); + assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); + + let col2 = iter.next().unwrap().unwrap(); + assert_eq!(col2.spec.name, "i2"); + assert_eq!(col2.spec.typ, ColumnType::Text); + assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); + + let col3 = iter.next().unwrap().unwrap(); + assert_eq!(col3.spec.name, "i3"); + assert_eq!(col3.spec.typ, ColumnType::Counter); + assert!(col3.slice.is_none()); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_struct_deserialization_loose_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Original order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Different order of columns - should still work + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + let byts = serialize_cells([val_int(123), val_str("abc")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + 
MyRow::type_check(specs).unwrap_err(); + } + + #[test] + fn test_struct_deserialization_strict_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Wrong order of columns + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); + } + + #[test] + fn test_struct_deserialization_no_name_check() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, no_field_name_verification)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Correct order of columns, but different names - should still succeed + let specs = &[spec("z", ColumnType::Text), spec("x", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + } + + fn val_int(i: i32) -> Option> { + Some(i.to_be_bytes().to_vec()) + } + + fn 
val_str(s: &str) -> Option> { + Some(s.as_bytes().to_vec()) + } + + fn spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + name: name.to_owned(), + typ, + table_spec: TableSpec { + ks_name: "ks".to_owned(), + table_name: "tbl".to_owned(), + }, + } + } + + fn deserialize<'frame, R>( + specs: &'frame [ColumnSpec], + byts: &'frame Bytes, + ) -> Result + where + R: DeserializeRow<'frame>, + { + >::type_check(specs)?; + let slice = FrameSlice::new(byts); + let iter = ColumnIterator::new(specs, slice); + >::deserialize(iter) + } +} diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs new file mode 100644 index 0000000000..5938bad9a2 --- /dev/null +++ b/scylla-cql/src/types/deserialize/value.rs @@ -0,0 +1,1814 @@ +//! Provides types for dealing with CQL value deserialization. + +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::hash::{BuildHasher, Hash}; +use std::net::IpAddr; + +use bytes::Bytes; +use chrono::{DateTime, Duration, NaiveDate, TimeZone, Utc}; +use uuid::Uuid; + +use crate::frame::frame_errors::ParseError; +use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; +use crate::frame::types; +use crate::frame::value::{Counter, CqlDuration, Date, Time, Timestamp}; + +use super::FrameSlice; + +/// A type that can be deserialized from a column value inside a row that was +/// returned from a query. +/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. +pub trait DeserializeCql<'frame> +where + Self: Sized, +{ + /// Checks that the column type matches what this type expects. 
+ fn type_check(typ: &ColumnType) -> Result<(), ParseError>; + + /// Deserialize a column value from given serialized representation. + /// + /// This function can assume that the driver called `type_check` to verify + /// the column's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result; +} + +impl<'frame> DeserializeCql<'frame> for CqlValue { + fn type_check(_typ: &ColumnType) -> Result<(), ParseError> { + // CqlValue accepts all possible CQL types + Ok(()) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let mut val = ensure_not_null(v)?; + let cql = deser_cql_value(typ, &mut val)?; + Ok(cql) + } +} + +impl<'frame, T> DeserializeCql<'frame> for Option +where + T: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + T::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + v.map(|_| T::deserialize(typ, v)).transpose() + } +} + +macro_rules! impl_strict_type { + ($cql_name:literal, $t:ty, $cql_type:pat, $conv:expr $(, $l:lifetime)?) => { + impl<$($l,)? 'frame> DeserializeCql<'frame> for $t + where + $('frame: $l)? + { + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + // TODO: Format the CQL type names in the same notation + // that ScyllaDB/Casssandra uses internally and include them + // in such form in the error message + match typ { + $cql_type => Ok(()), + _ => Err(ParseError::BadIncomingData(format!( + "Expected {}, got {:?}", + $cql_name, typ, + ))), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + $conv(typ, v) + } + } + }; +} + +// fixed numeric types + +macro_rules! 
impl_fixed_numeric_type { + ($cql_name:literal, $t:ty, $col_type:pat) => { + impl_strict_type!( + $cql_name, + $t, + $col_type, + |_typ: &'frame ColumnType, v: Option>| { + const SIZE: usize = std::mem::size_of::<$t>(); + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::($cql_name, val)?; + Ok(<$t>::from_be_bytes(arr)) + } + ); + }; +} + +impl_strict_type!( + "boolean", + bool, + ColumnType::Boolean, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<1>("boolean", val)?; + Ok(arr[0] != 0x00) + } +); + +impl_fixed_numeric_type!("tinyint", i8, ColumnType::TinyInt); +impl_fixed_numeric_type!("smallint", i16, ColumnType::SmallInt); +impl_fixed_numeric_type!("int", i32, ColumnType::Int); +impl_fixed_numeric_type!( + "bigint or counter", + i64, + ColumnType::BigInt | ColumnType::Counter +); +impl_fixed_numeric_type!("float", f32, ColumnType::Float); +impl_fixed_numeric_type!("double", f64, ColumnType::Double); + +// other numeric types + +impl_strict_type!( + "varint", + num_bigint::BigInt, + ColumnType::Varint, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + Ok(num_bigint::BigInt::from_signed_bytes_be(val)) + } +); + +impl_strict_type!( + "decimal", + bigdecimal::BigDecimal, + ColumnType::Decimal, + |_typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null(v)?; + let scale = types::read_int(&mut val)? 
as i64; + let int_value = num_bigint::BigInt::from_signed_bytes_be(val); + Ok(bigdecimal::BigDecimal::from((int_value, scale))) + } +); + +// blob + +impl_strict_type!( + "blob", + &'a [u8], + ColumnType::Blob, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + Ok(val) + }, + 'a +); +impl_strict_type!( + "blob", + Vec, + ColumnType::Blob, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + Ok(val.to_vec()) + } +); +impl_strict_type!( + "blob", + Bytes, + ColumnType::Blob, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_owned(v)?; + Ok(val) + } +); + +// string + +macro_rules! impl_string_type { + ($t:ty, $conv:expr $(, $l:lifetime)?) => { + impl_strict_type!( + "ascii or text", + $t, + ColumnType::Ascii | ColumnType::Text, + $conv + $(, $l)? + ); + }; +} + +impl_string_type!( + &'a str, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + check_ascii(typ, val)?; + Ok(std::str::from_utf8(val)?) 
+ }, + 'a +); +impl_string_type!( + String, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + check_ascii(typ, val)?; + Ok(std::str::from_utf8(val)?.to_string()) + } +); + +// TODO: Deserialization for string::String + +fn check_ascii(typ: &ColumnType, s: &[u8]) -> Result<(), ParseError> { + if matches!(typ, ColumnType::Ascii) && !s.is_ascii() { + return Err(ParseError::BadIncomingData( + "Expected a valid ASCII string".to_string(), + )); + } + Ok(()) +} + +// counter + +impl_strict_type!( + "counter", + Counter, + ColumnType::Counter, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<8>("counter", val)?; + let counter = i64::from_be_bytes(arr); + Ok(Counter(counter)) + } +); + +// date and time types + +impl_strict_type!( + "date", + Date, + ColumnType::Date, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<4>("date", val)?; + let days = u32::from_be_bytes(arr); + Ok(Date(days)) + } +); + +impl_strict_type!( + "date", + NaiveDate, + ColumnType::Date, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<4>("date", val)?; + let days = u32::from_be_bytes(arr); + let days_since_epoch = chrono::Duration::days(days as i64 - (1i64 << 31)); + NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .checked_add_signed(days_since_epoch) + .ok_or_else(|| { + ParseError::BadIncomingData( + "Value is out of representable range for NaiveDate".to_string(), + ) + }) + } +); + +impl_strict_type!( + "duration", + CqlDuration, + ColumnType::Duration, + |_typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null(v)?; + let months = i32::try_from(types::vint_decode(&mut val)?)?; + let days = i32::try_from(types::vint_decode(&mut val)?)?; + let nanoseconds = types::vint_decode(&mut val)?; + + Ok(CqlDuration { + months, + days, + nanoseconds, + }) + } +); + +impl_strict_type!( + 
"time or timestamp", + Duration, + ColumnType::Time | ColumnType::Timestamp, + |typ: &'frame ColumnType, v: Option>| { + // Delegate parsing to time/timestamp impls + match typ { + ColumnType::Time => Time::deserialize(typ, v).map(|t| t.0), + ColumnType::Timestamp => Timestamp::deserialize(typ, v).map(|t| t.0), + _ => Err(ParseError::BadIncomingData(format!( + "Invalid type: expected time or timestamp, got {:?}", + typ, + ))), + } + } +); + +impl_strict_type!( + "timestamp", + Timestamp, + ColumnType::Timestamp, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<8>("timestamp", val)?; + let duration = chrono::Duration::milliseconds(i64::from_be_bytes(arr)); + Ok(Timestamp(duration)) + } +); + +impl_strict_type!( + "timestamp", + DateTime, + ColumnType::Timestamp, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<8>("timestamp", val)?; + let millis = i64::from_be_bytes(arr); + match Utc.timestamp_millis_opt(millis) { + chrono::LocalResult::Single(datetime) => Ok(datetime), + _ => Err(ParseError::BadIncomingData(format!( + "Timestamp {} is out of the representable range for DateTime", + millis + ))), + } + } +); + +impl_strict_type!( + "time", + Time, + ColumnType::Time, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<8>("date", val)?; + let nanoseconds = i64::from_be_bytes(arr); + + // Valid values are in the range 0 to 86399999999999 + if !(0..=86399999999999).contains(&nanoseconds) { + return Err(ParseError::BadIncomingData(format!( + "Invalid time value only 0 to 86399999999999 allowed: {}.", + nanoseconds, + ))); + } + + Ok(Time(chrono::Duration::nanoseconds(nanoseconds))) + } +); + +// inet + +impl_strict_type!( + "inet", + IpAddr, + ColumnType::Inet, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + if let Ok(ipv4) = <[u8; 4]>::try_from(val) { + 
Ok(IpAddr::from(ipv4)) + } else if let Ok(ipv16) = <[u8; 16]>::try_from(val) { + Ok(IpAddr::from(ipv16)) + } else { + Err(ParseError::BadIncomingData(format!( + "Invalid inet bytes length: {}", + val.len(), + ))) + } + } +); + +// uuid +// TODO: Consider having separate types for timeuuid and uuid + +impl_strict_type!( + "timeuuid or uuid", + Uuid, + ColumnType::Uuid | ColumnType::Timeuuid, + |_typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null(v)?; + let arr = ensure_exact_length::<16>("timeuuid or uuid", val)?; + let i = u128::from_be_bytes(arr); + Ok(uuid::Uuid::from_u128(i)) + } +); + +/// A value that may be empty or not. +/// +/// In CQL, some types can have a special value of "empty", represented as +/// a serialized value of length 0. An example of this are integral types: +/// the "int" type can actually hold 2^32 + 1 possible values because of this +/// quirk. Note that this is distinct from being NULL. +/// +/// `MaybeEmpty` was introduced to help support this quirk for Rust types +/// which can't represent the empty, additional value. 
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub enum MaybeEmpty { + Empty, + Value(T), +} + +impl<'frame, T> DeserializeCql<'frame> for MaybeEmpty +where + T: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let val = ensure_not_null(v)?; + if val.is_empty() { + Ok(MaybeEmpty::Empty) + } else { + let v = >::deserialize(typ, v)?; + Ok(MaybeEmpty::Value(v)) + } + } +} + +// secrecy +#[cfg(feature = "secret")] +impl<'frame, T> DeserializeCql<'frame> for secrecy::Secret +where + T: DeserializeCql<'frame> + secrecy::Zeroize, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + >::deserialize(typ, v).map(secrecy::Secret::new) + } +} + +// collections + +// lists and sets + +/// An iterator over either a CQL set or list. 
+pub struct SequenceIterator<'frame, T> { + elem_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} + +impl<'frame, T> SequenceIterator<'frame, T> { + pub fn new(elem_typ: &'frame ColumnType, count: usize, slice: FrameSlice<'frame>) -> Self { + Self { + elem_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data: std::marker::PhantomData, + } + } +} + +impl<'frame, T> DeserializeCql<'frame> for SequenceIterator<'frame, T> +where + T: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + match typ { + ColumnType::List(el_t) | ColumnType::Set(el_t) => { + >::type_check(el_t) + } + _ => Err(ParseError::BadIncomingData(format!( + "Expected list or set, got {:?}", + typ, + ))), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_slice(v)?; + let mut mem = v.as_slice(); + let count = types::read_int_length(&mut mem)?; + let elem_typ = match typ { + ColumnType::List(elem_typ) | ColumnType::Set(elem_typ) => elem_typ, + _ => { + return Err(ParseError::BadIncomingData(format!( + "Expected list or set, got {:?}", + typ, + ))) + } + }; + Ok(Self::new( + elem_typ, + count, + FrameSlice::new_subslice(mem, v.as_bytes_ref()), + )) + } +} + +impl<'frame, T> Iterator for SequenceIterator<'frame, T> +where + T: DeserializeCql<'frame>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?; + Some(raw.and_then(|raw| T::deserialize(self.elem_typ, raw))) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, T> DeserializeCql<'frame> for Vec +where + T: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + SequenceIterator::<'frame, T>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + SequenceIterator::<'frame, 
T>::deserialize(typ, v)?.collect() + } +} + +impl<'frame, T> DeserializeCql<'frame> for BTreeSet +where + T: DeserializeCql<'frame> + Ord, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + SequenceIterator::<'frame, T>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + SequenceIterator::<'frame, T>::deserialize(typ, v)?.collect() + } +} + +impl<'frame, T, S> DeserializeCql<'frame> for HashSet +where + T: DeserializeCql<'frame> + Eq + Hash, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + SequenceIterator::<'frame, T>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + SequenceIterator::<'frame, T>::deserialize(typ, v)?.collect() + } +} + +/// An iterator over a CQL map. +pub struct MapIterator<'frame, K, V> { + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, +} + +impl<'frame, K, V> DeserializeCql<'frame> for MapIterator<'frame, K, V> +where + K: DeserializeCql<'frame>, + V: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + match typ { + ColumnType::Map(k_t, v_t) => { + >::type_check(k_t)?; + >::type_check(v_t)?; + Ok(()) + } + _ => Err(ParseError::BadIncomingData(format!( + "Expected map, got {:?}", + typ, + ))), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_slice(v)?; + let mut mem = v.as_slice(); + let count = types::read_int_length(&mut mem)?; + let (k_typ, v_typ) = match typ { + ColumnType::Map(k_t, v_t) => (k_t, v_t), + _ => { + return Err(ParseError::BadIncomingData(format!( + "Expected map, got {:?}", + typ, + ))) + } + }; + Ok(Self { + k_typ, + v_typ, + raw_iter: FixedLengthBytesSequenceIterator::new( + 2 * count, + 
FrameSlice::new_subslice(mem, v.as_bytes_ref()), + ), + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, + }) + } +} + +impl<'frame, K, V> Iterator for MapIterator<'frame, K, V> +where + K: DeserializeCql<'frame>, + V: DeserializeCql<'frame>, +{ + type Item = Result<(K, V), ParseError>; + + fn next(&mut self) -> Option { + let raw_k = match self.raw_iter.next() { + Some(Ok(raw_k)) => raw_k, + Some(Err(err)) => return Some(Err(err)), + None => return None, + }; + let raw_v = match self.raw_iter.next() { + Some(Ok(raw_v)) => raw_v, + Some(Err(err)) => return Some(Err(err)), + None => return None, + }; + let do_next = || -> Result<(K, V), ParseError> { + let k = K::deserialize(self.k_typ, raw_k)?; + let v = V::deserialize(self.v_typ, raw_v)?; + Ok((k, v)) + }; + do_next().map(Some).transpose() + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, K, V> DeserializeCql<'frame> for BTreeMap +where + K: DeserializeCql<'frame> + Ord, + V: DeserializeCql<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + MapIterator::<'frame, K, V>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + MapIterator::<'frame, K, V>::deserialize(typ, v)?.collect() + } +} + +impl<'frame, K, V, S> DeserializeCql<'frame> for HashMap +where + K: DeserializeCql<'frame> + Eq + Hash, + V: DeserializeCql<'frame>, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + MapIterator::<'frame, K, V>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + MapIterator::<'frame, K, V>::deserialize(typ, v)?.collect() + } +} + +// tuples + +// Implements tuple deserialization. +// The generated impl expects that the serialized data will contain at least +// the given amount of values. 
+// TODO: Include information about the id of the column that failed to parse +macro_rules! impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeCql<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeCql<'frame>),* + { + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + const TUPLE_LEN: usize = [0, $($idx),*].len() - 1; + let [$($idf),*] = ensure_tuple_type::(typ)?; + $( + <$Ti>::type_check($idf)?; + )* + Ok(()) + } + + fn deserialize(typ: &'frame ColumnType, v: Option>) -> Result { + const TUPLE_LEN: usize = [0, $($idx),*].len() - 1; + let [$($idf),*] = ensure_tuple_type::(typ)?; + + // Ignore the warning for the zero-sized tuple + #[allow(unused)] + let mut v = ensure_not_null_slice(v)?; + let ret = ( + $( + <$Ti>::deserialize($idf, v.read_cql_bytes()?)?, + )* + ); + Ok(ret) + } + } + } +} + +macro_rules! impl_tuple_multiple { + (;;) => { + impl_tuple!(;;); + }; + ($TN:ident $(,$Ti:ident)*; $idx_n:literal $(,$idx:literal)*; $idf_n:ident $(,$idf:ident)*) => { + impl_tuple_multiple!($($Ti),*; $($idx),*; $($idf),*); + impl_tuple!($TN $(,$Ti)*; $idx_n $(,$idx)*; $idf_n $(,$idf)*); + } +} + +impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + +// udts + +/// An iterator over fields of a User Defined Type. +/// +/// # Note +/// +/// A serialized UDT will generally have one value for each field, but it is +/// allowed to have fewer. 
This iterator differentiates null values +/// from non-existent values in the following way: +/// +/// - `None` - missing from the serialized form +/// - `Some(None)` - present, but null +/// - `Some(Some(...))` - non-null, present value +pub struct UdtIterator<'frame> { + fields: &'frame [(String, ColumnType)], + raw_iter: BytesSequenceIterator<'frame>, +} + +impl<'frame> UdtIterator<'frame> { + #[inline] + pub fn new(fields: &'frame [(String, ColumnType)], slice: FrameSlice<'frame>) -> Self { + Self { + fields, + raw_iter: BytesSequenceIterator::new(slice), + } + } + + #[inline] + pub fn fields(&self) -> &'frame [(String, ColumnType)] { + self.fields + } +} + +impl<'frame> DeserializeCql<'frame> for UdtIterator<'frame> { + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + match typ { + ColumnType::UserDefinedType { .. } => Ok(()), + _ => Err(ParseError::BadIncomingData(format!( + "Expected a user defined type, got {:?}", + typ, + ))), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_slice(v)?; + let mem = v.as_slice(); + let fields = match typ { + ColumnType::UserDefinedType { field_types, .. } => field_types.as_ref(), + _ => { + return Err(ParseError::BadIncomingData(format!( + "Expected a user defined type, got {:?}", + typ, + ))) + } + }; + Ok(Self::new( + fields, + FrameSlice::new_subslice(mem, v.as_bytes_ref()), + )) + } +} + +impl<'frame> Iterator for UdtIterator<'frame> { + type Item = Result< + ( + &'frame (String, ColumnType), + Option>>, + ), + ParseError, + >; + + fn next(&mut self) -> Option { + // TODO: Should we fail when there are too many fields? 
+ let (head, fields) = self.fields.split_first()?; + self.fields = fields; + let raw = match self.raw_iter.next() { + // The field is there and it was parsed correctly + Some(Ok(raw)) => Some(raw), + + // There were some bytes but they didn't parse as correct field value + Some(Err(err)) => return Some(Err(err)), + + // The field is just missing from the serialized form + None => None, + }; + Some(Ok((head, raw))) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +// Utilities + +fn ensure_not_null(v: Option) -> Result<&[u8], ParseError> { + match v { + Some(v) => Ok(v.as_slice()), + None => Err(ParseError::BadIncomingData( + "Expected a non-null value".to_string(), + )), + } +} + +fn ensure_not_null_owned(v: Option) -> Result { + match v { + Some(v) => Ok(v.to_bytes()), + None => Err(ParseError::BadIncomingData( + "Expected a non-null value".to_string(), + )), + } +} + +fn ensure_not_null_slice(v: Option) -> Result { + match v { + Some(v) => Ok(v), + None => Err(ParseError::BadIncomingData( + "Expected a non-null value".to_string(), + )), + } +} + +fn ensure_exact_length( + cql_name: &str, + v: &[u8], +) -> Result<[u8; SIZE], ParseError> { + v.try_into().map_err(|_| { + ParseError::BadIncomingData(format!( + "The type {} requires {} bytes, but got {}", + cql_name, + SIZE, + v.len(), + )) + }) +} + +fn ensure_tuple_type( + typ: &ColumnType, +) -> Result<&[ColumnType; SIZE], ParseError> { + let fail = || { + ParseError::BadIncomingData(format!( + "Expected tuple of size {}, but got {:?}", + SIZE, typ, + )) + }; + if let ColumnType::Tuple(typs_v) = typ { + typs_v.as_slice().try_into().map_err(|_| fail()) + } else { + Err(fail()) + } +} + +// Helper iterators + +/// Iterates over a sequence of `[bytes]` items from a frame subslice. +/// +/// The `[bytes]` items are parsed until the end of subslice is reached. 
+#[derive(Clone, Copy, Debug)] +pub struct BytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, +} + +impl<'frame> BytesSequenceIterator<'frame> { + #[inline] + fn new(slice: FrameSlice<'frame>) -> Self { + Self { slice } + } +} + +impl<'frame> From> for BytesSequenceIterator<'frame> { + #[inline] + fn from(slice: FrameSlice<'frame>) -> Self { + Self::new(slice) + } +} + +impl<'frame> Iterator for BytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + if self.slice.as_slice().is_empty() { + None + } else { + Some(self.slice.read_cql_bytes()) + } + } +} + +/// Iterates over a sequence of `[bytes]` items from a frame subslice, expecting +/// a particular number of items. +/// +/// The iterator does not consider it to be an error if there are some bytes +/// remaining in the slice after parsing requested amount of items. +#[derive(Clone, Copy, Debug)] +pub struct FixedLengthBytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, + remaining: usize, +} + +impl<'frame> FixedLengthBytesSequenceIterator<'frame> { + pub fn new(count: usize, slice: FrameSlice<'frame>) -> Self { + Self { + slice, + remaining: count, + } + } +} + +impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + Some(self.slice.read_cql_bytes()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; + use std::fmt::Debug; + use std::net::{IpAddr, Ipv6Addr}; + + use bigdecimal::BigDecimal; + use bytes::{BufMut, Bytes, BytesMut}; + use chrono::{DateTime, Duration, NaiveDate, Utc}; + use num_bigint::BigInt; + use scylla_macros::{DeserializeCql, FromUserType}; + use uuid::Uuid; + + use crate::frame::response::cql_to_rust::FromCqlVal; + use crate::frame::response::result::{ColumnType, CqlValue}; + use crate::frame::types; + use crate::frame::value::{CqlDuration, 
Value}; + use crate::frame::value::{Date, Time, Timestamp}; + use crate::frame::{ + frame_errors::ParseError, response::result::deser_cql_value, value::Counter, + }; + use crate::types::deserialize::value::MaybeEmpty; + use crate::types::deserialize::FrameSlice; + + use super::{DeserializeCql, MapIterator, SequenceIterator}; + + #[test] + fn test_deserialize_bytes() { + const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; + + let bytes = make_cell(ORIGINAL_BYTES); + + let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_vec = deserialize::>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); + + assert_eq!(decoded_slice, ORIGINAL_BYTES); + assert_eq!(decoded_vec, ORIGINAL_BYTES); + assert_eq!(decoded_bytes, ORIGINAL_BYTES); + } + + #[test] + fn test_deserialize_ascii() { + const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; + + let ascii = make_cell(ASCII_TEXT.as_bytes()); + + let decoded_ascii_str = deserialize::<&str>(&ColumnType::Ascii, &ascii).unwrap(); + let decoded_ascii_string = deserialize::(&ColumnType::Ascii, &ascii).unwrap(); + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &ascii).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, &ascii).unwrap(); + + assert_eq!(decoded_ascii_str, ASCII_TEXT); + assert_eq!(decoded_ascii_string, ASCII_TEXT); + assert_eq!(decoded_text_str, ASCII_TEXT); + assert_eq!(decoded_text_string, ASCII_TEXT); + } + + #[test] + fn test_deserialize_text() { + const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; + + let unicode = make_cell(UNICODE_TEXT.as_bytes()); + + // Should fail because it's not an ASCII string + deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); + deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); + + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, 
&unicode).unwrap(); + assert_eq!(decoded_text_str, UNICODE_TEXT); + assert_eq!(decoded_text_string, UNICODE_TEXT); + } + + #[test] + fn test_integral() { + let tinyint = make_cell(&[0x01]); + let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); + assert_eq!(decoded_tinyint, 0x01); + + let smallint = make_cell(&[0x01, 0x02]); + let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); + assert_eq!(decoded_smallint, 0x0102); + + let int = make_cell(&[0x01, 0x02, 0x03, 0x04]); + let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, 0x01020304); + + let bigint = make_cell(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); + assert_eq!(decoded_bigint, 0x0102030405060708); + } + + #[test] + fn test_floating_point() { + let float = make_cell(&[63, 0, 0, 0]); + let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); + assert_eq!(decoded_float, 0.5); + + let double = make_cell(&[64, 0, 0, 0, 0, 0, 0, 0]); + let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); + assert_eq!(decoded_double, 2.0); + } + + #[test] + fn test_list_and_set() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_cell(&mut collection_contents, "quick".as_bytes()); + append_cell(&mut collection_contents, "brown".as_bytes()); + append_cell(&mut collection_contents, "fox".as_bytes()); + + let collection = make_cell(&collection_contents); + + let list_typ = ColumnType::List(Box::new(ColumnType::Ascii)); + let set_typ = ColumnType::List(Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some("quick")); + assert_eq!(iter.next().transpose().unwrap(), Some("brown")); + assert_eq!(iter.next().transpose().unwrap(), Some("fox")); + 
assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_vec_str = vec!["quick", "brown", "fox"]; + let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; + + // list + let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); + let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(decoded_vec_str, expected_vec_str); + assert_eq!(decoded_vec_string, expected_vec_string); + + // hash set + let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_hash_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_hash_string, + expected_vec_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_btree_string, + expected_vec_string.into_iter().collect(), + ); + } + + #[test] + fn test_map() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_cell(&mut collection_contents, &1i32.to_be_bytes()); + append_cell(&mut collection_contents, "quick".as_bytes()); + append_cell(&mut collection_contents, &2i32.to_be_bytes()); + append_cell(&mut collection_contents, "brown".as_bytes()); + append_cell(&mut collection_contents, &3i32.to_be_bytes()); + append_cell(&mut collection_contents, "fox".as_bytes()); + + let collection = make_cell(&collection_contents); + + let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); + assert_eq!(iter.next().transpose().unwrap(), Some((2, 
"brown"))); + assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; + let expected_string = vec![ + (1, "quick".to_string()), + (2, "brown".to_string()), + (3, "fox".to_string()), + ]; + + // hash set + let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); + assert_eq!( + decoded_hash_string, + expected_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_str.clone().into_iter().collect(), + ); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect(),); + } + + #[test] + fn test_tuples() { + let mut tuple_contents = BytesMut::new(); + append_cell(&mut tuple_contents, &42i32.to_be_bytes()); + append_cell(&mut tuple_contents, "foo".as_bytes()); + append_null(&mut tuple_contents); + + let tuple = make_cell(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); + + let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); + assert_eq!(tup, (42, "foo", None)); + } + + #[test] + fn test_maybe_empty() { + let empty = make_cell(&[]); + let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); + assert_eq!(decoded_empty, MaybeEmpty::Empty); + + let non_empty = make_cell(&[0x01]); + let decoded_non_empty = + deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); + assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); + } + + #[test] + fn test_udt_loose_ordering() { + #[derive(DeserializeCql, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct Udt<'a> { + a: &'a str, + 
#[scylla(skip)] + x: String, + #[scylla(default_when_missing)] + b: Option, + } + + // UDF fields in correct same order + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + append_cell(&mut udt_contents, &42i32.to_be_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("a".to_owned(), ColumnType::Text), + ("b".to_owned(), ColumnType::Int), + ], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + + // The last UDT field is missing in serialized form - it should treat + // as if there were null at the end + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + let udt = make_cell(&udt_contents); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + + // UDF fields switched - should still work + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, &42i32.to_be_bytes()); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("b".to_owned(), ColumnType::Int), + ("a".to_owned(), ColumnType::Text), + ], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + + // Only field 'a' is present + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + 
field_types: vec![("a".to_owned(), ColumnType::Text)], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + + // Wrong column type + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![("a".to_owned(), ColumnType::Int)], + }; + Udt::type_check(&typ).unwrap_err(); + + // Missing required column + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![("b".to_owned(), ColumnType::Int)], + }; + Udt::type_check(&typ).unwrap_err(); + } + + #[test] + fn test_udt_strict_ordering() { + #[derive(DeserializeCql, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDF fields in correct same order + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + append_cell(&mut udt_contents, &42i32.to_be_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("a".to_owned(), ColumnType::Text), + ("b".to_owned(), ColumnType::Int), + ], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + + // The last UDF field is missing in serialized form - it should treat + // as if there were null at the end + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + let udt = make_cell(&udt_contents); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + + // UDF fields switched - will not work + let typ = ColumnType::UserDefinedType { + 
type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("b".to_owned(), ColumnType::Int), + ("a".to_owned(), ColumnType::Text), + ], + }; + Udt::type_check(&typ).unwrap_err(); + + // Wrong column type + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("a".to_owned(), ColumnType::Int), + ("b".to_owned(), ColumnType::Int), + ], + }; + Udt::type_check(&typ).unwrap_err(); + + // Missing required column + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![("b".to_owned(), ColumnType::Int)], + }; + Udt::type_check(&typ).unwrap_err(); + } + + #[test] + fn test_udt_no_name_check() { + #[derive(DeserializeCql, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, no_field_name_verification)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDF fields in correct same order + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + append_cell(&mut udt_contents, &42i32.to_be_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("a".to_owned(), ColumnType::Text), + ("b".to_owned(), ColumnType::Int), + ], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + + // Correct order of UDF fields, but different names - should still succeed + let mut udt_contents = BytesMut::new(); + append_cell(&mut udt_contents, "The quick brown fox".as_bytes()); + append_cell(&mut udt_contents, &42i32.to_be_bytes()); + let udt = make_cell(&udt_contents); + let typ = ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: vec![ + ("k".to_owned(), 
ColumnType::Text), + ("l".to_owned(), ColumnType::Int), + ], + }; + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + #[test] + fn test_custom_type_parser() { + #[derive(Default, Debug, PartialEq, Eq)] + struct SwappedPair(B, A); + impl<'frame, A, B> DeserializeCql<'frame> for SwappedPair + where + A: DeserializeCql<'frame>, + B: DeserializeCql<'frame>, + { + fn type_check(typ: &ColumnType) -> Result<(), ParseError> { + <(B, A) as DeserializeCql<'frame>>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + <(B, A) as DeserializeCql<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) + } + } + + let mut tuple_contents = BytesMut::new(); + append_cell(&mut tuple_contents, "foo".as_bytes()); + append_cell(&mut tuple_contents, &42i32.to_be_bytes()); + let tuple = make_cell(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); + + let tup = deserialize::>(&typ, &tuple).unwrap(); + assert_eq!(tup, SwappedPair("foo", 42)); + } + + #[test] + fn test_from_cql_value_compatibility() { + // This test should have a sub-case for each type + // that implements FromCqlValue + + // fixed size integers + for i in 0..7 { + let v: i8 = 1 << i; + compat_check::(&ColumnType::TinyInt, make_cell(&v.to_be_bytes())); + compat_check::(&ColumnType::TinyInt, make_cell(&(-v).to_be_bytes())); + } + for i in 0..15 { + let v: i16 = 1 << i; + compat_check::(&ColumnType::SmallInt, make_cell(&v.to_be_bytes())); + compat_check::(&ColumnType::SmallInt, make_cell(&(-v).to_be_bytes())); + } + for i in 0..31 { + let v: i32 = 1 << i; + compat_check::(&ColumnType::Int, make_cell(&v.to_be_bytes())); + compat_check::(&ColumnType::Int, make_cell(&(-v).to_be_bytes())); + } + for i in 0..63 { + let v: i64 = 1 << i; + compat_check::(&ColumnType::BigInt, make_cell(&v.to_be_bytes())); + compat_check::(&ColumnType::BigInt, 
make_cell(&(-v).to_be_bytes())); + } + + // counters + for i in 0..63 { + let v: i64 = 1 << i; + compat_check::(&ColumnType::Counter, make_cell(&v.to_be_bytes())); + } + + // bool + compat_check::(&ColumnType::Boolean, make_cell(&[0])); + compat_check::(&ColumnType::Boolean, make_cell(&[1])); + + // fixed size floating point types + compat_check::(&ColumnType::Float, make_cell(&123f32.to_be_bytes())); + compat_check::(&ColumnType::Float, make_cell(&(-123f32).to_be_bytes())); + compat_check::(&ColumnType::Double, make_cell(&123f64.to_be_bytes())); + compat_check::(&ColumnType::Double, make_cell(&(-123f64).to_be_bytes())); + + const PI_STR: &[u8] = b"3.1415926535897932384626433832795028841971693993751058209749445923"; + + // big integers + let num1 = PI_STR[2..].to_vec(); + let num2 = vec![b'-'] + .into_iter() + .chain(PI_STR[2..].iter().copied()) + .collect::>(); + let num3 = b"0".to_vec(); + + let num1 = BigInt::parse_bytes(&num1, 10).unwrap(); + let num2 = BigInt::parse_bytes(&num2, 10).unwrap(); + let num3 = BigInt::parse_bytes(&num3, 10).unwrap(); + compat_check::(&ColumnType::Varint, serialize_cell(&num1)); + compat_check::(&ColumnType::Varint, serialize_cell(&num2)); + compat_check::(&ColumnType::Varint, serialize_cell(&num3)); + + // big decimals + let num1 = PI_STR.to_vec(); + let num2 = vec![b'-'] + .into_iter() + .chain(PI_STR.iter().copied()) + .collect::>(); + let num3 = b"0.0".to_vec(); + + let num1 = BigDecimal::parse_bytes(&num1, 10).unwrap(); + let num2 = BigDecimal::parse_bytes(&num2, 10).unwrap(); + let num3 = BigDecimal::parse_bytes(&num3, 10).unwrap(); + compat_check::(&ColumnType::Decimal, serialize_cell(&num1)); + compat_check::(&ColumnType::Decimal, serialize_cell(&num2)); + compat_check::(&ColumnType::Decimal, serialize_cell(&num3)); + + // date and time + let date1 = (2u32.pow(31)).to_be_bytes(); + let date2 = (2u32.pow(31) - 30).to_be_bytes(); + let date3 = (2u32.pow(31) + 30).to_be_bytes(); + compat_check::(&ColumnType::Date, 
make_cell(&date1)); + compat_check::(&ColumnType::Date, make_cell(&date2)); + compat_check::(&ColumnType::Date, make_cell(&date3)); + + compat_check::(&ColumnType::Date, make_cell(&date1)); + compat_check::(&ColumnType::Date, make_cell(&date2)); + compat_check::(&ColumnType::Date, make_cell(&date3)); + + let timestamp1 = Duration::milliseconds(123); + let timestamp2 = Duration::seconds(123); + let timestamp3 = Duration::hours(18); + // Duration type is relevant for both `time` and `timestamp` CQL types + compat_check::(&ColumnType::Time, serialize_cell(&Time(timestamp1))); + compat_check::(&ColumnType::Time, serialize_cell(&Time(timestamp2))); + compat_check::(&ColumnType::Time, serialize_cell(&Time(timestamp3))); + compat_check::( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp1)), + ); + compat_check::( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp2)), + ); + compat_check::( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp3)), + ); + + compat_check::>( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp1)), + ); + compat_check::>( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp2)), + ); + compat_check::>( + &ColumnType::Timestamp, + serialize_cell(&Timestamp(timestamp3)), + ); + + // duration + let duration1 = CqlDuration { + days: 123, + months: 456, + nanoseconds: 789, + }; + let duration2 = CqlDuration { + days: 987, + months: 654, + nanoseconds: 321, + }; + compat_check::(&ColumnType::Duration, serialize_cell(&duration1)); + compat_check::(&ColumnType::Duration, serialize_cell(&duration2)); + + // text types + for typ in &[ColumnType::Ascii, ColumnType::Text] { + compat_check::(typ, make_cell("".as_bytes())); + compat_check::(typ, make_cell("foo".as_bytes())); + compat_check::(typ, make_cell("superfragilisticexpialidocious".as_bytes())); + } + + // blob + compat_check::>(&ColumnType::Blob, make_cell(&[])); + compat_check::>(&ColumnType::Blob, make_cell(&[1, 9, 2, 8, 3, 7, 4, 6, 5])); 
+ + let ipv4 = IpAddr::from([127u8, 0, 0, 1]); + let ipv6: IpAddr = Ipv6Addr::LOCALHOST.into(); + compat_check::(&ColumnType::Inet, make_ip_address(ipv4)); + compat_check::(&ColumnType::Inet, make_ip_address(ipv6)); + + // uuid and timeuuid + // new_v4 generates random UUIDs, so these are different cases + let uuid1 = Uuid::new_v4(); + let uuid2 = Uuid::new_v4(); + let uuid3 = Uuid::new_v4(); + compat_check::(&ColumnType::Uuid, serialize_cell(&uuid1)); + compat_check::(&ColumnType::Uuid, serialize_cell(&uuid2)); + compat_check::(&ColumnType::Uuid, serialize_cell(&uuid3)); + compat_check::(&ColumnType::Timeuuid, serialize_cell(&uuid1)); + compat_check::(&ColumnType::Timeuuid, serialize_cell(&uuid2)); + compat_check::(&ColumnType::Timeuuid, serialize_cell(&uuid3)); + + // nulls, represented via option + compat_check::>(&ColumnType::Int, serialize_cell(&123i32)); + compat_check::>(&ColumnType::Int, make_null()); + + // empty values + // ...are implemented via MaybeEmpty and are handled in other tests + + // collections + let mut list = BytesMut::new(); + list.put_i32(3); + append_cell(&mut list, &123i32.to_be_bytes()); + append_cell(&mut list, &456i32.to_be_bytes()); + append_cell(&mut list, &789i32.to_be_bytes()); + let list = make_cell(&list); + let list_type = ColumnType::List(Box::new(ColumnType::Int)); + compat_check::>(&list_type, list.clone()); + compat_check::>(&list_type, list.clone()); + compat_check::>(&list_type, list); + + let mut map = BytesMut::new(); + map.put_i32(3); + append_cell(&mut map, &123i32.to_be_bytes()); + append_cell(&mut map, "quick".as_bytes()); + append_cell(&mut map, &456i32.to_be_bytes()); + append_cell(&mut map, "brown".as_bytes()); + append_cell(&mut map, &789i32.to_be_bytes()); + append_cell(&mut map, "fox".as_bytes()); + let map = make_cell(&map); + let map_type = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Text)); + compat_check::>(&map_type, map.clone()); + compat_check::>(&map_type, map); + + // Tuples + let 
tup_type = ColumnType::Tuple(vec![ColumnType::Text, ColumnType::Int, ColumnType::Uuid]); + let mut tup = BytesMut::new(); + append_cell(&mut tup, "quick brown fox".as_bytes()); + append_cell(&mut tup, &123i32.to_be_bytes()); + append_cell(&mut tup, &uuid1.to_u128_le().to_be_bytes()); + let tup = make_cell(&tup); + compat_check::<(String, i32, Uuid)>(&tup_type, tup); + + // UDTs + #[derive(DeserializeCql, FromUserType, Debug, PartialEq, Eq)] + #[scylla_crate = "crate"] + #[scylla(crate = "crate")] + struct Udt { + a: String, + b: Option, + } + + let udt_type = ColumnType::UserDefinedType { + type_name: "udt".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("a".to_string(), ColumnType::Text), + ("b".to_string(), ColumnType::Int), + ], + }; + + let mut udt = BytesMut::new(); + append_cell(&mut udt, "quick brown fox".as_bytes()); + append_cell(&mut udt, &123i32.to_be_bytes()); + let udt = make_cell(&udt); + compat_check::(&udt_type, udt); + + // One column missing + let mut udt = BytesMut::new(); + append_cell(&mut udt, "quick brown fox".as_bytes()); + let udt = make_cell(&udt); + compat_check::(&udt_type, udt); + } + + // Checks that both new and old serialization framework + // produces the same results in this case + fn compat_check(typ: &ColumnType, raw: Bytes) + where + T: for<'f> DeserializeCql<'f>, + T: FromCqlVal>, + T: Debug + PartialEq, + { + let mut slice = raw.as_ref(); + let mut cell = types::read_bytes_opt(&mut slice).unwrap(); + let old = T::from_cql( + cell.as_mut() + .map(|c| deser_cql_value(typ, c)) + .transpose() + .unwrap(), + ) + .unwrap(); + let new = deserialize::(typ, &raw).unwrap(); + assert_eq!(old, new); + } + + fn deserialize<'frame, T>(typ: &'frame ColumnType, byts: &'frame Bytes) -> Result + where + T: DeserializeCql<'frame>, + { + >::type_check(typ)?; + let mut buf = byts.as_ref(); + let cell = types::read_bytes_opt(&mut buf)?; + let value = cell.map(|cell| FrameSlice::new_subslice(cell, byts)); + >::deserialize(typ, 
value) + } + + fn make_cell(cell: &[u8]) -> Bytes { + let mut b = BytesMut::new(); + append_cell(&mut b, cell); + b.freeze() + } + + fn make_null() -> Bytes { + let mut b = BytesMut::new(); + append_null(&mut b); + b.freeze() + } + + fn serialize_cell(value: &impl Value) -> Bytes { + let mut v = Vec::new(); + value.serialize(&mut v).unwrap(); + v.into() + } + + fn make_ip_address(ip: IpAddr) -> Bytes { + match ip { + IpAddr::V4(v4) => make_cell(&v4.octets()), + IpAddr::V6(v6) => make_cell(&v6.octets()), + } + } + + fn append_cell(b: &mut impl BufMut, cell: &[u8]) { + b.put_i32(cell.len() as i32); + b.put_slice(cell); + } + + fn append_null(b: &mut impl BufMut) { + b.put_i32(-1); + } +} diff --git a/scylla-cql/src/types/mod.rs b/scylla-cql/src/types/mod.rs new file mode 100644 index 0000000000..6339d95e71 --- /dev/null +++ b/scylla-cql/src/types/mod.rs @@ -0,0 +1 @@ +pub mod deserialize; diff --git a/scylla-macros/Cargo.toml b/scylla-macros/Cargo.toml index d428d24c71..1bc2a1ce79 100644 --- a/scylla-macros/Cargo.toml +++ b/scylla-macros/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0" proc-macro = true [dependencies] +darling = "0.14" syn = "1.0" quote = "1.0" -proc-macro2 = "1.0" \ No newline at end of file +proc-macro2 = "1.0" diff --git a/scylla-macros/src/deserialize/cql.rs b/scylla-macros/src/deserialize/cql.rs new file mode 100644 index 0000000000..864ab0ac4a --- /dev/null +++ b/scylla-macros/src/deserialize/cql.rs @@ -0,0 +1,566 @@ +use darling::{FromAttributes, FromField}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::{ext::IdentExt, parse_quote}; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the UDT. 
This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, + + // If true, then the type checking code won't verify the UDT field names. + // UDT fields will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + no_field_name_verification: bool, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> syn::Path { + match &self.crate_path { + Some(path) => parse_quote!(#path::_macro_internal), + None => parse_quote!(scylla::_macro_internal), + } + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. + #[darling(default)] + skip: bool, + + // If true, then - if this field is missing from the UDT fields - it will + // be initialized to Default::default(). + // Not supported in enforce_order mode. + #[darling(default)] + default_when_missing: bool, + + // If set, then deserializes from the UDT field with this particular name + // instead of the Rust field name. 
+ #[darling(default)] + rename: Option, + + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip || self.default_when_missing + } + + fn deserialize_target(&self) -> syn::Type { + self.ty.clone() + } +} + +// derive(DeserializeUserType) for the new API +pub fn deserialize_user_type_derive( + tokens_input: TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait: syn::Path = parse_quote!(types::deserialize::value::DeserializeCql); + let constraining_trait = implemented_trait.clone(); + let s = StructDesc::new(&input, "DeserializeCql", constraining_trait)?; + + let items = vec![ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} + +impl Field { + // Returns whether this field is mandatory for deserialization. + fn is_required(&self) -> bool { + !self.skip && !self.default_when_missing + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + let field_name = match self.rename.as_ref() { + Some(rename) => rename.to_owned(), + None => self.ident.as_ref().unwrap().unraw().to_string(), + }; + syn::LitStr::new(&field_name, Span::call_site()) + } +} + +type StructDesc = super::StructDescForDeserialize; + +impl StructDesc { + // Generates an expression which extracts the UDT fields or returns an error + fn generate_extract_fields_from_type(&self, typ_expr: syn::Expr) -> syn::Expr { + let crate_path = &self.struct_attrs().crate_path(); + parse_quote!( + match #typ_expr { + #crate_path::frame::response::result::ColumnType::UserDefinedType { field_types, .. 
} => field_types, + _ => return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + "Wrong type, expected an UDT".to_string(), + ), + ), + } + ) + } + + fn generate_type_check_method(&self) -> syn::ImplItemMethod { + if self.attrs.enforce_order { + TypeCheckAssumeOrderGenerator(self).generate() + } else { + TypeCheckUnorderedGenerator(self).generate() + } + } + + fn generate_deserialize_method(&self) -> syn::ImplItemMethod { + if self.attrs.enforce_order { + DeserializeAssumeOrderGenerator(self).generate() + } else { + DeserializeUnorderedGenerator(self).generate() + } + } +} + +struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { + fn generate_name_verification( + &self, + id: usize, + field: &Field, + udt_field_name: &syn::Ident, + ) -> Option { + if self.0.attrs.no_field_name_verification { + return None; + } + + let crate_path = &self.0.struct_attrs().crate_path; + let field_name = field.cql_name_literal(); + + Some(parse_quote!( + if #udt_field_name != #field_name { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!( + "Field #{} has wrong name, expected {} but got {}", + #id, + #udt_field_name, + #field_name + ) + ) + ); + } + )) + } + + // Generates the type_check method for when ensure_order == true. 
+ fn generate(&self) -> syn::ImplItemMethod { + // The generated method will: + // - Check that every required field appears on the list in the same order as struct fields + // - Every type on the list is correct + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let required_fields: Vec<_> = self.0.fields().iter().filter(|f| f.is_required()).collect(); + let extract_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + let field_deserializers = required_fields.iter().map(|f| f.deserialize_target()); + let required_field_count = required_fields.len(); + let field_count_lit = + syn::LitInt::new(&required_field_count.to_string(), Span::call_site()); + let numbers = 0usize..; + + let name_verifications: Vec<_> = required_fields + .iter() + .enumerate() + .map(|(id, field)| self.generate_name_verification(id, field, &parse_quote!(name))) + .collect(); + + parse_quote!( + fn type_check( + typ: &#crate_path::frame::response::result::ColumnType, + ) -> ::std::result::Result<(), #crate_path::frame::frame_errors::ParseError> { + // Extract information about the field types from the UDT + // type definition. + let fields = #extract_fields_expr; + + // Verify that the field count is correct + if fields.len() != #field_count_lit { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!( + "Wrong number of fields in a UDT, expected {} but got {}", + #field_count_lit, + fields.len(), + ) + ) + ) + } + + #( + let (name, typ) = &fields[#numbers]; + + // Verify the name (unless `no_field_name_verification` is specified) + #name_verifications + + // Verify the type + // TODO: Provide better context about which field this error is about + <#field_deserializers as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::type_check(typ)?; + )* + + // All is good! 
+ Ok(()) + } + ) + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let crate_path = self.0.struct_attrs().crate_path(); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + parse_quote!( + { + let res = iter.next().ok_or_else(|| { + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Missing field: {}", #cql_name_literal) + ) + })?; + let ((_, typ), value) = res?; + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); + <#deserializer as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(typ, value)? + } + ) + } + + fn generate(&self) -> syn::ImplItemMethod { + // We can assume that type_check was called. + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + #[allow(unused_mut)] + let mut iterator_type = + quote!(#crate_path::types::deserialize::value::UdtIterator<#constraint_lifetime>); + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #crate_path::frame::response::result::ColumnType, + v: ::std::option::Option<#crate_path::types::deserialize::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. 
+ let mut iter = <#iterator_type as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(typ, v)?; + + Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. + // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + if field.skip { + return None; + } + + let visited_flag = Self::visited_flag_variable(field); + Some(parse_quote!(let mut #visited_flag = false;)) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + if field.skip { + return None; + } + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required = if field.is_required() { + quote!(remaining_required_fields -= 1;) + } else { + quote!() + }; + Some(parse_quote!( + { + if !#visited_flag { + <#typ as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::type_check(typ)?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Field {} occurs more than once in serialized data", #cql_name_literal), + ), + ) + } + } + )) + } + + // Generates code 
that appends the flag name if it is missing. + // The generated code is used to construct a nice error message. + fn generate_append_name(field: &Field) -> Option { + if field.is_required() { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + Some(parse_quote!( + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + )) + } else { + None + } + } + + // Generates the type_check method for when ensure_order == false. + fn generate(&self) -> syn::ImplItemMethod { + // The generated method will: + // - Check that every required field appears on the list exactly once, in any order + // - Every type on the list is correct + + let crate_path = &self.0.struct_attrs().crate_path(); + let fields = self.0.fields(); + let visited_field_declarations = fields.iter().flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = fields.iter().flat_map(Self::generate_append_name); + let field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let required_field_count = fields.iter().filter(|f| f.is_required()).count(); + let field_count_lit = + syn::LitInt::new(&required_field_count.to_string(), Span::call_site()); + let extract_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + + parse_quote! { + fn type_check( + typ: &#crate_path::frame::response::result::ColumnType, + ) -> ::std::result::Result<(), #crate_path::frame::frame_errors::ParseError> { + // Extract information about the field types from the UDT + // type definition. 
+ let fields = #extract_fields_expr; + + // Counts down how many required fields are remaining + let mut remaining_required_fields: ::std::primitive::usize = #field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + for (name, typ) in fields { + // Pattern match on the name and verify that the type is correct. + match name.as_str() { + #(#field_names => #type_check_blocks,)* + unknown => { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Unknown field: {}", unknown), + ), + ) + } + } + } + + if remaining_required_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Missing fields: {:?}", missing_fields), + ), + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + // An identifier for a variable that is meant to store the parsed variable + // before being ultimately moved to the struct on deserialize + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates an expression which produces a value ready to be put into a field + // of the target structure + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let crate_path = self.0.struct_attrs().crate_path(); + let deserialize_field = Self::deserialize_field_variable(field); + if field.default_when_missing { + // 
Generate Default::default if the field was missing + parse_quote!(#deserialize_field.unwrap_or_default()) + } else { + let cql_name_literal = field.cql_name_literal(); + parse_quote!(#deserialize_field.ok_or_else(|| { + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Missing field: {}", #cql_name_literal) + ) + })?) + } + } + + // Generated code that performs deserialization when the raw field + // is being processed + fn generate_deserialization(&self, field: &Field) -> Option { + if field.skip { + return None; + } + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + Some(parse_quote!( + { + if #deserialize_field.is_some() { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Field {} occurs more than once in serialized data", #cql_name_literal), + ), + ); + } else { + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); + #deserialize_field = ::std::option::Option::Some( + <#deserializer as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(typ, value)? 
+ ); + } + } + )) + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + if field.skip { + return None; + } + let deserialize_field = Self::deserialize_field_variable(field); + Some(parse_quote!(let mut #deserialize_field = ::std::option::Option::None;)) + } + + fn generate(&self) -> syn::ImplItemMethod { + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields.iter().map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields.iter().flat_map(|f| self.generate_deserialization(f)); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + let iterator_type = + quote!(#crate_path::types::deserialize::value::UdtIterator<#constraint_lifetime>); + + // TODO: Allow collecting unrecognized fields into some special field + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #crate_path::frame::response::result::ColumnType, + v: ::std::option::Option<#crate_path::types::deserialize::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. + let iter = <#iterator_type as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(typ, v)?; + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. + #(#deserialize_field_decls)* + + for item in iter { + let ((name, typ), value) = item?; + // Pattern match on the field name and deserialize. 
+ match name.as_str() { + #(#field_names => #deserialize_blocks,)* + unknown => return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + format!("Unknown field: {}", unknown), + ) + ) + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `default_when_null` it will create a default value, otherwise + // it will report an error. + Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/deserialize/mod.rs b/scylla-macros/src/deserialize/mod.rs new file mode 100644 index 0000000000..4082c51eb9 --- /dev/null +++ b/scylla-macros/src/deserialize/mod.rs @@ -0,0 +1,167 @@ +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use quote::quote; +use syn::parse_quote; + +pub(crate) mod cql; +pub(crate) mod row; + +/// Common attributes that all deserialize impls should understand. +trait DeserializeCommonStructAttrs { + /// The path to either `scylla` or `scylla_cql` crate + fn crate_path(&self) -> syn::Path; +} + +/// Provides access to attributes that are common to DeserializeCql +/// and DeserializeRow traits. +trait DeserializeCommonFieldAttrs { + /// Does the type of this field need Default to be implemented? + fn needs_default(&self) -> bool; + + /// The type of the field, i.e. what this field deserializes to. + fn deserialize_target(&self) -> syn::Type; +} + +/// A structure helpful in implementing DeserializeCql and DeserializeRow. 
+///
+/// It implements some common logic for both traits:
+/// - Generates a unique lifetime that binds all other lifetimes in both structs,
+/// - Adds appropriate trait bounds (DeserializeCql + Default)
+struct StructDescForDeserialize<Attrs, Field> {
+    name: syn::Ident,
+    attrs: Attrs,
+    fields: Vec<Field>,
+    constraint_trait: syn::Path,
+    constraint_lifetime: syn::Lifetime,
+
+    generics: syn::Generics,
+}
+
+impl<Attrs, Field> StructDescForDeserialize<Attrs, Field>
+where
+    Attrs: FromAttributes + DeserializeCommonStructAttrs,
+    Field: FromField + DeserializeCommonFieldAttrs,
+{
+    fn new(
+        input: &syn::DeriveInput,
+        trait_name: &str,
+        constraint_trait: syn::Path,
+    ) -> Result<Self, darling::Error> {
+        let attrs = Attrs::from_attributes(&input.attrs)?;
+
+        // TODO: Handle errors from parse_named_fields
+        let fields = crate::parser::parse_named_fields(input, trait_name)
+            .named
+            .iter()
+            .map(Field::from_field)
+            .collect::<Result<_, darling::Error>>()?;
+
+        let constraint_lifetime = generate_unique_lifetime_for_impl(&input.generics);
+
+        Ok(Self {
+            name: input.ident.clone(),
+            attrs,
+            fields,
+            constraint_trait,
+            constraint_lifetime,
+            generics: input.generics.clone(),
+        })
+    }
+
+    fn struct_attrs(&self) -> &Attrs {
+        &self.attrs
+    }
+
+    fn constraint_lifetime(&self) -> &syn::Lifetime {
+        &self.constraint_lifetime
+    }
+
+    fn fields(&self) -> &[Field] {
+        &self.fields
+    }
+
+    fn generate_impl(&self, trait_: syn::Path, items: Vec<syn::ImplItem>) -> syn::ItemImpl {
+        let constraint_lifetime = &self.constraint_lifetime;
+        let (_, ty_generics, _) = self.generics.split_for_impl();
+        let impl_generics = &self.generics.params;
+
+        let scylla_crate = self.attrs.crate_path();
+        let struct_name = &self.name;
+        let mut predicates = Vec::new();
+        predicates.extend(generate_lifetime_constraints_for_impl(
+            &self.generics,
+            self.constraint_trait.clone(),
+            self.constraint_lifetime.clone(),
+        ));
+        predicates.extend(generate_default_constraints(&self.fields));
+        let trait_ = quote!(#scylla_crate::#trait_);
+
+        parse_quote!
{ + impl<#constraint_lifetime, #impl_generics> #trait_<#constraint_lifetime> for #struct_name #ty_generics + where #(#predicates),* + { + #(#items)* + } + } + } +} + +/// Generates T: Default constraints for those fields that need it. +fn generate_default_constraints(fields: &[Field]) -> Vec +where + Field: DeserializeCommonFieldAttrs, +{ + fields + .iter() + .filter(|f| f.needs_default()) + .map(|f| { + let t = f.deserialize_target(); + parse_quote!(#t: std::default::Default) + }) + .collect() +} + +/// Helps introduce a lifetime to an `impl` definition that constrains +/// other lifetimes and types. +/// +/// The original use case is DeriveCql and DeriveRow. Both of those traits +/// are parametrized with a lifetime. If T: DeriveCql<'a> then this means +/// that you can deserialize T as some CQL value from bytes that have +/// lifetime 'a, similarly for DeriveRow. In impls for those traits, +/// an additional lifetime must be introduced and properly constrained. +fn generate_lifetime_constraints_for_impl( + generics: &syn::Generics, + trait_full_name: syn::Path, + constraint_lifetime: syn::Lifetime, +) -> Vec { + let mut predicates = Vec::new(); + + // Constrain the new lifetime with the existing lifetime parameters + // 'lifetime: 'a + 'b + 'c ... + let lifetimes: Vec<_> = generics.lifetimes().map(|l| l.lifetime.clone()).collect(); + if !lifetimes.is_empty() { + predicates.push(parse_quote!(#constraint_lifetime: #(#lifetimes)+*)); + } + + // For each type parameter T, constrain it like this: + // T: DeriveCql<'lifetime>, + for t in generics.type_params() { + let t_ident = &t.ident; + predicates.push(parse_quote!(#t_ident: #trait_full_name<#constraint_lifetime>)); + } + + predicates +} + +/// Generates a new lifetime parameter, with a different name to any of the +/// existing generic lifetimes. 
+fn generate_unique_lifetime_for_impl(generics: &syn::Generics) -> syn::Lifetime { + let mut constraint_lifetime_name = "'lifetime".to_string(); + while generics + .lifetimes() + .any(|l| l.lifetime.to_string() == constraint_lifetime_name) + { + constraint_lifetime_name += "e"; + } + syn::Lifetime::new(&constraint_lifetime_name, Span::call_site()) +} diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs new file mode 100644 index 0000000000..da3e991fe4 --- /dev/null +++ b/scylla-macros/src/deserialize/row.rs @@ -0,0 +1,500 @@ +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use quote::quote; +use syn::ext::IdentExt; +use syn::parse_quote; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the columns. This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, + + // If true, then the type checking code won't verify the column names. + // Columns will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + no_field_name_verification: bool, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> syn::Path { + match &self.crate_path { + Some(path) => parse_quote!(#path::_macro_internal), + None => parse_quote!(scylla::_macro_internal), + } + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. 
+ #[darling(default)] + skip: bool, + + // If set, then deserialization will look for the UDT field of given name + // and deserialize to this Rust field, instead of just using the Rust + // field name. + #[darling(default)] + rename: Option, + + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip + } + + fn deserialize_target(&self) -> syn::Type { + self.ty.clone() + } +} + +// derive(DeserializeRow) for the new DeserializeRow trait +pub fn deserialize_row_derive( + tokens_input: proc_macro::TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait = parse_quote!(types::deserialize::row::DeserializeRow); + let constraining_trait = parse_quote!(types::deserialize::value::DeserializeCql); + let s = StructDesc::new(&input, "DeserializeRow", constraining_trait)?; + + let items = vec![ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} + +impl Field { + // Returns whether this field is mandatory for deserialization. 
+    fn is_required(&self) -> bool {
+        !self.skip
+    }
+
+    // A Rust literal representing the name of this field
+    fn cql_name_literal(&self) -> syn::LitStr {
+        let field_name = match self.rename.as_ref() {
+            Some(rename) => rename.to_owned(),
+            None => self.ident.as_ref().unwrap().unraw().to_string(),
+        };
+        syn::LitStr::new(&field_name, Span::call_site())
+    }
+}
+
+type StructDesc = super::StructDescForDeserialize<StructAttrs, Field>;
+
+impl StructDesc {
+    fn generate_type_check_method(&self) -> syn::ImplItemMethod {
+        if self.attrs.enforce_order {
+            TypeCheckAssumeOrderGenerator(self).generate()
+        } else {
+            TypeCheckUnorderedGenerator(self).generate()
+        }
+    }
+
+    fn generate_deserialize_method(&self) -> syn::ImplItemMethod {
+        if self.attrs.enforce_order {
+            DeserializeAssumeOrderGenerator(self).generate()
+        } else {
+            DeserializeUnorderedGenerator(self).generate()
+        }
+    }
+}
+
+struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc);
+
+impl<'sd> TypeCheckAssumeOrderGenerator<'sd> {
+    fn generate_name_verification(
+        &self,
+        id: usize,
+        field: &Field,
+        column_spec: &syn::Ident,
+    ) -> Option<syn::Expr> {
+        if self.0.attrs.no_field_name_verification {
+            return None;
+        }
+
+        let crate_path = self.0.struct_attrs().crate_path();
+        let field_name = field.cql_name_literal();
+
+        Some(parse_quote!(
+            if #column_spec.name != #field_name {
+                return ::std::result::Result::Err(
+                    #crate_path::frame::frame_errors::ParseError::BadIncomingData(
+                        ::std::format!(
+                            "Column #{} has wrong name, expected {} but got {}",
+                            #id,
+                            #field_name,
+                            #column_spec.name
+                        )
+                    )
+                );
+            }
+        ))
+    }
+
+    fn generate(&self) -> syn::ImplItemMethod {
+        // TODO: Better error messages here, add more context to which field
+        // failed to be parsed
+
+        // The generated method will check that the order and the types
+        // of the columns correspond to the fields' names/types.
+ + let crate_path = &self.0.struct_attrs().crate_path; + let constraint_lifetime = self.0.constraint_lifetime(); + + let required_fields: Vec<_> = self.0.fields().iter().filter(|f| f.is_required()).collect(); + let field_idents: Vec<_> = (0..required_fields.len()) + .map(|i| quote::format_ident!("f_{}", i)) + .collect(); + let name_verifications: Vec<_> = required_fields + .iter() + .zip(field_idents.iter()) + .enumerate() + .map(|(id, (field, fidents))| self.generate_name_verification(id, field, fidents)) + .collect(); + + let field_deserializers = required_fields.iter().map(|f| f.deserialize_target()); + + let field_count = required_fields.len(); + + parse_quote! { + fn type_check( + specs: &[#crate_path::frame::response::result::ColumnSpec], + ) -> ::std::result::Result<(), #crate_path::frame::frame_errors::ParseError> { + match specs { + [#(#field_idents),*] => { + #( + // Verify the name (unless `no_field_name_verification' is specified) + #name_verifications + + // Verify the type + // TODO: Provide better context about which field this error is about + <#field_deserializers as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::type_check(&#field_idents.typ)?; + )* + ::std::result::Result::Ok(()) + }, + _ => ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + format!( + "Wrong number of columns, expected {} but got {}", + #field_count, + specs.len(), + ) + ), + ), + } + } + } + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let crate_path = self.0.struct_attrs().crate_path(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + parse_quote!( + { + let 
col = row.next().ok_or_else(|| { + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + "Not enough fields".to_owned(), + ) + })??; + <#deserializer as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice)? + } + ) + } + + fn generate(&self) -> syn::ImplItemMethod { + let crate_path = &self.0.struct_attrs().crate_path; + let constraint_lifetime = self.0.constraint_lifetime(); + + let fields = self.0.fields(); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #crate_path::types::deserialize::row::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + ::std::result::Result::Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. 
+ // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + if field.skip { + return None; + } + + let visited_flag = Self::visited_flag_variable(field); + Some(parse_quote!(let mut #visited_flag = false;)) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + if field.skip { + return None; + } + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required = if field.is_required() { + quote!(remaining_required_fields -= 1;) + } else { + quote!() + }; + Some(parse_quote!( + { + if !#visited_flag { + <#typ as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::type_check(&spec.typ)?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Column {} occurs more than once in serialized data", #cql_name_literal), + ), + ) + } + } + )) + } + + // Generates code that appends the flag name if it is missing. + // The generated code is used to construct a nice error message. 
+ fn generate_append_name(field: &Field) -> Option { + if field.is_required() { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + Some(parse_quote!( + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + )) + } else { + None + } + } + + fn generate(&self) -> syn::ImplItemMethod { + let crate_path = self.0.struct_attrs().crate_path(); + + let fields = self.0.fields(); + let visited_field_declarations = fields.iter().flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = fields.iter().flat_map(Self::generate_append_name); + let field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let field_count_lit = fields.iter().filter(|f| f.is_required()).count(); + + parse_quote! { + fn type_check( + specs: &[#crate_path::frame::response::result::ColumnSpec], + ) -> ::std::result::Result<(), #crate_path::frame::frame_errors::ParseError> { + // Counts down how many required fields are remaining + let mut remaining_required_fields: ::std::primitive::usize = #field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + for spec in specs { + // Pattern match on the name and verify that the type is correct. 
+ match spec.name.as_str() { + #(#field_names => #type_check_blocks,)* + unknown => { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Unknown field: {}", unknown), + ), + ) + } + } + } + + if remaining_required_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Missing fields: {:?}", missing_fields), + ), + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + // An identifier for a variable that is meant to store the parsed variable + // before being ultimately moved to the struct on deserialize + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates an expression which produces a value ready to be put into a field + // of the target structure + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let crate_path = self.0.struct_attrs().crate_path(); + let deserialize_field = Self::deserialize_field_variable(field); + { + let cql_name_literal = field.cql_name_literal(); + parse_quote!(#deserialize_field.ok_or_else(|| { + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Missing field: {}", #cql_name_literal) + ) + })?) 
+ } + } + + // Generated code that performs deserialization when the raw field + // is being processed + fn generate_deserialization(&self, field: &Field) -> Option { + if field.skip { + return None; + } + + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + Some(parse_quote!( + { + if #deserialize_field.is_some() { + return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + ::std::format!("Field {} occurs more than once in serialized data", #cql_name_literal), + ), + ); + } else { + #deserialize_field = ::std::option::Option::Some( + <#deserializer as #crate_path::types::deserialize::value::DeserializeCql<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice)? + ); + } + } + )) + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + if field.skip { + return None; + } + let deserialize_field = Self::deserialize_field_variable(field); + Some(parse_quote!(let mut #deserialize_field = ::std::option::Option::None;)) + } + + fn generate(&self) -> syn::ImplItemMethod { + let crate_path = self.0.struct_attrs().crate_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields.iter().map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields.iter().flat_map(|f| self.generate_deserialization(f)); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + // TODO: Allow collecting unrecognized fields into 
some special field + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #crate_path::types::deserialize::row::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. + #(#deserialize_field_decls)* + + for col in row { + let col = col?; + // Pattern match on the field name and deserialize. + match col.spec.name.as_str() { + #(#field_names => #deserialize_blocks,)* + unknown => return ::std::result::Result::Err( + #crate_path::frame::frame_errors::ParseError::BadIncomingData( + format!("Unknown column: {}", unknown), + ) + ) + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `default_when_null` it will create a default value, otherwise + // it will report an error. + Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index f5ad28a26d..0caef39203 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -1,11 +1,30 @@ +use darling::ToTokens; use proc_macro::TokenStream; +mod deserialize; + mod from_row; mod from_user_type; mod into_user_type; mod parser; mod value_list; +#[proc_macro_derive(DeserializeRow, attributes(scylla, scylla_crate))] +pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match deserialize::row::deserialize_row_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} + +#[proc_macro_derive(DeserializeCql, attributes(scylla))] +pub fn deserialize_cql_derive(tokens_input: TokenStream) -> TokenStream { + match deserialize::cql::deserialize_user_type_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} + 
/// #[derive(FromRow)] derives FromRow for struct /// Works only on simple structs without generics etc #[proc_macro_derive(FromRow, attributes(scylla_crate))] diff --git a/scylla-macros/src/parser.rs b/scylla-macros/src/parser.rs index 0da54370c8..efb4581d2b 100644 --- a/scylla-macros/src/parser.rs +++ b/scylla-macros/src/parser.rs @@ -1,5 +1,4 @@ -use syn::{Data, DeriveInput, Fields, FieldsNamed}; -use syn::{Lit, Meta}; +use syn::{Data, DeriveInput, Fields, FieldsNamed, Lit, Meta}; /// Parses the tokens_input to a DeriveInput and returns the struct name from which it derives and /// the named fields @@ -19,8 +18,8 @@ pub(crate) fn parse_named_fields<'a>( } } -pub(crate) fn get_path(input: &DeriveInput) -> Result { - let mut this_path: Option = None; +pub(crate) fn get_path(input: &DeriveInput) -> Result { + let mut this_path: Option = None; for attr in input.attrs.iter() { if !attr.path.is_ident("scylla_crate") { continue; @@ -30,7 +29,7 @@ pub(crate) fn get_path(input: &DeriveInput) -> Result().unwrap(); if this_path.is_none() { - this_path = Some(quote::quote!(#path_val::_macro_internal)); + this_path = Some(syn::parse_quote!(#path_val::_macro_internal)); } else { return Err(syn::Error::new_spanned( &meta_name_value.lit, @@ -55,5 +54,5 @@ pub(crate) fn get_path(input: &DeriveInput) -> Result Result<(), Box> { -//! let session: Session = SessionBuilder::new() +//! let session: Legacy08Session = SessionBuilder::new() //! .known_node("127.0.0.1:9042") //! .known_node("1.2.3.4:9876") -//! .build() +//! .build_legacy() //! .await?; //! //! Ok(()) @@ -50,9 +50,9 @@ //! //! The easiest way to specify bound values in a query is using a tuple: //! ```rust -//! # use scylla::Session; +//! # use scylla::Legacy08Session; //! # use std::error::Error; -//! # async fn check_only_compiles(session: &Session) -> Result<(), Box> { +//! # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { //! // Insert an int and text into the table //! session //! 
.query( @@ -69,9 +69,9 @@ //! The easiest way to read rows returned by a query is to cast each row to a tuple of values: //! //! ```rust -//! # use scylla::Session; +//! # use scylla::Legacy08Session; //! # use std::error::Error; -//! # async fn check_only_compiles(session: &Session) -> Result<(), Box> { +//! # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { //! use scylla::IntoTypedRows; //! //! // Read rows containing an int and text @@ -100,6 +100,7 @@ pub mod _macro_internal { pub use scylla_cql::frame; pub use scylla_cql::macros::{self, *}; +pub use scylla_cql::types; pub mod authentication; #[cfg(feature = "cloud")] @@ -125,10 +126,11 @@ pub use statement::query; pub use frame::response::cql_to_rust; pub use frame::response::cql_to_rust::FromRow; -pub use transport::caching_session::CachingSession; +pub use transport::caching_session::{CachingSession, Legacy08CachingSession}; pub use transport::execution_profile::ExecutionProfile; +pub use transport::legacy_query_result::Legacy08QueryResult; pub use transport::query_result::QueryResult; -pub use transport::session::{IntoTypedRows, Session, SessionConfig}; +pub use transport::session::{IntoTypedRows, Legacy08Session, Session, SessionConfig}; pub use transport::session_builder::SessionBuilder; #[cfg(feature = "cloud")] diff --git a/scylla/src/tracing.rs b/scylla/src/tracing.rs index dbdff59963..8d323d95d8 100644 --- a/scylla/src/tracing.rs +++ b/scylla/src/tracing.rs @@ -1,17 +1,16 @@ use crate::statement::Consistency; use itertools::Itertools; +use scylla_macros::DeserializeRow; use std::collections::HashMap; use std::net::IpAddr; use std::num::NonZeroU32; use std::time::Duration; use uuid::Uuid; -use crate::cql_to_rust::{FromRow, FromRowError}; -use crate::frame::response::result::Row; - /// Tracing info retrieved from `system_traces.sessions` /// with all events from `system_traces.events` -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, DeserializeRow, Clone, PartialEq, 
Eq)] +#[scylla(crate = "crate")] pub struct TracingInfo { pub client: Option, pub command: Option, @@ -22,11 +21,13 @@ pub struct TracingInfo { /// started_at is a timestamp - time since unix epoch pub started_at: Option, + #[scylla(skip)] pub events: Vec, } /// A single event happening during a traced query -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, DeserializeRow, Clone, PartialEq, Eq)] +#[scylla(crate = "crate")] pub struct TracingEvent { pub event_id: Uuid, pub activity: Option, @@ -79,51 +80,3 @@ pub(crate) const TRACES_SESSION_QUERY_STR: &str = pub(crate) const TRACES_EVENTS_QUERY_STR: &str = "SELECT event_id, activity, source, source_elapsed, thread \ FROM system_traces.events WHERE session_id = ?"; - -// Converts a row received by performing TRACES_SESSION_QUERY_STR to TracingInfo -impl FromRow for TracingInfo { - fn from_row(row: Row) -> Result { - let (client, command, coordinator, duration, parameters, request, started_at) = - <( - Option, - Option, - Option, - Option, - Option>, - Option, - Option, - )>::from_row(row)?; - - Ok(TracingInfo { - client, - command, - coordinator, - duration, - parameters, - request, - started_at, - events: Vec::new(), - }) - } -} - -// Converts a row received by performing TRACES_SESSION_QUERY_STR to TracingInfo -impl FromRow for TracingEvent { - fn from_row(row: Row) -> Result { - let (event_id, activity, source, source_elapsed, thread) = <( - Uuid, - Option, - Option, - Option, - Option, - )>::from_row(row)?; - - Ok(TracingEvent { - event_id, - activity, - source, - source_elapsed, - thread, - }) - } -} diff --git a/scylla/src/transport/authenticate_test.rs b/scylla/src/transport/authenticate_test.rs index 2e8f32e542..38bcc7c059 100644 --- a/scylla/src/transport/authenticate_test.rs +++ b/scylla/src/transport/authenticate_test.rs @@ -14,7 +14,7 @@ async fn authenticate_superuser() { let session = crate::SessionBuilder::new() .known_node(uri) .user("cassandra", "cassandra") - .build() + .build_legacy() .await 
.unwrap(); let ks = unique_keyspace_name(); @@ -69,7 +69,7 @@ async fn custom_authentication() { let session = crate::SessionBuilder::new() .known_node(uri) .authenticator_provider(Arc::new(CustomAuthenticatorProvider)) - .build() + .build_legacy() .await .unwrap(); let ks = unique_keyspace_name(); diff --git a/scylla/src/transport/caching_session.rs b/scylla/src/transport/caching_session.rs index cc0a1dedbd..39a9ea75c2 100644 --- a/scylla/src/transport/caching_session.rs +++ b/scylla/src/transport/caching_session.rs @@ -3,9 +3,9 @@ use crate::frame::value::{BatchValues, ValueList}; use crate::prepared_statement::PreparedStatement; use crate::query::Query; use crate::transport::errors::QueryError; -use crate::transport::iterator::RowIterator; +use crate::transport::iterator::Legacy08RowIterator; use crate::transport::partitioner::PartitionerName; -use crate::{QueryResult, Session}; +use crate::{Legacy08QueryResult, QueryResult}; use bytes::Bytes; use dashmap::DashMap; use futures::future::try_join_all; @@ -13,6 +13,11 @@ use scylla_cql::frame::response::result::PreparedMetadata; use std::collections::hash_map::RandomState; use std::hash::BuildHasher; +use super::iterator::RawIterator; +use super::session::{ + CurrentDeserializationApi, DeserializationApiKind, GenericSession, Legacy08DeserializationApi, +}; + /// Contains just the parts of a prepared statement that were returned /// from the database. All remaining parts (query string, page size, /// consistency, etc.) 
are taken from the Query passed @@ -27,11 +32,12 @@ struct RawPreparedStatementData { /// Provides auto caching while executing queries #[derive(Debug)] -pub struct CachingSession +pub struct GenericCachingSession where S: Clone + BuildHasher, + DeserializationApi: DeserializationApiKind, { - session: Session, + session: GenericSession, /// The prepared statement cache size /// If a prepared statement is added while the limit is reached, the oldest prepared statement /// is removed from the cache @@ -39,11 +45,16 @@ where cache: DashMap, } -impl CachingSession +pub type CachingSession = GenericCachingSession; +pub type Legacy08CachingSession = + GenericCachingSession; + +impl GenericCachingSession where S: Default + BuildHasher + Clone, + DeserApi: DeserializationApiKind, { - pub fn from(session: Session, cache_size: usize) -> Self { + pub fn from(session: GenericSession, cache_size: usize) -> Self { Self { session, max_capacity: cache_size, @@ -52,20 +63,26 @@ where } } -impl CachingSession +impl GenericCachingSession where S: BuildHasher + Clone, + DeserApi: DeserializationApiKind, { /// Builds a [`CachingSession`] from a [`Session`], a cache size, and a [`BuildHasher`]., /// using a customer hasher. 
- pub fn with_hasher(session: Session, cache_size: usize, hasher: S) -> Self { + pub fn with_hasher(session: GenericSession, cache_size: usize, hasher: S) -> Self { Self { session, max_capacity: cache_size, cache: DashMap::with_hasher(hasher), } } +} +impl GenericCachingSession +where + S: BuildHasher + Clone, +{ /// Does the same thing as [`Session::execute`] but uses the prepared statement cache pub async fn execute( &self, @@ -83,7 +100,7 @@ where &self, query: impl Into, values: impl ValueList, - ) -> Result { + ) -> Result { let query = query.into(); let prepared = self.add_prepared_statement_owned(query).await?; let values = values.serialized()?; @@ -125,7 +142,78 @@ where self.session.batch(&prepared_batch, &values).await } } +} +impl GenericCachingSession +where + S: BuildHasher + Clone, +{ + /// Does the same thing as [`Session::execute`] but uses the prepared statement cache + pub async fn execute( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + let values = values.serialized()?; + self.session.execute(&prepared, values.clone()).await + } + + /// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache + pub async fn execute_iter( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + let values = values.serialized()?; + self.session.execute_iter(prepared, values.clone()).await + } + + /// Does the same thing as [`Session::execute_paged`] but uses the prepared statement cache + pub async fn execute_paged( + &self, + query: impl Into, + values: impl ValueList, + paging_state: Option, + ) -> Result { + let query = query.into(); + let prepared = self.add_prepared_statement_owned(query).await?; + let values = values.serialized()?; + self.session + .execute_paged(&prepared, values.clone(), 
paging_state.clone()) + .await + } + + /// Does the same thing as [`Session::batch`] but uses the prepared statement cache\ + /// Prepares batch using CachingSession::prepare_batch if needed and then executes it + pub async fn batch( + &self, + batch: &Batch, + values: impl BatchValues, + ) -> Result { + let all_prepared: bool = batch + .statements + .iter() + .all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_))); + + if all_prepared { + self.session.batch(batch, &values).await + } else { + let prepared_batch: Batch = self.prepare_batch(batch).await?; + + self.session.batch(&prepared_batch, &values).await + } + } +} + +impl GenericCachingSession +where + S: BuildHasher + Clone, + DeserApi: DeserializationApiKind, +{ /// Prepares all statements within the batch and returns a new batch where every /// statement is prepared. /// Uses the prepared statements cache. @@ -211,7 +299,7 @@ where self.max_capacity } - pub fn get_session(&self) -> &Session { + pub fn get_session(&self) -> &GenericSession { &self.session } } @@ -221,13 +309,15 @@ mod tests { use crate::query::Query; use crate::test_utils::create_new_session_builder; use crate::transport::partitioner::PartitionerName; + use crate::transport::session::Session; use crate::utils::test_utils::unique_keyspace_name; use crate::{ batch::{Batch, BatchStatement}, prepared_statement::PreparedStatement, - CachingSession, Session, + CachingSession, }; use futures::TryStreamExt; + use scylla_cql::frame::response::result::Row; use std::collections::BTreeSet; async fn new_for_test() -> Session { @@ -323,7 +413,7 @@ mod tests { .unwrap(); assert_eq!(1, session.cache.len()); - assert_eq!(1, result.rows.unwrap().len()); + assert_eq!(1, result.rows_num().unwrap()); let result = session .execute("select * from test_table", &[]) @@ -331,7 +421,7 @@ mod tests { .unwrap(); assert_eq!(1, session.cache.len()); - assert_eq!(1, result.rows.unwrap().len()); + assert_eq!(1, result.rows_num().unwrap()); } /// Checks that caching 
works with execute_iter @@ -344,7 +434,8 @@ mod tests { let iter = session .execute_iter("select * from test_table", &[]) .await - .unwrap(); + .unwrap() + .into_typed::(); let rows = iter.try_collect::>().await.unwrap().len(); @@ -365,7 +456,7 @@ mod tests { .unwrap(); assert_eq!(1, session.cache.len()); - assert_eq!(1, result.rows.unwrap().len()); + assert_eq!(1, result.rows_num().unwrap()); } async fn assert_test_batch_table_rows_contain( @@ -376,7 +467,7 @@ mod tests { .execute("SELECT a, b FROM test_batch_table", ()) .await .unwrap() - .rows_typed::<(i32, i32)>() + .rows::<(i32, i32)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -582,7 +673,8 @@ mod tests { .execute("SELECT b, WRITETIME(b) FROM tbl", ()) .await .unwrap() - .rows_typed_or_empty::<(i32, i64)>() + .rows::<(i32, i64)>() + .unwrap() .collect::, _>>() .unwrap(); diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 6b2a63662d..de3867467c 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -40,7 +40,16 @@ use super::topology::Strategy; /// Cluster manages up to date information and connections to database nodes. /// All data can be accessed by cloning Arc in the `data` field -pub struct Cluster { +// +// NOTE: This structure was intentionally made cloneable. The reason for this +// is to make it possible to use two different Session APIs in the same program +// that share the same session resources. +// +// It is safe to do because the Cluster struct is just a facade for the real, +// "semantic" Cluster object. Cloned instance of this struct will use the same +// ClusterData and worker and will observe the same state. 
+#[derive(Clone)] +pub(crate) struct Cluster { // `ArcSwap` is wrapped in `Arc` to support sharing cluster data // between `Cluster` and `ClusterWorker` data: Arc>, @@ -48,12 +57,12 @@ pub struct Cluster { refresh_channel: tokio::sync::mpsc::Sender, use_keyspace_channel: tokio::sync::mpsc::Sender, - _worker_handle: RemoteHandle<()>, + _worker_handle: Arc>, } /// Enables printing [Cluster] struct in a neat way, by skipping the rather useless /// print of channels state and printing [ClusterData] neatly. -pub struct ClusterNeatDebug<'a>(pub &'a Cluster); +pub(crate) struct ClusterNeatDebug<'a>(pub &'a Cluster); impl<'a> std::fmt::Debug for ClusterNeatDebug<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let cluster = self.0; @@ -193,7 +202,7 @@ impl Cluster { data: cluster_data, refresh_channel: refresh_sender, use_keyspace_channel: use_keyspace_sender, - _worker_handle: worker_handle, + _worker_handle: Arc::new(worker_handle), }; Ok(result) diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index 223301e202..cd280d0ac7 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -34,7 +34,8 @@ use std::{ }; use super::errors::{BadKeyspaceName, DbError, QueryError}; -use super::iterator::RowIterator; +use super::iterator::RawIterator; +use super::query_result::{QueryResult, SingleRowError}; use super::session::AddressTranslator; use super::topology::{PeerEndpoint, UntranslatedEndpoint, UntranslatedPeer}; use super::NodeAddr; @@ -55,13 +56,8 @@ use crate::query::Query; use crate::routing::ShardInfo; use crate::statement::prepared_statement::PreparedStatement; use crate::statement::Consistency; -use crate::transport::session::IntoTypedRows; use crate::transport::Compression; -// Existing code imports scylla::transport::connection::QueryResult because it used to be located in this file. -// Reexport QueryResult to avoid breaking the existing code. 
-pub use crate::QueryResult; - // Queries for schema agreement const LOCAL_VERSION: &str = "SELECT schema_version FROM system.local WHERE key='local'"; @@ -194,14 +190,9 @@ impl NonErrorQueryResponse { } pub fn into_query_result(self) -> Result { - let (rows, paging_state, col_specs, serialized_size) = match self.response { - NonErrorResponse::Result(result::Result::Rows(rs)) => ( - Some(rs.rows), - rs.metadata.paging_state, - rs.metadata.col_specs, - rs.serialized_size, - ), - NonErrorResponse::Result(_) => (None, None, vec![], 0), + let raw_rows = match self.response { + NonErrorResponse::Result(result::Result::Rows(rs)) => Some(rs), + NonErrorResponse::Result(_) => None, _ => { return Err(QueryError::ProtocolError( "Unexpected server response, expected Result or Error", @@ -209,14 +200,7 @@ impl NonErrorQueryResponse { } }; - Ok(QueryResult { - rows, - warnings: self.warnings, - tracing_id: self.tracing_id, - paging_state, - col_specs, - serialized_size, - }) + Ok(QueryResult::new(raw_rows, self.tracing_id, self.warnings)) } } #[cfg(feature = "ssl")] @@ -594,7 +578,7 @@ impl Connection { self: Arc, query: Query, values: impl ValueList, - ) -> Result { + ) -> Result { let serialized_values = values.serialized()?.into_owned(); let consistency = query @@ -602,7 +586,7 @@ impl Connection { .determine_consistency(self.config.default_consistency); let serial_consistency = query.config.serial_consistency.flatten(); - RowIterator::new_for_connection_query_iter( + RawIterator::new_for_connection_query_iter( query, self, serialized_values, @@ -744,12 +728,18 @@ impl Connection { let (version_id,): (Uuid,) = self .query_single_page(LOCAL_VERSION, &[]) .await? - .rows - .ok_or(QueryError::ProtocolError("Version query returned not rows"))? - .into_typed::<(Uuid,)>() - .next() - .ok_or(QueryError::ProtocolError("Admin table returned empty rows"))? 
- .map_err(|_| QueryError::ProtocolError("Row is not uuid type as it should be"))?; + .single_row::<(Uuid,)>() + .map_err(|err| match err { + SingleRowError::NotRowsResponse => { + QueryError::ProtocolError("Version query returned not rows") + } + SingleRowError::UnexpectedRowCount(_) => { + QueryError::ProtocolError("system.local query returned a wrong number of rows") + } + SingleRowError::TypeCheckFailed(_) => { + QueryError::ProtocolError("Row is not uuid type as it should be") + } + })?; Ok(version_id) } @@ -832,8 +822,7 @@ impl Connection { ); } - let response = - Response::deserialize(features, task_response.opcode, &mut &*body_with_ext.body)?; + let response = Response::deserialize(features, task_response.opcode, body_with_ext.body)?; Ok(QueryResponse { response, @@ -1588,6 +1577,7 @@ mod tests { use scylla_cql::frame::protocol_features::{ LWT_OPTIMIZATION_META_BIT_MASK_KEY, SCYLLA_LWT_ADD_METADATA_MARK_EXTENSION, }; + use scylla_cql::frame::response::result::Row; use scylla_cql::frame::types; use scylla_proxy::{ Condition, Node, Proxy, Reaction, RequestFrame, RequestOpcode, RequestReaction, @@ -1652,7 +1642,7 @@ mod tests { // Preparation phase let session = SessionBuilder::new() .known_node_addr(addr) - .build() + .build_legacy() .await .unwrap(); session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'SimpleStrategy', 'replication_factor' : 1}}", ks.clone()), &[]).await.unwrap(); @@ -1682,6 +1672,7 @@ mod tests { .query_iter(select_query.clone(), &[]) .await .unwrap() + .into_typed::() .try_collect::>() .await .unwrap(); @@ -1715,6 +1706,7 @@ mod tests { .query_iter(insert_query, (0,)) .await .unwrap() + .into_typed::() .try_collect::>() .await .unwrap(); diff --git a/scylla/src/transport/cql_collections_test.rs b/scylla/src/transport/cql_collections_test.rs index 8c998f62fd..ec0263a735 100644 --- a/scylla/src/transport/cql_collections_test.rs +++ b/scylla/src/transport/cql_collections_test.rs @@ -1,8 +1,10 @@ -use 
crate::cql_to_rust::FromCqlVal; +use crate::transport::session::Session; +use scylla_cql::types::deserialize::value::DeserializeCql; + +use crate::frame::response::result::CqlValue; use crate::frame::value::Value; use crate::test_utils::create_new_session_builder; use crate::utils::test_utils::unique_keyspace_name; -use crate::{frame::response::result::CqlValue, IntoTypedRows, Session}; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; async fn connect() -> Session { @@ -34,7 +36,7 @@ async fn insert_and_select( expected: &SelectT, ) where InsertT: Value, - SelectT: FromCqlVal> + PartialEq + std::fmt::Debug, + SelectT: for<'r> DeserializeCql<'r> + PartialEq + std::fmt::Debug, { session .query( @@ -48,11 +50,7 @@ async fn insert_and_select( .query(format!("SELECT val FROM {} WHERE p = 0", table_name), ()) .await .unwrap() - .rows - .unwrap() - .into_typed::<(SelectT,)>() - .next() - .unwrap() + .single_row::<(SelectT,)>() .unwrap() .0; diff --git a/scylla/src/transport/cql_types_test.rs b/scylla/src/transport/cql_types_test.rs index 999476ecee..9f39900f65 100644 --- a/scylla/src/transport/cql_types_test.rs +++ b/scylla/src/transport/cql_types_test.rs @@ -1,17 +1,16 @@ use crate as scylla; -use crate::cql_to_rust::FromCqlVal; use crate::frame::response::result::CqlValue; use crate::frame::value::Counter; use crate::frame::value::Value; use crate::frame::value::{Date, Time, Timestamp}; use crate::macros::{FromUserType, IntoUserType}; use crate::test_utils::create_new_session_builder; -use crate::transport::session::IntoTypedRows; use crate::transport::session::Session; use crate::utils::test_utils::unique_keyspace_name; use bigdecimal::BigDecimal; use chrono::{Duration, NaiveDate}; use num_bigint::BigInt; +use scylla_cql::types::deserialize::value::DeserializeCql; use std::cmp::PartialEq; use std::fmt::Debug; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -67,7 +66,7 @@ async fn init_test(table_name: &str, type_name: &str) -> Session { // Expected values 
and bound values are computed using T::from_str async fn run_tests(tests: &[&str], type_name: &str) where - T: Value + FromCqlVal + FromStr + Debug + Clone + PartialEq, + T: Value + for<'r> DeserializeCql<'r> + FromStr + Debug + Clone + PartialEq, { let session: Session = init_test(type_name, type_name).await; session.await_schema_agreement().await.unwrap(); @@ -92,9 +91,8 @@ where .query(select_values, &[]) .await .unwrap() - .rows + .rows::<(T,)>() .unwrap() - .into_typed::<(T,)>() .map(Result::unwrap) .map(|row| row.0) .collect::>(); @@ -182,9 +180,8 @@ async fn test_counter() { .query(select_values, (i as i32,)) .await .unwrap() - .rows + .rows::<(Counter,)>() .unwrap() - .into_typed::<(Counter,)>() .map(Result::unwrap) .map(|row| row.0) .collect::>(); @@ -262,9 +259,8 @@ async fn test_naive_date() { .query("SELECT val from naive_date", &[]) .await .unwrap() - .rows + .rows::<(NaiveDate,)>() .unwrap() - .into_typed::<(NaiveDate,)>() .next() .unwrap() .ok() @@ -286,11 +282,7 @@ async fn test_naive_date() { .query("SELECT val from naive_date", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(NaiveDate,)>() - .next() - .unwrap() + .single_row::<(NaiveDate,)>() .unwrap(); assert_eq!(read_date, *naive_date); } @@ -341,18 +333,11 @@ async fn test_date() { .await .unwrap(); - let read_date: Date = session + let (read_date,): (Date,) = session .query("SELECT val from date_tests", &[]) .await .unwrap() - .rows - .unwrap()[0] - .columns[0] - .as_ref() - .map(|cql_val| match cql_val { - CqlValue::Date(days) => Date(*days), - _ => panic!(), - }) + .single_row::<(Date,)>() .unwrap(); assert_eq!(read_date, *date); @@ -394,11 +379,7 @@ async fn test_time() { .query("SELECT val from time_tests", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Duration,)>() - .next() - .unwrap() + .single_row::<(Duration,)>() .unwrap(); assert_eq!(read_time, *time_duration); @@ -416,11 +397,7 @@ async fn test_time() { .query("SELECT val from time_tests", &[]) .await 
.unwrap() - .rows - .unwrap() - .into_typed::<(Duration,)>() - .next() - .unwrap() + .single_row::<(Duration,)>() .unwrap(); assert_eq!(read_time, *time_duration); @@ -498,11 +475,7 @@ async fn test_timestamp() { .query("SELECT val from timestamp_tests", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Duration,)>() - .next() - .unwrap() + .single_row::<(Duration,)>() .unwrap(); assert_eq!(read_timestamp, *timestamp_duration); @@ -520,11 +493,7 @@ async fn test_timestamp() { .query("SELECT val from timestamp_tests", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Duration,)>() - .next() - .unwrap() + .single_row::<(Duration,)>() .unwrap(); assert_eq!(read_timestamp, *timestamp_duration); @@ -574,11 +543,7 @@ async fn test_timeuuid() { .query("SELECT val from timeuuid_tests", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Uuid,)>() - .next() - .unwrap() + .single_row::<(Uuid,)>() .unwrap(); assert_eq!(read_timeuuid.as_bytes(), timeuuid_bytes); @@ -597,11 +562,7 @@ async fn test_timeuuid() { .query("SELECT val from timeuuid_tests", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Uuid,)>() - .next() - .unwrap() + .single_row::<(Uuid,)>() .unwrap(); assert_eq!(read_timeuuid.as_bytes(), timeuuid_bytes); @@ -666,11 +627,7 @@ async fn test_inet() { .query("SELECT val from inet_tests WHERE id = 0", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(IpAddr,)>() - .next() - .unwrap() + .single_row::<(IpAddr,)>() .unwrap(); assert_eq!(read_inet, *inet); @@ -685,11 +642,7 @@ async fn test_inet() { .query("SELECT val from inet_tests WHERE id = 0", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(IpAddr,)>() - .next() - .unwrap() + .single_row::<(IpAddr,)>() .unwrap(); assert_eq!(read_inet, *inet); @@ -739,11 +692,7 @@ async fn test_blob() { .query("SELECT val from blob_tests WHERE id = 0", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Vec,)>() - .next() - .unwrap() + .single_row::<(Vec,)>() .unwrap(); 
assert_eq!(read_blob, *blob); @@ -758,11 +707,7 @@ async fn test_blob() { .query("SELECT val from blob_tests WHERE id = 0", &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(Vec,)>() - .next() - .unwrap() + .single_row::<(Vec,)>() .unwrap(); assert_eq!(read_blob, *blob); @@ -848,11 +793,9 @@ async fn test_udt_after_schema_update() { .query(format!("SELECT val from {} WHERE id = 0", table_name), &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(UdtV1,)>() - .next() + .into_legacy_result() .unwrap() + .single_row_typed::<(UdtV1,)>() .unwrap(); assert_eq!(read_udt, v1); @@ -869,11 +812,9 @@ async fn test_udt_after_schema_update() { .query(format!("SELECT val from {} WHERE id = 0", table_name), &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(UdtV1,)>() - .next() + .into_legacy_result() .unwrap() + .single_row_typed::<(UdtV1,)>() .unwrap(); assert_eq!(read_udt, v1); @@ -894,11 +835,9 @@ async fn test_udt_after_schema_update() { .query(format!("SELECT val from {} WHERE id = 0", table_name), &[]) .await .unwrap() - .rows - .unwrap() - .into_typed::<(UdtV2,)>() - .next() + .into_legacy_result() .unwrap() + .single_row_typed::<(UdtV2,)>() .unwrap(); assert_eq!( @@ -927,7 +866,7 @@ async fn test_empty() { .query("SELECT val FROM empty_tests WHERE id = 0", ()) .await .unwrap() - .first_row_typed::<(CqlValue,)>() + .first_row::<(CqlValue,)>() .unwrap(); assert_eq!(empty, CqlValue::Empty); @@ -944,7 +883,7 @@ async fn test_empty() { .query("SELECT val FROM empty_tests WHERE id = 1", ()) .await .unwrap() - .first_row_typed::<(CqlValue,)>() + .first_row::<(CqlValue,)>() .unwrap(); assert_eq!(empty, CqlValue::Empty); diff --git a/scylla/src/transport/cql_value_test.rs b/scylla/src/transport/cql_value_test.rs index 75c736644c..7009cf6c59 100644 --- a/scylla/src/transport/cql_value_test.rs +++ b/scylla/src/transport/cql_value_test.rs @@ -1,4 +1,5 @@ -use crate::frame::{response::result::CqlValue, value::CqlDuration}; +use 
crate::frame::response::result::{CqlValue, Row}; +use crate::frame::value::CqlDuration; use crate::test_utils::create_new_session_builder; use crate::utils::test_utils::unique_keyspace_name; @@ -57,7 +58,9 @@ async fn test_cqlvalue_udt() { .query("SELECT my FROM cqlvalue_udt_test", &[]) .await .unwrap() - .rows + .rows::() + .unwrap() + .collect::, _>>() .unwrap(); assert_eq!(rows.len(), 1); @@ -111,7 +114,9 @@ async fn test_cqlvalue_duration() { ) .await .unwrap() - .rows + .rows::() + .unwrap() + .collect::, _>>() .unwrap(); assert_eq!(rows.len(), 4); diff --git a/scylla/src/transport/execution_profile.rs b/scylla/src/transport/execution_profile.rs index 8fe82f3e36..157de3e60d 100644 --- a/scylla/src/transport/execution_profile.rs +++ b/scylla/src/transport/execution_profile.rs @@ -16,7 +16,7 @@ //! # extern crate scylla; //! # use std::error::Error; //! # async fn check_only_compiles() -> Result<(), Box> { -//! use scylla::{Session, SessionBuilder}; +//! use scylla::{Legacy08Session, SessionBuilder}; //! use scylla::statement::Consistency; //! use scylla::transport::ExecutionProfile; //! @@ -27,10 +27,10 @@ //! //! let handle = profile.into_handle(); //! -//! let session: Session = SessionBuilder::new() +//! let session: Legacy08Session = SessionBuilder::new() //! .known_node("127.0.0.1:9042") //! .default_execution_profile_handle(handle) -//! .build() +//! .build_legacy() //! .await?; //! # Ok(()) //! # } @@ -109,7 +109,7 @@ //! # extern crate scylla; //! # use std::error::Error; //! # async fn check_only_compiles() -> Result<(), Box> { -//! use scylla::{Session, SessionBuilder}; +//! use scylla::{Legacy08Session, SessionBuilder}; //! use scylla::query::Query; //! use scylla::statement::Consistency; //! use scylla::transport::ExecutionProfile; @@ -125,10 +125,10 @@ //! let mut handle1 = profile1.clone().into_handle(); //! let mut handle2 = profile2.clone().into_handle(); //! -//! let session: Session = SessionBuilder::new() +//! 
let session: Legacy08Session = SessionBuilder::new() //! .known_node("127.0.0.1:9042") //! .default_execution_profile_handle(handle1.clone()) -//! .build() +//! .build_legacy() //! .await?; //! //! let mut query1 = Query::from("SELECT * FROM ks.table"); diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index e6fd344dc7..f0cdbfd12c 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -1,7 +1,6 @@ //! Iterators over rows returned by paged queries use std::future::Future; -use std::mem; use std::net::SocketAddr; use std::ops::ControlFlow; use std::pin::Pin; @@ -10,8 +9,10 @@ use std::task::{Context, Poll}; use bytes::Bytes; use futures::Stream; +use scylla_cql::frame::response::result::{RawRows, RawRowsLendingIterator}; use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::frame::types::SerialConsistency; +use scylla_cql::types::deserialize::row::{ColumnIterator, DeserializeRow}; use std::result::Result; use thiserror::Error; use tokio::sync::mpsc; @@ -26,7 +27,7 @@ use crate::frame::types::LegacyConsistency; use crate::frame::{ response::{ result, - result::{ColumnSpec, Row, Rows}, + result::{ColumnSpec, Row}, }, value::SerializedValues, }; @@ -43,6 +44,20 @@ use crate::transport::{Node, NodeRef}; use tracing::{trace, trace_span, warn, Instrument}; use uuid::Uuid; +// Like std::task::ready!, but handles the whole stack of Poll>>. +// If it matches Poll::Ready(Some(Ok(_))), then it returns the innermost value, +// otherwise it returns from the surrounding function. +macro_rules! 
ready_some_ok { + ($e:expr) => { + match $e { + Poll::Ready(Some(Ok(x))) => x, + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + }; +} + // #424 // // Both `Query` and `PreparedStatement` have page size set to `None` as default, @@ -59,17 +74,25 @@ use uuid::Uuid; // value at the beginning of `query_iter` and `execute_iter`. const DEFAULT_ITER_PAGE_SIZE: i32 = 5000; -/// Iterator over rows returned by paged queries\ -/// Allows to easily access rows without worrying about handling multiple pages -pub struct RowIterator { - current_row_idx: usize, - current_page: Rows, +/// An intermediate object that allows to construct an iterator over a query +/// that is asynchronously paged in the background. +/// +/// Before the results can be processed, the RawIterator needs to be cast +/// into a typed iterator. +/// +/// TODO: How? +/// +/// A pre-0.8.0 interface is also available: +/// +/// TODO +pub struct RawIterator { + current_page: RawRowsLendingIterator, page_receiver: mpsc::Receiver>, tracing_ids: Vec, } struct ReceivedPage { - pub rows: Rows, + pub rows: RawRows, pub tracing_id: Option, } @@ -83,60 +106,88 @@ pub(crate) struct PreparedIteratorConfig { pub metrics: Arc, } -/// Fetching pages is asynchronous so `RowIterator` does not implement the `Iterator` trait.\ -/// Instead it uses the asynchronous `Stream` trait -impl Stream for RowIterator { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut s = self.as_mut(); +/// RawIterator is not an iterator or a stream! However, it implements +/// a `next()` method which returns a ColumnIterator<'r>. The ColumnIterator +/// borrows from the RawIterator, and the futures::Stream trait does not allow +/// for such a pattern. Lending streams are not a thing yet. +impl RawIterator { + /// Returns the next item from the stream. 
+ /// + /// This is not a part of the Stream interface because the returned iterator + /// borrows from self. + /// + /// This is cancel-safe. + pub async fn next(&mut self) -> Option> { + let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await; + match res { + Some(Ok(())) => {} + Some(Err(err)) => return Some(Err(err)), + None => return None, + } - if s.is_current_page_exhausted() { - match Pin::new(&mut s.page_receiver).poll_recv(cx) { - Poll::Ready(Some(Ok(received_page))) => { - s.current_page = received_page.rows; - s.current_row_idx = 0; + // We are guaranteed here to have a non-empty page, so unwrap + Some(self.current_page.next().unwrap().map_err(|e| e.into())) + } - if let Some(tracing_id) = received_page.tracing_id { - s.tracing_ids.push(tracing_id); - } - } - Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } + /// Tries to acquire a non-empty page, if current page is exhausted. + fn poll_fill_page<'r>( + mut self: Pin<&'r mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if !self.is_current_page_exhausted() { + return Poll::Ready(Some(Ok(()))); + } + ready_some_ok!(self.as_mut().poll_next_page(cx)); + if self.is_current_page_exhausted() { + // Try again later + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Some(Ok(()))) } + } + + /// Makes an attempt to acquire the next page (which may be empty). + /// + /// On success, returns Some(Ok()). + /// On failure, returns Some(Err()). + /// If there are no more pages, returns None. 
+ fn poll_next_page<'r>( + mut self: Pin<&'r mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let mut s = self.as_mut(); - let idx = s.current_row_idx; - if idx < s.current_page.rows.len() { - let row = mem::take(&mut s.current_page.rows[idx]); - s.current_row_idx += 1; - return Poll::Ready(Some(Ok(row))); + let received_page = ready_some_ok!(Pin::new(&mut s.page_receiver).poll_recv(cx)); + s.current_page = RawRowsLendingIterator::new(received_page.rows); + + if let Some(tracing_id) = received_page.tracing_id { + s.tracing_ids.push(tracing_id); } - // We probably got a zero-sized page - // Yield, but tell that we are ready - cx.waker().wake_by_ref(); - Poll::Pending + Poll::Ready(Some(Ok(()))) } -} -impl RowIterator { - /// Converts this iterator into an iterator over rows parsed as given type - pub fn into_typed(self) -> TypedRowIterator { + pub fn into_typed(self) -> TypedRowIterator { TypedRowIterator { - row_iterator: self, - phantom_data: Default::default(), + raw_iterator: self, + _phantom: Default::default(), } } + /// Converts this iterator into an iterator over rows parsed as given type, + /// using the legacy deserialization framework. 
+ pub fn into_legacy(self) -> Legacy08RowIterator { + Legacy08RowIterator { raw_iterator: self } + } + pub(crate) async fn new_for_query( mut query: Query, values: SerializedValues, execution_profile: Arc, cluster_data: Arc, metrics: Arc, - ) -> Result { + ) -> Result { if query.get_page_size().is_none() { query.set_page_size(DEFAULT_ITER_PAGE_SIZE); } @@ -210,7 +261,7 @@ impl RowIterator { pub(crate) async fn new_for_prepared_statement( mut config: PreparedIteratorConfig, - ) -> Result { + ) -> Result { if config.prepared.get_page_size().is_none() { config.prepared.set_page_size(DEFAULT_ITER_PAGE_SIZE); } @@ -328,7 +379,7 @@ impl RowIterator { values: SerializedValues, consistency: Consistency, serial_consistency: Option, - ) -> Result { + ) -> Result { if query.get_page_size().is_none() { query.set_page_size(DEFAULT_ITER_PAGE_SIZE); } @@ -356,7 +407,7 @@ impl RowIterator { async fn new_from_worker_future( worker_task: impl Future + Send + 'static, mut receiver: mpsc::Receiver>, - ) -> Result { + ) -> Result { tokio::task::spawn(worker_task.with_current_subscriber()); // This unwrap is safe because: @@ -366,9 +417,8 @@ impl RowIterator { // cancelled let pages_received = receiver.recv().await.unwrap()?; - Ok(RowIterator { - current_row_idx: 0, - current_page: pages_received.rows, + Ok(Self { + current_page: RawRowsLendingIterator::new(pages_received.rows), page_receiver: receiver, tracing_ids: if let Some(tracing_id) = pages_received.tracing_id { vec![tracing_id] @@ -385,18 +435,19 @@ impl RowIterator { /// Returns specification of row columns pub fn get_column_specs(&self) -> &[ColumnSpec] { - &self.current_page.metadata.col_specs + &self.current_page.metadata().col_specs } fn is_current_page_exhausted(&self) -> bool { - self.current_row_idx >= self.current_page.rows.len() + self.current_page.rows_remaining() == 0 } } // A separate module is used here so that the parent module cannot construct // SendAttemptedProof directly. 
mod checked_channel_sender { - use scylla_cql::{errors::QueryError, frame::response::result::Rows}; + use scylla_cql::errors::QueryError; + use scylla_cql::frame::response::result::RawRows; use std::marker::PhantomData; use tokio::sync::mpsc; use uuid::Uuid; @@ -437,12 +488,7 @@ mod checked_channel_sender { Result<(), mpsc::error::SendError>, ) { let empty_page = ReceivedPage { - rows: Rows { - metadata: Default::default(), - rows_count: 0, - rows: Vec::new(), - serialized_size: 0, - }, + rows: RawRows::default(), tracing_id, }; self.send(Ok(empty_page)).await @@ -651,7 +697,7 @@ where match query_response { Ok(NonErrorQueryResponse { - response: NonErrorResponse::Result(result::Result::Rows(mut rows)), + response: NonErrorResponse::Result(result::Result::Rows(rows)), tracing_id, .. }) => { @@ -662,7 +708,7 @@ where .load_balancing_policy .on_query_success(&self.statement_info, elapsed, node); - self.paging_state = rows.metadata.paging_state.take(); + self.paging_state = rows.metadata().paging_state.clone(); request_span.record_rows_fields(&rows); @@ -826,8 +872,8 @@ where let result = (self.fetcher)(paging_state).await?; let response = result.into_non_error_query_response()?; match response.response { - NonErrorResponse::Result(result::Result::Rows(mut rows)) => { - paging_state = rows.metadata.paging_state.take(); + NonErrorResponse::Result(result::Result::Rows(rows)) => { + paging_state = rows.metadata().paging_state.clone(); let (proof, send_result) = self .sender .send(Ok(ReceivedPage { @@ -857,15 +903,101 @@ where } } +pub struct TypedRowIterator { + raw_iterator: RawIterator, + _phantom: std::marker::PhantomData, +} + +/// Stream implementation for TypedRowIterator. +/// +/// It only works with owned types! For example, &str is not supported. 
+impl Stream for TypedRowIterator +where + RowT: for<'r> DeserializeRow<'r>, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut s = self.as_mut(); + + let next_fut = s.raw_iterator.next(); + futures::pin_mut!(next_fut); + let iter = ready_some_ok!(next_fut.poll(cx)); + let value = >::deserialize(iter).map_err(|e| e.into()); + Poll::Ready(Some(value)) + } +} + +impl TypedRowIterator { + /// If tracing was enabled returns tracing ids of all finished page queries + pub fn get_tracing_ids(&self) -> &[Uuid] { + self.raw_iterator.get_tracing_ids() + } + + /// Returns specification of row columns + pub fn get_column_specs(&self) -> &[ColumnSpec] { + self.raw_iterator.get_column_specs() + } +} + +impl Unpin for TypedRowIterator {} + +pub struct Legacy08RowIterator { + raw_iterator: RawIterator, +} + +impl Stream for Legacy08RowIterator { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut s = self.as_mut(); + + let next_fut = s.raw_iterator.next(); + futures::pin_mut!(next_fut); + + let next_elem: Option, QueryError>> = match next_fut.poll(cx) { + Poll::Ready(next_elem) => next_elem, + Poll::Pending => return Poll::Pending, + }; + + let next_ready: Option = match next_elem { + Some(Ok(iter)) => Some(Row::deserialize(iter).map_err(|e| e.into())), + Some(Err(e)) => Some(Err(e)), + None => None, + }; + + Poll::Ready(next_ready) + } +} + +impl Legacy08RowIterator { + /// If tracing was enabled returns tracing ids of all finished page queries + pub fn get_tracing_ids(&self) -> &[Uuid] { + self.raw_iterator.get_tracing_ids() + } + + /// Returns specification of row columns + pub fn get_column_specs(&self) -> &[ColumnSpec] { + self.raw_iterator.get_column_specs() + } + + pub fn into_typed(self) -> Legacy08TypedRowIterator { + Legacy08TypedRowIterator { + row_iterator: self, + _phantom_data: Default::default(), + } + } +} + /// Iterator over rows returned by paged 
queries /// where each row is parsed as the given type\ /// Returned by `RowIterator::into_typed` -pub struct TypedRowIterator { - row_iterator: RowIterator, - phantom_data: std::marker::PhantomData, +pub struct Legacy08TypedRowIterator { + row_iterator: Legacy08RowIterator, + _phantom_data: std::marker::PhantomData, } -impl TypedRowIterator { +impl Legacy08TypedRowIterator { /// If tracing was enabled returns tracing ids of all finished page queries pub fn get_tracing_ids(&self) -> &[Uuid] { self.row_iterator.get_tracing_ids() @@ -891,27 +1023,17 @@ pub enum NextRowError { /// Fetching pages is asynchronous so `TypedRowIterator` does not implement the `Iterator` trait.\ /// Instead it uses the asynchronous `Stream` trait -impl Stream for TypedRowIterator { +impl Stream for Legacy08TypedRowIterator { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut s = self.as_mut(); - let next_elem: Option> = - match Pin::new(&mut s.row_iterator).poll_next(cx) { - Poll::Ready(next_elem) => next_elem, - Poll::Pending => return Poll::Pending, - }; - - let next_ready: Option = match next_elem { - Some(Ok(next_row)) => Some(RowT::from_row(next_row).map_err(|e| e.into())), - Some(Err(e)) => Some(Err(e.into())), - None => None, - }; - - Poll::Ready(next_ready) + let next_row = ready_some_ok!(Pin::new(&mut s.row_iterator).poll_next(cx)); + let typed_row_res = RowT::from_row(next_row).map_err(|e| e.into()); + Poll::Ready(Some(typed_row_res)) } } // TypedRowIterator can be moved freely for any RowT so it's Unpin -impl Unpin for TypedRowIterator {} +impl Unpin for Legacy08TypedRowIterator {} diff --git a/scylla/src/transport/legacy_query_result.rs b/scylla/src/transport/legacy_query_result.rs new file mode 100644 index 0000000000..634b4592a1 --- /dev/null +++ b/scylla/src/transport/legacy_query_result.rs @@ -0,0 +1,590 @@ +use crate::frame::response::cql_to_rust::{FromRow, FromRowError}; +use crate::frame::response::result::ColumnSpec; +use 
crate::frame::response::result::Row; +use crate::transport::session::{IntoTypedRows, TypedRowIter}; +use bytes::Bytes; +use thiserror::Error; +use uuid::Uuid; + +/// Result of a single query\ +/// Contains all rows returned by the database and some more information +#[non_exhaustive] +#[derive(Default, Debug)] +pub struct Legacy08QueryResult { + /// Rows returned by the database.\ + /// Queries like `SELECT` will have `Some(Vec)`, while queries like `INSERT` will have `None`.\ + /// Can contain an empty Vec. + pub rows: Option>, + /// Warnings returned by the database + pub warnings: Vec, + /// CQL Tracing uuid - can only be Some if tracing is enabled for this query + pub tracing_id: Option, + /// Paging state returned from the server + pub paging_state: Option, + /// Column specification returned from the server + pub col_specs: Vec, + /// The original size of the serialized rows in request + pub serialized_size: usize, +} + +impl Legacy08QueryResult { + /// Returns the number of received rows.\ + /// Fails when the query isn't of a type that could return rows, same as [`rows()`](Legacy08QueryResult::rows). + pub fn rows_num(&self) -> Result { + match &self.rows { + Some(rows) => Ok(rows.len()), + None => Err(RowsExpectedError), + } + } + + /// Returns the received rows when present.\ + /// If `Legacy08QueryResult.rows` is `None`, which means that this query is not supposed to return rows (e.g `INSERT`), returns an error.\ + /// Can return an empty `Vec`. + pub fn rows(self) -> Result, RowsExpectedError> { + match self.rows { + Some(rows) => Ok(rows), + None => Err(RowsExpectedError), + } + } + + /// Returns the received rows parsed as the given type.\ + /// Equal to `rows()?.into_typed()`.\ + /// Fails when the query isn't of a type that could return rows, same as [`rows()`](Legacy08QueryResult::rows). 
+ pub fn rows_typed(self) -> Result, RowsExpectedError> { + Ok(self.rows()?.into_typed()) + } + + /// Returns `Ok` for a result of a query that shouldn't contain any rows.\ + /// Will return `Ok` for `INSERT` result, but a `SELECT` result, even an empty one, will cause an error.\ + /// Opposite of [`rows()`](Legacy08QueryResult::rows). + pub fn result_not_rows(&self) -> Result<(), RowsNotExpectedError> { + match self.rows { + Some(_) => Err(RowsNotExpectedError), + None => Ok(()), + } + } + + /// Returns rows when `Legacy08QueryResult.rows` is `Some`, otherwise an empty Vec.\ + /// Equal to `rows().unwrap_or_default()`. + pub fn rows_or_empty(self) -> Vec { + self.rows.unwrap_or_default() + } + + /// Returns rows parsed as the given type.\ + /// When `Legacy08QueryResult.rows` is `None`, returns 0 rows.\ + /// Equal to `rows_or_empty().into_typed::()`. + pub fn rows_typed_or_empty(self) -> TypedRowIter { + self.rows_or_empty().into_typed::() + } + + /// Returns first row from the received rows.\ + /// When the first row is not available, returns an error. + pub fn first_row(self) -> Result { + match self.maybe_first_row()? { + Some(row) => Ok(row), + None => Err(FirstRowError::RowsEmpty), + } + } + + /// Returns first row from the received rows parsed as the given type.\ + /// When the first row is not available, returns an error. + pub fn first_row_typed(self) -> Result { + Ok(self.first_row()?.into_typed()?) + } + + /// Returns `Option` containing the first of a result.\ + /// Fails when the query isn't of a type that could return rows, same as [`rows()`](Legacy08QueryResult::rows). + pub fn maybe_first_row(self) -> Result, RowsExpectedError> { + Ok(self.rows()?.into_iter().next()) + } + + /// Returns `Option` containing the first of a result.\ + /// Fails when the query isn't of a type that could return rows, same as [`rows()`](Legacy08QueryResult::rows). 
+ pub fn maybe_first_row_typed( + self, + ) -> Result, MaybeFirstRowTypedError> { + match self.maybe_first_row()? { + Some(row) => Ok(Some(row.into_typed::()?)), + None => Ok(None), + } + } + + /// Returns the only received row.\ + /// Fails if the result is anything else than a single row.\ + pub fn single_row(self) -> Result { + let rows: Vec = self.rows()?; + + if rows.len() != 1 { + return Err(SingleRowError::BadNumberOfRows(rows.len())); + } + + Ok(rows.into_iter().next().unwrap()) + } + + /// Returns the only received row parsed as the given type.\ + /// Fails if the result is anything else than a single row.\ + pub fn single_row_typed(self) -> Result { + Ok(self.single_row()?.into_typed::()?) + } + + /// Returns a column specification for a column with given name, or None if not found + pub fn get_column_spec<'a>(&'a self, name: &str) -> Option<(usize, &'a ColumnSpec)> { + self.col_specs + .iter() + .enumerate() + .find(|(_id, spec)| spec.name == name) + } +} + +/// [`Legacy08QueryResult::rows()`](Legacy08QueryResult::rows) or a similar function called on a bad Legacy08QueryResult.\ +/// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`.\ +/// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ +/// It is `None` for queries that can't return rows (e.g `INSERT`). +#[derive(Debug, Clone, Error, PartialEq, Eq)] +#[error( + "Legacy08QueryResult::rows() or similar function called on a bad Legacy08QueryResult. + Expected Legacy08QueryResult.rows to be Some, but it was None. + Legacy08QueryResult.rows is Some for queries that can return rows (e.g SELECT). + It is None for queries that can't return rows (e.g INSERT)." 
+)] +pub struct RowsExpectedError; + +/// [`Legacy08QueryResult::result_not_rows()`](Legacy08QueryResult::result_not_rows) called on a bad Legacy08QueryResult.\ +/// Expected `Legacy08QueryResult.rows` to be `None`, but it was `Some`.\ +/// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ +/// It is `None` for queries that can't return rows (e.g `INSERT`). +#[derive(Debug, Clone, Error, PartialEq, Eq)] +#[error( + "Legacy08QueryResult::result_not_rows() called on a bad Legacy08QueryResult. + Expected Legacy08QueryResult.rows to be None, but it was Some. + Legacy08QueryResult.rows is Some for queries that can return rows (e.g SELECT). + It is None for queries that can't return rows (e.g INSERT)." +)] +pub struct RowsNotExpectedError; + +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum FirstRowError { + /// [`Legacy08QueryResult::first_row()`](Legacy08QueryResult::first_row) called on a bad Legacy08QueryResult.\ + /// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`.\ + /// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ + /// It is `None` for queries that can't return rows (e.g `INSERT`). + #[error(transparent)] + RowsExpected(#[from] RowsExpectedError), + + /// Rows in `Legacy08QueryResult` are empty + #[error("Rows in Legacy08QueryResult are empty")] + RowsEmpty, +} + +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum FirstRowTypedError { + /// [`Legacy08QueryResult::first_row_typed()`](Legacy08QueryResult::first_row_typed) called on a bad Legacy08QueryResult.\ + /// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`.\ + /// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ + /// It is `None` for queries that can't return rows (e.g `INSERT`). 
+ #[error(transparent)] + RowsExpected(#[from] RowsExpectedError), + + /// Rows in `Legacy08QueryResult` are empty + #[error("Rows in Legacy08QueryResult are empty")] + RowsEmpty, + + /// Parsing row as the given type failed + #[error(transparent)] + FromRowError(#[from] FromRowError), +} + +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum MaybeFirstRowTypedError { + /// [`Legacy08QueryResult::maybe_first_row_typed()`](Legacy08QueryResult::maybe_first_row_typed) called on a bad Legacy08QueryResult.\ + /// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`. + /// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ + /// It is `None` for queries that can't return rows (e.g `INSERT`). + #[error(transparent)] + RowsExpected(#[from] RowsExpectedError), + + /// Parsing row as the given type failed + #[error(transparent)] + FromRowError(#[from] FromRowError), +} + +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum SingleRowError { + /// [`Legacy08QueryResult::single_row()`](Legacy08QueryResult::single_row) called on a bad Legacy08QueryResult.\ + /// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`.\ + /// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ + /// It is `None` for queries that can't return rows (e.g `INSERT`). 
+ #[error(transparent)] + RowsExpected(#[from] RowsExpectedError), + + /// Expected a single row, found other number of rows + #[error("Expected a single row, found {0} rows")] + BadNumberOfRows(usize), +} + +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum SingleRowTypedError { + /// [`Legacy08QueryResult::single_row_typed()`](Legacy08QueryResult::single_row_typed) called on a bad Legacy08QueryResult.\ + /// Expected `Legacy08QueryResult.rows` to be `Some`, but it was `None`.\ + /// `Legacy08QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ + /// It is `None` for queries that can't return rows (e.g `INSERT`). + #[error(transparent)] + RowsExpected(#[from] RowsExpectedError), + + /// Expected a single row, found other number of rows + #[error("Expected a single row, found {0} rows")] + BadNumberOfRows(usize), + + /// Parsing row as the given type failed + #[error(transparent)] + FromRowError(#[from] FromRowError), +} + +impl From for FirstRowTypedError { + fn from(err: FirstRowError) -> FirstRowTypedError { + match err { + FirstRowError::RowsExpected(e) => FirstRowTypedError::RowsExpected(e), + FirstRowError::RowsEmpty => FirstRowTypedError::RowsEmpty, + } + } +} + +impl From for SingleRowTypedError { + fn from(err: SingleRowError) -> SingleRowTypedError { + match err { + SingleRowError::RowsExpected(e) => SingleRowTypedError::RowsExpected(e), + SingleRowError::BadNumberOfRows(r) => SingleRowTypedError::BadNumberOfRows(r), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row, TableSpec}; + use std::convert::TryInto; + + // Returns specified number of rows, each one containing one int32 value. + // Values are 0, 1, 2, 3, 4, ... 
+ fn make_rows(rows_num: usize) -> Vec { + let mut rows: Vec = Vec::with_capacity(rows_num); + for cur_value in 0..rows_num { + let int_val: i32 = cur_value.try_into().unwrap(); + rows.push(Row { + columns: vec![Some(CqlValue::Int(int_val))], + }); + } + rows + } + + // Just like make_rows, but each column has one String value + // values are "val0", "val1", "val2", ... + fn make_string_rows(rows_num: usize) -> Vec { + let mut rows: Vec = Vec::with_capacity(rows_num); + for cur_value in 0..rows_num { + rows.push(Row { + columns: vec![Some(CqlValue::Text(format!("val{}", cur_value)))], + }); + } + rows + } + + fn make_not_rows_query_result() -> Legacy08QueryResult { + let table_spec = TableSpec { + ks_name: "some_keyspace".to_string(), + table_name: "some_table".to_string(), + }; + + let column_spec = ColumnSpec { + table_spec, + name: "column0".to_string(), + typ: ColumnType::Int, + }; + + Legacy08QueryResult { + rows: None, + warnings: vec![], + tracing_id: None, + paging_state: None, + col_specs: vec![column_spec], + serialized_size: 0, + } + } + + fn make_rows_query_result(rows_num: usize) -> Legacy08QueryResult { + let mut res = make_not_rows_query_result(); + res.rows = Some(make_rows(rows_num)); + res + } + + fn make_string_rows_query_result(rows_num: usize) -> Legacy08QueryResult { + let mut res = make_not_rows_query_result(); + res.rows = Some(make_string_rows(rows_num)); + res + } + + #[test] + fn rows_num_test() { + assert_eq!( + make_not_rows_query_result().rows_num(), + Err(RowsExpectedError) + ); + assert_eq!(make_rows_query_result(0).rows_num(), Ok(0)); + assert_eq!(make_rows_query_result(1).rows_num(), Ok(1)); + assert_eq!(make_rows_query_result(2).rows_num(), Ok(2)); + assert_eq!(make_rows_query_result(3).rows_num(), Ok(3)); + } + + #[test] + fn rows_test() { + assert_eq!(make_not_rows_query_result().rows(), Err(RowsExpectedError)); + assert_eq!(make_rows_query_result(0).rows(), Ok(vec![])); + assert_eq!(make_rows_query_result(1).rows(), 
Ok(make_rows(1))); + assert_eq!(make_rows_query_result(2).rows(), Ok(make_rows(2))); + } + + #[test] + fn rows_typed_test() { + assert!(make_not_rows_query_result().rows_typed::<(i32,)>().is_err()); + + let rows0: Vec<(i32,)> = make_rows_query_result(0) + .rows_typed::<(i32,)>() + .unwrap() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows0, vec![]); + + let rows1: Vec<(i32,)> = make_rows_query_result(1) + .rows_typed::<(i32,)>() + .unwrap() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows1, vec![(0,)]); + + let rows2: Vec<(i32,)> = make_rows_query_result(2) + .rows_typed::<(i32,)>() + .unwrap() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows2, vec![(0,), (1,)]); + } + + #[test] + fn result_not_rows_test() { + assert_eq!(make_not_rows_query_result().result_not_rows(), Ok(())); + assert_eq!( + make_rows_query_result(0).result_not_rows(), + Err(RowsNotExpectedError) + ); + assert_eq!( + make_rows_query_result(1).result_not_rows(), + Err(RowsNotExpectedError) + ); + assert_eq!( + make_rows_query_result(2).result_not_rows(), + Err(RowsNotExpectedError) + ); + } + + #[test] + fn rows_or_empty_test() { + assert_eq!(make_not_rows_query_result().rows_or_empty(), vec![]); + assert_eq!(make_rows_query_result(0).rows_or_empty(), make_rows(0)); + assert_eq!(make_rows_query_result(1).rows_or_empty(), make_rows(1)); + assert_eq!(make_rows_query_result(2).rows_or_empty(), make_rows(2)); + } + + #[test] + fn rows_typed_or_empty() { + let rows_empty: Vec<(i32,)> = make_not_rows_query_result() + .rows_typed_or_empty::<(i32,)>() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows_empty, vec![]); + + let rows0: Vec<(i32,)> = make_rows_query_result(0) + .rows_typed_or_empty::<(i32,)>() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows0, vec![]); + + let rows1: Vec<(i32,)> = make_rows_query_result(1) + .rows_typed_or_empty::<(i32,)>() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows1, vec![(0,)]); + + let rows2: Vec<(i32,)> = 
make_rows_query_result(2) + .rows_typed_or_empty::<(i32,)>() + .map(|r| r.unwrap()) + .collect(); + + assert_eq!(rows2, vec![(0,), (1,)]); + } + + #[test] + fn first_row_test() { + assert_eq!( + make_not_rows_query_result().first_row(), + Err(FirstRowError::RowsExpected(RowsExpectedError)) + ); + assert_eq!( + make_rows_query_result(0).first_row(), + Err(FirstRowError::RowsEmpty) + ); + assert_eq!( + make_rows_query_result(1).first_row(), + Ok(make_rows(1).into_iter().next().unwrap()) + ); + assert_eq!( + make_rows_query_result(2).first_row(), + Ok(make_rows(2).into_iter().next().unwrap()) + ); + assert_eq!( + make_rows_query_result(3).first_row(), + Ok(make_rows(3).into_iter().next().unwrap()) + ); + } + + #[test] + fn first_row_typed_test() { + assert_eq!( + make_not_rows_query_result().first_row_typed::<(i32,)>(), + Err(FirstRowTypedError::RowsExpected(RowsExpectedError)) + ); + assert_eq!( + make_rows_query_result(0).first_row_typed::<(i32,)>(), + Err(FirstRowTypedError::RowsEmpty) + ); + assert_eq!( + make_rows_query_result(1).first_row_typed::<(i32,)>(), + Ok((0,)) + ); + assert_eq!( + make_rows_query_result(2).first_row_typed::<(i32,)>(), + Ok((0,)) + ); + assert_eq!( + make_rows_query_result(3).first_row_typed::<(i32,)>(), + Ok((0,)) + ); + + assert!(matches!( + make_string_rows_query_result(2).first_row_typed::<(i32,)>(), + Err(FirstRowTypedError::FromRowError(_)) + )); + } + + #[test] + fn maybe_first_row_test() { + assert_eq!( + make_not_rows_query_result().maybe_first_row(), + Err(RowsExpectedError) + ); + assert_eq!(make_rows_query_result(0).maybe_first_row(), Ok(None)); + assert_eq!( + make_rows_query_result(1).maybe_first_row(), + Ok(Some(make_rows(1).into_iter().next().unwrap())) + ); + assert_eq!( + make_rows_query_result(2).maybe_first_row(), + Ok(Some(make_rows(2).into_iter().next().unwrap())) + ); + assert_eq!( + make_rows_query_result(3).maybe_first_row(), + Ok(Some(make_rows(3).into_iter().next().unwrap())) + ); + } + + #[test] + fn 
maybe_first_row_typed_test() { + assert_eq!( + make_not_rows_query_result().maybe_first_row_typed::<(i32,)>(), + Err(MaybeFirstRowTypedError::RowsExpected(RowsExpectedError)) + ); + + assert_eq!( + make_rows_query_result(0).maybe_first_row_typed::<(i32,)>(), + Ok(None) + ); + + assert_eq!( + make_rows_query_result(1).maybe_first_row_typed::<(i32,)>(), + Ok(Some((0,))) + ); + + assert_eq!( + make_rows_query_result(2).maybe_first_row_typed::<(i32,)>(), + Ok(Some((0,))) + ); + + assert_eq!( + make_rows_query_result(3).maybe_first_row_typed::<(i32,)>(), + Ok(Some((0,))) + ); + + assert!(matches!( + make_string_rows_query_result(1).maybe_first_row_typed::<(i32,)>(), + Err(MaybeFirstRowTypedError::FromRowError(_)) + )) + } + + #[test] + fn single_row_test() { + assert_eq!( + make_not_rows_query_result().single_row(), + Err(SingleRowError::RowsExpected(RowsExpectedError)) + ); + assert_eq!( + make_rows_query_result(0).single_row(), + Err(SingleRowError::BadNumberOfRows(0)) + ); + assert_eq!( + make_rows_query_result(1).single_row(), + Ok(make_rows(1).into_iter().next().unwrap()) + ); + assert_eq!( + make_rows_query_result(2).single_row(), + Err(SingleRowError::BadNumberOfRows(2)) + ); + assert_eq!( + make_rows_query_result(3).single_row(), + Err(SingleRowError::BadNumberOfRows(3)) + ); + } + + #[test] + fn single_row_typed_test() { + assert_eq!( + make_not_rows_query_result().single_row_typed::<(i32,)>(), + Err(SingleRowTypedError::RowsExpected(RowsExpectedError)) + ); + assert_eq!( + make_rows_query_result(0).single_row_typed::<(i32,)>(), + Err(SingleRowTypedError::BadNumberOfRows(0)) + ); + assert_eq!( + make_rows_query_result(1).single_row_typed::<(i32,)>(), + Ok((0,)) + ); + assert_eq!( + make_rows_query_result(2).single_row_typed::<(i32,)>(), + Err(SingleRowTypedError::BadNumberOfRows(2)) + ); + assert_eq!( + make_rows_query_result(3).single_row_typed::<(i32,)>(), + Err(SingleRowTypedError::BadNumberOfRows(3)) + ); + + assert!(matches!( + 
make_string_rows_query_result(1).single_row_typed::<(i32,)>(), + Err(SingleRowTypedError::FromRowError(_)) + )); + } +} diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index ceec5a42b3..9c9d6e5933 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -2610,7 +2610,7 @@ mod latency_awareness { let session = create_new_session_builder() .default_execution_profile_handle(handle) - .build() + .build_legacy() .await .unwrap(); diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 3b6bea3830..831fc9d0af 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -6,6 +6,7 @@ pub mod downgrading_consistency_retry_policy; pub mod execution_profile; pub mod host_filter; pub mod iterator; +pub mod legacy_query_result; pub mod load_balancing; pub mod locator; pub(crate) mod metrics; diff --git a/scylla/src/transport/query_result.rs b/scylla/src/transport/query_result.rs index 98a623f01c..df24b56d40 100644 --- a/scylla/src/transport/query_result.rs +++ b/scylla/src/transport/query_result.rs @@ -1,590 +1,228 @@ -use crate::frame::response::cql_to_rust::{FromRow, FromRowError}; -use crate::frame::response::result::ColumnSpec; -use crate::frame::response::result::Row; -use crate::transport::session::{IntoTypedRows, TypedRowIter}; use bytes::Bytes; use thiserror::Error; use uuid::Uuid; -/// Result of a single query\ -/// Contains all rows returned by the database and some more information -#[non_exhaustive] +use scylla_cql::frame::frame_errors::ParseError; +use scylla_cql::frame::response::result::{ColumnSpec, RawRows, Row}; +use scylla_cql::types::deserialize::row::DeserializeRow; +use scylla_cql::types::deserialize::TypedRowIterator; + +use super::legacy_query_result::Legacy08QueryResult; + +/// Raw results of a single query. 
+/// +/// More comprehensive description TODO #[derive(Default, Debug)] pub struct QueryResult { - /// Rows returned by the database.\ - /// Queries like `SELECT` will have `Some(Vec)`, while queries like `INSERT` will have `None`.\ - /// Can contain an empty Vec. - pub rows: Option>, - /// Warnings returned by the database - pub warnings: Vec, - /// CQL Tracing uuid - can only be Some if tracing is enabled for this query - pub tracing_id: Option, - /// Paging state returned from the server - pub paging_state: Option, - /// Column specification returned from the server - pub col_specs: Vec, - /// The original size of the serialized rows in request - pub serialized_size: usize, + raw_rows: Option, + tracing_id: Option, + warnings: Vec, } impl QueryResult { - /// Returns the number of received rows.\ - /// Fails when the query isn't of a type that could return rows, same as [`rows()`](QueryResult::rows). - pub fn rows_num(&self) -> Result { - match &self.rows { - Some(rows) => Ok(rows.len()), - None => Err(RowsExpectedError), + pub(crate) fn new( + raw_rows: Option, + tracing_id: Option, + warnings: Vec, + ) -> Self { + Self { + raw_rows, + tracing_id, + warnings, } } - /// Returns the received rows when present.\ - /// If `QueryResult.rows` is `None`, which means that this query is not supposed to return rows (e.g `INSERT`), returns an error.\ - /// Can return an empty `Vec`. - pub fn rows(self) -> Result, RowsExpectedError> { - match self.rows { - Some(rows) => Ok(rows), - None => Err(RowsExpectedError), - } + /// Returns the number of received rows, or `None` if the response wasn't of Rows type. + pub fn rows_num(&self) -> Option { + Some(self.raw_rows.as_ref()?.rows_count()) } - /// Returns the received rows parsed as the given type.\ - /// Equal to `rows()?.into_typed()`.\ - /// Fails when the query isn't of a type that could return rows, same as [`rows()`](QueryResult::rows). 
- pub fn rows_typed(self) -> Result, RowsExpectedError> { - Ok(self.rows()?.into_typed()) + /// Returns the size of the serialized rows, or `None` if the response wasn't of Rows type. + pub fn rows_size(&self) -> Option { + Some(self.raw_rows.as_ref()?.rows_size()) } - /// Returns `Ok` for a result of a query that shouldn't contain any rows.\ - /// Will return `Ok` for `INSERT` result, but a `SELECT` result, even an empty one, will cause an error.\ - /// Opposite of [`rows()`](QueryResult::rows). - pub fn result_not_rows(&self) -> Result<(), RowsNotExpectedError> { - match self.rows { - Some(_) => Err(RowsNotExpectedError), - None => Ok(()), - } + /// Returns a bool indicating the current response is of Rows type. + pub fn is_rows(&self) -> bool { + self.raw_rows.is_some() } - /// Returns rows when `QueryResult.rows` is `Some`, otherwise an empty Vec.\ - /// Equal to `rows().unwrap_or_default()`. - pub fn rows_or_empty(self) -> Vec { - self.rows.unwrap_or_default() + /// Returns the received rows when present. + /// + /// Returns an error if the original query didn't return a Rows response (e.g. it was an `INSERT`), + /// or the response is of incorrect type. + pub fn rows<'s, R: DeserializeRow<'s>>(&'s self) -> Result, RowsError> { + Ok(self + .raw_rows + .as_ref() + .ok_or(RowsError::NotRowsResponse)? + .rows_iter()?) } - /// Returns rows parsed as the given type.\ - /// When `QueryResult.rows` is `None`, returns 0 rows.\ - /// Equal to `rows_or_empty().into_typed::()`. - pub fn rows_typed_or_empty(self) -> TypedRowIter { - self.rows_or_empty().into_typed::() + /// Returns the received rows when present, or None. + /// + /// Returns an error if the rows are present but are of incorrect type. 
+ pub fn maybe_rows<'s, R: DeserializeRow<'s>>( + &'s self, + ) -> Result>, ParseError> { + match &self.raw_rows { + Some(rows) => Ok(Some(rows.rows_iter()?)), + None => Ok(None), + } } - /// Returns first row from the received rows.\ - /// When the first row is not available, returns an error. - pub fn first_row(self) -> Result { - match self.maybe_first_row()? { - Some(row) => Ok(row), - None => Err(FirstRowError::RowsEmpty), + /// Returns `Ok` for a result of a query that shouldn't contain any rows.\ + /// Will return `Ok` for `INSERT` result, but a `SELECT` result, even an empty one, will cause an error.\ + /// Opposite of [`rows()`](QueryResult::rows). + pub fn result_not_rows(&self) -> Result<(), ResultNotRowsError> { + match &self.raw_rows { + Some(_) => Err(ResultNotRowsError), + None => Ok(()), } } - /// Returns first row from the received rows parsed as the given type.\ + /// Returns first row from the received rows.\ /// When the first row is not available, returns an error. - pub fn first_row_typed(self) -> Result { - Ok(self.first_row()?.into_typed()?) - } - - /// Returns `Option` containing the first of a result.\ - /// Fails when the query isn't of a type that could return rows, same as [`rows()`](QueryResult::rows). - pub fn maybe_first_row(self) -> Result, RowsExpectedError> { - Ok(self.rows()?.into_iter().next()) + pub fn first_row<'s, R: DeserializeRow<'s>>(&'s self) -> Result { + match self.maybe_first_row::() { + Ok(Some(row)) => Ok(row), + Ok(None) => Err(FirstRowError::RowsEmpty), + Err(RowsError::NotRowsResponse) => Err(FirstRowError::NotRowsResponse), + Err(RowsError::TypeCheckFailed(err)) => Err(FirstRowError::TypeCheckFailed(err)), + } } - /// Returns `Option` containing the first of a result.\ + /// Returns `Option` containing the first of a result.\ /// Fails when the query isn't of a type that could return rows, same as [`rows()`](QueryResult::rows). 
- pub fn maybe_first_row_typed( - self, - ) -> Result, MaybeFirstRowTypedError> { - match self.maybe_first_row()? { - Some(row) => Ok(Some(row.into_typed::()?)), - None => Ok(None), - } + pub fn maybe_first_row<'s, R: DeserializeRow<'s>>(&'s self) -> Result, RowsError> { + Ok(self.rows::()?.next().transpose()?) } /// Returns the only received row.\ /// Fails if the result is anything else than a single row.\ - pub fn single_row(self) -> Result { - let rows: Vec = self.rows()?; - - if rows.len() != 1 { - return Err(SingleRowError::BadNumberOfRows(rows.len())); + pub fn single_row<'s, R: DeserializeRow<'s>>(&'s self) -> Result { + match self.rows::() { + Ok(mut rows) => match rows.next() { + Some(Ok(row)) => { + if rows.rows_remaining() != 0 { + return Err(SingleRowError::UnexpectedRowCount( + rows.rows_remaining() + 1, + )); + } + Ok(row) + } + Some(Err(err)) => Err(err.into()), + None => Err(SingleRowError::UnexpectedRowCount(0)), + }, + Err(RowsError::NotRowsResponse) => Err(SingleRowError::NotRowsResponse), + Err(RowsError::TypeCheckFailed(err)) => Err(SingleRowError::TypeCheckFailed(err)), } - - Ok(rows.into_iter().next().unwrap()) - } - - /// Returns the only received row parsed as the given type.\ - /// Fails if the result is anything else than a single row.\ - pub fn single_row_typed(self) -> Result { - Ok(self.single_row()?.into_typed::()?) } /// Returns a column specification for a column with given name, or None if not found pub fn get_column_spec<'a>(&'a self, name: &str) -> Option<(usize, &'a ColumnSpec)> { - self.col_specs + self.raw_rows + .as_ref()? 
+ .metadata() + .col_specs .iter() .enumerate() .find(|(_id, spec)| spec.name == name) } -} - -/// [`QueryResult::rows()`](QueryResult::rows) or a similar function called on a bad QueryResult.\ -/// Expected `QueryResult.rows` to be `Some`, but it was `None`.\ -/// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ -/// It is `None` for queries that can't return rows (e.g `INSERT`). -#[derive(Debug, Clone, Error, PartialEq, Eq)] -#[error( - "QueryResult::rows() or similar function called on a bad QueryResult. - Expected QueryResult.rows to be Some, but it was None. - QueryResult.rows is Some for queries that can return rows (e.g SELECT). - It is None for queries that can't return rows (e.g INSERT)." -)] -pub struct RowsExpectedError; - -/// [`QueryResult::result_not_rows()`](QueryResult::result_not_rows) called on a bad QueryResult.\ -/// Expected `QueryResult.rows` to be `None`, but it was `Some`.\ -/// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ -/// It is `None` for queries that can't return rows (e.g `INSERT`). -#[derive(Debug, Clone, Error, PartialEq, Eq)] -#[error( - "QueryResult::result_not_rows() called on a bad QueryResult. - Expected QueryResult.rows to be None, but it was Some. - QueryResult.rows is Some for queries that can return rows (e.g SELECT). - It is None for queries that can't return rows (e.g INSERT)." -)] -pub struct RowsNotExpectedError; - -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum FirstRowError { - /// [`QueryResult::first_row()`](QueryResult::first_row) called on a bad QueryResult.\ - /// Expected `QueryResult.rows` to be `Some`, but it was `None`.\ - /// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ - /// It is `None` for queries that can't return rows (e.g `INSERT`). 
- #[error(transparent)] - RowsExpected(#[from] RowsExpectedError), - /// Rows in `QueryResult` are empty - #[error("Rows in QueryResult are empty")] - RowsEmpty, -} - -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum FirstRowTypedError { - /// [`QueryResult::first_row_typed()`](QueryResult::first_row_typed) called on a bad QueryResult.\ - /// Expected `QueryResult.rows` to be `Some`, but it was `None`.\ - /// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ - /// It is `None` for queries that can't return rows (e.g `INSERT`). - #[error(transparent)] - RowsExpected(#[from] RowsExpectedError), - - /// Rows in `QueryResult` are empty - #[error("Rows in QueryResult are empty")] - RowsEmpty, - - /// Parsing row as the given type failed - #[error(transparent)] - FromRowError(#[from] FromRowError), -} - -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum MaybeFirstRowTypedError { - /// [`QueryResult::maybe_first_row_typed()`](QueryResult::maybe_first_row_typed) called on a bad QueryResult.\ - /// Expected `QueryResult.rows` to be `Some`, but it was `None`. - /// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ - /// It is `None` for queries that can't return rows (e.g `INSERT`). - #[error(transparent)] - RowsExpected(#[from] RowsExpectedError), - - /// Parsing row as the given type failed - #[error(transparent)] - FromRowError(#[from] FromRowError), -} - -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum SingleRowError { - /// [`QueryResult::single_row()`](QueryResult::single_row) called on a bad QueryResult.\ - /// Expected `QueryResult.rows` to be `Some`, but it was `None`.\ - /// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ - /// It is `None` for queries that can't return rows (e.g `INSERT`). 
- #[error(transparent)] - RowsExpected(#[from] RowsExpectedError), - - /// Expected a single row, found other number of rows - #[error("Expected a single row, found {0} rows")] - BadNumberOfRows(usize), -} - -#[derive(Debug, Clone, Error, PartialEq, Eq)] -pub enum SingleRowTypedError { - /// [`QueryResult::single_row_typed()`](QueryResult::single_row_typed) called on a bad QueryResult.\ - /// Expected `QueryResult.rows` to be `Some`, but it was `None`.\ - /// `QueryResult.rows` is `Some` for queries that can return rows (e.g `SELECT`).\ - /// It is `None` for queries that can't return rows (e.g `INSERT`). - #[error(transparent)] - RowsExpected(#[from] RowsExpectedError), - - /// Expected a single row, found other number of rows - #[error("Expected a single row, found {0} rows")] - BadNumberOfRows(usize), - - /// Parsing row as the given type failed - #[error(transparent)] - FromRowError(#[from] FromRowError), -} - -impl From for FirstRowTypedError { - fn from(err: FirstRowError) -> FirstRowTypedError { - match err { - FirstRowError::RowsExpected(e) => FirstRowTypedError::RowsExpected(e), - FirstRowError::RowsEmpty => FirstRowTypedError::RowsEmpty, - } + pub fn column_specs(&self) -> Option<&[ColumnSpec]> { + Some(self.raw_rows.as_ref()?.metadata().col_specs.as_slice()) } -} -impl From for SingleRowTypedError { - fn from(err: SingleRowError) -> SingleRowTypedError { - match err { - SingleRowError::RowsExpected(e) => SingleRowTypedError::RowsExpected(e), - SingleRowError::BadNumberOfRows(r) => SingleRowTypedError::BadNumberOfRows(r), - } + pub fn warnings(&self) -> impl Iterator { + self.warnings.iter().map(String::as_str) } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row, TableSpec}; - use std::convert::TryInto; - // Returns specified number of rows, each one containing one int32 value. - // Values are 0, 1, 2, 3, 4, ... 
- fn make_rows(rows_num: usize) -> Vec { - let mut rows: Vec = Vec::with_capacity(rows_num); - for cur_value in 0..rows_num { - let int_val: i32 = cur_value.try_into().unwrap(); - rows.push(Row { - columns: vec![Some(CqlValue::Int(int_val))], - }); - } - rows + pub fn paging_state(&self) -> Option { + self.raw_rows.as_ref()?.metadata().paging_state.clone() } - // Just like make_rows, but each column has one String value - // values are "val0", "val1", "val2", ... - fn make_string_rows(rows_num: usize) -> Vec { - let mut rows: Vec = Vec::with_capacity(rows_num); - for cur_value in 0..rows_num { - rows.push(Row { - columns: vec![Some(CqlValue::Text(format!("val{}", cur_value)))], - }); - } - rows + pub fn tracing_id(&self) -> Option { + self.tracing_id } - fn make_not_rows_query_result() -> QueryResult { - let table_spec = TableSpec { - ks_name: "some_keyspace".to_string(), - table_name: "some_table".to_string(), - }; - - let column_spec = ColumnSpec { - table_spec, - name: "column0".to_string(), - typ: ColumnType::Int, - }; - - QueryResult { - rows: None, - warnings: vec![], - tracing_id: None, - paging_state: None, - col_specs: vec![column_spec], - serialized_size: 0, + pub fn into_legacy_result(self) -> Result { + if let Some(raw_rows) = self.raw_rows { + let deserialized_rows = raw_rows + .rows_iter::()? 
+ .collect::, ParseError>>()?; + let serialized_size = raw_rows.rows_size(); + let metadata = raw_rows.into_metadata(); + Ok(Legacy08QueryResult { + rows: Some(deserialized_rows), + warnings: self.warnings, + tracing_id: self.tracing_id, + paging_state: metadata.paging_state, + col_specs: metadata.col_specs, + serialized_size, + }) + } else { + Ok(Legacy08QueryResult { + rows: None, + warnings: self.warnings, + tracing_id: self.tracing_id, + paging_state: None, + col_specs: Vec::new(), + serialized_size: 0, + }) } } +} - fn make_rows_query_result(rows_num: usize) -> QueryResult { - let mut res = make_not_rows_query_result(); - res.rows = Some(make_rows(rows_num)); - res - } - - fn make_string_rows_query_result(rows_num: usize) -> QueryResult { - let mut res = make_not_rows_query_result(); - res.rows = Some(make_string_rows(rows_num)); - res - } - - #[test] - fn rows_num_test() { - assert_eq!( - make_not_rows_query_result().rows_num(), - Err(RowsExpectedError) - ); - assert_eq!(make_rows_query_result(0).rows_num(), Ok(0)); - assert_eq!(make_rows_query_result(1).rows_num(), Ok(1)); - assert_eq!(make_rows_query_result(2).rows_num(), Ok(2)); - assert_eq!(make_rows_query_result(3).rows_num(), Ok(3)); - } - - #[test] - fn rows_test() { - assert_eq!(make_not_rows_query_result().rows(), Err(RowsExpectedError)); - assert_eq!(make_rows_query_result(0).rows(), Ok(vec![])); - assert_eq!(make_rows_query_result(1).rows(), Ok(make_rows(1))); - assert_eq!(make_rows_query_result(2).rows(), Ok(make_rows(2))); - } - - #[test] - fn rows_typed_test() { - assert!(make_not_rows_query_result().rows_typed::<(i32,)>().is_err()); - - let rows0: Vec<(i32,)> = make_rows_query_result(0) - .rows_typed::<(i32,)>() - .unwrap() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows0, vec![]); - - let rows1: Vec<(i32,)> = make_rows_query_result(1) - .rows_typed::<(i32,)>() - .unwrap() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows1, vec![(0,)]); - - let rows2: Vec<(i32,)> = 
make_rows_query_result(2) - .rows_typed::<(i32,)>() - .unwrap() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows2, vec![(0,), (1,)]); - } - - #[test] - fn result_not_rows_test() { - assert_eq!(make_not_rows_query_result().result_not_rows(), Ok(())); - assert_eq!( - make_rows_query_result(0).result_not_rows(), - Err(RowsNotExpectedError) - ); - assert_eq!( - make_rows_query_result(1).result_not_rows(), - Err(RowsNotExpectedError) - ); - assert_eq!( - make_rows_query_result(2).result_not_rows(), - Err(RowsNotExpectedError) - ); - } - - #[test] - fn rows_or_empty_test() { - assert_eq!(make_not_rows_query_result().rows_or_empty(), vec![]); - assert_eq!(make_rows_query_result(0).rows_or_empty(), make_rows(0)); - assert_eq!(make_rows_query_result(1).rows_or_empty(), make_rows(1)); - assert_eq!(make_rows_query_result(2).rows_or_empty(), make_rows(2)); - } - - #[test] - fn rows_typed_or_empty() { - let rows_empty: Vec<(i32,)> = make_not_rows_query_result() - .rows_typed_or_empty::<(i32,)>() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows_empty, vec![]); - - let rows0: Vec<(i32,)> = make_rows_query_result(0) - .rows_typed_or_empty::<(i32,)>() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows0, vec![]); - - let rows1: Vec<(i32,)> = make_rows_query_result(1) - .rows_typed_or_empty::<(i32,)>() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows1, vec![(0,)]); - - let rows2: Vec<(i32,)> = make_rows_query_result(2) - .rows_typed_or_empty::<(i32,)>() - .map(|r| r.unwrap()) - .collect(); - - assert_eq!(rows2, vec![(0,), (1,)]); - } - - #[test] - fn first_row_test() { - assert_eq!( - make_not_rows_query_result().first_row(), - Err(FirstRowError::RowsExpected(RowsExpectedError)) - ); - assert_eq!( - make_rows_query_result(0).first_row(), - Err(FirstRowError::RowsEmpty) - ); - assert_eq!( - make_rows_query_result(1).first_row(), - Ok(make_rows(1).into_iter().next().unwrap()) - ); - assert_eq!( - make_rows_query_result(2).first_row(), - 
Ok(make_rows(2).into_iter().next().unwrap()) - ); - assert_eq!( - make_rows_query_result(3).first_row(), - Ok(make_rows(3).into_iter().next().unwrap()) - ); - } - - #[test] - fn first_row_typed_test() { - assert_eq!( - make_not_rows_query_result().first_row_typed::<(i32,)>(), - Err(FirstRowTypedError::RowsExpected(RowsExpectedError)) - ); - assert_eq!( - make_rows_query_result(0).first_row_typed::<(i32,)>(), - Err(FirstRowTypedError::RowsEmpty) - ); - assert_eq!( - make_rows_query_result(1).first_row_typed::<(i32,)>(), - Ok((0,)) - ); - assert_eq!( - make_rows_query_result(2).first_row_typed::<(i32,)>(), - Ok((0,)) - ); - assert_eq!( - make_rows_query_result(3).first_row_typed::<(i32,)>(), - Ok((0,)) - ); - - assert!(matches!( - make_string_rows_query_result(2).first_row_typed::<(i32,)>(), - Err(FirstRowTypedError::FromRowError(_)) - )); - } - - #[test] - fn maybe_first_row_test() { - assert_eq!( - make_not_rows_query_result().maybe_first_row(), - Err(RowsExpectedError) - ); - assert_eq!(make_rows_query_result(0).maybe_first_row(), Ok(None)); - assert_eq!( - make_rows_query_result(1).maybe_first_row(), - Ok(Some(make_rows(1).into_iter().next().unwrap())) - ); - assert_eq!( - make_rows_query_result(2).maybe_first_row(), - Ok(Some(make_rows(2).into_iter().next().unwrap())) - ); - assert_eq!( - make_rows_query_result(3).maybe_first_row(), - Ok(Some(make_rows(3).into_iter().next().unwrap())) - ); - } +/// An error returned by [`QueryResult::rows`] or [`QueryResult::maybe_first_row`]. 
+#[derive(Debug, Error)] +pub enum RowsError { + /// The query response was not a Rows response + #[error("The query response was not a Rows response")] + NotRowsResponse, - #[test] - fn maybe_first_row_typed_test() { - assert_eq!( - make_not_rows_query_result().maybe_first_row_typed::<(i32,)>(), - Err(MaybeFirstRowTypedError::RowsExpected(RowsExpectedError)) - ); + /// Type check failed + #[error("Type check failed: {0}")] + TypeCheckFailed(#[from] ParseError), +} - assert_eq!( - make_rows_query_result(0).maybe_first_row_typed::<(i32,)>(), - Ok(None) - ); +/// An error returned by [`QueryResult::first_row`]. +#[derive(Debug, Error)] +pub enum FirstRowError { + /// The query response was not a Rows response + #[error("The query response was not a Rows response")] + NotRowsResponse, - assert_eq!( - make_rows_query_result(1).maybe_first_row_typed::<(i32,)>(), - Ok(Some((0,))) - ); + /// The query response was of Rows type, but no rows were returned + #[error("The query response was of Rows type, but no rows were returned")] + RowsEmpty, - assert_eq!( - make_rows_query_result(2).maybe_first_row_typed::<(i32,)>(), - Ok(Some((0,))) - ); + /// Type check failed + #[error("Type check failed: {0}")] + TypeCheckFailed(#[from] ParseError), +} - assert_eq!( - make_rows_query_result(3).maybe_first_row_typed::<(i32,)>(), - Ok(Some((0,))) - ); +/// An error returned by [`QueryResult::single_row`]. 
+#[derive(Debug, Error)] +pub enum SingleRowError { + /// The query response was not a Rows response + #[error("The query response was not a Rows response")] + NotRowsResponse, - assert!(matches!( - make_string_rows_query_result(1).maybe_first_row_typed::<(i32,)>(), - Err(MaybeFirstRowTypedError::FromRowError(_)) - )) - } + /// Expected one row, but got a different count + #[error("Expected a single row, but got {0}")] + UnexpectedRowCount(usize), - #[test] - fn single_row_test() { - assert_eq!( - make_not_rows_query_result().single_row(), - Err(SingleRowError::RowsExpected(RowsExpectedError)) - ); - assert_eq!( - make_rows_query_result(0).single_row(), - Err(SingleRowError::BadNumberOfRows(0)) - ); - assert_eq!( - make_rows_query_result(1).single_row(), - Ok(make_rows(1).into_iter().next().unwrap()) - ); - assert_eq!( - make_rows_query_result(2).single_row(), - Err(SingleRowError::BadNumberOfRows(2)) - ); - assert_eq!( - make_rows_query_result(3).single_row(), - Err(SingleRowError::BadNumberOfRows(3)) - ); - } + /// Type check failed + #[error("Type check failed: {0}")] + TypeCheckFailed(#[from] ParseError), +} - #[test] - fn single_row_typed_test() { - assert_eq!( - make_not_rows_query_result().single_row_typed::<(i32,)>(), - Err(SingleRowTypedError::RowsExpected(RowsExpectedError)) - ); - assert_eq!( - make_rows_query_result(0).single_row_typed::<(i32,)>(), - Err(SingleRowTypedError::BadNumberOfRows(0)) - ); - assert_eq!( - make_rows_query_result(1).single_row_typed::<(i32,)>(), - Ok((0,)) - ); - assert_eq!( - make_rows_query_result(2).single_row_typed::<(i32,)>(), - Err(SingleRowTypedError::BadNumberOfRows(2)) - ); - assert_eq!( - make_rows_query_result(3).single_row_typed::<(i32,)>(), - Err(SingleRowTypedError::BadNumberOfRows(3)) - ); +/// An error returned by [`QueryResult::result_not_rows`]. +/// +/// It indicates that response to the query was, unexpectedly, of Rows kind. 
+#[derive(Debug, Error)] +#[error("The query response was, unexpectedly, of Rows kind")] +pub struct ResultNotRowsError; - assert!(matches!( - make_string_rows_query_result(1).single_row_typed::<(i32,)>(), - Err(SingleRowTypedError::FromRowError(_)) - )); - } -} +// TODO: Tests diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 783124412e..a354d935f8 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -3,6 +3,7 @@ #[cfg(feature = "cloud")] use crate::cloud::CloudConfig; +use crate::Legacy08QueryResult; use crate::frame::types::LegacyConsistency; use crate::history; @@ -14,12 +15,14 @@ use bytes::Bytes; use futures::future::join_all; use futures::future::try_join_all; pub use scylla_cql::errors::TranslationError; -use scylla_cql::frame::response::result::Rows; +use scylla_cql::frame::response::result::RawRows; use scylla_cql::frame::response::NonErrorResponse; use std::borrow::Borrow; +use std::borrow::Cow; use std::collections::HashMap; use std::fmt::Display; use std::future::Future; +use std::marker::PhantomData; use std::net::SocketAddr; use std::str::FromStr; use std::sync::atomic::AtomicUsize; @@ -38,7 +41,9 @@ use super::connection::QueryResponse; use super::connection::SslConfig; use super::errors::{BadQuery, NewSessionError, QueryError}; use super::execution_profile::{ExecutionProfile, ExecutionProfileHandle, ExecutionProfileInner}; +use super::iterator::RawIterator; use super::partitioner::PartitionerName; +use super::query_result::RowsError; use super::topology::UntranslatedPeer; use super::NodeRef; use crate::cql_to_rust::FromRow; @@ -51,12 +56,12 @@ use crate::prepared_statement::{PartitionKeyError, PreparedStatement}; use crate::query::Query; use crate::routing::Token; use crate::statement::{Consistency, SerialConsistency}; -use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo}; +use crate::tracing::{GetTracingConfig, TracingInfo}; use crate::transport::cluster::{Cluster, 
ClusterData, ClusterNeatDebug}; use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName}; use crate::transport::connection_pool::PoolConfig; use crate::transport::host_filter::HostFilter; -use crate::transport::iterator::{PreparedIteratorConfig, RowIterator}; +use crate::transport::iterator::{Legacy08RowIterator, PreparedIteratorConfig}; use crate::transport::load_balancing::{self, RoutingInfo}; use crate::transport::metrics::Metrics; use crate::transport::node::Node; @@ -75,6 +80,10 @@ use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; +mod sealed { + pub trait Sealed {} +} + #[async_trait] pub trait AddressTranslator: Send + Sync { async fn translate_address( @@ -116,20 +125,40 @@ impl AddressTranslator for HashMap<&'static str, &'static str> { } } +pub trait DeserializationApiKind: sealed::Sealed {} + +pub enum CurrentDeserializationApi {} +impl sealed::Sealed for CurrentDeserializationApi {} +impl DeserializationApiKind for CurrentDeserializationApi {} + +pub enum Legacy08DeserializationApi {} +impl sealed::Sealed for Legacy08DeserializationApi {} +impl DeserializationApiKind for Legacy08DeserializationApi {} + /// `Session` manages connections to the cluster and allows to perform queries -pub struct Session { +pub struct GenericSession +where + DeserializationApi: DeserializationApiKind, +{ cluster: Cluster, default_execution_profile_handle: ExecutionProfileHandle, schema_agreement_interval: Duration, metrics: Arc, auto_await_schema_agreement_timeout: Option, refresh_metadata_on_auto_schema_agreement: bool, - keyspace_name: ArcSwapOption, + keyspace_name: Arc>, + _phantom_deser_api: PhantomData, } +pub type Session = GenericSession; +pub type Legacy08Session = GenericSession; + /// This implementation deliberately omits some details from Cluster in order /// to avoid cluttering the print with much information of little usability. 
-impl std::fmt::Debug for Session { +impl std::fmt::Debug for GenericSession +where + DeserApi: DeserializationApiKind, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Session") .field("cluster", &ClusterNeatDebug(&self.cluster)) @@ -355,31 +384,429 @@ impl IntoTypedRows for Vec { phantom_data: Default::default(), } } -} - -/// Iterator over rows parsed as the given type\ -/// Returned by `rows.into_typed::<(...)>()` -pub struct TypedRowIter { - row_iter: std::vec::IntoIter, - phantom_data: std::marker::PhantomData, -} +} + +/// Iterator over rows parsed as the given type\ +/// Returned by `rows.into_typed::<(...)>()` +pub struct TypedRowIter { + row_iter: std::vec::IntoIter, + phantom_data: std::marker::PhantomData, +} + +impl Iterator for TypedRowIter { + type Item = Result; + + fn next(&mut self) -> Option { + self.row_iter.next().map(RowT::from_row) + } +} + +pub enum RunQueryResult { + IgnoredWriteError, + Completed(ResT), +} + +impl GenericSession { + /// Sends a query to the database and receives a response.\ + /// Returns only a single page of results, to receive multiple pages use [query_iter](Session::query_iter) + /// + /// This is the easiest way to make a query, but performance is worse than that of prepared queries. + /// + /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/simple.html) for more information + /// # Arguments + /// * `query` - query to perform, can be just a `&str` or the [Query](crate::query::Query) struct. 
+ /// * `values` - values bound to the query, easiest way is to use a tuple of bound values + /// + /// # Examples + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// // Insert an int and text into a table + /// session + /// .query( + /// "INSERT INTO ks.tab (a, b) VALUES(?, ?)", + /// (2_i32, "some text") + /// ) + /// .await?; + /// # Ok(()) + /// # } + /// ``` + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// use scylla::IntoTypedRows; + /// + /// // Read rows containing an int and text + /// let rows_opt = session + /// .query("SELECT a, b FROM ks.tab", &[]) + /// .await? + /// .rows; + /// + /// if let Some(rows) = rows_opt { + /// for row in rows.into_typed::<(i32, String)>() { + /// // Parse row as int and text \ + /// let (int_val, text_val): (i32, String) = row?; + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + self.do_query(query.into(), values.serialized()?).await + } + + /// Queries the database with a custom paging state. 
+ /// # Arguments + /// + /// * `query` - query to be performed + /// * `values` - values bound to the query + /// * `paging_state` - previously received paging state or None + pub async fn query_paged( + &self, + query: impl Into, + values: impl ValueList, + paging_state: Option, + ) -> Result { + self.do_query_paged(query.into(), values.serialized()?, paging_state) + .await + } + + /// Run a simple query with paging\ + /// This method will query all pages of the result\ + /// + /// Returns an async iterator (stream) over all received rows\ + /// Page size can be specified in the [Query](crate::query::Query) passed to the function + /// + /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/paged.html) for more information + /// + /// # Arguments + /// * `query` - query to perform, can be just a `&str` or the [Query](crate::query::Query) struct. + /// * `values` - values bound to the query, easiest way is to use a tuple of bound values + /// + /// # Example + /// + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// use scylla::IntoTypedRows; + /// use futures::stream::StreamExt; + /// + /// let mut rows_stream = session + /// .query_iter("SELECT a, b FROM ks.t", &[]) + /// .await? + /// .into_typed::<(i32, i32)>(); + /// + /// while let Some(next_row_res) = rows_stream.next().await { + /// let (a, b): (i32, i32) = next_row_res?; + /// println!("a, b: {}, {}", a, b); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query_iter( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + self.do_query_iter(query.into(), values.serialized()?).await + } + + /// Execute a prepared query. 
Requires a [PreparedStatement](crate::prepared_statement::PreparedStatement) + /// generated using [`Session::prepare`](Session::prepare)\ + /// Returns only a single page of results, to receive multiple pages use [execute_iter](Session::execute_iter) + /// + /// Prepared queries are much faster than simple queries: + /// * Database doesn't need to parse the query + /// * They are properly load balanced using token aware routing + /// + /// > ***Warning***\ + /// > For token/shard aware load balancing to work properly, all partition key values + /// > must be sent as bound values + /// > (see [performance section](https://rust-driver.docs.scylladb.com/stable/queries/prepared.html#performance)) + /// + /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/prepared.html) for more information + /// + /// # Arguments + /// * `prepared` - the prepared statement to execute, generated using [`Session::prepare`](Session::prepare) + /// * `values` - values bound to the query, easiest way is to use a tuple of bound values + /// + /// # Example + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// use scylla::prepared_statement::PreparedStatement; + /// + /// // Prepare the query for later execution + /// let prepared: PreparedStatement = session + /// .prepare("INSERT INTO ks.tab (a) VALUES(?)") + /// .await?; + /// + /// // Run the prepared query with some values, just like a simple query + /// let to_insert: i32 = 12345; + /// session.execute(&prepared, (to_insert,)).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn execute( + &self, + prepared: &PreparedStatement, + values: impl ValueList, + ) -> Result { + self.do_execute(prepared, values.serialized()?).await + } + + /// Executes a previously prepared statement with previously received paging state + /// # Arguments + /// + /// * `prepared` - a statement prepared with 
[prepare](crate::transport::session::Session::prepare) + /// * `values` - values bound to the query + /// * `paging_state` - paging state from the previous query or None + pub async fn execute_paged( + &self, + prepared: &PreparedStatement, + values: impl ValueList, + paging_state: Option, + ) -> Result { + self.do_execute_paged(prepared, values.serialized()?, paging_state) + .await + } + + /// Run a prepared query with paging\ + /// This method will query all pages of the result\ + /// + /// Returns an async iterator (stream) over all received rows\ + /// Page size can be specified in the [PreparedStatement](crate::prepared_statement::PreparedStatement) + /// passed to the function + /// + /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/paged.html) for more information + /// + /// # Arguments + /// * `prepared` - the prepared statement to execute, generated using [`Session::prepare`](Session::prepare) + /// * `values` - values bound to the query, easiest way is to use a tuple of bound values + /// + /// # Example + /// + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// use scylla::prepared_statement::PreparedStatement; + /// use scylla::IntoTypedRows; + /// use futures::stream::StreamExt; + /// + /// // Prepare the query for later execution + /// let prepared: PreparedStatement = session + /// .prepare("SELECT a, b FROM ks.t") + /// .await?; + /// + /// // Execute the query and receive all pages + /// let mut rows_stream = session + /// .execute_iter(prepared, &[]) + /// .await? 
+ /// .into_typed::<(i32, i32)>(); + /// + /// while let Some(next_row_res) = rows_stream.next().await { + /// let (a, b): (i32, i32) = next_row_res?; + /// println!("a, b: {}, {}", a, b); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn execute_iter( + &self, + prepared: impl Into, + values: impl ValueList, + ) -> Result { + self.do_execute_iter(prepared.into(), values.serialized()?) + .await + } + + /// Perform a batch query\ + /// Batch contains many `simple` or `prepared` queries which are executed at once\ + /// Batch doesn't return any rows + /// + /// Batch values must contain values for each of the queries + /// + /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/batch.html) for more information + /// + /// # Arguments + /// * `batch` - [Batch](crate::batch::Batch) to be performed + /// * `values` - List of values for each query, it's the easiest to use a tuple of tuples + /// + /// # Example + /// ```rust + /// # use scylla::Legacy08Session; + /// # use std::error::Error; + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { + /// use scylla::batch::Batch; + /// + /// let mut batch: Batch = Default::default(); + /// + /// // A query with two bound values + /// batch.append_statement("INSERT INTO ks.tab(a, b) VALUES(?, ?)"); + /// + /// // A query with one bound value + /// batch.append_statement("INSERT INTO ks.tab(a, b) VALUES(3, ?)"); + /// + /// // A query with no bound values + /// batch.append_statement("INSERT INTO ks.tab(a, b) VALUES(5, 6)"); + /// + /// // Batch values is a tuple of 3 tuples containing values for each query + /// let batch_values = ((1_i32, 2_i32), // Tuple with two values for the first query + /// (4_i32,), // Tuple with one value for the second query + /// ()); // Empty tuple/unit for the third query + /// + /// // Run the batch + /// session.batch(&batch, batch_values).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn batch( + &self, + batch: &Batch, + values: 
impl BatchValues, + ) -> Result { + self.do_batch(batch, values).await + } + + /// Creates a new Session instance that shared resources with + /// the current Session but supports the legacy API. + /// + /// This method is provided in order to make migration to the new + /// deserialization API easier. For example, if your program in general uses + /// the new API but you still have some modules left that use the old one, + /// you can use this method to create an instance that supports the old API + /// and pass it to the module that you intend to migrate later. + pub fn make_shared_session_with_legacy_api(&self) -> Legacy08Session { + Legacy08Session { + cluster: self.cluster.clone(), + auto_await_schema_agreement_timeout: self.auto_await_schema_agreement_timeout, + default_execution_profile_handle: self.default_execution_profile_handle.clone(), + metrics: self.metrics.clone(), + refresh_metadata_on_auto_schema_agreement: self + .refresh_metadata_on_auto_schema_agreement, + schema_agreement_interval: self.schema_agreement_interval, + keyspace_name: self.keyspace_name.clone(), + _phantom_deser_api: PhantomData, + } + } +} + +impl GenericSession { + pub async fn query( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + Ok(self + .do_query(query.into(), values.serialized()?) + .await? + .into_legacy_result()?) + } + + pub async fn query_paged( + &self, + query: impl Into, + values: impl ValueList, + paging_state: Option, + ) -> Result { + Ok(self + .do_query_paged(query.into(), values.serialized()?, paging_state) + .await? + .into_legacy_result()?) + } + + pub async fn query_iter( + &self, + query: impl Into, + values: impl ValueList, + ) -> Result { + self.do_query_iter(query.into(), values.serialized()?) + .await + .map(RawIterator::into_legacy) + } + + pub async fn execute( + &self, + prepared: &PreparedStatement, + values: impl ValueList, + ) -> Result { + Ok(self + .do_execute(prepared, values.serialized()?) + .await? 
+ .into_legacy_result()?) + } + + pub async fn execute_paged( + &self, + prepared: &PreparedStatement, + values: impl ValueList, + paging_state: Option, + ) -> Result { + Ok(self + .do_execute_paged(prepared, values.serialized()?, paging_state) + .await? + .into_legacy_result()?) + } -impl Iterator for TypedRowIter { - type Item = Result; + pub async fn execute_iter( + &self, + prepared: impl Into, + values: impl ValueList, + ) -> Result { + self.do_execute_iter(prepared.into(), values.serialized()?) + .await + .map(RawIterator::into_legacy) + } - fn next(&mut self) -> Option { - self.row_iter.next().map(RowT::from_row) + pub async fn batch( + &self, + batch: &Batch, + values: impl BatchValues, + ) -> Result { + Ok(self.do_batch(batch, values).await?.into_legacy_result()?) } -} -pub enum RunQueryResult { - IgnoredWriteError, - Completed(ResT), + /// Creates a new Session instance that shares resources with + /// the current Session but supports the new API. + /// + /// This method is provided in order to make migration to the new + /// deserialization API easier. For example, if your program in general uses + /// the old API but you want to migrate some modules to the new one, you + /// can use this method to create an instance that supports the new API + /// and pass it to the module that you intend to migrate. + /// + /// The new session object will use the same connections and cluster + /// metadata. 
+ pub fn make_shared_session_with_new_api(&self) -> Session { + Session { + cluster: self.cluster.clone(), + auto_await_schema_agreement_timeout: self.auto_await_schema_agreement_timeout, + default_execution_profile_handle: self.default_execution_profile_handle.clone(), + metrics: self.metrics.clone(), + refresh_metadata_on_auto_schema_agreement: self + .refresh_metadata_on_auto_schema_agreement, + schema_agreement_interval: self.schema_agreement_interval, + keyspace_name: self.keyspace_name.clone(), + _phantom_deser_api: PhantomData, + } + } } /// Represents a CQL session, which can be used to communicate /// with the database -impl Session { +impl GenericSession +where + DeserApi: DeserializationApiKind, +{ /// Estabilishes a CQL session with the database /// /// Usually it's easier to use [SessionBuilder](crate::transport::session_builder::SessionBuilder) @@ -392,17 +819,17 @@ impl Session { /// ```rust /// # use std::error::Error; /// # async fn check_only_compiles() -> Result<(), Box> { - /// use scylla::{Session, SessionConfig}; + /// use scylla::{Legacy08Session, SessionConfig}; /// use scylla::transport::session::KnownNode; /// /// let mut config = SessionConfig::new(); /// config.known_nodes.push(KnownNode::Hostname("127.0.0.1:9042".to_string())); /// - /// let session: Session = Session::connect(config).await?; + /// let session: Legacy08Session = Legacy08Session::connect(config).await?; /// # Ok(()) /// # } /// ``` - pub async fn connect(config: SessionConfig) -> Result { + pub async fn connect(config: SessionConfig) -> Result { let known_nodes = config.known_nodes; #[cfg(feature = "cloud")] @@ -491,7 +918,7 @@ impl Session { let default_execution_profile_handle = config.default_execution_profile_handle; - let session = Session { + let session = Self { cluster, default_execution_profile_handle, schema_agreement_interval: config.schema_agreement_interval, @@ -499,7 +926,8 @@ impl Session { auto_await_schema_agreement_timeout: 
config.auto_await_schema_agreement_timeout, refresh_metadata_on_auto_schema_agreement: config .refresh_metadata_on_auto_schema_agreement, - keyspace_name: ArcSwapOption::default(), // will be set by use_keyspace + keyspace_name: Arc::new(ArcSwapOption::default()), // will be set by use_keyspace + _phantom_deser_api: PhantomData, }; if let Some(keyspace_name) = config.used_keyspace { @@ -511,75 +939,20 @@ impl Session { Ok(session) } - /// Sends a query to the database and receives a response.\ - /// Returns only a single page of results, to receive multiple pages use [query_iter](Session::query_iter) - /// - /// This is the easiest way to make a query, but performance is worse than that of prepared queries. - /// - /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/simple.html) for more information - /// # Arguments - /// * `query` - query to perform, can be just a `&str` or the [Query](crate::query::Query) struct. - /// * `values` - values bound to the query, easiest way is to use a tuple of bound values - /// - /// # Examples - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// // Insert an int and text into a table - /// session - /// .query( - /// "INSERT INTO ks.tab (a, b) VALUES(?, ?)", - /// (2_i32, "some text") - /// ) - /// .await?; - /// # Ok(()) - /// # } - /// ``` - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::IntoTypedRows; - /// - /// // Read rows containing an int and text - /// let rows_opt = session - /// .query("SELECT a, b FROM ks.tab", &[]) - /// .await? 
- /// .rows; - /// - /// if let Some(rows) = rows_opt { - /// for row in rows.into_typed::<(i32, String)>() { - /// // Parse row as int and text \ - /// let (int_val, text_val): (i32, String) = row?; - /// } - /// } - /// # Ok(()) - /// # } - /// ``` - pub async fn query( + async fn do_query( &self, - query: impl Into, - values: impl ValueList, + query: Query, + values: Cow<'_, SerializedValues>, ) -> Result { - self.query_paged(query, values, None).await + self.do_query_paged(query, values, None).await } - /// Queries the database with a custom paging state. - /// # Arguments - /// - /// * `query` - query to be performed - /// * `values` - values bound to the query - /// * `paging_state` - previously received paging state or None - pub async fn query_paged( + async fn do_query_paged( &self, - query: impl Into, - values: impl ValueList, + query: Query, + serialized_values: Cow<'_, SerializedValues>, paging_state: Option, ) -> Result { - let query: Query = query.into(); - let serialized_values = values.serialized()?; - let span = RequestSpan::new_query(&query.contents, serialized_values.size()); let run_query_result = self .run_query( @@ -678,53 +1051,17 @@ impl Session { Ok(()) } - /// Run a simple query with paging\ - /// This method will query all pages of the result\ - /// - /// Returns an async iterator (stream) over all received rows\ - /// Page size can be specified in the [Query](crate::query::Query) passed to the function - /// - /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/paged.html) for more information - /// - /// # Arguments - /// * `query` - query to perform, can be just a `&str` or the [Query](crate::query::Query) struct. 
- /// * `values` - values bound to the query, easiest way is to use a tuple of bound values - /// - /// # Example - /// - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::IntoTypedRows; - /// use futures::stream::StreamExt; - /// - /// let mut rows_stream = session - /// .query_iter("SELECT a, b FROM ks.t", &[]) - /// .await? - /// .into_typed::<(i32, i32)>(); - /// - /// while let Some(next_row_res) = rows_stream.next().await { - /// let (a, b): (i32, i32) = next_row_res?; - /// println!("a, b: {}, {}", a, b); - /// } - /// # Ok(()) - /// # } - /// ``` - pub async fn query_iter( + async fn do_query_iter( &self, - query: impl Into, - values: impl ValueList, - ) -> Result { - let query: Query = query.into(); - let serialized_values = values.serialized()?; - + query: Query, + serialized_values: Cow<'_, SerializedValues>, + ) -> Result { let execution_profile = query .get_execution_profile_handle() .unwrap_or_else(|| self.get_default_execution_profile_handle()) .access(); - RowIterator::new_for_query( + RawIterator::new_for_query( query, serialized_values.into_owned(), execution_profile, @@ -753,9 +1090,9 @@ impl Session { /// /// # Example /// ```rust - /// # use scylla::Session; + /// # use scylla::Legacy08Session; /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { /// use scylla::prepared_statement::PreparedStatement; /// /// // Prepare the query for later execution @@ -832,64 +1169,20 @@ impl Session { .as_deref() } - /// Execute a prepared query. 
Requires a [PreparedStatement](crate::prepared_statement::PreparedStatement) - /// generated using [`Session::prepare`](Session::prepare)\ - /// Returns only a single page of results, to receive multiple pages use [execute_iter](Session::execute_iter) - /// - /// Prepared queries are much faster than simple queries: - /// * Database doesn't need to parse the query - /// * They are properly load balanced using token aware routing - /// - /// > ***Warning***\ - /// > For token/shard aware load balancing to work properly, all partition key values - /// > must be sent as bound values - /// > (see [performance section](https://rust-driver.docs.scylladb.com/stable/queries/prepared.html#performance)) - /// - /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/prepared.html) for more information - /// - /// # Arguments - /// * `prepared` - the prepared statement to execute, generated using [`Session::prepare`](Session::prepare) - /// * `values` - values bound to the query, easiest way is to use a tuple of bound values - /// - /// # Example - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::prepared_statement::PreparedStatement; - /// - /// // Prepare the query for later execution - /// let prepared: PreparedStatement = session - /// .prepare("INSERT INTO ks.tab (a) VALUES(?)") - /// .await?; - /// - /// // Run the prepared query with some values, just like a simple query - /// let to_insert: i32 = 12345; - /// session.execute(&prepared, (to_insert,)).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn execute( + async fn do_execute( &self, prepared: &PreparedStatement, - values: impl ValueList, + values: Cow<'_, SerializedValues>, ) -> Result { - self.execute_paged(prepared, values, None).await + self.do_execute_paged(prepared, values, None).await } - /// Executes a previously prepared statement with previously received paging 
state - /// # Arguments - /// - /// * `prepared` - a statement prepared with [prepare](crate::transport::session::Session::prepare) - /// * `values` - values bound to the query - /// * `paging_state` - paging state from the previous query or None - pub async fn execute_paged( + async fn do_execute_paged( &self, prepared: &PreparedStatement, - values: impl ValueList, + serialized_values: Cow<'_, SerializedValues>, paging_state: Option, ) -> Result { - let serialized_values = values.serialized()?; let values_ref = &serialized_values; let paging_state_ref = &paging_state; @@ -975,54 +1268,11 @@ impl Session { Ok(result) } - /// Run a prepared query with paging\ - /// This method will query all pages of the result\ - /// - /// Returns an async iterator (stream) over all received rows\ - /// Page size can be specified in the [PreparedStatement](crate::prepared_statement::PreparedStatement) - /// passed to the function - /// - /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/paged.html) for more information - /// - /// # Arguments - /// * `prepared` - the prepared statement to execute, generated using [`Session::prepare`](Session::prepare) - /// * `values` - values bound to the query, easiest way is to use a tuple of bound values - /// - /// # Example - /// - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::prepared_statement::PreparedStatement; - /// use scylla::IntoTypedRows; - /// use futures::stream::StreamExt; - /// - /// // Prepare the query for later execution - /// let prepared: PreparedStatement = session - /// .prepare("SELECT a, b FROM ks.t") - /// .await?; - /// - /// // Execute the query and receive all pages - /// let mut rows_stream = session - /// .execute_iter(prepared, &[]) - /// .await? 
- /// .into_typed::<(i32, i32)>(); - /// - /// while let Some(next_row_res) = rows_stream.next().await { - /// let (a, b): (i32, i32) = next_row_res?; - /// println!("a, b: {}, {}", a, b); - /// } - /// # Ok(()) - /// # } - /// ``` - pub async fn execute_iter( + async fn do_execute_iter( &self, - prepared: impl Into, - values: impl ValueList, - ) -> Result { - let prepared = prepared.into(); - let serialized_values = values.serialized()?; + prepared: PreparedStatement, + serialized_values: Cow<'_, SerializedValues>, + ) -> Result { let partition_key = self.calculate_partition_key(&prepared, &serialized_values)?; let token = partition_key .as_ref() @@ -1033,7 +1283,7 @@ impl Session { .unwrap_or_else(|| self.get_default_execution_profile_handle()) .access(); - RowIterator::new_for_prepared_statement(PreparedIteratorConfig { + RawIterator::new_for_prepared_statement(PreparedIteratorConfig { prepared, values: serialized_values.into_owned(), partition_key, @@ -1045,47 +1295,7 @@ impl Session { .await } - /// Perform a batch query\ - /// Batch contains many `simple` or `prepared` queries which are executed at once\ - /// Batch doesn't return any rows - /// - /// Batch values must contain values for each of the queries - /// - /// See [the book](https://rust-driver.docs.scylladb.com/stable/queries/batch.html) for more information - /// - /// # Arguments - /// * `batch` - [Batch](crate::batch::Batch) to be performed - /// * `values` - List of values for each query, it's the easiest to use a tuple of tuples - /// - /// # Example - /// ```rust - /// # use scylla::Session; - /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::batch::Batch; - /// - /// let mut batch: Batch = Default::default(); - /// - /// // A query with two bound values - /// batch.append_statement("INSERT INTO ks.tab(a, b) VALUES(?, ?)"); - /// - /// // A query with one bound value - /// batch.append_statement("INSERT INTO ks.tab(a, b) 
VALUES(3, ?)"); - /// - /// // A query with no bound values - /// batch.append_statement("INSERT INTO ks.tab(a, b) VALUES(5, 6)"); - /// - /// // Batch values is a tuple of 3 tuples containing values for each query - /// let batch_values = ((1_i32, 2_i32), // Tuple with two values for the first query - /// (4_i32,), // Tuple with one value for the second query - /// ()); // Empty tuple/unit for the third query - /// - /// // Run the batch - /// session.batch(&batch, batch_values).await?; - /// # Ok(()) - /// # } - /// ``` - pub async fn batch( + async fn do_batch( &self, batch: &Batch, values: impl BatchValues, @@ -1168,9 +1378,9 @@ impl Session { /// /// # Example /// ```rust /// # extern crate scylla; - /// # use scylla::Session; + /// # use scylla::Legacy08Session; /// # use std::error::Error; - /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { + /// # async fn check_only_compiles(session: &Legacy08Session) -> Result<(), Box> { /// use scylla::batch::Batch; /// /// // Create a batch statement with unprepared statements @@ -1229,10 +1439,10 @@ impl Session { /// * `case_sensitive` - if set to true the generated query will put keyspace name in quotes /// # Example /// ```rust - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::Compression; /// # async fn example() -> Result<(), Box> { - /// # let session = SessionBuilder::new().known_node("127.0.0.1:9042").build().await?; + /// # let session = SessionBuilder::new().known_node("127.0.0.1:9042").build_legacy().await?; /// session /// .query("INSERT INTO my_keyspace.tab (a) VALUES ('test1')", &[]) /// .await?; @@ -1359,46 +1569,40 @@ impl Session { traces_events_query.config.consistency = consistency; traces_events_query.set_page_size(1024); + let serialized_tracing_id = (tracing_id,).serialized()?.into_owned(); let (traces_session_res, traces_events_res) = tokio::try_join!( - self.query(traces_session_query, 
(tracing_id,)), - self.query(traces_events_query, (tracing_id,)) + self.do_query(traces_session_query, Cow::Borrowed(&serialized_tracing_id)), + self.do_query(traces_events_query, Cow::Borrowed(&serialized_tracing_id)) )?; // Get tracing info - let tracing_info_row_res: Option> = traces_session_res - .rows - .ok_or(QueryError::ProtocolError( - "Response to system_traces.sessions query was not Rows", - ))? - .into_typed::() - .next(); - - let mut tracing_info: TracingInfo = match tracing_info_row_res { - Some(tracing_info_row_res) => tracing_info_row_res.map_err(|_| { - QueryError::ProtocolError( + let maybe_tracing_info: Option = traces_session_res + .maybe_first_row() + .map_err(|err| match err { + RowsError::NotRowsResponse => QueryError::ProtocolError( + "Response to system_traces.sessions query was not Rows", + ), + RowsError::TypeCheckFailed(_) => QueryError::ProtocolError( "Columns from system_traces.session have an unexpected type", - ) - })?, + ), + })?; + + let mut tracing_info = match maybe_tracing_info { None => return Ok(None), + Some(tracing_info) => tracing_info, }; // Get tracing events - let tracing_event_rows = traces_events_res - .rows - .ok_or(QueryError::ProtocolError( - "Response to system_traces.events query was not Rows", - ))? 
- .into_typed::(); - - for event in tracing_event_rows { - let tracing_event: TracingEvent = event.map_err(|_| { - QueryError::ProtocolError( - "Columns from system_traces.events have an unexpected type", - ) - })?; + let tracing_event_rows = traces_events_res.rows().map_err(|err| match err { + RowsError::NotRowsResponse => { + QueryError::ProtocolError("Response to system_traces.events query was not Rows") + } + RowsError::TypeCheckFailed(_) => QueryError::ProtocolError( + "Columns from system_traces.events have an unexpected type", + ), + })?; - tracing_info.events.push(tracing_event); - } + tracing_info.events = tracing_event_rows.collect::>()?; if tracing_info.events.is_empty() { return Ok(None); @@ -2043,15 +2247,13 @@ impl RequestSpan { } pub(crate) fn record_result_fields(&self, result: &QueryResult) { - self.span.record("result_size", result.serialized_size); - if let Some(rows) = result.rows.as_ref() { - self.span.record("result_rows", rows.len()); - } + self.span.record("result_size", result.rows_size()); + self.span.record("result_rows", result.rows_num()); } - pub(crate) fn record_rows_fields(&self, rows: &Rows) { - self.span.record("result_size", rows.serialized_size); - self.span.record("result_rows", rows.rows.len()); + pub(crate) fn record_rows_fields(&self, rows: &RawRows) { + self.span.record("result_size", rows.rows_size()); + self.span.record("result_rows", rows.rows_count()); } pub(crate) fn record_replicas<'a>(&'a self, replicas: &'a [impl Borrow>]) { diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs index b3ebaba017..5ba7894c71 100644 --- a/scylla/src/transport/session_builder.rs +++ b/scylla/src/transport/session_builder.rs @@ -2,7 +2,10 @@ use super::errors::NewSessionError; use super::execution_profile::ExecutionProfileHandle; -use super::session::{AddressTranslator, Session, SessionConfig}; +use super::session::{ + AddressTranslator, CurrentDeserializationApi, GenericSession, 
Legacy08DeserializationApi, + SessionConfig, +}; use super::Compression; #[cfg(feature = "cloud")] @@ -52,13 +55,13 @@ pub type CloudSessionBuilder = GenericSessionBuilder; /// # Example /// /// ``` -/// # use scylla::{Session, SessionBuilder}; +/// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::Compression; /// # async fn example() -> Result<(), Box> { -/// let session: Session = SessionBuilder::new() +/// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .compression(Some(Compression::Snappy)) -/// .build() +/// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -84,17 +87,23 @@ impl SessionBuilder { /// Add a known node with a hostname /// # Examples /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new().known_node("127.0.0.1:9042").build().await?; + /// let session: Legacy08Session = SessionBuilder::new() + /// .known_node("127.0.0.1:9042") + /// .build_legacy() + /// .await?; /// # Ok(()) /// # } /// ``` /// /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new().known_node("db1.example.com").build().await?; + /// let session: Legacy08Session = SessionBuilder::new() + /// .known_node("db1.example.com") + /// .build_legacy() + /// .await?; /// # Ok(()) /// # } /// ``` @@ -106,12 +115,12 @@ impl SessionBuilder { /// Add a known node with an IP address /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use std::net::{SocketAddr, IpAddr, Ipv4Addr}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// 
.known_node_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 9042)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -124,11 +133,11 @@ impl SessionBuilder { /// Add a list of known nodes with hostnames /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_nodes(["127.0.0.1:9042", "db1.example.com"]) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -141,15 +150,15 @@ impl SessionBuilder { /// Add a list of known nodes with IP addresses /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use std::net::{SocketAddr, IpAddr, Ipv4Addr}; /// # async fn example() -> Result<(), Box> { /// let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 3)), 9042); /// let addr2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 17, 0, 4)), 9042); /// - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_nodes_addr([addr1, addr2]) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -167,14 +176,14 @@ impl SessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::Compression; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .use_keyspace("my_keyspace_name", false) /// .user("cassandra", "cassandra") - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -193,7 +202,7 @@ impl SessionBuilder { /// ``` /// # use std::sync::Arc; /// use 
bytes::Bytes; - /// use scylla::{Session, SessionBuilder}; + /// use scylla::{Legacy08Session, SessionBuilder}; /// use async_trait::async_trait; /// use scylla::authentication::{AuthenticatorProvider, AuthenticatorSession, AuthError}; /// # use scylla::transport::Compression; @@ -221,12 +230,12 @@ impl SessionBuilder { /// } /// /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .use_keyspace("my_keyspace_name", false) /// .user("cassandra", "cassandra") /// .authenticator_provider(Arc::new(CustomAuthenticatorProvider)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -247,7 +256,7 @@ impl SessionBuilder { /// # use async_trait::async_trait; /// # use std::net::SocketAddr; /// # use std::sync::Arc; - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::session::{AddressTranslator, TranslationError}; /// # use scylla::transport::topology::UntranslatedPeer; /// struct IdentityTranslator; @@ -263,10 +272,10 @@ impl SessionBuilder { /// } /// /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .address_translator(Arc::new(IdentityTranslator)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -277,7 +286,7 @@ impl SessionBuilder { /// # use std::sync::Arc; /// # use std::collections::HashMap; /// # use std::str::FromStr; - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::session::{AddressTranslator, TranslationError}; /// # /// # async fn example() -> Result<(), Box> { @@ -285,10 +294,10 @@ impl SessionBuilder { /// let addr_before_translation = 
SocketAddr::from_str("192.168.0.42:19042").unwrap(); /// let addr_after_translation = SocketAddr::from_str("157.123.12.42:23203").unwrap(); /// translation_rules.insert(addr_before_translation, addr_after_translation); - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .address_translator(Arc::new(translation_rules)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -329,13 +338,13 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::Compression; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .compression(Some(Compression::Snappy)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -350,13 +359,13 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use std::time::Duration; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .schema_agreement_interval(Duration::from_secs(5)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -370,17 +379,17 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{statement::Consistency, ExecutionProfile, Session, SessionBuilder}; + /// # use scylla::{statement::Consistency, ExecutionProfile, Legacy08Session, SessionBuilder}; /// # use std::time::Duration; /// # async fn example() -> Result<(), Box> { /// let execution_profile = ExecutionProfile::builder() /// .consistency(Consistency::All) /// 
.request_timeout(Some(Duration::from_secs(2))) /// .build(); - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .default_execution_profile_handle(execution_profile.into_handle()) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -398,12 +407,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .tcp_nodelay(true) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -415,17 +424,17 @@ impl GenericSessionBuilder { /// Set keyspace to be used on all connections.\ /// Each connection will send `"USE "` before sending any requests.\ - /// This can be later changed with [`Session::use_keyspace`] + /// This can be later changed with [`crate::Session::use_keyspace`] /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::Compression; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .use_keyspace("my_keyspace_name", false) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -446,7 +455,7 @@ impl GenericSessionBuilder { /// ``` /// # use std::fs; /// # use std::path::PathBuf; - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use openssl::ssl::{SslContextBuilder, SslVerifyMode, SslMethod, SslFiletype}; /// # async fn example() -> Result<(), Box> { /// let certdir = 
fs::canonicalize(PathBuf::from("./examples/certs/scylla.crt"))?; @@ -454,10 +463,10 @@ impl GenericSessionBuilder { /// context_builder.set_certificate_file(certdir.as_path(), SslFiletype::PEM)?; /// context_builder.set_verify(SslVerifyMode::NONE); /// - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .ssl_context(Some(context_builder.build())) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -468,7 +477,34 @@ impl GenericSessionBuilder { self } - /// Builds the Session after setting all the options + /// Builds the Session after setting all the options. + /// + /// The new session object uses the legacy deserialization API. If you wish + /// to use the new API, use [`SessionBuilder::build`]. + /// + /// # Example + /// ``` + /// # use scylla::{Legacy08Session, SessionBuilder}; + /// # use scylla::transport::Compression; + /// # async fn example() -> Result<(), Box> { + /// let session: Legacy08Session = SessionBuilder::new() + /// .known_node("127.0.0.1:9042") + /// .compression(Some(Compression::Snappy)) + /// .build_legacy() // Turns SessionBuilder into Session + /// .await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn build_legacy( + &self, + ) -> Result, NewSessionError> { + GenericSession::connect(self.config.clone()).await + } + + /// Builds the Session after setting all the options. + /// + /// The new session object uses the new deserialization API. If you wish + /// to use the old API, use [`SessionBuilder::build_legacy`]. 
/// /// # Example /// ``` @@ -483,8 +519,10 @@ impl GenericSessionBuilder { /// # Ok(()) /// # } /// ``` - pub async fn build(&self) -> Result { - Session::connect(self.config.clone()).await + pub async fn build( + &self, + ) -> Result, NewSessionError> { + GenericSession::connect(self.config.clone()).await } /// Changes connection timeout @@ -493,13 +531,13 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use std::time::Duration; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .connection_timeout(Duration::from_secs(30)) - /// .build() // Turns SessionBuilder into Session + /// .build_legacy() // Turns SessionBuilder into Session /// .await?; /// # Ok(()) /// # } @@ -514,17 +552,17 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { /// use std::num::NonZeroUsize; /// use scylla::transport::session::PoolSize; /// /// // This session will establish 4 connections to each node. 
/// // For Scylla clusters, this number will be divided across shards - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .pool_size(PoolSize::PerHost(NonZeroUsize::new(4).unwrap())) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -558,12 +596,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .disallow_shard_aware_port(true) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -578,12 +616,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .keyspaces_to_fetch(["my_keyspace"]) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -601,12 +639,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .fetch_schema_metadata(true) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -622,12 +660,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), 
Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .keepalive_interval(std::time::Duration::from_secs(42)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -649,12 +687,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .auto_schema_agreement_timeout(std::time::Duration::from_secs(120)) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -669,12 +707,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .no_auto_schema_agreement() - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } @@ -697,16 +735,16 @@ impl GenericSessionBuilder { /// # use async_trait::async_trait; /// # use std::net::SocketAddr; /// # use std::sync::Arc; - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # use scylla::transport::session::{AddressTranslator, TranslationError}; /// # use scylla::transport::host_filter::DcHostFilter; /// /// # async fn example() -> Result<(), Box> { /// // The session will only connect to nodes from "my-local-dc" - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .host_filter(Arc::new(DcHostFilter::new("my-local-dc".to_string()))) - /// .build() + /// 
.build_legacy() /// .await?; /// # Ok(()) /// # } @@ -721,12 +759,12 @@ impl GenericSessionBuilder { /// /// # Example /// ``` - /// # use scylla::{Session, SessionBuilder}; + /// # use scylla::{Legacy08Session, SessionBuilder}; /// # async fn example() -> Result<(), Box> { - /// let session: Session = SessionBuilder::new() + /// let session: Legacy08Session = SessionBuilder::new() /// .known_node("127.0.0.1:9042") /// .refresh_metadata_on_auto_schema_agreement(true) - /// .build() + /// .build_legacy() /// .await?; /// # Ok(()) /// # } diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 7ee0e735bf..065e074303 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -1,6 +1,4 @@ -use crate as scylla; use crate::batch::{Batch, BatchStatement}; -use crate::frame::response::result::Row; use crate::frame::value::ValueList; use crate::prepared_statement::PreparedStatement; use crate::query::Query; @@ -10,6 +8,7 @@ use crate::statement::Consistency; use crate::tracing::{GetTracingConfig, TracingInfo}; use crate::transport::errors::{BadKeyspaceName, BadQuery, DbError, QueryError}; use crate::transport::partitioner::{Murmur3Partitioner, Partitioner, PartitionerName}; +use crate::transport::session::Session; use crate::transport::topology::Strategy::SimpleStrategy; use crate::transport::topology::{ CollectionType, ColumnKind, CqlType, NativeType, UserDefinedType, @@ -17,14 +16,14 @@ use crate::transport::topology::{ use crate::utils::test_utils::{ create_new_session_builder, supports_feature, unique_keyspace_name, }; -use crate::CachingSession; use crate::ExecutionProfile; -use crate::QueryResult; -use crate::{IntoTypedRows, Session, SessionBuilder}; +use crate::{self as scylla, QueryResult}; +use crate::{CachingSession, SessionBuilder}; use assert_matches::assert_matches; use bytes::Bytes; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{FutureExt, TryStreamExt}; use 
itertools::Itertools; +use scylla_cql::frame::response::result::Row; use scylla_cql::frame::value::Value; use std::collections::BTreeSet; use std::collections::{BTreeMap, HashMap}; @@ -52,7 +51,10 @@ async fn test_connection_failure() { .remote_handle(); tokio::spawn(fut); - let res = SessionBuilder::new().known_node_addr(addr).build().await; + let res = SessionBuilder::new() + .known_node_addr(addr) + .build_legacy() + .await; match res { Ok(_) => panic!("Unexpected success"), Err(err) => println!("Connection error (it was expected): {:?}", err), @@ -103,29 +105,24 @@ async fn test_unprepared_statement() { .await .unwrap(); - let (a_idx, _) = query_result.get_column_spec("a").unwrap(); - let (b_idx, _) = query_result.get_column_spec("b").unwrap(); - let (c_idx, _) = query_result.get_column_spec("c").unwrap(); + assert_eq!(query_result.get_column_spec("a").unwrap().0, 0); + assert_eq!(query_result.get_column_spec("b").unwrap().0, 1); + assert_eq!(query_result.get_column_spec("c").unwrap().0, 2); assert!(query_result.get_column_spec("d").is_none()); - let rs = query_result.rows.unwrap(); + let mut results = query_result + .rows::<(i32, i32, String)>() + .unwrap() + .collect::, _>>() + .unwrap(); - let mut results: Vec<(i32, i32, &String)> = rs - .iter() - .map(|r| { - let a = r.columns[a_idx].as_ref().unwrap().as_int().unwrap(); - let b = r.columns[b_idx].as_ref().unwrap().as_int().unwrap(); - let c = r.columns[c_idx].as_ref().unwrap().as_text().unwrap(); - (a, b, c) - }) - .collect(); results.sort(); assert_eq!( results, vec![ - (1, 2, &String::from("abc")), - (1, 4, &String::from("hello")), - (7, 11, &String::from("")) + (1, 2, String::from("abc")), + (1, 4, String::from("hello")), + (7, 11, String::from("")) ] ); let query_result = session @@ -138,7 +135,7 @@ async fn test_unprepared_statement() { assert_eq!(spec.name, name); // Check column name. 
assert_eq!(spec.table_spec.ks_name, ks); } - let mut results_from_manual_paging: Vec = vec![]; + let mut results_from_manual_paging = vec![]; let query = Query::new(format!("SELECT a, b, c FROM {}.t", ks)).with_page_size(1); let mut paging_state: Option = None; let mut watchdog = 0; @@ -147,14 +144,19 @@ async fn test_unprepared_statement() { .query_paged(query.clone(), &[], paging_state) .await .unwrap(); - results_from_manual_paging.append(&mut rs_manual.rows.unwrap()); - if watchdog > 30 || rs_manual.paging_state.is_none() { + let mut page_results = rs_manual + .rows::<(i32, i32, String)>() + .unwrap() + .collect::, _>>() + .unwrap(); + results_from_manual_paging.append(&mut page_results); + if watchdog > 30 || rs_manual.paging_state().is_none() { break; } watchdog += 1; - paging_state = rs_manual.paging_state; + paging_state = rs_manual.paging_state(); } - assert_eq!(results_from_manual_paging, rs); + assert_eq!(results_from_manual_paging, results); } #[tokio::test] @@ -217,19 +219,13 @@ async fn test_prepared_statement() { // Verify that token calculation is compatible with Scylla { - let rs = session + let (value,): (i64,) = session .query(format!("SELECT token(a) FROM {}.t2", ks), &[]) .await .unwrap() - .rows + .single_row::<(i64,)>() .unwrap(); - let token = Token { - value: rs.first().unwrap().columns[0] - .as_ref() - .unwrap() - .as_bigint() - .unwrap(), - }; + let token = Token { value }; let prepared_token = Murmur3Partitioner::hash( &prepared_statement .compute_partition_key(&serialized_values) @@ -243,19 +239,13 @@ async fn test_prepared_statement() { assert_eq!(token, cluster_data_token); } { - let rs = session + let (value,): (i64,) = session .query(format!("SELECT token(a,b,c) FROM {}.complex_pk", ks), &[]) .await .unwrap() - .rows + .single_row::<(i64,)>() .unwrap(); - let token = Token { - value: rs.first().unwrap().columns[0] - .as_ref() - .unwrap() - .as_bigint() - .unwrap(), - }; + let token = Token { value }; let prepared_token = 
Murmur3Partitioner::hash( &prepared_complex_pk_statement .compute_partition_key(&serialized_values) @@ -275,15 +265,14 @@ async fn test_prepared_statement() { .query(format!("SELECT a,b,c FROM {}.t2", ks), &[]) .await .unwrap() - .rows + .rows::<(i32, i32, String)>() + .unwrap() + .collect::, _>>() .unwrap(); - let r = rs.first().unwrap(); - let a = r.columns[0].as_ref().unwrap().as_int().unwrap(); - let b = r.columns[1].as_ref().unwrap().as_int().unwrap(); - let c = r.columns[2].as_ref().unwrap().as_text().unwrap(); - assert_eq!((a, b, c), (17, 16, &String::from("I'm prepared!!!"))); + let r = &rs[0]; + assert_eq!(r, &(17, 16, String::from("I'm prepared!!!"))); - let mut results_from_manual_paging: Vec = vec![]; + let mut results_from_manual_paging = vec![]; let query = Query::new(format!("SELECT a, b, c FROM {}.t2", ks)).with_page_size(1); let prepared_paged = session.prepare(query).await.unwrap(); let mut paging_state: Option = None; @@ -293,30 +282,32 @@ async fn test_prepared_statement() { .execute_paged(&prepared_paged, &[], paging_state) .await .unwrap(); - results_from_manual_paging.append(&mut rs_manual.rows.unwrap()); - if watchdog > 30 || rs_manual.paging_state.is_none() { + let mut page_results = rs_manual + .rows::<(i32, i32, String)>() + .unwrap() + .collect::, _>>() + .unwrap(); + results_from_manual_paging.append(&mut page_results); + if watchdog > 30 || rs_manual.paging_state().is_none() { break; } watchdog += 1; - paging_state = rs_manual.paging_state; + paging_state = rs_manual.paging_state(); } assert_eq!(results_from_manual_paging, rs); } { - let rs = session + let (a, b, c, d, e): (i32, i32, String, i32, Option) = session .query(format!("SELECT a,b,c,d,e FROM {}.complex_pk", ks), &[]) .await .unwrap() - .rows + .single_row::<(i32, i32, String, i32, Option)>() .unwrap(); - let r = rs.first().unwrap(); - let a = r.columns[0].as_ref().unwrap().as_int().unwrap(); - let b = r.columns[1].as_ref().unwrap().as_int().unwrap(); - let c = 
r.columns[2].as_ref().unwrap().as_text().unwrap(); - let d = r.columns[3].as_ref().unwrap().as_int().unwrap(); - let e = r.columns[4].as_ref(); assert!(e.is_none()); - assert_eq!((a, b, c, d), (17, 16, &String::from("I'm prepared!!!"), 7)) + assert_eq!( + (a, b, c.as_str(), d, e), + (17, 16, "I'm prepared!!!", 7, None) + ); } // Check that ValueList macro works { @@ -345,7 +336,7 @@ async fn test_prepared_statement() { ) .await .unwrap(); - let mut rs = session + let output: ComplexPk = session .query( format!( "SELECT * FROM {}.complex_pk WHERE a = 9 and b = 8 and c = 'seven'", @@ -355,10 +346,10 @@ async fn test_prepared_statement() { ) .await .unwrap() - .rows + .into_legacy_result() // TODO: Fix after macros are added .unwrap() - .into_typed::(); - let output = rs.next().unwrap().unwrap(); + .single_row_typed() + .unwrap(); assert_eq!(input, output) } } @@ -415,29 +406,22 @@ async fn test_batch() { .await .unwrap(); - let rs = session + let mut results: Vec<(i32, i32, String)> = session .query(format!("SELECT a, b, c FROM {}.t_batch", ks), &[]) .await .unwrap() - .rows + .rows::<(i32, i32, String)>() + .unwrap() + .collect::>() .unwrap(); - let mut results: Vec<(i32, i32, &String)> = rs - .iter() - .map(|r| { - let a = r.columns[0].as_ref().unwrap().as_int().unwrap(); - let b = r.columns[1].as_ref().unwrap().as_int().unwrap(); - let c = r.columns[2].as_ref().unwrap().as_text().unwrap(); - (a, b, c) - }) - .collect(); results.sort(); assert_eq!( results, vec![ - (1, 2, &String::from("abc")), - (1, 4, &String::from("hello")), - (7, 11, &String::from("")) + (1, 2, String::from("abc")), + (1, 4, String::from("hello")), + (7, 11, String::from("")) ] ); @@ -456,26 +440,19 @@ async fn test_batch() { .unwrap(); session.batch(&batch, values).await.unwrap(); - let rs = session + let results: Vec<(i32, i32, String)> = session .query( format!("SELECT a, b, c FROM {}.t_batch WHERE a = 4", ks), &[], ) .await .unwrap() - .rows - .unwrap(); - let results: Vec<(i32, i32, 
&String)> = rs - .iter() - .map(|r| { - let a = r.columns[0].as_ref().unwrap().as_int().unwrap(); - let b = r.columns[1].as_ref().unwrap().as_int().unwrap(); - let c = r.columns[2].as_ref().unwrap().as_text().unwrap(); - (a, b, c) - }) - .collect(); + .rows::<(i32, i32, String)>() + .unwrap() + .collect::>() + .unwrap(); - assert_eq!(results, vec![(4, 20, &String::from("foobar"))]); + assert_eq!(results, vec![(4, 20, String::from("foobar"))]); } #[tokio::test] @@ -512,22 +489,16 @@ async fn test_token_calculation() { let serialized_values = values.serialized().unwrap().into_owned(); session.execute(&prepared_statement, &values).await.unwrap(); - let rs = session + let (value,): (i64,) = session .query( format!("SELECT token(a) FROM {}.t3 WHERE a = ?", ks), &values, ) .await .unwrap() - .rows + .single_row::<(i64,)>() .unwrap(); - let token = Token { - value: rs.first().unwrap().columns[0] - .as_ref() - .unwrap() - .as_bigint() - .unwrap(), - }; + let token = Token { value }; let prepared_token = Murmur3Partitioner::hash( &prepared_statement .compute_partition_key(&serialized_values) @@ -578,7 +549,7 @@ async fn test_token_awareness() { // Execute a query and observe tracing info let res = session.execute(&prepared_statement, values).await.unwrap(); let tracing_info = session - .get_tracing_info_custom(res.tracing_id.as_ref().unwrap(), &get_tracing_config) + .get_tracing_info_custom(res.tracing_id().as_ref().unwrap(), &get_tracing_config) .await .unwrap(); @@ -632,9 +603,8 @@ async fn test_use_keyspace() { .query("SELECT * FROM tab", &[]) .await .unwrap() - .rows + .rows::<(String,)>() .unwrap() - .into_typed::<(String,)>() .map(|res| res.unwrap().0) .collect(); @@ -682,9 +652,8 @@ async fn test_use_keyspace() { .query("SELECT * FROM tab", &[]) .await .unwrap() - .rows + .rows::<(String,)>() .unwrap() - .into_typed::<(String,)>() .map(|res| res.unwrap().0) .collect(); @@ -742,9 +711,8 @@ async fn test_use_keyspace_case_sensitivity() { .query("SELECT * from tab", 
&[]) .await .unwrap() - .rows + .rows::<(String,)>() .unwrap() - .into_typed::<(String,)>() .map(|row| row.unwrap().0) .collect(); @@ -758,9 +726,8 @@ async fn test_use_keyspace_case_sensitivity() { .query("SELECT * from tab", &[]) .await .unwrap() - .rows + .rows::<(String,)>() .unwrap() - .into_typed::<(String,)>() .map(|row| row.unwrap().0) .collect(); @@ -799,9 +766,8 @@ async fn test_raw_use_keyspace() { .query("SELECT * FROM tab", &[]) .await .unwrap() - .rows + .rows::<(String,)>() .unwrap() - .into_typed::<(String,)>() .map(|res| res.unwrap().0) .collect(); @@ -915,17 +881,17 @@ async fn test_tracing_query(session: &Session, ks: String) { let untraced_query: Query = Query::new(format!("SELECT * FROM {}.tab", ks)); let untraced_query_result: QueryResult = session.query(untraced_query, &[]).await.unwrap(); - assert!(untraced_query_result.tracing_id.is_none()); + assert!(untraced_query_result.tracing_id().is_none()); // A query with tracing enabled has a tracing uuid in result let mut traced_query: Query = Query::new(format!("SELECT * FROM {}.tab", ks)); traced_query.config.tracing = true; let traced_query_result: QueryResult = session.query(traced_query, &[]).await.unwrap(); - assert!(traced_query_result.tracing_id.is_some()); + assert!(traced_query_result.tracing_id().is_some()); // Querying this uuid from tracing table gives some results - assert_in_tracing_table(session, traced_query_result.tracing_id.unwrap()).await; + assert_in_tracing_table(session, traced_query_result.tracing_id().unwrap()).await; } async fn test_tracing_execute(session: &Session, ks: String) { @@ -938,7 +904,7 @@ async fn test_tracing_execute(session: &Session, ks: String) { let untraced_prepared_result: QueryResult = session.execute(&untraced_prepared, &[]).await.unwrap(); - assert!(untraced_prepared_result.tracing_id.is_none()); + assert!(untraced_prepared_result.tracing_id().is_none()); // Executing a prepared statement with tracing enabled has a tracing uuid in result let mut 
traced_prepared = session @@ -949,10 +915,10 @@ async fn test_tracing_execute(session: &Session, ks: String) { traced_prepared.config.tracing = true; let traced_prepared_result: QueryResult = session.execute(&traced_prepared, &[]).await.unwrap(); - assert!(traced_prepared_result.tracing_id.is_some()); + assert!(traced_prepared_result.tracing_id().is_some()); // Querying this uuid from tracing table gives some results - assert_in_tracing_table(session, traced_prepared_result.tracing_id.unwrap()).await; + assert_in_tracing_table(session, traced_prepared_result.tracing_id().unwrap()).await; } async fn test_tracing_prepare(session: &Session, ks: String) { @@ -983,7 +949,7 @@ async fn test_get_tracing_info(session: &Session, ks: String) { traced_query.config.tracing = true; let traced_query_result: QueryResult = session.query(traced_query, &[]).await.unwrap(); - let tracing_id: Uuid = traced_query_result.tracing_id.unwrap(); + let tracing_id: Uuid = traced_query_result.tracing_id().unwrap(); // The reason why we enable so long waiting for TracingInfo is... Cassandra. (Yes, again.) // In Cassandra Java Driver, the wait time for tracing info is 10 seconds, so here we do the same. 
@@ -1031,7 +997,7 @@ async fn test_tracing_query_iter(session: &Session, ks: String) { assert!(!traced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let traced_typed_row_iter = traced_row_iter.into_typed::<(i32,)>(); + let traced_typed_row_iter = traced_row_iter.into_legacy().into_typed::<(i32,)>(); assert!(!traced_typed_row_iter.get_tracing_ids().is_empty()); for tracing_id in traced_typed_row_iter.get_tracing_ids() { @@ -1054,7 +1020,7 @@ async fn test_tracing_execute_iter(session: &Session, ks: String) { assert!(untraced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let untraced_typed_row_iter = untraced_row_iter.into_typed::<(i32,)>(); + let untraced_typed_row_iter = untraced_row_iter.into_legacy().into_typed::<(i32,)>(); assert!(untraced_typed_row_iter.get_tracing_ids().is_empty()); // A prepared statement with tracing enabled has a tracing ids in result @@ -1072,7 +1038,7 @@ async fn test_tracing_execute_iter(session: &Session, ks: String) { assert!(!traced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let traced_typed_row_iter = traced_row_iter.into_typed::<(i32,)>(); + let traced_typed_row_iter = traced_row_iter.into_legacy().into_typed::<(i32,)>(); assert!(!traced_typed_row_iter.get_tracing_ids().is_empty()); for tracing_id in traced_typed_row_iter.get_tracing_ids() { @@ -1086,7 +1052,7 @@ async fn test_tracing_batch(session: &Session, ks: String) { untraced_batch.append_statement(&format!("INSERT INTO {}.tab (a) VALUES('a')", ks)[..]); let untraced_batch_result: QueryResult = session.batch(&untraced_batch, ((),)).await.unwrap(); - assert!(untraced_batch_result.tracing_id.is_none()); + assert!(untraced_batch_result.tracing_id().is_none()); // Batch with tracing enabled has a tracing uuid in result let mut traced_batch: Batch = Default::default(); @@ -1094,9 +1060,9 @@ async fn test_tracing_batch(session: &Session, ks: String) { traced_batch.config.tracing = 
true; let traced_batch_result: QueryResult = session.batch(&traced_batch, ((),)).await.unwrap(); - assert!(traced_batch_result.tracing_id.is_some()); + assert!(traced_batch_result.tracing_id().is_some()); - assert_in_tracing_table(session, traced_batch_result.tracing_id.unwrap()).await; + assert_in_tracing_table(session, traced_batch_result.tracing_id().unwrap()).await; } async fn assert_in_tracing_table(session: &Session, tracing_uuid: Uuid) { @@ -1107,15 +1073,14 @@ async fn assert_in_tracing_table(session: &Session, tracing_uuid: Uuid) { // If rows are empty perform 8 retries with a 32ms wait in between for _ in 0..8 { - let row_opt = session + let rows_num = session .query(traces_query.clone(), (tracing_uuid,)) .await .unwrap() - .rows - .into_iter() - .next(); + .rows_num() + .unwrap(); - if row_opt.is_some() { + if rows_num > 0 { // Ok there was some row for this tracing_uuid return; } @@ -1240,9 +1205,8 @@ async fn test_timestamp() { ) .await .unwrap() - .rows + .rows::<(String, String, i64)>() .unwrap() - .into_typed::<(String, String, i64)>() .map(Result::unwrap) .collect::>(); results.sort(); @@ -1294,7 +1258,7 @@ async fn test_request_timeout() { { let timeouting_session = create_new_session_builder() .default_execution_profile_handle(fast_timeouting_profile_handle) - .build() + .build_legacy() .await .unwrap(); @@ -1754,7 +1718,7 @@ async fn test_table_partitioner_in_metadata() { async fn test_turning_off_schema_fetching() { let session = create_new_session_builder() .fetch_schema_metadata(false) - .build() + .build_legacy() .await .unwrap(); let ks = unique_keyspace_name(); @@ -1854,9 +1818,8 @@ async fn test_named_bind_markers() { .query("SELECT pk, ck, v FROM t", &[]) .await .unwrap() - .rows + .rows::<(i32, i32, i32)>() .unwrap() - .into_typed::<(i32, i32, i32)>() .map(|res| res.unwrap()) .collect(); @@ -1987,7 +1950,7 @@ async fn test_unprepared_reprepare_in_execute() { .query("SELECT a, b, c FROM tab", ()) .await .unwrap() - .rows_typed::<(i32, 
i32, i32)>() + .rows::<(i32, i32, i32)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2036,7 +1999,7 @@ async fn test_unusual_valuelists() { .query("SELECT a, b, c FROM tab", ()) .await .unwrap() - .rows_typed::<(i32, i32, String)>() + .rows::<(i32, i32, String)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2107,7 +2070,7 @@ async fn test_unprepared_reprepare_in_batch() { .query("SELECT a, b, c FROM tab", ()) .await .unwrap() - .rows_typed::<(i32, i32, i32)>() + .rows::<(i32, i32, i32)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2174,7 +2137,7 @@ async fn test_unprepared_reprepare_in_caching_session_execute() { .execute("SELECT a, b, c FROM tab", &()) .await .unwrap() - .rows_typed::<(i32, i32, i32)>() + .rows::<(i32, i32, i32)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2241,7 +2204,7 @@ async fn assert_test_batch_table_rows_contain(sess: &Session, expected_rows: &[( .query("SELECT a, b FROM test_batch_table", ()) .await .unwrap() - .rows_typed::<(i32, i32)>() + .rows::<(i32, i32)>() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2467,7 +2430,7 @@ async fn test_batch_lwts() { let batch_res: QueryResult = session.batch(&batch, ((), (), ())).await.unwrap(); // Scylla returns 5 columns, but Cassandra returns only 1 - let is_scylla: bool = batch_res.col_specs.len() == 5; + let is_scylla: bool = batch_res.column_specs().unwrap().len() == 5; if is_scylla { test_batch_lwts_for_scylla(&session, &batch, batch_res).await; @@ -2482,11 +2445,8 @@ async fn test_batch_lwts_for_scylla(session: &Session, batch: &Batch, batch_res: // Returned columns are: // [applied], p1, c1, r1, r2 - let batch_res_rows: Vec<(bool, IntOrNull, IntOrNull, IntOrNull, IntOrNull)> = batch_res - .rows_typed() - .unwrap() - .map(|r| r.unwrap()) - .collect(); + let batch_res_rows: Vec<(bool, IntOrNull, IntOrNull, IntOrNull, IntOrNull)> = + batch_res.rows().unwrap().map(|r| r.unwrap()).collect(); let expected_batch_res_rows = vec![ (true, Some(0), Some(0), Some(0), Some(0)), @@ -2502,7 
+2462,7 @@ async fn test_batch_lwts_for_scylla(session: &Session, batch: &Batch, batch_res: let prepared_batch_res_rows: Vec<(bool, IntOrNull, IntOrNull, IntOrNull, IntOrNull)> = prepared_batch_res - .rows_typed() + .rows() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2522,11 +2482,7 @@ async fn test_batch_lwts_for_cassandra(session: &Session, batch: &Batch, batch_r // Returned columns are: // [applied] - let batch_res_rows: Vec<(bool,)> = batch_res - .rows_typed() - .unwrap() - .map(|r| r.unwrap()) - .collect(); + let batch_res_rows: Vec<(bool,)> = batch_res.rows().unwrap().map(|r| r.unwrap()).collect(); let expected_batch_res_rows = vec![(true,)]; @@ -2540,7 +2496,7 @@ async fn test_batch_lwts_for_cassandra(session: &Session, batch: &Batch, batch_r // [applied], p1, c1, r1, r2 let prepared_batch_res_rows: Vec<(bool, IntOrNull, IntOrNull, IntOrNull, IntOrNull)> = prepared_batch_res - .rows_typed() + .rows() .unwrap() .map(|r| r.unwrap()) .collect(); @@ -2646,7 +2602,8 @@ async fn test_iter_works_when_retry_policy_returns_ignore_write_error() { let mut iter = session .query_iter("INSERT INTO t (pk v) VALUES (1, 2)", ()) .await - .unwrap(); + .unwrap() + .into_typed::(); assert!(retried_flag.load(Ordering::Relaxed)); while iter.try_next().await.unwrap().is_some() {} @@ -2657,7 +2614,11 @@ async fn test_iter_works_when_retry_policy_returns_ignore_write_error() { .prepare("INSERT INTO t (pk, v) VALUES (?, ?)") .await .unwrap(); - let mut iter = session.execute_iter(p, (1, 2)).await.unwrap(); + let mut iter = session + .execute_iter(p, (1, 2)) + .await + .unwrap() + .into_typed::(); assert!(retried_flag.load(Ordering::Relaxed)); while iter.try_next().await.unwrap().is_some() {} @@ -2724,3 +2685,44 @@ async fn test_get_keyspace_name() { .unwrap(); assert_eq!(*session.get_keyspace().unwrap(), ks); } + +#[tokio::test] +async fn test_api_migration_session_sharing() { + { + let session = create_new_session_builder().build().await.unwrap(); + let session_shared = 
session.make_shared_session_with_legacy_api(); + + // If we are unlucky then we will race with metadata fetch/cluster update + // and both invocations will return different cluster data. This should be + // SUPER rare, but in order to reduce the chance of flakiness to a minimum + // we will try it three times in a row. Cluster data is updated once per + // minute, so this should be good enough. + let mut matched = false; + for _ in 0..3 { + let cd1 = session.get_cluster_data(); + let cd2 = session_shared.get_cluster_data(); + + if Arc::ptr_eq(&cd1, &cd2) { + matched = true; + break; + } + } + assert!(matched); + } + { + let session = create_new_session_builder().build_legacy().await.unwrap(); + let session_shared = session.make_shared_session_with_new_api(); + + let mut matched = false; + for _ in 0..3 { + let cd1 = session.get_cluster_data(); + let cd2 = session_shared.get_cluster_data(); + + if Arc::ptr_eq(&cd1, &cd2) { + matched = true; + break; + } + } + assert!(matched); + } +} diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs index 63107e9db5..64a0ba30a2 100644 --- a/scylla/src/transport/topology.rs +++ b/scylla/src/transport/topology.rs @@ -12,9 +12,9 @@ use futures::stream::{self, StreamExt, TryStreamExt}; use futures::Stream; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; -use scylla_cql::frame::response::result::Row; use scylla_cql::frame::value::ValueList; -use scylla_macros::FromRow; +use scylla_cql::types::deserialize::row::DeserializeRow; +use scylla_macros::DeserializeRow; use std::borrow::BorrowMut; use std::cell::Cell; use std::collections::HashMap; @@ -640,11 +640,13 @@ async fn query_metadata( Ok(Metadata { peers, keyspaces }) } -#[derive(FromRow)] -#[scylla_crate = "scylla_cql"] +#[derive(DeserializeRow)] +#[scylla(crate = "scylla_cql")] struct NodeInfoRow { host_id: Option, + #[scylla(rename = "rpc_address")] untranslated_ip_addr: IpAddr, + #[scylla(rename = "data_center")] datacenter: Option, rack: 
Option, tokens: Option>, @@ -672,6 +674,7 @@ async fn query_peers(conn: &Arc, connect_port: u16) -> Result())) .into_stream() .try_flatten() .and_then(|row_result| future::ok((NodeInfoSource::Peer, row_result))); @@ -682,6 +685,7 @@ async fn query_peers(conn: &Arc, connect_port: u16) -> Result())) .into_stream() .try_flatten() .and_then(|row_result| future::ok((NodeInfoSource::Local, row_result))); @@ -692,10 +696,7 @@ async fn query_peers(conn: &Arc, connect_port: u16) -> Result( conn: &Arc, query_str: &str, keyspaces_to_fetch: &[String], -) -> impl Stream> { +) -> impl Stream> +where + R: for<'r> DeserializeRow<'r>, +{ let keyspaces = &[keyspaces_to_fetch] as &[&[String]]; let (query_str, query_values) = if !keyspaces_to_fetch.is_empty() { (format!("{query_str} where keyspace_name in ?"), keyspaces) @@ -791,7 +795,9 @@ fn query_filter_keyspace_name( query.set_page_size(1024); let fut = async move { let query_values = query_values?; - conn.query_iter(query, query_values).await + conn.query_iter(query, query_values) + .await + .map(|it| it.into_typed::()) }; fut.into_stream().try_flatten() } @@ -801,7 +807,7 @@ async fn query_keyspaces( keyspaces_to_fetch: &[String], fetch_schema: bool, ) -> Result, QueryError> { - let rows = query_filter_keyspace_name( + let rows = query_filter_keyspace_name::<(String, HashMap)>( conn, "select keyspace_name, replication from system_schema.keyspaces", keyspaces_to_fetch, @@ -819,10 +825,7 @@ async fn query_keyspaces( }; rows.map(|row_result| { - let row = row_result?; - let (keyspace_name, strategy_map) = row.into_typed().map_err(|_| { - QueryError::ProtocolError("system_schema.keyspaces has invalid column type") - })?; + let (keyspace_name, strategy_map) = row_result?; let strategy: Strategy = strategy_from_string_map(strategy_map)?; let tables = all_tables.remove(&keyspace_name).unwrap_or_default(); @@ -844,8 +847,8 @@ async fn query_keyspaces( .await } -#[derive(FromRow, Debug)] -#[scylla_crate = "crate"] 
+#[derive(DeserializeRow, Debug)] +#[scylla(crate = "crate")] struct UdtRow { keyspace_name: String, type_name: String, @@ -887,7 +890,7 @@ async fn query_user_defined_types( conn: &Arc, keyspaces_to_fetch: &[String], ) -> Result>>, QueryError> { - let rows = query_filter_keyspace_name( + let rows = query_filter_keyspace_name::( conn, "select keyspace_name, type_name, field_names, field_types from system_schema.types", keyspaces_to_fetch, @@ -895,14 +898,7 @@ async fn query_user_defined_types( let mut udt_rows: Vec = rows .map(|row_result| { - let row = row_result?; - let udt_row = row - .into_typed::() - .map_err(|_| { - QueryError::ProtocolError("system_schema.types has invalid column type") - })? - .try_into()?; - + let udt_row = row_result?.try_into()?; Ok::<_, QueryError>(udt_row) }) .try_collect() @@ -1208,7 +1204,7 @@ async fn query_tables( keyspaces_to_fetch: &[String], udts: &HashMap>>, ) -> Result>, QueryError> { - let rows = query_filter_keyspace_name( + let rows = query_filter_keyspace_name::<(String, String)>( conn, "SELECT keyspace_name, table_name FROM system_schema.tables", keyspaces_to_fetch, @@ -1217,12 +1213,7 @@ async fn query_tables( let mut tables = query_tables_schema(conn, keyspaces_to_fetch, udts).await?; rows.map(|row_result| { - let row = row_result?; - let (keyspace_name, table_name) = row.into_typed().map_err(|_| { - QueryError::ProtocolError("system_schema.tables has invalid column type") - })?; - - let keyspace_and_table_name = (keyspace_name, table_name); + let keyspace_and_table_name = row_result?; let table = tables.remove(&keyspace_and_table_name).unwrap_or(Table { columns: HashMap::new(), @@ -1249,7 +1240,7 @@ async fn query_views( keyspaces_to_fetch: &[String], udts: &HashMap>>, ) -> Result>, QueryError> { - let rows = query_filter_keyspace_name( + let rows = query_filter_keyspace_name::<(String, String, String)>( conn, "SELECT keyspace_name, view_name, base_table_name FROM system_schema.views", keyspaces_to_fetch, @@ -1259,11 
+1250,7 @@ async fn query_views( let mut tables = query_tables_schema(conn, keyspaces_to_fetch, udts).await?; rows.map(|row_result| { - let row = row_result?; - let (keyspace_name, view_name, base_table_name) = row.into_typed().map_err(|_| { - QueryError::ProtocolError("system_schema.views has invalid column type") - })?; - + let (keyspace_name, view_name, base_table_name) = row_result?; let keyspace_and_view_name = (keyspace_name, view_name); let table = tables.remove(&keyspace_and_view_name).unwrap_or(Table { @@ -1300,24 +1287,16 @@ async fn query_tables_schema( // This column shouldn't be exposed to the user but is currently exposed in system tables. const THRIFT_EMPTY_TYPE: &str = "empty"; - let rows = query_filter_keyspace_name(conn, + type RowType = (String, String, String, String, i32, String); + + let rows = query_filter_keyspace_name::(conn, "select keyspace_name, table_name, column_name, kind, position, type from system_schema.columns", keyspaces_to_fetch ); let mut tables_schema = HashMap::new(); rows.map(|row_result| { - let row = row_result?; - let (keyspace_name, table_name, column_name, kind, position, type_): ( - String, - String, - String, - String, - i32, - String, - ) = row.into_typed().map_err(|_| { - QueryError::ProtocolError("system_schema.columns has invalid column type") - })?; + let (keyspace_name, table_name, column_name, kind, position, type_) = row_result?; if type_ == THRIFT_EMPTY_TYPE { return Ok::<_, QueryError>(()); @@ -1522,15 +1501,13 @@ async fn query_table_partitioners( let rows = conn .clone() .query_iter(partitioner_query, &[]) + .map(|it| it.map(|it| it.into_typed::<(String, String, Option)>())) .into_stream() .try_flatten(); let result = rows .map(|row_result| { - let (keyspace_name, table_name, partitioner) = - row_result?.into_typed().map_err(|_| { - QueryError::ProtocolError("system_schema.tables has invalid column type") - })?; + let (keyspace_name, table_name, partitioner) = row_result?; Ok::<_, 
QueryError>(((keyspace_name, table_name), partitioner)) }) .try_collect::>() diff --git a/scylla/src/utils/test_utils.rs b/scylla/src/utils/test_utils.rs index aea5ed27d4..bf004157ed 100644 --- a/scylla/src/utils/test_utils.rs +++ b/scylla/src/utils/test_utils.rs @@ -46,7 +46,7 @@ pub(crate) async fn supports_feature(session: &Session, feature: &str) -> bool { .query("SELECT supported_features FROM system.local", ()) .await .unwrap() - .single_row_typed() + .single_row() .unwrap(); features diff --git a/scylla/tests/execution_profiles.rs b/scylla/tests/execution_profiles.rs index 39a32530d0..a271e7b501 100644 --- a/scylla/tests/execution_profiles.rs +++ b/scylla/tests/execution_profiles.rs @@ -156,7 +156,7 @@ async fn test_execution_profiles() { .known_node(proxy_uris[0].as_str()) .address_translator(Arc::new(translation_map)) .default_execution_profile_handle(profile1.into_handle()) - .build() + .build_legacy() .await .unwrap(); let ks = unique_keyspace_name();