Skip to content

Commit

Permalink
Add SQL runner utility primitives to io.sql namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Nov 7, 2023
1 parent da0f0f5 commit d604d81
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

## Unreleased

- Add SQL runner utility primitives to `io.sql` namespace


## 2023/11/06 v0.0.2
Expand Down
Empty file added cratedb_toolkit/io/__init__.py
Empty file.
1 change: 1 addition & 0 deletions cratedb_toolkit/io/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from cratedb_toolkit.util.database import DatabaseAdapter, run_sql # noqa: F401
41 changes: 34 additions & 7 deletions cratedb_toolkit/util/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Copyright (c) 2023, Crate.io Inc.
# Distributed under the terms of the AGPLv3 license, see LICENSE.
import io
import typing as t
from pathlib import Path

import sqlalchemy as sa
import sqlparse
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.sql.elements import AsBoolean

Expand All @@ -22,12 +27,23 @@ def __init__(self, dburi: str):
self.engine = sa.create_engine(self.dburi, echo=False)
self.connection = self.engine.connect()

def run_sql(self, sql: str, records: bool = False, ignore: str = None):
def run_sql(self, sql: t.Union[str, Path, io.IOBase], records: bool = False, ignore: str = None):
"""
Run SQL statement, and return results, optionally ignoring exceptions.
"""

sql_effective: str
if isinstance(sql, str):
sql_effective = sql
elif isinstance(sql, Path):
sql_effective = sql.read_text()
elif isinstance(sql, io.IOBase):
sql_effective = sql.read()
else:
raise KeyError("SQL statement type must be either string, Path, or IO handle")

try:
return self.run_sql_real(sql=sql, records=records)
return self.run_sql_real(sql=sql_effective, records=records)
except Exception as ex:
if not ignore:
raise
Expand All @@ -38,12 +54,23 @@ def run_sql_real(self, sql: str, records: bool = False):
"""
Invoke SQL statement, and return results.
"""
result = self.connection.execute(sa.text(sql))
if records:
rows = result.mappings().fetchall()
return [dict(row.items()) for row in rows]
results = []
with self.engine.connect() as connection:
for statement in sqlparse.split(sql):
result = connection.execute(sa.text(statement))
data: t.Any
if records:
rows = result.mappings().fetchall()
data = [dict(row.items()) for row in rows]
else:
data = result.fetchall()
results.append(data)

# Backward-compatibility.
if len(results) == 1:
return results[0]
else:
return result.fetchall()
return results

def count_records(self, tablename_full: str):
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ dependencies = [
"crash",
"crate[sqlalchemy]>=0.34",
"sqlalchemy>=2",
"sqlparse<0.5",
]
[project.optional-dependencies]
develop = [
Expand Down
Empty file added tests/io/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/io/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import io

import pytest
import sqlalchemy as sa

from cratedb_toolkit.io.sql import run_sql


def sqlcmd(sql):
return run_sql(dburi="crate://crate@localhost:4200/", sql=sql, records=True)


def test_run_sql_from_string():
sql_string = "SELECT 1;"
outcome = sqlcmd(sql_string)
assert outcome == [{"1": 1}]


def test_run_sql_from_file(tmp_path):
sql_file = tmp_path / "temp.sql"
sql_file.write_text("SELECT 1;")
outcome = sqlcmd(sql_file)
assert outcome == [{"1": 1}]


def test_run_sql_from_buffer():
sql_buffer = io.StringIO("SELECT 1;")
outcome = sqlcmd(sql_buffer)
assert outcome == [{"1": 1}]


def test_run_sql_invalid_host(capsys):
with pytest.raises(sa.exc.OperationalError) as ex:
run_sql(dburi="crate://localhost:12345", sql="SELECT 1;")
assert ex.match(
".*ConnectionError.*No more Servers available.*HTTPConnectionPool.*"
"Failed to establish a new connection.*Connection refused.*"
)
2 changes: 1 addition & 1 deletion tests/retention/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_setup_verbose(caplog, cratedb, settings):
assert result.exit_code == 0

assert cratedb.database.table_exists(settings.policy_table.fullname) is True
assert 3 <= len(caplog.records) <= 7
assert 3 <= len(caplog.records) <= 10


def test_setup_dryrun(caplog, cratedb, settings):
Expand Down

0 comments on commit d604d81

Please sign in to comment.