From 95d8cef236629815027b26cbb66a87069a5ccddc Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Mon, 21 Oct 2024 18:37:53 -0800 Subject: [PATCH] fix: propagate kwargs through compilation nodes Fixes https://github.com/Mause/duckdb_engine/issues/1138 When we visit a struct or union, we are handed an IdentifierProvider. Before, we just used it at this first layer, but forgot to pass it down to lower compilation steps. Now, we pass it along. This was a problem when we encountered nested dtypes, eg struct> --- duckdb_engine/datatypes.py | 17 +++++++++++------ duckdb_engine/tests/test_datatypes.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/duckdb_engine/datatypes.py b/duckdb_engine/datatypes.py index d993e997..363f6e87 100644 --- a/duckdb_engine/datatypes.py +++ b/duckdb_engine/datatypes.py @@ -233,7 +233,7 @@ def visit_struct( identifier_preparer: PGIdentifierPreparer, **kw: Any, ) -> str: - return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer) + return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer, **kw) @compiles(Union, "duckdb") # type: ignore[misc] @@ -243,13 +243,14 @@ def visit_union( identifier_preparer: PGIdentifierPreparer, **kw: Any, ) -> str: - return "UNION" + struct_or_union(instance, compiler, identifier_preparer) + return "UNION" + struct_or_union(instance, compiler, identifier_preparer, **kw) def struct_or_union( instance: typing.Union[Union, Struct], compiler: PGTypeCompiler, identifier_preparer: PGIdentifierPreparer, + **kw: Any, ) -> str: fields = instance.fields if fields is None: @@ -257,7 +258,10 @@ def struct_or_union( return "({})".format( ", ".join( "{} {}".format( - identifier_preparer.quote_identifier(key), process_type(value, compiler) + identifier_preparer.quote_identifier(key), + process_type( + value, compiler, identifier_preparer=identifier_preparer, **kw + ), ) for key, value in fields.items() ) @@ -267,13 +271,14 @@ def struct_or_union( def process_type( value: typing.Union[TypeEngine, Type[TypeEngine]], compiler: PGTypeCompiler, + **kw: Any, ) -> str: - return compiler.process(type_api.to_instance(value)) + return compiler.process(type_api.to_instance(value), **kw) @compiles(Map, "duckdb") # type: ignore[misc] def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str: return "MAP({}, {})".format( - process_type(instance.key_type, compiler), - process_type(instance.value_type, compiler), + process_type(instance.key_type, compiler, **kw), + process_type(instance.value_type, compiler, **kw), ) diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 2e3b3cd8..6f28adf6 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -211,6 +211,29 @@ class Entry(base): assert result.map == map_data +def test_double_nested_types(engine: Engine, session: Session) -> None: + """Test for https://github.com/Mause/duckdb_engine/issues/1138""" + importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0 + base = declarative_base() + + class Entry(base): + __tablename__ = "test_struct" + + id = Column(Integer, primary_key=True, default=0) + outer = Column(Struct({"inner": Struct({"val": Integer})})) + + base.metadata.create_all(bind=engine) + + outer = {"inner": {"val": 42}} + + session.add(Entry(outer=outer)) # type: ignore[call-arg] + session.commit() + + result = session.query(Entry).one() + + assert result.outer == outer + + def test_interval(engine: Engine, snapshot: SnapshotTest) -> None: test_table = Table("test_table", MetaData(), Column("duration", Interval))