Skip to content

Commit

Permalink
Merge pull request #377 from Mause/feature/preload-extensions
Browse files Browse the repository at this point in the history
feat: allow preloading of extensions
  • Loading branch information
Mause authored Aug 21, 2022
2 parents 43dfe91 + ba3fbe8 commit 0e71b68
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 5 deletions.
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@

Basic SQLAlchemy driver for [DuckDB](https://duckdb.org/)

<!--ts-->
* [duckdb_engine](#duckdb_engine)
* [Installation](#installation)
* [Usage](#usage)
* [Configuration](#configuration)
* [How to register a pandas DataFrame](#how-to-register-a-pandas-dataframe)
* [Things to keep in mind](#things-to-keep-in-mind)
* [Auto-incrementing ID columns](#auto-incrementing-id-columns)
* [Pandas read_sql() chunksize](#pandas-read_sql-chunksize)
* [Unsigned integer support](#unsigned-integer-support)
* [Preloading extensions (experimental)](#preloading-extensions-experimental)
* [The name](#the-name)
<!--te-->

## Installation
```sh
$ pip install duckdb-engine
Expand Down Expand Up @@ -115,6 +129,24 @@ The `pandas.read_sql()` method can read tables from `duckdb_engine` into DataFra

Unsigned integers are supported by DuckDB, and are available in [`duckdb_engine.datatypes`](duckdb_engine/datatypes.py).

## Preloading extensions (experimental)

Until the DuckDB Python client allows you to preload extensions natively, experimental support is available via a `connect_args` parameter:

```python
from sqlalchemy import create_engine

create_engine(
'duckdb:///:memory:',
connect_args={
'preload_extensions': ['httpfs'],
'config': {
's3_region': 'ap-southeast-1'
}
}
)
```

## The name

Yes, I'm aware this package should be named `duckdb-driver` or something similar. I wasn't thinking when I named it, and it's too hard to change the name now.
20 changes: 17 additions & 3 deletions duckdb_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.ext.compiler import compiles

from . import datatypes
from .config import apply_config, get_core_config

__version__ = "0.5.0"

Expand Down Expand Up @@ -131,7 +132,6 @@ class Dialect(PGDialect_psycopg2):
supports_statement_cache = False
supports_comments = False
supports_sane_rowcount = False
supports_comments = False
inspector = DuckDBInspector
# colspecs TODO: remap types to duckdb types
colspecs = util.update_copy(
Expand All @@ -149,7 +149,21 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def connect(self, *cargs: Any, **cparams: Any) -> "Connection":
    """Open a DuckDB connection, preload extensions and apply extension settings.

    Two keys of ``cparams`` get special treatment:

    * ``preload_extensions`` -- iterable of extension names; removed from
      ``cparams`` and each one is loaded with ``LOAD <name>`` right after
      the connection is opened.
    * ``config`` -- mapping of settings. Keys returned by
      ``get_core_config()`` are forwarded to ``duckdb.connect``; all other
      keys are treated as extension settings and applied through
      ``apply_config`` once the extensions have been loaded.

    Everything else in ``cargs`` / ``cparams`` is passed to
    ``duckdb.connect`` unchanged.

    Returns:
        A ``ConnectionWrapper`` around the configured DuckDB connection.
    """
    core_keys = get_core_config()
    preload_extensions = cparams.pop("preload_extensions", [])

    # Work on a copy: the nested ``config`` dict belongs to the caller and
    # is reused for every new connection, so popping keys from it directly
    # would make later connections lose their extension settings.
    config = dict(cparams.get("config", {}))
    ext = {k: config.pop(k) for k in list(config) if k not in core_keys}
    if "config" in cparams:
        # Forward only the core settings; extension settings are not
        # known to DuckDB until the extensions are loaded below.
        cparams["config"] = config

    conn = duckdb.connect(*cargs, **cparams)

    for extension in preload_extensions:
        conn.execute(f"LOAD {extension}")

    # Extension settings can only be applied after the LOAD statements.
    apply_config(self, conn, ext)

    return ConnectionWrapper(conn)

def on_connect(self) -> None:
    """No-op hook: performs no work and returns ``None``."""
    return None
Expand Down Expand Up @@ -189,7 +203,7 @@ def get_view_names(
connection: Any,
schema: Optional[Any] = ...,
include: Any = ...,
**kw: Any
**kw: Any,
) -> Any:
s = "SELECT name FROM sqlite_master WHERE type='view' ORDER BY name"
rs = connection.exec_driver_sql(s)
Expand Down
24 changes: 24 additions & 0 deletions duckdb_engine/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from functools import lru_cache
from typing import Dict, Set

import duckdb
from sqlalchemy import String
from sqlalchemy.engine import Dialect


@lru_cache()
def get_core_config() -> Set[str]:
    """Return the names of the settings a fresh DuckDB connection reports.

    The names are read from ``duckdb_settings()`` on a throwaway in-memory
    connection. The result is memoised by ``lru_cache``, so the query runs
    only once per process and every caller receives the same set object;
    callers should treat it as read-only.
    """
    conn = duckdb.connect(":memory:")
    try:
        rows = conn.execute("SELECT name FROM duckdb_settings()").fetchall()
    finally:
        # The connection is only needed for this one query; close it
        # explicitly instead of leaving it to the garbage collector.
        conn.close()
    return {name for name, in rows}


def apply_config(
    dialect: Dialect, conn: duckdb.DuckDBPyConnection, ext: Dict[str, str]
) -> None:
    """Issue one ``SET <name> = <value>`` statement per entry of ``ext``.

    Each value is rendered as a SQL string literal by the ``String`` type's
    literal processor for ``dialect``; setting names are interpolated as-is.

    Args:
        dialect: Dialect used to build the string literal processor.
        conn: Connection the ``SET`` statements are executed on.
        ext: Mapping of setting name to string value.
    """
    quote = String().literal_processor(dialect=dialect)
    for setting in ext:
        statement = f"SET {setting} = {quote(ext[setting])}"
        conn.execute(statement)
19 changes: 19 additions & 0 deletions duckdb_engine/tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import zlib
from datetime import timedelta
from pathlib import Path
Expand Down Expand Up @@ -151,6 +152,24 @@ def test_get_views(engine: Engine) -> None:
assert views == ["test"]


@mark.skipif(os.uname().machine == "aarch64", reason="not supported on aarch64")
def test_preload_extension() -> None:
    """Check that ``preload_extensions`` loads httpfs on new connections."""
    # The extension has to be installed before it can be loaded.
    duckdb.default_connection.execute("INSTALL httpfs")

    connect_args = {
        "preload_extensions": ["httpfs"],
        "config": {"s3_region": "ap-southeast-2"},
    }
    engine = create_engine("duckdb:///", connect_args=connect_args)

    query = "SELECT * FROM read_parquet('https://domain/path/to/file.parquet');"

    # check that we get an error indicating that the extension was loaded
    with engine.connect() as conn:
        with raises(Exception, match="HTTP HEAD error"):
            conn.execute(query)


@fixture
def inspector(engine: Engine, session: Session) -> Inspector:
session.execute(text("create table test (id int);"))
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ whitelist_externals = poetry
commands =
poetry install -v
poetry run pip install duckdb --pre -U
poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml
poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml --verbose -rs

[testenv]
whitelist_externals = poetry
commands =
poetry install -v
poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml
poetry run pytest --junitxml=results.xml --cov --cov-report xml:coverage.xml --verbose -rs

0 comments on commit 0e71b68

Please sign in to comment.