diff --git a/duckdb_engine/datatypes.py b/duckdb_engine/datatypes.py index d993e997..363f6e87 100644 --- a/duckdb_engine/datatypes.py +++ b/duckdb_engine/datatypes.py @@ -233,7 +233,7 @@ def visit_struct( identifier_preparer: PGIdentifierPreparer, **kw: Any, ) -> str: - return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer) + return "STRUCT" + struct_or_union(instance, compiler, identifier_preparer, **kw) @compiles(Union, "duckdb") # type: ignore[misc] @@ -243,13 +243,14 @@ def visit_union( identifier_preparer: PGIdentifierPreparer, **kw: Any, ) -> str: - return "UNION" + struct_or_union(instance, compiler, identifier_preparer) + return "UNION" + struct_or_union(instance, compiler, identifier_preparer, **kw) def struct_or_union( instance: typing.Union[Union, Struct], compiler: PGTypeCompiler, identifier_preparer: PGIdentifierPreparer, + **kw: Any, ) -> str: fields = instance.fields if fields is None: @@ -257,7 +258,10 @@ def struct_or_union( return "({})".format( ", ".join( "{} {}".format( - identifier_preparer.quote_identifier(key), process_type(value, compiler) + identifier_preparer.quote_identifier(key), + process_type( + value, compiler, identifier_preparer=identifier_preparer, **kw + ), ) for key, value in fields.items() ) @@ -267,13 +271,14 @@ def struct_or_union( def process_type( value: typing.Union[TypeEngine, Type[TypeEngine]], compiler: PGTypeCompiler, + **kw: Any, ) -> str: - return compiler.process(type_api.to_instance(value)) + return compiler.process(type_api.to_instance(value), **kw) @compiles(Map, "duckdb") # type: ignore[misc] def visit_map(instance: Map, compiler: PGTypeCompiler, **kw: Any) -> str: return "MAP({}, {})".format( - process_type(instance.key_type, compiler), - process_type(instance.value_type, compiler), + process_type(instance.key_type, compiler, **kw), + process_type(instance.value_type, compiler, **kw), ) diff --git a/duckdb_engine/tests/test_datatypes.py b/duckdb_engine/tests/test_datatypes.py index 2e3b3cd8..6f28adf6 100644 --- a/duckdb_engine/tests/test_datatypes.py +++ b/duckdb_engine/tests/test_datatypes.py @@ -211,6 +211,29 @@ class Entry(base): assert result.map == map_data +def test_double_nested_types(engine: Engine, session: Session) -> None: + """Test for https://github.com/Mause/duckdb_engine/issues/1138""" + importorskip("duckdb", "0.5.0") # nested types require at least duckdb 0.5.0 + base = declarative_base() + + class Entry(base): + __tablename__ = "test_struct" + + id = Column(Integer, primary_key=True, default=0) + outer = Column(Struct({"inner": Struct({"val": Integer})})) + + base.metadata.create_all(bind=engine) + + outer = {"inner": {"val": 42}} + + session.add(Entry(outer=outer)) # type: ignore[call-arg] + session.commit() + + result = session.query(Entry).one() + + assert result.outer == outer + + def test_interval(engine: Engine, snapshot: SnapshotTest) -> None: test_table = Table("test_table", MetaData(), Column("duration", Interval))