From dcd6a033d3ccf2199039a9a9d74636842e29710e Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 30 Mar 2024 12:41:39 +0100 Subject: [PATCH 1/5] Updated dependencies and code accordingly. Signed-off-by: Pavel Kirilin --- Cargo.toml | 13 ++-- pyproject.toml | 2 +- src/batches.rs | 14 +++-- src/exceptions/rust_err.rs | 6 +- src/queries.rs | 14 ++--- src/query_builder/delete.rs | 7 +-- src/query_builder/insert.rs | 7 +-- src/query_builder/select.rs | 6 +- src/query_builder/update.rs | 9 ++- src/query_results.rs | 14 ++++- src/scylla_cls.rs | 26 +++++--- src/utils.rs | 119 ++++++++++++++++++++++-------------- 12 files changed, 142 insertions(+), 95 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f0b7709..266e170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,20 +14,23 @@ eq-float = "0.1.0" futures = "0.3.28" log = "0.4.20" openssl = { version = "0.10.57", features = ["vendored"] } -pyo3 = { version = "0.19.2", features = [ +pyo3 = { version = "0.20.0", features = [ "auto-initialize", "abi3-py38", "extension-module", "chrono", + "rust_decimal", ] } -pyo3-asyncio = { version = "0.19.0", features = ["tokio-runtime"] } -pyo3-log = "0.8.3" +rust_decimal = "1.0" +pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] } +pyo3-log = "0.9.0" rustc-hash = "1.1.0" -scylla = { version = "0.10.1", features = ["ssl"] } -scylla-cql = "0.0.9" +scylla = { version = "0.12.0", features = ["ssl", "full-serialization"] } +bigdecimal-04 = { package = "bigdecimal", version = "0.4" } thiserror = "1.0.48" tokio = { version = "1.32.0", features = ["bytes"] } uuid = { version = "1.4.1", features = ["v4"] } +time = { version = "*", features = ["formatting", "macros"] } [profile.release] lto = "fat" diff --git a/pyproject.toml b/pyproject.toml index 861b2ee..b9e00fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ "Intended Audience :: Developers", "Topic :: Database :: Front-Ends", ] - +dependencies = ["python-dateutil"] [tool.maturin] python-source = "python" diff --git a/src/batches.rs b/src/batches.rs index 76e7b6e..c13e681 100644 --- a/src/batches.rs +++ b/src/batches.rs @@ -1,11 +1,13 @@ use pyo3::{pyclass, pymethods, types::PyDict, PyAny}; -use scylla::batch::{Batch, BatchStatement, BatchType}; +use scylla::{ + batch::{Batch, BatchStatement, BatchType}, + frame::value::LegacySerializedValues, +}; use crate::{ exceptions::rust_err::ScyllaPyResult, inputs::BatchQueryInput, queries::ScyllaPyRequestParams, utils::parse_python_query_params, }; -use scylla::frame::value::SerializedValues; #[pyclass(name = "BatchType")] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -27,7 +29,7 @@ pub struct ScyllaPyBatch { pub struct ScyllaPyInlineBatch { inner: Batch, request_params: ScyllaPyRequestParams, - values: Vec, + values: Vec, } impl From for Batch { @@ -38,7 +40,7 @@ impl From for Batch { } } -impl From for (Batch, Vec) { +impl From for (Batch, Vec) { fn from(mut value: ScyllaPyInlineBatch) -> Self { value.request_params.apply_to_batch(&mut value.inner); (value.inner, value.values) @@ -74,7 +76,7 @@ impl ScyllaPyInlineBatch { pub fn add_query_inner( &mut self, query: impl Into, - values: impl Into, + values: impl Into, ) { self.inner.append_statement(query); self.values.push(values.into()); @@ -123,7 +125,7 @@ impl ScyllaPyInlineBatch { self.values .push(parse_python_query_params(Some(passed_params), false, None)?); } else { - self.values.push(SerializedValues::new()); + self.values.push(LegacySerializedValues::new()); } Ok(()) } diff --git a/src/exceptions/rust_err.rs b/src/exceptions/rust_err.rs index 72e607e..744c5d9 100644 --- a/src/exceptions/rust_err.rs +++ b/src/exceptions/rust_err.rs @@ -19,15 +19,15 @@ pub enum ScyllaPyError { // Derived exception. #[error("{0}")] - QueryError(#[from] scylla_cql::errors::QueryError), + QueryError(#[from] scylla::transport::errors::QueryError), #[error("{0}")] - DBError(#[from] scylla_cql::errors::DbError), + DBError(#[from] scylla::transport::errors::DbError), #[error("Python exception: {0}.")] PyError(#[from] pyo3::PyErr), #[error("OpenSSL error: {0}.")] SSLError(#[from] openssl::error::ErrorStack), #[error("Cannot construct new session: {0}.")] - ScyllaSessionError(#[from] scylla_cql::errors::NewSessionError), + ScyllaSessionError(#[from] scylla::transport::errors::NewSessionError), // Binding errors #[error("Binding error. Cannot build values for query: {0},")] diff --git a/src/queries.rs b/src/queries.rs index 1d593d6..878da8e 100644 --- a/src/queries.rs +++ b/src/queries.rs @@ -66,31 +66,31 @@ impl ScyllaPyRequestParams { }; Ok(Self { consistency: params - .get_item("consistency") + .get_item("consistency")? .map(pyo3::FromPyObject::extract) .transpose()?, serial_consistency: params - .get_item("serial_consistency") + .get_item("serial_consistency")? .map(pyo3::FromPyObject::extract) .transpose()?, request_timeout: params - .get_item("request_timeout") + .get_item("request_timeout")? .map(pyo3::FromPyObject::extract) .transpose()?, timestamp: params - .get_item("timestamp") + .get_item("timestamp")? .map(pyo3::FromPyObject::extract) .transpose()?, is_idempotent: params - .get_item("is_idempotent") + .get_item("is_idempotent")? .map(pyo3::FromPyObject::extract) .transpose()?, tracing: params - .get_item("tracing") + .get_item("tracing")? .map(pyo3::FromPyObject::extract) .transpose()?, profile: params - .get_item("profile") + .get_item("profile")? .map(pyo3::FromPyObject::extract) .transpose()?, }) diff --git a/src/query_builder/delete.rs b/src/query_builder/delete.rs index 79a4c43..912db59 100644 --- a/src/query_builder/delete.rs +++ b/src/query_builder/delete.rs @@ -1,5 +1,5 @@ use pyo3::{pyclass, pymethods, types::PyDict, PyAny, PyRefMut, Python}; -use scylla::query::Query; +use scylla::{frame::value::LegacySerializedValues, query::Query}; use super::utils::{pretty_build, IfCluase, Timeout}; use crate::{ @@ -9,7 +9,6 @@ use crate::{ scylla_cls::Scylla, utils::{py_to_value, ScyllaPyCQLDTO}, }; -use scylla::frame::value::SerializedValues; #[pyclass] #[derive(Clone, Debug, Default)] @@ -208,7 +207,7 @@ impl Delete { /// /// Adds current query to batch. /// - /// # Error + /// # Errors /// /// May result into error if query cannot be build. /// Or values cannot be passed to batch. @@ -221,7 +220,7 @@ impl Delete { } else { self.values_.clone() }; - let mut serialized = SerializedValues::new(); + let mut serialized = LegacySerializedValues::new(); for val in values { serialized.add_value(&val)?; } diff --git a/src/query_builder/insert.rs b/src/query_builder/insert.rs index ed916ac..30f6e53 100644 --- a/src/query_builder/insert.rs +++ b/src/query_builder/insert.rs @@ -1,5 +1,5 @@ use pyo3::{pyclass, pymethods, types::PyDict, PyAny, PyRefMut, Python}; -use scylla::query::Query; +use scylla::{frame::value::LegacySerializedValues, query::Query}; use crate::{ batches::ScyllaPyInlineBatch, @@ -8,7 +8,6 @@ use crate::{ scylla_cls::Scylla, utils::{py_to_value, ScyllaPyCQLDTO}, }; -use scylla::frame::value::SerializedValues; use super::utils::{pretty_build, Timeout}; @@ -172,7 +171,7 @@ impl Insert { /// /// Adds current query to batch. /// - /// # Error + /// # Errors /// /// May result into error if query cannot be build. /// Or values cannot be passed to batch. @@ -180,7 +179,7 @@ impl Insert { let mut query = Query::new(self.build_query()?); self.request_params_.apply_to_query(&mut query); - let mut serialized = SerializedValues::new(); + let mut serialized = LegacySerializedValues::new(); for val in self.values_.clone() { serialized.add_value(&val)?; } diff --git a/src/query_builder/select.rs b/src/query_builder/select.rs index f1683d2..3e24819 100644 --- a/src/query_builder/select.rs +++ b/src/query_builder/select.rs @@ -14,7 +14,7 @@ use crate::{ }; use super::utils::{pretty_build, Timeout}; -use scylla::frame::value::SerializedValues; +use scylla::frame::value::LegacySerializedValues; #[pyclass] #[derive(Clone, Debug, Default)] @@ -255,14 +255,14 @@ impl Select { /// /// Adds current query to batch. /// - /// # Error + /// # Errors /// /// Returns error if values cannot be passed to batch. pub fn add_to_batch(&self, batch: &mut ScyllaPyInlineBatch) -> ScyllaPyResult<()> { let mut query = Query::new(self.build_query()); self.request_params_.apply_to_query(&mut query); - let mut serialized = SerializedValues::new(); + let mut serialized = LegacySerializedValues::new(); for val in self.values_.clone() { serialized.add_value(&val)?; } diff --git a/src/query_builder/update.rs b/src/query_builder/update.rs index 0a9120f..4cb766c 100644 --- a/src/query_builder/update.rs +++ b/src/query_builder/update.rs @@ -1,5 +1,5 @@ use pyo3::{pyclass, pymethods, types::PyDict, PyAny, PyRefMut, Python}; -use scylla::query::Query; +use scylla::{frame::value::LegacySerializedValues, query::Query}; use crate::{ batches::ScyllaPyInlineBatch, @@ -10,7 +10,6 @@ use crate::{ }; use super::utils::{pretty_build, IfCluase, Timeout}; -use scylla::frame::value::SerializedValues; #[derive(Clone, Debug)] enum UpdateAssignment { Simple(String), @@ -136,7 +135,7 @@ impl Update { /// Increment column value. /// - /// # Error + /// # Errors /// /// If cannot convert python type /// to appropriate rust type. @@ -292,7 +291,7 @@ impl Update { /// /// Adds current query to batch. /// - /// # Error + /// # Errors /// /// May result into error if query cannot be build. /// Or values cannot be passed to batch. @@ -308,7 +307,7 @@ impl Update { values }; - let mut serialized = SerializedValues::new(); + let mut serialized = LegacySerializedValues::new(); for val in values { serialized.add_value(&val)?; } diff --git a/src/query_results.rs b/src/query_results.rs index 0740bbf..fcad553 100644 --- a/src/query_results.rs +++ b/src/query_results.rs @@ -130,7 +130,7 @@ impl ScyllaPyQueryResult { /// This function grabs rows from all function and /// tries to get the first column of any row. /// - /// # Erros + /// # Errors /// /// May result in an error if: /// * Query doesn't have a returns; @@ -158,7 +158,7 @@ impl ScyllaPyQueryResult { /// This function grabs first row and /// tries to get the first column of a result. /// - /// # Erros + /// # Errors /// /// May result in an error if: /// * Query doesn't have a returns; @@ -238,7 +238,15 @@ impl ScyllaPyIterableQueryResult { /// Actual async iteration. /// - /// Here we define how to + /// Here we define how to iterate over rows. + /// + /// # Errors + /// + /// May return an error if: + /// * No more rows to iterate; + /// * No columns in a row. + /// * Cannot convert column to python object. + /// * Cannot acquire GIL. pub fn __anext__(&self, py: Python<'_>) -> ScyllaPyResult> { let streamer = self.inner.clone(); let map_function = self.mapper.clone(); diff --git a/src/scylla_cls.rs b/src/scylla_cls.rs index 11352e7..2e6f3fa 100644 --- a/src/scylla_cls.rs +++ b/src/scylla_cls.rs @@ -67,11 +67,13 @@ impl Scylla { if paged { match (query, prepared) { (Some(query), None) => Ok(ScyllaPyQueryReturns::IterableQueryResult( - ScyllaPyIterableQueryResult::new(session.query_iter(query, values).await?), + ScyllaPyIterableQueryResult::new( + session.query_iter(query, values.serialized()?).await?, + ), )), (None, Some(prepared)) => Ok(ScyllaPyQueryReturns::IterableQueryResult( ScyllaPyIterableQueryResult::new( - session.execute_iter(prepared, values).await?, + session.execute_iter(prepared, values.serialized()?).await?, ), )), _ => Err(ScyllaPyError::SessionError( @@ -81,11 +83,13 @@ impl Scylla { } else { match (query, prepared) { (Some(query), None) => Ok(ScyllaPyQueryReturns::QueryResult( - ScyllaPyQueryResult::new(session.query(query, values).await?), - )), - (None, Some(prepared)) => Ok(ScyllaPyQueryReturns::QueryResult( - ScyllaPyQueryResult::new(session.execute(&prepared, values).await?), + ScyllaPyQueryResult::new(session.query(query, values.serialized()?).await?), )), + (None, Some(prepared)) => { + Ok(ScyllaPyQueryReturns::QueryResult(ScyllaPyQueryResult::new( + session.execute(&prepared, values.serialized()?).await?, + ))) + } _ => Err(ScyllaPyError::SessionError( "You should pass either query or prepared query.".into(), )), @@ -246,7 +250,7 @@ impl Scylla { session_builder.connection_timeout(Duration::from_secs(connection_timeout)); } let mut session_guard = scylla_session.write().await; - *session_guard = Some(session_builder.build().await?); + *session_guard = Some(Box::pin(session_builder.build()).await?); Ok(()) }) } @@ -309,6 +313,10 @@ impl Scylla { /// Execute a batch statement. /// /// This function takes a batch and list of lists of params. + /// + /// # Errors + /// + /// Can result in an error in any case, when something goes wrong. #[pyo3(signature = (batch, params = None))] pub fn batch<'a>( &'a self, @@ -357,6 +365,10 @@ impl Scylla { /// /// After preparation it returns a prepared /// query, that you can use later. + /// + /// # Errors + /// + /// May return an error, if session is not initialized. pub fn prepare<'a>( &'a self, python: Python<'a>, diff --git a/src/utils.rs b/src/utils.rs index 5a14ee2..c154a1a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,18 +1,16 @@ use std::{collections::HashMap, future::Future, hash::BuildHasherDefault, str::FromStr}; -use chrono::Duration; use pyo3::{ types::{PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PyModule, PySet, PyString, PyTuple}, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; use scylla::{ frame::{ - response::result::{ColumnType, CqlValue}, - value::{SerializedValues, Value}, + response::result::{ColumnSpec, ColumnType, CqlValue}, + value::{LegacySerializedValues, Value}, }, BufMut, }; -use scylla_cql::frame::response::result::ColumnSpec; use std::net::IpAddr; @@ -21,6 +19,9 @@ use crate::{ extra_types::{BigInt, Counter, Double, ScyllaPyUnset, SmallInt, TinyInt}, }; +const DATE_FORMAT: &[::time::format_description::FormatItem<'static>] = + ::time::macros::format_description!(version = 2, "[year]-[month]-[day]"); + /// Add submodule. /// /// This function is required, @@ -94,8 +95,8 @@ pub enum ScyllaPyCQLDTO { Float(eq_float::F32), Bytes(Vec), Date(chrono::NaiveDate), - Time(chrono::Duration), - Timestamp(chrono::Duration), + Time(chrono::NaiveTime), + Timestamp(chrono::DateTime), Uuid(uuid::Uuid), Inet(IpAddr), List(Vec), @@ -121,14 +122,14 @@ impl Value for ScyllaPyCQLDTO { ScyllaPyCQLDTO::Counter(counter) => counter.serialize(buf), ScyllaPyCQLDTO::TinyInt(tinyint) => tinyint.serialize(buf), ScyllaPyCQLDTO::Date(date) => date.serialize(buf), - ScyllaPyCQLDTO::Time(time) => scylla::frame::value::Time(*time).serialize(buf), + ScyllaPyCQLDTO::Time(time) => time.serialize(buf), ScyllaPyCQLDTO::Map(map) => map .iter() .cloned() .collect::>>() .serialize(buf), ScyllaPyCQLDTO::Timestamp(timestamp) => { - scylla::frame::value::Timestamp(*timestamp).serialize(buf) + scylla::frame::value::CqlTimestamp::from(*timestamp).serialize(buf) } ScyllaPyCQLDTO::Null => Option::::None.serialize(buf), ScyllaPyCQLDTO::Udt(udt) => { @@ -243,20 +244,19 @@ pub fn py_to_value( item.call_method0("isoformat")?.extract::<&str>()?, )?)) } else if item.get_type().name()? == "time" { - Ok(ScyllaPyCQLDTO::Time( - chrono::NaiveTime::from_str(item.call_method0("isoformat")?.extract::<&str>()?)? - .signed_duration_since( - chrono::NaiveTime::from_num_seconds_from_midnight_opt(0, 0).ok_or( - ScyllaPyError::BindingError(format!( - "Cannot calculate offset from midnight for value {item}" - )), - )?, - ), - )) + Ok(ScyllaPyCQLDTO::Time(chrono::NaiveTime::from_str( + item.call_method0("isoformat")?.extract::<&str>()?, + )?)) } else if item.get_type().name()? == "datetime" { let milliseconds = item.call_method0("timestamp")?.extract::()? * 1000f64; #[allow(clippy::cast_possible_truncation)] - let timestamp = Duration::milliseconds(milliseconds.trunc() as i64); + let seconds = milliseconds as i64 / 1_000; + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let nsecs = (milliseconds as i64 % 1_000) as u32 * 1_000_000; + let timestamp = chrono::DateTime::::from_timestamp(seconds, nsecs).ok_or( + ScyllaPyError::BindingError("Cannot convert datetime to timestamp.".into()), + )?; Ok(ScyllaPyCQLDTO::Timestamp(timestamp)) } else if item.is_instance_of::() || item.is_instance_of::() @@ -418,6 +418,7 @@ pub fn cql_to_py<'a>( col_name.into(), "Timeuuid", ))? + .as_ref() .as_simple() .to_string(); Ok(py.import("uuid")?.getattr("UUID")?.call1((uuid_str,))?) @@ -431,37 +432,61 @@ pub fn cql_to_py<'a>( // same driver. Will fix it on demand. let duration = unwrapped_value - .as_duration() + .as_cql_duration() .ok_or(ScyllaPyError::ValueDowncastError( col_name.into(), "Duration", ))?; let kwargs = PyDict::new(py); - kwargs.set_item("microseconds", duration.num_microseconds())?; + kwargs.set_item("months", duration.months)?; + kwargs.set_item("days", duration.days)?; + kwargs.set_item("microseconds", duration.nanoseconds / 1_000)?; Ok(py - .import("datetime")? - .getattr("timedelta")? + .import("dateutil")? + .getattr("relativedelta")? + .getattr("relativedelta")? .call((), Some(kwargs))?) } ColumnType::Timestamp => { // Timestamp - num of milliseconds since unix epoch. let timestamp = unwrapped_value - .as_duration() + .as_cql_timestamp() .ok_or(ScyllaPyError::ValueDowncastError( col_name.into(), "Timestamp", ))?; + let milliseconds = timestamp.0; + if milliseconds < 0 { + return Err(ScyllaPyError::ValueDowncastError( + col_name.into(), + "Timestamp cannot be less than 0", + )); + } + let seconds = + milliseconds + .checked_div(1_000) + .ok_or(ScyllaPyError::ValueDowncastError( + col_name.into(), + "Cannot get seconds out of milliseconds.", + ))?; + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + let nsecs = (milliseconds % 1_000).checked_mul(1_000_000).ok_or( + ScyllaPyError::ValueDowncastError(col_name.into(), "Cannot construct nanoseconds"), + )? as u32; + + let timestamp = chrono::DateTime::::from_timestamp(seconds, nsecs).ok_or( + ScyllaPyError::ValueDowncastError( + col_name.into(), + "Cannot construct datetime based on timestamp", + ), + )?; #[allow(clippy::cast_precision_loss)] - let seconds = timestamp.num_seconds() as f64; - #[allow(clippy::cast_precision_loss)] - let micros = (timestamp - Duration::seconds(timestamp.num_seconds())).num_milliseconds() - as f64 - / 1_000f64; // Converting microseconds to seconds to construct timestamp - Ok(py - .import("datetime")? - .getattr("datetime")? - .call_method1("fromtimestamp", (seconds + micros,))?) + Ok(py.import("datetime")?.getattr("datetime")?.call_method1( + "fromtimestamp", + (timestamp.timestamp_millis() as f64 / 1000f64,), + )?) } ColumnType::Inet => Ok(unwrapped_value .as_inet() @@ -472,7 +497,8 @@ pub fn cql_to_py<'a>( let formatted_date = unwrapped_value .as_date() .ok_or(ScyllaPyError::ValueDowncastError(col_name.into(), "Date"))? - .format("%Y-%m-%d") + .format(DATE_FORMAT) + .map_err(|_| ScyllaPyError::ValueDowncastError(col_name.into(), "Date"))? .to_string(); Ok(py .import("datetime")? @@ -500,19 +526,18 @@ pub fn cql_to_py<'a>( .to_object(py) .into_ref(py)), ColumnType::Time => { - let duration = unwrapped_value - .as_duration() + let time = unwrapped_value + .as_time() .ok_or(ScyllaPyError::ValueDowncastError(col_name.into(), "Time"))?; - let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(0, 0).ok_or( - ScyllaPyError::ValueDowncastError( - col_name.into(), - "Time, because it's value is too big", - ), - )? + duration; + let kwargs = PyDict::new(py); + kwargs.set_item("hour", time.hour())?; + kwargs.set_item("minute", time.minute())?; + kwargs.set_item("second", time.second())?; + kwargs.set_item("microsecond", time.microsecond())?; Ok(py .import("datetime")? .getattr("time")? - .call_method1("fromisoformat", (time.format("%H:%M:%S%.6f").to_string(),))?) + .call((), Some(kwargs))?) } ColumnType::UserDefinedType { type_name, @@ -557,13 +582,13 @@ pub fn cql_to_py<'a>( } } -/// Parse python type to `SerializedValues`. +/// Parse python type to `LegacySerializedValues`. /// /// Serialized values are used for /// parameter binding. We parse python types /// into our own types that are capable /// of being bound to query and add parsed -/// results to `SerializedValues`. +/// results to `LegacySerializedValues`. /// /// # Errors /// @@ -573,8 +598,8 @@ pub fn parse_python_query_params( params: Option<&PyAny>, allow_dicts: bool, col_spec: Option<&[ColumnSpec]>, -) -> ScyllaPyResult { - let mut values = SerializedValues::new(); +) -> ScyllaPyResult { + let mut values = LegacySerializedValues::new(); let Some(params) = params else { return Ok(values); From cde055d430447d781fbe5f8d54236cace8097a28 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 30 Mar 2024 15:42:46 +0100 Subject: [PATCH 2/5] Fixed clippy. Signed-off-by: Pavel Kirilin --- src/query_builder/delete.rs | 2 +- src/query_builder/insert.rs | 2 +- src/query_builder/update.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/query_builder/delete.rs b/src/query_builder/delete.rs index 912db59..8a784f5 100644 --- a/src/query_builder/delete.rs +++ b/src/query_builder/delete.rs @@ -34,7 +34,7 @@ impl Delete { .columns .as_ref() .map_or(String::new(), |cols| cols.join(", ")); - let params = vec![ + let params = [ self.timestamp_ .map(|timestamp| format!("TIMESTAMP {timestamp}")), self.timeout_.as_ref().map(|timeout| match timeout { diff --git a/src/query_builder/insert.rs b/src/query_builder/insert.rs index 30f6e53..67ed097 100644 --- a/src/query_builder/insert.rs +++ b/src/query_builder/insert.rs @@ -50,7 +50,7 @@ impl Insert { } else { "" }; - let params = vec![ + let params = [ self.timestamp_ .map(|timestamp| format!("TIMESTAMP {timestamp}")), self.ttl_.map(|ttl| format!("TTL {ttl}")), diff --git a/src/query_builder/update.rs b/src/query_builder/update.rs index 4cb766c..ef22ab3 100644 --- a/src/query_builder/update.rs +++ b/src/query_builder/update.rs @@ -57,7 +57,7 @@ impl Update { "Update should contain at least one where clause", )); } - let params = vec![ + let params = [ self.timestamp_ .map(|timestamp| format!("TIMESTAMP {timestamp}")), self.ttl_.map(|ttl| format!("TTL {ttl}")), From 86149b89af318562857a8132edbae4188087aa3e Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 30 Mar 2024 19:22:12 +0100 Subject: [PATCH 3/5] Added Decimals support. Signed-off-by: Pavel Kirilin --- Cargo.toml | 2 -- pyproject.toml | 2 +- python/tests/test_bindings.py | 3 +++ src/utils.rs | 34 ++++++++++++++++++++++++++++++---- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 266e170..6cefd0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,9 +19,7 @@ pyo3 = { version = "0.20.0", features = [ "abi3-py38", "extension-module", "chrono", - "rust_decimal", ] } -rust_decimal = "1.0" pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] } pyo3-log = "0.9.0" rustc-hash = "1.1.0" diff --git a/pyproject.toml b/pyproject.toml index b9e00fe..283e523 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,4 +105,4 @@ convention = "pep257" ignore-decorators = ["typing.overload"] [tool.ruff.pylint] -allow-magic-value-types = ["int", "str", "float", "tuple"] +allow-magic-value-types = ["int", "str", "float"] diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index f99ec1a..7cb1cb1 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -2,6 +2,7 @@ import ipaddress import random import uuid +from decimal import Decimal from typing import Any, Callable import pytest @@ -30,6 +31,8 @@ ("UUID", uuid.uuid5(uuid.uuid4(), "name")), ("INET", ipaddress.ip_address("192.168.1.1")), ("INET", ipaddress.ip_address("2001:db8::8a2e:370:7334")), + ("DECIMAL", Decimal("1.1")), + ("DECIMAL", Decimal("1.112e10")), ], ) async def test_bindings( diff --git a/src/utils.rs b/src/utils.rs index c154a1a..5e1328c 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -92,6 +92,7 @@ pub enum ScyllaPyCQLDTO { Counter(i64), Bool(bool), Double(eq_float::F64), + Decimal(bigdecimal_04::BigDecimal), Float(eq_float::F32), Bytes(Vec), Date(chrono::NaiveDate), @@ -131,11 +132,12 @@ impl Value for ScyllaPyCQLDTO { ScyllaPyCQLDTO::Timestamp(timestamp) => { scylla::frame::value::CqlTimestamp::from(*timestamp).serialize(buf) } - ScyllaPyCQLDTO::Null => Option::::None.serialize(buf), + ScyllaPyCQLDTO::Null => Option::::None.serialize(buf), ScyllaPyCQLDTO::Udt(udt) => { buf.extend(udt); Ok(()) } + ScyllaPyCQLDTO::Decimal(decimal) => decimal.serialize(buf), ScyllaPyCQLDTO::Unset => scylla::frame::value::Unset.serialize(buf), } } @@ -247,6 +249,12 @@ pub fn py_to_value( Ok(ScyllaPyCQLDTO::Time(chrono::NaiveTime::from_str( item.call_method0("isoformat")?.extract::<&str>()?, )?)) + } else if item.get_type().name()? == "Decimal" { + Ok(ScyllaPyCQLDTO::Decimal( + bigdecimal_04::BigDecimal::from_str(item.str()?.to_str()?).map_err(|err| { + ScyllaPyError::BindingError(format!("Cannot parse decimal {err}")) + })?, + )) } else if item.get_type().name()? == "datetime" { let milliseconds = item.call_method0("timestamp")?.extract::()? * 1000f64; #[allow(clippy::cast_possible_truncation)] @@ -576,9 +584,27 @@ pub fn cql_to_py<'a>( } Ok(res_map) } - ColumnType::Custom(_) | ColumnType::Varint | ColumnType::Decimal => Err( - ScyllaPyError::ValueDowncastError(col_name.into(), "Unknown"), - ), + ColumnType::Decimal => { + // Because the `as_decimal` method is not implemented for `CqlValue`, + // will make a PR. + let decimal: bigdecimal_04::BigDecimal = match unwrapped_value { + CqlValue::Decimal(inner) => inner.clone().into(), + _ => { + return Err(ScyllaPyError::ValueDowncastError( + col_name.into(), + "Decimal", + )) + } + }; + Ok(py + .import("decimal")? + .getattr("Decimal")? + .call1((decimal.to_scientific_notation(),))?) + } + ColumnType::Custom(_) | ColumnType::Varint => Err(ScyllaPyError::ValueDowncastError( + col_name.into(), + "Unknown", + )), } } From e5cdceb130e3faeaa46e22d6e05ab2402220512c Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 30 Mar 2024 20:38:33 +0100 Subject: [PATCH 4/5] Added duration and varint support. Signed-off-by: Pavel Kirilin --- .pre-commit-config.yaml | 4 +++- README.md | 29 ++++++++++++++++++++++++ python/tests/test_bindings.py | 12 ++++++---- python/tests/test_extra_types.py | 27 +++++++++++++++++++++- src/utils.rs | 39 ++++++++++++++++++++++++++++++-- 5 files changed, 102 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6885f35..5081196 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,11 +20,13 @@ repos: always_run: true args: ["python"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.5.1 + rev: v1.9.0 hooks: - id: mypy name: python mypy always_run: true + additional_dependencies: + - "types-python-dateutil" pass_filenames: false args: ["python"] - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/README.md b/README.md index f64fd88..9b0ba0e 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,35 @@ new_query = query.with_consistency(Consistency.ALL) All `with_` methods create new query, copying all other parameters. +Here's the list of scylla types and corresponding python types that you should use while passing parameters to queries: + +| Scylla type | Python type | +| ----------- | ---------------------- | +| int | int | +| tinyint | extra_types.TinyInt | +| bigint | extra_types.BigInt | +| varint | any int type | +| float | float | +| double | extra_types.Double | +| decimal | decimal.Decimal | +| ascii | str | +| text | str | +| varchar | str | +| blob | bytes | +| boolean | bool | +| counter | extra_types.Counter | +| date | datetime.date | +| uuid | uuid.UUID | +| inet | ipaddress | +| time | datetime.time | +| timestamp | datetime.datetime | +| duration | dateutil.relativedelta | + +All types from `extra_types` module are used to eliminate any possible ambiguity while passing parameters to queries. You can find more information about them in `Extra types` section. + +We use relative delta from `dateutil` for duration, because it's the only way to represent it in python. Since scylla operates with months, days and nanosecond, there's no way we can represent it in python, becuase months are variable length. + + ## Named parameters Also, you can provide named parameters to querties, by using name diff --git a/python/tests/test_bindings.py b/python/tests/test_bindings.py index 7cb1cb1..87f0b18 100644 --- a/python/tests/test_bindings.py +++ b/python/tests/test_bindings.py @@ -6,6 +6,7 @@ from typing import Any, Callable import pytest +from dateutil.relativedelta import relativedelta from tests.utils import random_string from scyllapy import Scylla @@ -33,6 +34,8 @@ ("INET", ipaddress.ip_address("2001:db8::8a2e:370:7334")), ("DECIMAL", Decimal("1.1")), ("DECIMAL", Decimal("1.112e10")), + ("DURATION", relativedelta(months=1, days=2, microseconds=10)), + ("VARINT", 1000), ], ) async def test_bindings( @@ -42,15 +45,14 @@ async def test_bindings( ) -> None: table_name = random_string(4) await scylla.execute( - f"CREATE TABLE {table_name} (id {type_name}, PRIMARY KEY (id))", + f"CREATE TABLE {table_name} (id INT, value {type_name}, PRIMARY KEY (id))", ) - insert_query = f"INSERT INTO {table_name}(id) VALUES (?)" - await scylla.execute(insert_query, [test_val]) + insert_query = f"INSERT INTO {table_name}(id, value) VALUES (?, ?)" + await scylla.execute(insert_query, [1, test_val]) result = await scylla.execute(f"SELECT * FROM {table_name}") rows = result.all() - assert len(rows) == 1 - assert rows[0] == {"id": test_val} + assert rows == [{"id": 1, "value": test_val}] @pytest.mark.anyio diff --git a/python/tests/test_extra_types.py b/python/tests/test_extra_types.py index a0f7d8b..befee62 100644 --- a/python/tests/test_extra_types.py +++ b/python/tests/test_extra_types.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import Any +from typing import Any, Callable import pytest from tests.utils import random_string @@ -147,3 +147,28 @@ async def test_autocast_positional(scylla: Scylla, typ: str, val: Any) -> None: await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val {typ})") prepared = await scylla.prepare(f"INSERT INTO {table_name}(id, val) VALUES (?, ?)") await scylla.execute(prepared, [1, val]) + + +@pytest.mark.parametrize( + ["cast_func", "val"], + [ + (extra_types.BigInt, 1000000), + (extra_types.SmallInt, 10), + (extra_types.TinyInt, 1), + (int, 1), + ], +) +@pytest.mark.anyio +async def test_varint( + scylla: Scylla, + cast_func: Callable[[Any], Any], + val: Any, +) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val VARINT)") + await scylla.execute( + f"INSERT INTO {table_name}(id, val) VALUES (?, ?)", + (1, cast_func(val)), + ) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "val": val}] diff --git a/src/utils.rs b/src/utils.rs index 5e1328c..2cff220 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -7,7 +7,7 @@ use pyo3::{ use scylla::{ frame::{ response::result::{ColumnSpec, ColumnType, CqlValue}, - value::{LegacySerializedValues, Value}, + value::{CqlDuration, LegacySerializedValues, Value}, }, BufMut, }; @@ -93,6 +93,11 @@ pub enum ScyllaPyCQLDTO { Bool(bool), Double(eq_float::F64), Decimal(bigdecimal_04::BigDecimal), + Duration { + months: i32, + days: i32, + nanoseconds: i64, + }, Float(eq_float::F32), Bytes(Vec), Date(chrono::NaiveDate), @@ -139,6 +144,16 @@ impl Value for ScyllaPyCQLDTO { } ScyllaPyCQLDTO::Decimal(decimal) => decimal.serialize(buf), ScyllaPyCQLDTO::Unset => scylla::frame::value::Unset.serialize(buf), + ScyllaPyCQLDTO::Duration { + months, + days, + nanoseconds, + } => CqlDuration { + months: *months, + days: *days, + nanoseconds: *nanoseconds, + } + .serialize(buf), } } } @@ -266,6 +281,16 @@ pub fn py_to_value( ScyllaPyError::BindingError("Cannot convert datetime to timestamp.".into()), )?; Ok(ScyllaPyCQLDTO::Timestamp(timestamp)) + } else if item.get_type().name()? == "relativedelta" { + let months = item.getattr("months")?.extract::()?; + let days = item.getattr("days")?.extract::()?; + let nanoseconds = item.getattr("microseconds")?.extract::()? * 1_000 + + item.getattr("seconds")?.extract::()? * 1_000_000; + Ok(ScyllaPyCQLDTO::Duration { + months, + days, + nanoseconds, + }) } else if item.is_instance_of::() || item.is_instance_of::() || item.is_instance_of::() @@ -601,7 +626,17 @@ pub fn cql_to_py<'a>( .getattr("Decimal")? .call1((decimal.to_scientific_notation(),))?) } - ColumnType::Custom(_) | ColumnType::Varint => Err(ScyllaPyError::ValueDowncastError( + ColumnType::Varint => { + let bigint: bigdecimal_04::num_bigint::BigInt = match unwrapped_value { + CqlValue::Varint(inner) => inner.clone().into(), + _ => return Err(ScyllaPyError::ValueDowncastError(col_name.into(), "Varint")), + }; + Ok(py + .import("builtins")? + .getattr("int")? + .call1((bigint.to_string(),))?) + } + ColumnType::Custom(_) => Err(ScyllaPyError::ValueDowncastError( col_name.into(), "Unknown", )), From 651273456bc9d620db815b53a6f05032a2e868e1 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 20 Apr 2024 11:28:49 +0200 Subject: [PATCH 5/5] Version bumped to 1.3.1 Signed-off-by: Pavel Kirilin --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6cefd0c..f8e04dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scyllapy" -version = "1.3.0" +version = "1.3.1" edition = "2021" [lib]