Skip to content

Commit

Permalink
Merge pull request #659 from Mause/bugfix/nested-schema-reflection
Browse files Browse the repository at this point in the history
feat: nested column types
  • Loading branch information
Mause authored Jun 20, 2023
2 parents eaa9bf7 + 6e0b24c commit 58200b5
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 8 deletions.
143 changes: 137 additions & 6 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
select * from duckdb_types where type_category = 'NUMERIC';
```
"""
import typing
from typing import Any, Callable, Dict, Optional, Type

from typing import Any

from sqlalchemy.dialects.postgresql.base import PGTypeCompiler
from sqlalchemy import exc
from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer, PGTypeCompiler
from sqlalchemy.engine import Dialect
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql import sqltypes, type_api
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.types import BigInteger, Integer, SmallInteger

Expand Down Expand Up @@ -85,12 +87,88 @@ def compile_uint(element: Integer, compiler: PGTypeCompiler, **kw: Any) -> str:
assert types


TV = typing.Union[Type[TypeEngine], TypeEngine]


class Struct(TypeEngine):
pass
"""
Represents a STRUCT type in DuckDB
```python
from duckdb_engine.datatypes import Struct
from sqlalchemy import Table, Column, String
Table(
'hello',
Column('name', Struct({'first': String, 'last': String})
)
```
:param fields: only optional due to limitations with how much type information DuckDB returns to us in the description field
"""

__visit_name__ = "struct"

def __init__(self, fields: Optional[Dict[str, TV]] = None):
self.fields = fields


class Map(TypeEngine):
pass
"""
Represents a MAP type in DuckDB
```python
from duckdb_engine.datatypes import Map
from sqlalchemy import Table, Column, String
Table(
'hello',
Column('name', Map(String, String)
)
```
"""

__visit_name__ = "map"
key_type: TV
value_type: TV

def __init__(self, key_type: TV, value_type: TV):
self.key_type = key_type
self.value_type = value_type

def bind_processor(
self, dialect: Dialect
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: (
{"key": list(value), "value": list(value.values())} if value else None
)

def result_processor(
self, dialect: Dialect, coltype: str
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: dict(zip(value["key"], value["value"])) if value else {}


class Union(TypeEngine):
"""
Represents a UNION type in DuckDB
```python
from duckdb_engine.datatypes import Union
from sqlalchemy import Table, Column, String
Table(
'hello',
Column('name', Union({"name": String, "age": String})
)
```
"""

__visit_name__ = "union"
fields: Dict[str, TV]

def __init__(self, fields: Dict[str, TV]):
self.fields = fields


ISCHEMA_NAMES = {
Expand All @@ -110,3 +188,56 @@ class Map(TypeEngine):
def register_extension_types() -> None:
for subclass in types:
compiles(subclass, "duckdb")(compile_uint)


@compiles(Struct, "duckdb") # type: ignore[misc]
def visit_struct(
instance: Struct,
compiler: PGTypeCompiler,
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer)


@compiles(Union, "duckdb") # type: ignore[misc]
def visit_union(
instance: Union,
compiler: PGTypeCompiler,
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "UNION" + struct_or_union(instance, compiler, identifier_preparer)


def struct_or_union(
instance: typing.Union[Union, Struct],
compiler: PGTypeCompiler,
identifier_preparer: PGIdentifierPreparer,
) -> str:
fields = instance.fields
if fields is None:
raise exc.CompileError(f"DuckDB {repr(instance)} type requires fields")
return "({})".format(
", ".join(
"{} {}".format(
identifier_preparer.quote_identifier(key), process_type(value, compiler)
)
for key, value in fields.items()
)
)


def process_type(
value: typing.Union[TypeEngine, Type[TypeEngine]],
compiler: PGTypeCompiler,
) -> str:
return compiler.process(type_api.to_instance(value))


@compiles(Map, "duckdb") # type: ignore[misc]
def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str:
return "MAP({}, {})".format(
process_type(instance.key_type, compiler),
process_type(instance.value_type, compiler),
)
6 changes: 6 additions & 0 deletions duckdb_engine/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import wraps
from typing import Any, Callable, TypeVar

Expand All @@ -8,6 +9,11 @@
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import ParamSpec

warnings.filterwarnings(
"ignore",
"distutils Version classes are deprecated. Use packaging.version instead.",
DeprecationWarning,
)
P = ParamSpec("P")

FuncT = TypeVar("FuncT", bound=Callable[..., Any])
Expand Down
30 changes: 28 additions & 2 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import duckdb
from pytest import importorskip, mark
from sqlalchemy import Column, Integer, MetaData, Table, inspect, text
from sqlalchemy import Column, Integer, MetaData, String, Table, inspect, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.engine import Engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from sqlalchemy.sql import sqltypes
from sqlalchemy.types import JSON

from ..datatypes import types
from ..datatypes import Map, Struct, types


@mark.parametrize("coltype", types)
Expand Down Expand Up @@ -114,3 +114,29 @@ def test_all_types_reflection(engine: Engine) -> None:
else:
assert col.type != sqltypes.NULLTYPE, name
assert not capture


def test_nested_types(engine: Engine, session: Session) -> None:
importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0
base = declarative_base()

class Entry(base):
__tablename__ = "test_struct"

id = Column(Integer, primary_key=True, default=0)
struct = Column(Struct(fields={"name": String}))
map = Column(Map(String, Integer))
# union = Column(Union(fields={"name": String, "age": Integer}))

base.metadata.create_all(bind=engine)

struct_data = {"name": "Edgar"}
map_data = {"one": 1, "two": 2}

session.add(Entry(struct=struct_data, map=map_data)) # type: ignore[call-arg]
session.commit()

result = session.query(Entry).one()

assert result.struct == struct_data
assert result.map == map_data

0 comments on commit 58200b5

Please sign in to comment.