diff --git a/README.md b/README.md index f847ef80..cf583043 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,20 @@ Basic SQLAlchemy driver for [DuckDB](https://duckdb.org/) + +* [duckdb_engine](#duckdb_engine) + * [Installation](#installation) + * [Usage](#usage) + * [Configuration](#configuration) + * [How to register a pandas DataFrame](#how-to-register-a-pandas-dataframe) + * [Things to keep in mind](#things-to-keep-in-mind) + * [Auto-incrementing ID columns](#auto-incrementing-id-columns) + * [Pandas read_sql() chunksize](#pandas-read_sql-chunksize) + * [Unsigned integer support](#unsigned-integer-support) + * [Preloading extensions (experimental)](#preloading-extensions-experimental) + * [The name](#the-name) + + ## Installation ```sh $ pip install duckdb-engine @@ -115,6 +129,24 @@ The `pandas.read_sql()` method can read tables from `duckdb_engine` into DataFra Unsigned integers are supported by DuckDB, and are available in [`duckdb_engine.datatypes`](duckdb_engine/datatypes.py). +## Preloading extensions (experimental) + +Until the DuckDB Python client allows you to natively preload extensions, I've added experimental support via the `preload_extensions` key of `connect_args`: + +```python +from sqlalchemy import create_engine + +create_engine( + 'duckdb:///:memory:', + connect_args={ + 'preload_extensions': ['httpfs'], + 'config': { + 's3_region': 'ap-southeast-1' + } + } +) +``` + ## The name Yes, I'm aware this package should be named `duckdb-driver` or something, I wasn't thinking when I named it and it's too hard to change the name now diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index f1bd0299..8413eb5f 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -10,6 +10,7 @@ from sqlalchemy.ext.compiler import compiles from . 
import datatypes +from .config import apply_config, get_core_config __version__ = "0.5.0" @@ -131,7 +132,6 @@ class Dialect(PGDialect_psycopg2): supports_statement_cache = False supports_comments = False supports_sane_rowcount = False - supports_comments = False inspector = DuckDBInspector # colspecs TODO: remap types to duckdb types colspecs = util.update_copy( @@ -149,7 +149,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) def connect(self, *cargs: Any, **cparams: Any) -> "Connection": - return ConnectionWrapper(duckdb.connect(*cargs, **cparams)) + + core_keys = get_core_config() + preload_extensions = cparams.pop("preload_extensions", []) + config = cparams["config"] = dict(cparams.get("config", {})) + + ext = {k: config.pop(k) for k in list(config) if k not in core_keys} + + conn = duckdb.connect(*cargs, **cparams) + + for extension in preload_extensions: + conn.execute(f"LOAD {extension}") + + apply_config(self, conn, ext) + + return ConnectionWrapper(conn) def on_connect(self) -> None: pass @@ -189,7 +203,7 @@ def get_view_names( connection: Any, schema: Optional[Any] = ..., include: Any = ..., - **kw: Any + **kw: Any, ) -> Any: s = "SELECT name FROM sqlite_master WHERE type='view' ORDER BY name" rs = connection.exec_driver_sql(s) diff --git a/duckdb_engine/config.py b/duckdb_engine/config.py new file mode 100644 index 00000000..fe197139 --- /dev/null +++ b/duckdb_engine/config.py @@ -0,0 +1,24 @@ +from functools import lru_cache +from typing import Dict, Set + +import duckdb +from sqlalchemy import String +from sqlalchemy.engine import Dialect + + +@lru_cache() +def get_core_config() -> Set[str]: + conn = duckdb.connect(":memory:") + try: + rows = conn.execute("SELECT name FROM duckdb_settings()").fetchall() + finally: + conn.close() + return {name for name, in rows} + + +def apply_config( + dialect: Dialect, conn: duckdb.DuckDBPyConnection, ext: Dict[str, str] +) -> None: + process = String().literal_processor(dialect=dialect) + for k, v in ext.items(): + 
conn.execute(f"SET {k} = {process(v)}") diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index 5c1b4e41..3da94fe5 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -1,4 +1,5 @@ import logging +import platform import zlib from datetime import timedelta from pathlib import Path @@ -151,6 +152,24 @@ def test_get_views(engine: Engine) -> None: assert views == ["test"] +@mark.skipif(platform.machine() == "aarch64", reason="not supported on aarch64") +def test_preload_extension() -> None: + duckdb.default_connection.execute("INSTALL httpfs") + engine = create_engine( + "duckdb:///", + connect_args={ + "preload_extensions": ["httpfs"], + "config": {"s3_region": "ap-southeast-2"}, + }, + ) + + # check that we get an error indicating that the extension was loaded + with engine.connect() as conn, raises(Exception, match="HTTP HEAD error"): + conn.execute( + "SELECT * FROM read_parquet('https://domain/path/to/file.parquet');" + ) + + @fixture def inspector(engine: Engine, session: Session) -> Inspector: session.execute(text("create table test (id int);")) diff --git a/tox.ini b/tox.ini index a5ece71b..cf75560f 100644 --- a/tox.ini +++ b/tox.ini @@ -21,10 +21,10 @@ whitelist_externals = poetry commands = poetry install -v poetry run pip install duckdb --pre -U - poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml + poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml --verbose -rs [testenv] whitelist_externals = poetry commands = poetry install -v - poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml + poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml --verbose -rs