Skip to content

Commit

Permalink
Merge pull request #1139 from NickCrews/forward-compile-kwargs
Browse files Browse the repository at this point in the history
fix: propagate kwargs through compilation nodes
  • Loading branch information
Mause authored Oct 22, 2024
2 parents 80d4c0b + 95d8cef commit c5aa43f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
17 changes: 11 additions & 6 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def visit_struct(
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer)
return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer, **kw)


@compiles(Union, "duckdb") # type: ignore[misc]
Expand All @@ -243,21 +243,25 @@ def visit_union(
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
return "UNION" + struct_or_union(instance, compiler, identifier_preparer)
return "UNION" + struct_or_union(instance, compiler, identifier_preparer, **kw)


def struct_or_union(
instance: typing.Union[Union, Struct],
compiler: PGTypeCompiler,
identifier_preparer: PGIdentifierPreparer,
**kw: Any,
) -> str:
fields = instance.fields
if fields is None:
raise exc.CompileError(f"DuckDB {repr(instance)} type requires fields")
return "({})".format(
", ".join(
"{} {}".format(
identifier_preparer.quote_identifier(key), process_type(value, compiler)
identifier_preparer.quote_identifier(key),
process_type(
value, compiler, identifier_preparer=identifier_preparer, **kw
),
)
for key, value in fields.items()
)
Expand All @@ -267,13 +271,14 @@ def struct_or_union(
def process_type(
value: typing.Union[TypeEngine, Type[TypeEngine]],
compiler: PGTypeCompiler,
**kw: Any,
) -> str:
return compiler.process(type_api.to_instance(value))
return compiler.process(type_api.to_instance(value), **kw)


@compiles(Map, "duckdb") # type: ignore[misc]
def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str:
return "MAP({}, {})".format(
process_type(instance.key_type, compiler),
process_type(instance.value_type, compiler),
process_type(instance.key_type, compiler, **kw),
process_type(instance.value_type, compiler, **kw),
)
23 changes: 23 additions & 0 deletions duckdb_engine/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,29 @@ class Entry(base):
assert result.map == map_data


def test_double_nested_types(engine: Engine, session: Session) -> None:
"""Test for https://github.com/Mause/duckdb_engine/issues/1138"""
importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0
base = declarative_base()

class Entry(base):
__tablename__ = "test_struct"

id = Column(Integer, primary_key=True, default=0)
outer = Column(Struct({"inner": Struct({"val": Integer})}))

base.metadata.create_all(bind=engine)

outer = {"inner": {"val": 42}}

session.add(Entry(outer=outer)) # type: ignore[call-arg]
session.commit()

result = session.query(Entry).one()

assert result.outer == outer


def test_interval(engine: Engine, snapshot: SnapshotTest) -> None:
test_table = Table("test_table", MetaData(), Column("duration", Interval))

Expand Down

0 comments on commit c5aa43f

Please sign in to comment.