Skip to content

Commit

Permalink
Merge branch 'release/1.3.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed Nov 3, 2023
2 parents 8d31f9d + 9802476 commit bba4648
Show file tree
Hide file tree
Showing 15 changed files with 369 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "scyllapy"
version = "1.2.1"
version = "1.3.0"
edition = "2021"

[lib]
Expand Down
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,55 @@ async def execute(scylla: Scylla) -> None:
)
```

## User defined types

We also support user defined types. You can pass them as a parameter to query.
Or parse it as a model in response.

Here's binding example. Imagine we have defined a type in scylla like this:

```cql
CREATE TYPE IF NOT EXISTS test (
id int,
name text
);
```

Now we need to define a model for it in python.

```python
from dataclasses import dataclass
from scyllapy.extra_types import ScyllaPyUDT

@dataclass
class TestUDT(ScyllaPyUDT):
# Always define fields in the same order as in scylla.
# Otherwise you will get an error, or wrong data.
id: int
name: str

async def execute(scylla: Scylla) -> None:
await scylla.execute(
"INSERT INTO table(id, udt_col) VALUES (?, ?)",
[1, TestUDT(id=1, name="test")],
)

```

We also support pydantic based models. Decalre them like this:

```python
from pydantic import BaseModel
from scyllapy.extra_types import ScyllaPyUDT


class TestUDT(BaseModel, ScyllaPyUDT):
# Always define fields in the same order as in scylla.
# Otherwise you will get an error, or wrong data.
id: int
name: str

