diff --git a/Cargo.toml b/Cargo.toml index 94637c5..f0b7709 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scyllapy" -version = "1.2.1" +version = "1.3.0" edition = "2021" [lib] diff --git a/README.md b/README.md index 0d3a8d2..f64fd88 100644 --- a/README.md +++ b/README.md @@ -324,6 +324,55 @@ async def execute(scylla: Scylla) -> None: ) ``` +## User defined types + +We also support user defined types. You can pass them as a parameter to query. +Or parse it as a model in response. + +Here's binding example. Imagine we have defined a type in scylla like this: + +```cql +CREATE TYPE IF NOT EXISTS test ( + id int, + name text +); +``` + +Now we need to define a model for it in python. + +```python +from dataclasses import dataclass +from scyllapy.extra_types import ScyllaPyUDT + +@dataclass +class TestUDT(ScyllaPyUDT): + # Always define fields in the same order as in scylla. + # Otherwise you will get an error, or wrong data. + id: int + name: str + +async def execute(scylla: Scylla) -> None: + await scylla.execute( + "INSERT INTO table(id, udt_col) VALUES (?, ?)", + [1, TestUDT(id=1, name="test")], + ) + +``` + +We also support pydantic based models. Decalre them like this: + +```python +from pydantic import BaseModel +from scyllapy.extra_types import ScyllaPyUDT + + +class TestUDT(BaseModel, ScyllaPyUDT): + # Always define fields in the same order as in scylla. + # Otherwise you will get an error, or wrong data. + id: int + name: str + +``` # Query building diff --git a/python/scyllapy/extra_types.py b/python/scyllapy/extra_types.py index 34ca76b..d38054d 100644 --- a/python/scyllapy/extra_types.py +++ b/python/scyllapy/extra_types.py @@ -1,3 +1,48 @@ +import dataclasses +from typing import Any, List + from ._internal.extra_types import BigInt, Counter, Double, SmallInt, TinyInt, Unset -__all__ = ("BigInt", "Counter", "Double", "SmallInt", "TinyInt", "Unset") +try: + import pydantic +except ImportError: + pydantic = None + + +class ScyllaPyUDT: + """ + Class for declaring UDT models. + + This class is a mixin for models like dataclasses and pydantic models, + or classes that have `__slots__` attribute. + + It can be further extended to support other model types. + """ + + def __dump_udt__(self) -> List[Any]: + """ + Method to dump UDT models to a dict. + + This method returns a list of values in the order of the UDT fields. + Because in the protocol, UDT fields should be sent in the same order as + they were declared. + """ + if dataclasses.is_dataclass(self): + values = [] + for field in dataclasses.fields(self): + values.append(getattr(self, field.name)) + return values + if pydantic is not None and isinstance(self, pydantic.BaseModel): + values = [] + for param in self.__class__.__signature__.parameters: + values.append(getattr(self, param)) + return values + if hasattr(self, "__slots__"): + values = [] + for slot in self.__slots__: + values.append(getattr(self, slot)) + return values + raise ValueError("Unsupported model type") + + +__all__ = ("BigInt", "Counter", "Double", "SmallInt", "TinyInt", "Unset", "ScyllaPyUDT") diff --git a/python/tests/test_extra_types.py b/python/tests/test_extra_types.py index 28ed073..a0f7d8b 100644 --- a/python/tests/test_extra_types.py +++ b/python/tests/test_extra_types.py @@ -1,3 +1,4 @@ +from dataclasses import asdict, dataclass from typing import Any import pytest @@ -69,3 +70,80 @@ async def test_unset(scylla: Scylla) -> None: f"INSERT INTO {table_name}(id, name) VALUES (?, ?)", [1, extra_types.Unset()], ) + + +@pytest.mark.anyio +async def test_udts(scylla: Scylla) -> None: + @dataclass + class TestUDT(extra_types.ScyllaPyUDT): + id: int + name: str + + table_name = random_string(4) + + udt_val = TestUDT(id=1, name="test") + await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)") + await scylla.execute( + f"CREATE TABLE {table_name} " + f"(id INT PRIMARY KEY, udt_col frozen)", + ) + await scylla.execute( + f"INSERT INTO {table_name} (id, udt_col) VALUES (?, ?)", + [1, udt_val], + ) + + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "udt_col": asdict(udt_val)}] + + +@pytest.mark.anyio +async def test_nested_udts(scylla: Scylla) -> None: + @dataclass + class NestedUDT(extra_types.ScyllaPyUDT): + one: int + two: str + + @dataclass + class TestUDT(extra_types.ScyllaPyUDT): + id: int + name: str + nested: NestedUDT + + table_name = random_string(4) + + udt_val = TestUDT(id=1, name="test", nested=NestedUDT(one=1, two="2")) + await scylla.execute(f"CREATE TYPE nested_udt{table_name} (one int, two text)") + await scylla.execute( + f"CREATE TYPE test_udt{table_name} " + f"(id int, name text, nested frozen)", + ) + await scylla.execute( + f"CREATE TABLE {table_name} " + f"(id INT PRIMARY KEY, udt_col frozen)", + ) + await scylla.execute( + f"INSERT INTO {table_name} (id, udt_col) VALUES (?, ?)", + [1, udt_val], + ) + + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "udt_col": asdict(udt_val)}] + + +@pytest.mark.parametrize( + ["typ", "val"], + [ + ("BIGINT", 1), + ("TINYINT", 1), + ("SMALLINT", 1), + ("INT", 1), + ("FLOAT", 1.0), + ("DOUBLE", 1.0), + ], +) +@pytest.mark.anyio +async def test_autocast_positional(scylla: Scylla, typ: str, val: Any) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val {typ})") + prepared = await scylla.prepare(f"INSERT INTO {table_name}(id, val) VALUES (?, ?)") + await scylla.execute(prepared, [1, val]) diff --git a/python/tests/test_parsing.py b/python/tests/test_parsing.py new file mode 100644 index 0000000..b4cfab5 --- /dev/null +++ b/python/tests/test_parsing.py @@ -0,0 +1,19 @@ +import pytest +from tests.utils import random_string + +from scyllapy import Scylla + + +@pytest.mark.anyio +async def test_udt_parsing(scylla: Scylla) -> None: + table_name = random_string(4) + await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)") + await scylla.execute( + f"CREATE TABLE {table_name} " + f"(id int PRIMARY KEY, udt_col frozen)", + ) + await scylla.execute( + f"INSERT INTO {table_name} (id, udt_col) VALUES (1, {{id: 1, name: 'test'}})", + ) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all() == [{"id": 1, "udt_col": {"id": 1, "name": "test"}}] diff --git a/python/tests/test_queries.py b/python/tests/test_queries.py index d265f0a..5d3652a 100644 --- a/python/tests/test_queries.py +++ b/python/tests/test_queries.py @@ -28,3 +28,34 @@ class TestDTO: res = await scylla.execute(f"SELECT id FROM {table_name}") assert res.all(as_class=TestDTO) == [TestDTO(id=42)] + + +@pytest.mark.anyio +async def test_udt_as_dataclass(scylla: Scylla) -> None: + @dataclass + class UDTType: + id: int + name: str + + @dataclass + class TestDTO: + id: int + udt_col: UDTType + + def __post_init__(self) -> None: + if not isinstance(self.udt_col, UDTType): + self.udt_col = UDTType(**self.udt_col) + + table_name = random_string(4) + await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)") + await scylla.execute( + f"CREATE TABLE {table_name} " + f"(id int PRIMARY KEY, udt_col frozen)", + ) + await scylla.execute( + f"INSERT INTO {table_name} (id, udt_col) VALUES (1, {{id: 1, name: 'test'}})", + ) + res = await scylla.execute(f"SELECT * FROM {table_name}") + assert res.all(as_class=TestDTO) == [ + TestDTO(id=1, udt_col=UDTType(id=1, name="test")), + ] diff --git a/src/batches.rs b/src/batches.rs index 7fdc42d..76e7b6e 100644 --- a/src/batches.rs +++ b/src/batches.rs @@ -121,7 +121,7 @@ impl ScyllaPyInlineBatch { self.inner.append_statement(query); if let Some(passed_params) = values { self.values - .push(parse_python_query_params(Some(passed_params), false)?); + .push(parse_python_query_params(Some(passed_params), false, None)?); } else { self.values.push(SerializedValues::new()); } diff --git a/src/exceptions/rust_err.rs b/src/exceptions/rust_err.rs index 8d4baa2..72e607e 100644 --- a/src/exceptions/rust_err.rs +++ b/src/exceptions/rust_err.rs @@ -44,6 +44,8 @@ pub enum ScyllaPyError { RowsDowncastError(String), #[error("Cannot parse value of column {0} as {1}.")] ValueDowncastError(String, &'static str), + #[error("Cannot downcast UDT {0} of column {1}. Reason: {2}.")] + UDTDowncastError(String, String, String), #[error("Query didn't suppose to return anything.")] NoReturnsError, #[error("Query doesn't have columns.")] @@ -73,6 +75,7 @@ impl From for pyo3::PyErr { | ScyllaPyError::IpParseError(_) => ScyllaPyBindingError::new_err((err_desc,)), ScyllaPyError::RowsDowncastError(_) | ScyllaPyError::ValueDowncastError(_, _) + | ScyllaPyError::UDTDowncastError(_, _, _) | ScyllaPyError::NoReturnsError | ScyllaPyError::NoColumns => ScyllaPyMappingError::new_err((err_desc,)), ScyllaPyError::QueryBuilderError(_) => ScyllaPyQueryBuiderError::new_err((err_desc,)), diff --git a/src/prepared_queries.rs b/src/prepared_queries.rs index 64623b1..65ee4f9 100644 --- a/src/prepared_queries.rs +++ b/src/prepared_queries.rs @@ -4,7 +4,7 @@ use scylla::prepared_statement::PreparedStatement; #[pyclass(name = "PreparedQuery")] #[derive(Clone, Debug)] pub struct ScyllaPyPreparedQuery { - inner: PreparedStatement, + pub inner: PreparedStatement, } impl From for ScyllaPyPreparedQuery { diff --git a/src/query_builder/delete.rs b/src/query_builder/delete.rs index 14df316..79a4c43 100644 --- a/src/query_builder/delete.rs +++ b/src/query_builder/delete.rs @@ -110,7 +110,7 @@ impl Delete { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } } Ok(slf) @@ -148,7 +148,7 @@ impl Delete { ) -> ScyllaPyResult> { let parsed_values = if let Some(vals) = values { vals.iter() - .map(|item| py_to_value(item)) + .map(|item| py_to_value(item, None)) .collect::, _>>()? } else { vec![] diff --git a/src/query_builder/insert.rs b/src/query_builder/insert.rs index fb00826..ed916ac 100644 --- a/src/query_builder/insert.rs +++ b/src/query_builder/insert.rs @@ -114,7 +114,7 @@ impl Insert { if value.is_none() { slf.values_.push(ScyllaPyCQLDTO::Unset); } else { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } Ok(slf) } diff --git a/src/query_builder/select.rs b/src/query_builder/select.rs index 875544b..f1683d2 100644 --- a/src/query_builder/select.rs +++ b/src/query_builder/select.rs @@ -153,7 +153,7 @@ impl Select { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); } } Ok(slf) diff --git a/src/query_builder/update.rs b/src/query_builder/update.rs index 43d678a..0a9120f 100644 --- a/src/query_builder/update.rs +++ b/src/query_builder/update.rs @@ -130,7 +130,7 @@ impl Update { value: &'a PyAny, ) -> ScyllaPyResult> { slf.assignments_.push(UpdateAssignment::Simple(name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } @@ -147,7 +147,7 @@ impl Update { ) -> ScyllaPyResult> { slf.assignments_ .push(UpdateAssignment::Inc(name.clone(), name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } @@ -164,7 +164,7 @@ impl Update { ) -> ScyllaPyResult> { slf.assignments_ .push(UpdateAssignment::Dec(name.clone(), name)); - slf.values_.push(py_to_value(value)?); + slf.values_.push(py_to_value(value, None)?); Ok(slf) } /// Add where clause. @@ -187,7 +187,7 @@ impl Update { slf.where_clauses_.push(clause); if let Some(vals) = values { for value in vals { - slf.where_values_.push(py_to_value(value)?); + slf.where_values_.push(py_to_value(value, None)?); } } Ok(slf) @@ -248,7 +248,7 @@ impl Update { ) -> ScyllaPyResult> { let parsed_values = if let Some(vals) = values { vals.iter() - .map(|item| py_to_value(item)) + .map(|item| py_to_value(item, None)) .collect::, _>>()? } else { vec![] diff --git a/src/scylla_cls.rs b/src/scylla_cls.rs index 320d23d..11352e7 100644 --- a/src/scylla_cls.rs +++ b/src/scylla_cls.rs @@ -290,9 +290,13 @@ impl Scylla { params: Option<&'a PyAny>, paged: bool, ) -> ScyllaPyResult<&'a PyAny> { + let mut col_spec = None; // We need to prepare parameter we're going to use // in query. - let query_params = parse_python_query_params(params, true)?; + if let ExecuteInput::PreparedQuery(prepared) = &query { + col_spec = Some(prepared.inner.get_prepared_metadata().col_specs.as_ref()); + } + let query_params = parse_python_query_params(params, true, col_spec)?; // We need this clone, to safely share the session between threads. let (query, prepared) = match query { ExecuteInput::Text(txt) => (Some(Query::new(txt)), None), @@ -322,7 +326,11 @@ impl Scylla { let mut batch_params = Vec::new(); if let Some(passed_params) = params { for query_params in passed_params { - batch_params.push(parse_python_query_params(Some(query_params), false)?); + batch_params.push(parse_python_query_params( + Some(query_params), + false, + None, + )?); } } (batch.into(), batch_params) diff --git a/src/utils.rs b/src/utils.rs index 5b4a960..5a14ee2 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,10 +5,14 @@ use pyo3::{ types::{PyBool, PyBytes, PyDict, PyFloat, PyInt, PyList, PyModule, PySet, PyString, PyTuple}, IntoPy, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; -use scylla::frame::{ - response::result::{ColumnType, CqlValue}, - value::{SerializedValues, Value}, +use scylla::{ + frame::{ + response::result::{ColumnType, CqlValue}, + value::{SerializedValues, Value}, + }, + BufMut, }; +use scylla_cql::frame::response::result::ColumnSpec; use std::net::IpAddr; @@ -96,6 +100,8 @@ pub enum ScyllaPyCQLDTO { Inet(IpAddr), List(Vec), Map(Vec<(ScyllaPyCQLDTO, ScyllaPyCQLDTO)>), + // UDT holds serialized bytes according to the protocol. + Udt(Vec), } impl Value for ScyllaPyCQLDTO { @@ -125,6 +131,10 @@ impl Value for ScyllaPyCQLDTO { scylla::frame::value::Timestamp(*timestamp).serialize(buf) } ScyllaPyCQLDTO::Null => Option::::None.serialize(buf), + ScyllaPyCQLDTO::Udt(udt) => { + buf.extend(udt); + Ok(()) + } ScyllaPyCQLDTO::Unset => scylla::frame::value::Unset.serialize(buf), } } @@ -140,7 +150,11 @@ impl Value for ScyllaPyCQLDTO { /// /// May raise an error, if /// value cannot be converted or unnown type was passed. -pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { +#[allow(clippy::too_many_lines)] +pub fn py_to_value( + item: &PyAny, + column_type: Option<&ColumnType>, +) -> ScyllaPyResult { if item.is_none() { Ok(ScyllaPyCQLDTO::Null) } else if item.is_instance_of::() { @@ -150,9 +164,20 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { } else if item.is_instance_of::() { Ok(ScyllaPyCQLDTO::Bool(item.extract::()?)) } else if item.is_instance_of::() { - Ok(ScyllaPyCQLDTO::Int(item.extract::()?)) + match column_type { + Some(ColumnType::TinyInt) => Ok(ScyllaPyCQLDTO::TinyInt(item.extract::()?)), + Some(ColumnType::SmallInt) => Ok(ScyllaPyCQLDTO::SmallInt(item.extract::()?)), + Some(ColumnType::BigInt) => Ok(ScyllaPyCQLDTO::BigInt(item.extract::()?)), + Some(ColumnType::Counter) => Ok(ScyllaPyCQLDTO::Counter(item.extract::()?)), + Some(_) | None => Ok(ScyllaPyCQLDTO::Int(item.extract::()?)), + } } else if item.is_instance_of::() { - Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::()?))) + match column_type { + Some(ColumnType::Double) => Ok(ScyllaPyCQLDTO::Double(eq_float::F64( + item.extract::()?, + ))), + Some(_) | None => Ok(ScyllaPyCQLDTO::Float(eq_float::F32(item.extract::()?))), + } } else if item.is_instance_of::() { Ok(ScyllaPyCQLDTO::SmallInt( item.extract::()?.get_value(), @@ -175,6 +200,36 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { )) } else if item.is_instance_of::() { Ok(ScyllaPyCQLDTO::Bytes(item.extract::>()?)) + } else if item.hasattr("__dump_udt__")? { + let dumped = item.call_method0("__dump_udt__")?; + let dumped_py = dumped.downcast::().map_err(|err| { + ScyllaPyError::BindingError(format!( + "Cannot get UDT values. __dump_udt__ has returned not a list value. {err}" + )) + })?; + let mut buf = Vec::new(); + // Here we put the size of UDT value. + // Now it's zero, but we will replace it after serialization. + buf.put_i32(0); + for val in dumped_py { + // Here we serialize all fields. + py_to_value(val, None)? + .serialize(buf.as_mut()) + .map_err(|err| { + ScyllaPyError::BindingError(format!( + "Cannot serialize UDT field because of {err}" + )) + })?; + } + // Then we calculate the size of the UDT value, cast it to i32 + // and put it in the beginning of the buffer. + let buf_len: i32 = buf.len().try_into().map_err(|_| { + ScyllaPyError::BindingError("Cannot serialize. UDT value is too big.".into()) + })?; + // Here we also subtract 4 bytes, because we don't want to count + // size buffer itself. + buf[0..4].copy_from_slice(&(buf_len - 4).to_be_bytes()[..]); + Ok(ScyllaPyCQLDTO::Udt(buf)) } else if item.get_type().name()? == "UUID" { Ok(ScyllaPyCQLDTO::Uuid(uuid::Uuid::parse_str( item.str()?.extract::<&str>()?, @@ -209,7 +264,7 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { { let mut items = Vec::new(); for inner in item.iter()? { - items.push(py_to_value(inner?)?); + items.push(py_to_value(inner?, column_type)?); } Ok(ScyllaPyCQLDTO::List(items)) } else if item.is_instance_of::() { @@ -222,8 +277,8 @@ pub fn py_to_value(item: &PyAny) -> ScyllaPyResult { ScyllaPyError::BindingError(format!("Cannot cast to tuple: {err}")) })?; items.push(( - py_to_value(item_tuple.get_item(0)?)?, - py_to_value(item_tuple.get_item(1)?)?, + py_to_value(item_tuple.get_item(0)?, column_type)?, + py_to_value(item_tuple.get_item(1)?, column_type)?, )); } Ok(ScyllaPyCQLDTO::Map(items)) @@ -459,13 +514,46 @@ pub fn cql_to_py<'a>( .getattr("time")? .call_method1("fromisoformat", (time.format("%H:%M:%S%.6f").to_string(),))?) } - ColumnType::Custom(_) - | ColumnType::Varint - | ColumnType::Decimal - | ColumnType::UserDefinedType { .. } => Err(ScyllaPyError::ValueDowncastError( - col_name.into(), - "Unknown", - )), + ColumnType::UserDefinedType { + type_name, + keyspace, + field_types, + } => { + let mut fields: HashMap<&str, &ColumnType, _> = HashMap::with_capacity_and_hasher( + field_types.len(), + BuildHasherDefault::::default(), + ); + for (field_name, field_type) in field_types { + fields.insert(field_name.as_str(), field_type); + } + let map_values = unwrapped_value + .as_udt() + .ok_or(ScyllaPyError::ValueDowncastError(col_name.into(), "UDT"))? + .iter() + .map(|(key, val)| -> ScyllaPyResult<(&str, &'a PyAny)> { + let column_type = fields.get(key.as_str()).ok_or_else(|| { + ScyllaPyError::UDTDowncastError( + format!("{keyspace}.{type_name}"), + col_name.into(), + format!("UDT field {key} is not defined in schema"), + ) + })?; + Ok(( + key.as_str(), + cql_to_py(py, col_name, column_type, val.as_ref())?, + )) + }) + .collect::, _>>()?; + + let res_map = PyDict::new(py); + for (key, value) in map_values { + res_map.set_item(key, value)?; + } + Ok(res_map) + } + ColumnType::Custom(_) | ColumnType::Varint | ColumnType::Decimal => Err( + ScyllaPyError::ValueDowncastError(col_name.into(), "Unknown"), + ), } } @@ -484,6 +572,7 @@ pub fn cql_to_py<'a>( pub fn parse_python_query_params( params: Option<&PyAny>, allow_dicts: bool, + col_spec: Option<&[ColumnSpec]>, ) -> ScyllaPyResult { let mut values = SerializedValues::new(); @@ -495,17 +584,30 @@ pub fn parse_python_query_params( // Otherwise it parses dict to named parameters. if params.is_instance_of::() || params.is_instance_of::() { let params = params.extract::>()?; - for param in params { - let py_dto = py_to_value(param)?; + for (index, param) in params.iter().enumerate() { + let coltype = col_spec.and_then(|specs| specs.get(index)).map(|f| &f.typ); + let py_dto = py_to_value(param, coltype)?; values.add_value(&py_dto)?; } return Ok(values); } else if params.is_instance_of::() { if allow_dicts { + let types_map = col_spec + .map(|specs| { + specs + .iter() + .map(|spec| (spec.name.as_str(), spec.typ.clone())) + .collect::>>() + }) + .unwrap_or_default(); + // let map = HashMap::with_capacity_and_hasher(, hasher) let dict = params .extract::>>()?; for (name, value) in dict { - values.add_named_value(name.to_lowercase().as_str(), &py_to_value(value)?)?; + values.add_named_value( + name.to_lowercase().as_str(), + &py_to_value(value, types_map.get(name))?, + )?; } return Ok(values); }