Skip to content

Commit

Permalink
Merge pull request #1068 from Mause/Mause-patch-3
Browse files Browse the repository at this point in the history
fix: update Map type for 1.0+
  • Loading branch information
Mause authored Oct 21, 2024
2 parents 5a031c8 + 0fa58a1 commit bc79a18
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
22 changes: 18 additions & 4 deletions duckdb_engine/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import typing
from typing import Any, Callable, Dict, Optional, Type

import duckdb
from packaging.version import Version
from sqlalchemy import exc
from sqlalchemy.dialects.postgresql.base import PGIdentifierPreparer, PGTypeCompiler
from sqlalchemy.engine import Dialect
Expand All @@ -23,6 +25,10 @@
# BIGINT INT8, LONG -9223372036854775808 9223372036854775807
(BigInteger, SmallInteger) # pure reexport

duckdb_version = duckdb.__version__

IS_GT_1 = Version(duckdb_version) > Version("1.0.0")


class UInt64(Integer):
pass
Expand Down Expand Up @@ -82,8 +88,10 @@ class UInteger(Integer):
pass


class VarInt(Integer):
pass
if IS_GT_1:

class VarInt(Integer):
pass


def compile_uint(element: Integer, compiler: PGTypeCompiler, **kw: Any) -> str:
Expand Down Expand Up @@ -157,7 +165,12 @@ def bind_processor(
def result_processor(
self, dialect: Dialect, coltype: str
) -> Optional[Callable[[Optional[dict]], Optional[dict]]]:
return lambda value: dict(zip(value["key"], value["value"])) if value else {}
if IS_GT_1:
return lambda value: value
else:
return (
lambda value: dict(zip(value["key"], value["value"])) if value else {}
)


class Union(TypeEngine):
Expand Down Expand Up @@ -203,8 +216,9 @@ def __init__(self, fields: Dict[str, TV]):
"enum": sqltypes.Enum,
"bool": sqltypes.BOOLEAN,
"varchar": String,
"varint": VarInt,
}
if IS_GT_1:
ISCHEMA_NAMES["varint"] = VarInt


def register_extension_types() -> None:
Expand Down
4 changes: 3 additions & 1 deletion duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,10 @@ def test_close(engine: Engine) -> None:


def test_with_cache(tmp_path: Path) -> None:
importorskip("duckdb", "1.0.0")
tmp_db_path = str(tmp_path / "db_cached")
engine1 = create_engine(f"duckdb:///{tmp_db_path}?threads=1")
engine2 = create_engine(f"duckdb:///{tmp_db_path}?threads=2")
engine2 = create_engine(f"duckdb:///{tmp_db_path}?threads=1")
with engine1.connect() as conn1:
with engine2.connect() as conn2:
res1 = conn1.execute(
Expand All @@ -624,6 +625,7 @@ def test_with_cache(tmp_path: Path) -> None:
text("select value from duckdb_settings() where name = 'threads'")
).fetchall()
assert res1 == res2
# TODO: how do we validate that both connections point to the same database instance?
assert res1[0][0] == "1"


Expand Down

0 comments on commit bc79a18

Please sign in to comment.