```

# Query building

Expand Down
47 changes: 46 additions & 1 deletion python/scyllapy/extra_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,48 @@
import dataclasses
from typing import Any, List

from ._internal.extra_types import BigInt, Counter, Double, SmallInt, TinyInt, Unset

__all__ = ("BigInt", "Counter", "Double", "SmallInt", "TinyInt", "Unset")
try:
import pydantic
except ImportError:
pydantic = None


class ScyllaPyUDT:
"""
Class for declaring UDT models.
This class is a mixin for models like dataclasses and pydantic models,
or classes that have `__slots__` attribute.
It can be further extended to support other model types.
"""

def __dump_udt__(self) -> List[Any]:
"""
Method to dump UDT models to a dict.
This method returns a list of values in the order of the UDT fields.
Because in the protocol, UDT fields should be sent in the same order as
they were declared.
"""
if dataclasses.is_dataclass(self):
values = []
for field in dataclasses.fields(self):
values.append(getattr(self, field.name))
return values
if pydantic is not None and isinstance(self, pydantic.BaseModel):
values = []
for param in self.__class__.__signature__.parameters:
values.append(getattr(self, param))
return values
if hasattr(self, "__slots__"):
values = []
for slot in self.__slots__:
values.append(getattr(self, slot))
return values
raise ValueError("Unsupported model type")


__all__ = ("BigInt", "Counter", "Double", "SmallInt", "TinyInt", "Unset", "ScyllaPyUDT")
78 changes: 78 additions & 0 deletions python/tests/test_extra_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import asdict, dataclass
from typing import Any

import pytest
Expand Down Expand Up @@ -69,3 +70,80 @@ async def test_unset(scylla: Scylla) -> None:
f"INSERT INTO {table_name}(id, name) VALUES (?, ?)",
[1, extra_types.Unset()],
)


@pytest.mark.anyio
async def test_udts(scylla: Scylla) -> None:
@dataclass
class TestUDT(extra_types.ScyllaPyUDT):
id: int
name: str

table_name = random_string(4)

udt_val = TestUDT(id=1, name="test")
await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)")
await scylla.execute(
f"CREATE TABLE {table_name} "
f"(id INT PRIMARY KEY, udt_col frozen<test_udt{table_name}>)",
)
await scylla.execute(
f"INSERT INTO {table_name} (id, udt_col) VALUES (?, ?)",
[1, udt_val],
)

res = await scylla.execute(f"SELECT * FROM {table_name}")
assert res.all() == [{"id": 1, "udt_col": asdict(udt_val)}]


@pytest.mark.anyio
async def test_nested_udts(scylla: Scylla) -> None:
@dataclass
class NestedUDT(extra_types.ScyllaPyUDT):
one: int
two: str

@dataclass
class TestUDT(extra_types.ScyllaPyUDT):
id: int
name: str
nested: NestedUDT

table_name = random_string(4)

udt_val = TestUDT(id=1, name="test", nested=NestedUDT(one=1, two="2"))
await scylla.execute(f"CREATE TYPE nested_udt{table_name} (one int, two text)")
await scylla.execute(
f"CREATE TYPE test_udt{table_name} "
f"(id int, name text, nested frozen<nested_udt{table_name}>)",
)
await scylla.execute(
f"CREATE TABLE {table_name} "
f"(id INT PRIMARY KEY, udt_col frozen<test_udt{table_name}>)",
)
await scylla.execute(
f"INSERT INTO {table_name} (id, udt_col) VALUES (?, ?)",
[1, udt_val],
)

res = await scylla.execute(f"SELECT * FROM {table_name}")
assert res.all() == [{"id": 1, "udt_col": asdict(udt_val)}]


@pytest.mark.parametrize(
["typ", "val"],
[
("BIGINT", 1),
("TINYINT", 1),
("SMALLINT", 1),
("INT", 1),
("FLOAT", 1.0),
("DOUBLE", 1.0),
],
)
@pytest.mark.anyio
async def test_autocast_positional(scylla: Scylla, typ: str, val: Any) -> None:
table_name = random_string(4)
await scylla.execute(f"CREATE TABLE {table_name}(id INT PRIMARY KEY, val {typ})")
prepared = await scylla.prepare(f"INSERT INTO {table_name}(id, val) VALUES (?, ?)")
await scylla.execute(prepared, [1, val])
19 changes: 19 additions & 0 deletions python/tests/test_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
from tests.utils import random_string

from scyllapy import Scylla


@pytest.mark.anyio
async def test_udt_parsing(scylla: Scylla) -> None:
table_name = random_string(4)
await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)")
await scylla.execute(
f"CREATE TABLE {table_name} "
f"(id int PRIMARY KEY, udt_col frozen<test_udt{table_name}>)",
)
await scylla.execute(
f"INSERT INTO {table_name} (id, udt_col) VALUES (1, {{id: 1, name: 'test'}})",
)
res = await scylla.execute(f"SELECT * FROM {table_name}")
assert res.all() == [{"id": 1, "udt_col": {"id": 1, "name": "test"}}]
31 changes: 31 additions & 0 deletions python/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,34 @@ class TestDTO:
res = await scylla.execute(f"SELECT id FROM {table_name}")

assert res.all(as_class=TestDTO) == [TestDTO(id=42)]


@pytest.mark.anyio
async def test_udt_as_dataclass(scylla: Scylla) -> None:
@dataclass
class UDTType:
id: int
name: str

@dataclass
class TestDTO:
id: int
udt_col: UDTType

def __post_init__(self) -> None:
if not isinstance(self.udt_col, UDTType):
self.udt_col = UDTType(**self.udt_col)

table_name = random_string(4)
await scylla.execute(f"CREATE TYPE test_udt{table_name} (id int, name text)")
await scylla.execute(
f"CREATE TABLE {table_name} "
f"(id int PRIMARY KEY, udt_col frozen<test_udt{table_name}>)",
)
await scylla.execute(
f"INSERT INTO {table_name} (id, udt_col) VALUES (1, {{id: 1, name: 'test'}})",
)
res = await scylla.execute(f"SELECT * FROM {table_name}")
assert res.all(as_class=TestDTO) == [
TestDTO(id=1, udt_col=UDTType(id=1, name="test")),
]
2 changes: 1 addition & 1 deletion src/batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl ScyllaPyInlineBatch {
self.inner.append_statement(query);
if let Some(passed_params) = values {
self.values
.push(parse_python_query_params(Some(passed_params), false)?);
.push(parse_python_query_params(Some(passed_params), false, None)?);
} else {
self.values.push(SerializedValues::new());
}
Expand Down
3 changes: 3 additions & 0 deletions src/exceptions/rust_err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub enum ScyllaPyError {
RowsDowncastError(String),
#[error("Cannot parse value of column {0} as {1}.")]
ValueDowncastError(String, &'static str),
#[error("Cannot downcast UDT {0} of column {1}. Reason: {2}.")]
UDTDowncastError(String, String, String),
#[error("Query didn't suppose to return anything.")]
NoReturnsError,
#[error("Query doesn't have columns.")]
Expand Down Expand Up @@ -73,6 +75,7 @@ impl From<ScyllaPyError> for pyo3::PyErr {
| ScyllaPyError::IpParseError(_) => ScyllaPyBindingError::new_err((err_desc,)),
ScyllaPyError::RowsDowncastError(_)
| ScyllaPyError::ValueDowncastError(_, _)
| ScyllaPyError::UDTDowncastError(_, _, _)
| ScyllaPyError::NoReturnsError
| ScyllaPyError::NoColumns => ScyllaPyMappingError::new_err((err_desc,)),
ScyllaPyError::QueryBuilderError(_) => ScyllaPyQueryBuiderError::new_err((err_desc,)),
Expand Down
2 changes: 1 addition & 1 deletion src/prepared_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use scylla::prepared_statement::PreparedStatement;
#[pyclass(name = "PreparedQuery")]
#[derive(Clone, Debug)]
pub struct ScyllaPyPreparedQuery {
inner: PreparedStatement,
pub inner: PreparedStatement,
}

impl From<PreparedStatement> for ScyllaPyPreparedQuery {
Expand Down
4 changes: 2 additions & 2 deletions src/query_builder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl Delete {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down Expand Up @@ -148,7 +148,7 @@ impl Delete {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
let parsed_values = if let Some(vals) = values {
vals.iter()
.map(|item| py_to_value(item))
.map(|item| py_to_value(item, None))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![]
Expand Down
2 changes: 1 addition & 1 deletion src/query_builder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl Insert {
if value.is_none() {
slf.values_.push(ScyllaPyCQLDTO::Unset);
} else {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
Ok(slf)
}
Expand Down
2 changes: 1 addition & 1 deletion src/query_builder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl Select {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down
10 changes: 5 additions & 5 deletions src/query_builder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl Update {
value: &'a PyAny,
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_.push(UpdateAssignment::Simple(name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}

Expand All @@ -147,7 +147,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_
.push(UpdateAssignment::Inc(name.clone(), name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}

Expand All @@ -164,7 +164,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
slf.assignments_
.push(UpdateAssignment::Dec(name.clone(), name));
slf.values_.push(py_to_value(value)?);
slf.values_.push(py_to_value(value, None)?);
Ok(slf)
}
/// Add where clause.
Expand All @@ -187,7 +187,7 @@ impl Update {
slf.where_clauses_.push(clause);
if let Some(vals) = values {
for value in vals {
slf.where_values_.push(py_to_value(value)?);
slf.where_values_.push(py_to_value(value, None)?);
}
}
Ok(slf)
Expand Down Expand Up @@ -248,7 +248,7 @@ impl Update {
) -> ScyllaPyResult<PyRefMut<'a, Self>> {
let parsed_values = if let Some(vals) = values {
vals.iter()
.map(|item| py_to_value(item))
.map(|item| py_to_value(item, None))
.collect::<Result<Vec<_>, _>>()?
} else {
vec![]
Expand Down
12 changes: 10 additions & 2 deletions src/scylla_cls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,13 @@ impl Scylla {
params: Option<&'a PyAny>,
paged: bool,
) -> ScyllaPyResult<&'a PyAny> {
let mut col_spec = None;
// We need to prepare parameter we're going to use
// in query.
let query_params = parse_python_query_params(params, true)?;
if let ExecuteInput::PreparedQuery(prepared) = &query {
col_spec = Some(prepared.inner.get_prepared_metadata().col_specs.as_ref());
}
let query_params = parse_python_query_params(params, true, col_spec)?;
// We need this clone, to safely share the session between threads.
let (query, prepared) = match query {
ExecuteInput::Text(txt) => (Some(Query::new(txt)), None),
Expand Down Expand Up @@ -322,7 +326,11 @@ impl Scylla {
let mut batch_params = Vec::new();
if let Some(passed_params) = params {
for query_params in passed_params {
batch_params.push(parse_python_query_params(Some(query_params), false)?);
batch_params.push(parse_python_query_params(
Some(query_params),
false,
None,
)?);
}
}
(batch.into(), batch_params)
Expand Down
Loading

0 comments on commit bba4648

Please sign in to comment.