Skip to content

Commit

Permalink
feat: support USE DATABASE query (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiurin authored Dec 14, 2023
1 parent db9e51d commit b9d5c22
Show file tree
Hide file tree
Showing 17 changed files with 409 additions and 22 deletions.
28 changes: 19 additions & 9 deletions src/firebolt/async_db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def __init__(
super().__init__(*args, **kwargs)
self._client = client
self.connection = connection
if connection.database:
self.database = connection.database

@property
def database(self) -> Optional[str]:
return self.parameters.get("database")

@database.setter
def database(self, database: str) -> None:
self.parameters["database"] = database

@abstractmethod
async def _api_request(
Expand All @@ -100,12 +110,8 @@ async def _raise_if_error(self, resp: Response) -> None:
f"Error executing query:\n{resp.read().decode('utf-8')}"
)
if resp.status_code == codes.FORBIDDEN:
if self.connection.database and not await self.is_db_available(
self.connection.database
):
raise FireboltDatabaseError(
f"Database {self.connection.database} does not exist"
)
if self.database and not await self.is_db_available(self.database):
raise FireboltDatabaseError(f"Database {self.database} does not exist")
raise ProgrammingError(resp.read().decode("utf-8"))
if (
resp.status_code == codes.SERVICE_UNAVAILABLE
Expand Down Expand Up @@ -200,6 +206,8 @@ async def _do_execute(
query, {"output_format": JSON_OUTPUT_FORMAT}
)
await self._raise_if_error(resp)
# get parameters from response
self._parse_response_headers(resp.headers)
row_set = self._row_set_from_response(resp)

self._append_row_set(row_set)
Expand Down Expand Up @@ -439,8 +447,8 @@ async def _api_request(
parameters = parameters or {}
if use_set_parameters:
parameters = {**(self._set_parameters or {}), **parameters}
if self.connection.database:
parameters["database"] = self.connection.database
if self.parameters:
parameters = {**self.parameters, **parameters}
if self.connection._is_system:
assert isinstance(self._client, AsyncClientV2)
parameters["account_id"] = await self._client.account_id
Expand Down Expand Up @@ -543,13 +551,15 @@ async def _api_request(
set parameters are sent. Setting this to False will allow
self._set_parameters to be ignored.
"""
parameters = parameters or {}
if use_set_parameters:
parameters = {**(self._set_parameters or {}), **(parameters or {})}
if self.parameters:
parameters = {**self.parameters, **parameters}
return await self._client.request(
url=f"/{path}",
method="POST",
params={
"database": self.connection.database,
**(parameters or dict()),
},
content=query,
Expand Down
11 changes: 10 additions & 1 deletion src/firebolt/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@

from firebolt.client.auth import Auth
from firebolt.client.auth.base import AuthRequest
from firebolt.client.constants import DEFAULT_API_URL
from firebolt.client.constants import (
DEFAULT_API_URL,
PROTOCOL_VERSION,
PROTOCOL_VERSION_HEADER_NAME,
)
from firebolt.utils.exception import (
AccountNotFoundError,
FireboltEngineError,
Expand Down Expand Up @@ -51,6 +55,11 @@ def __init__(
self._api_endpoint = URL(fix_url_schema(api_endpoint))
self._auth_endpoint = get_auth_endpoint(self._api_endpoint)
super().__init__(*args, auth=auth, **kwargs)
self._set_default_header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION)

def _set_default_header(self, key: str, value: str) -> None:
if key not in self.headers:
self.headers[key] = value

def _build_auth(self, auth: Optional[AuthTypes]) -> Auth:
"""Create Auth object based on auth provided.
Expand Down
2 changes: 2 additions & 0 deletions src/firebolt/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from httpx import CookieConflict, HTTPError, InvalidURL, StreamError

DEFAULT_API_URL: str = "api.app.firebolt.io"
PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version"
PROTOCOL_VERSION: str = "2.0"
_REQUEST_ERRORS: Tuple[Type, ...] = (
HTTPError,
InvalidURL,
Expand Down
28 changes: 27 additions & 1 deletion src/firebolt/common/base_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from httpx import Response
from httpx import Headers, Response

from firebolt.common._types import (
ColType,
Expand Down Expand Up @@ -52,6 +52,10 @@ class QueryStatus(Enum):
EXECUTION_ERROR = 8


# known parameters that can be set on the server side
SERVER_SIDE_PARAMETERS = ["database"]


@dataclass
class Statistics:
"""
Expand Down Expand Up @@ -109,6 +113,7 @@ def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any:
class BaseCursor:
__slots__ = (
"connection",
"parameters",
"_arraysize",
"_client",
"_state",
Expand Down Expand Up @@ -140,6 +145,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
]
] = []
self._set_parameters: Dict[str, Any] = dict()
self.parameters: Dict[str, str] = dict()
self._rowcount = -1
self._idx = 0
self._next_set_idx = 0
Expand Down Expand Up @@ -243,6 +249,26 @@ def _reset(self) -> None:
self._next_set_idx = 0
self._query_id = ""

def _parse_response_headers(self, headers: Headers) -> None:
"""Parse response and update relevant cursor fields."""
update_parameters = headers.get("Firebolt-Update-Parameters")
# parse update parameters dict and set keys as attributes
if update_parameters:
# parse key1=value1,key2=value2 comma separated string into dict
param_dict = dict(item.split("=") for item in update_parameters.split(","))
# strip whitespace from keys and values
param_dict = {
key.strip(): value.strip() for key, value in param_dict.items()
}
for key, value in param_dict.items():
if key in SERVER_SIDE_PARAMETERS:
self.parameters[key] = value
else:
logger.debug(
f"Unknown parameter {key} returned by the server. "
"It will be ignored."
)

def _row_set_from_response(
self, response: Response
) -> Tuple[
Expand Down
26 changes: 19 additions & 7 deletions src/firebolt/db/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def __init__(
super().__init__(*args, **kwargs)
self._client = client
self.connection = connection
if connection.database:
self.database = connection.database

@property
def database(self) -> Optional[str]:
return self.parameters.get("database")

@database.setter
def database(self, database: str) -> None:
self.parameters["database"] = database

def _raise_if_error(self, resp: Response) -> None:
"""Raise a proper error if any"""
Expand All @@ -84,11 +94,9 @@ def _raise_if_error(self, resp: Response) -> None:
f"Error executing query:\n{resp.read().decode('utf-8')}"
)
if resp.status_code == codes.FORBIDDEN:
if self.connection.database and not self.is_db_available(
self.connection.database
):
if self.database and not self.is_db_available(self.database):
raise FireboltDatabaseError(
f"Database {self.connection.database} does not exist"
f"Database {self.parameters['database']} does not exist"
)
raise ProgrammingError(resp.read().decode("utf-8"))
if (
Expand Down Expand Up @@ -188,6 +196,8 @@ def _do_execute(
query, {"output_format": JSON_OUTPUT_FORMAT}
)
self._raise_if_error(resp)
# get parameters from response
self._parse_response_headers(resp.headers)
row_set = self._row_set_from_response(resp)

self._append_row_set(row_set)
Expand Down Expand Up @@ -379,8 +389,8 @@ def _api_request(
parameters = parameters or {}
if use_set_parameters:
parameters = {**(self._set_parameters or {}), **parameters}
if self.connection.database:
parameters["database"] = self.connection.database
if self.parameters:
parameters = {**self.parameters, **parameters}
if self.connection._is_system:
assert isinstance(self._client, ClientV2) # Type check
parameters["account_id"] = self._client.account_id
Expand Down Expand Up @@ -480,13 +490,15 @@ def _api_request(
set parameters are sent. Setting this to False will allow
self._set_parameters to be ignored.
"""
parameters = parameters or {}
if use_set_parameters:
parameters = {**(self._set_parameters or {}), **(parameters or {})}
if self.parameters:
parameters = {**self.parameters, **parameters}
return self._client.request(
url=f"/{path}",
method="POST",
params={
"database": self.connection.database,
**(parameters or dict()),
},
content=query,
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def database_name() -> str:
return must_env(DATABASE_NAME_ENV)


@fixture(scope="session")
def use_db_name(database_name: str):
return f"{database_name}_use_db_test"


@fixture(scope="session")
def account_name() -> str:
return must_env(ACCOUNT_NAME_ENV)
Expand Down
40 changes: 39 additions & 1 deletion tests/integration/dbapi/async/V1/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from decimal import Decimal
from typing import Any, List

from pytest import mark, raises
from pytest import fixture, mark, raises

from firebolt.async_db import Binary, Connection, Cursor, OperationalError
from firebolt.async_db.cursor import QueryStatus
Expand Down Expand Up @@ -486,3 +486,41 @@ async def test_bytea_roundtrip(
assert (
bytes_data.decode("utf-8") == data
), "Invalid bytea data returned after roundtrip"


@fixture
async def setup_db(connection_no_engine: Connection, use_db_name: str):
use_db_name = f"{use_db_name}_async"
with connection_no_engine.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {use_db_name}")
yield
await cursor.execute(f"DROP DATABASE {use_db_name}")


@mark.xfail(reason="USE DATABASE is not yet available in 1.0 Firebolt")
async def test_use_database(
setup_db,
connection_no_engine: Connection,
use_db_name: str,
database_name: str,
) -> None:
test_db_name = f"{use_db_name}_async"
test_table_name = "verify_use_db_async"
"""Use database works as expected."""
with connection_no_engine.cursor() as c:
await c.execute(f"USE DATABASE {test_db_name}")
assert c.database == test_db_name
await c.execute(f"CREATE TABLE {test_table_name} (id int)")
await c.execute(
"SELECT table_name FROM information_schema.tables "
f"WHERE table_name = '{test_table_name}'"
)
assert (await c.fetchone())[0] == test_table_name, "Table was not created"
# Change DB and verify table is not there
await c.execute(f"USE DATABASE {database_name}")
assert c.database == database_name
await c.execute(
"SELECT table_name FROM information_schema.tables "
f"WHERE table_name = '{test_table_name}'"
)
assert (await c.fetchone()) is None, "Database was not changed"
42 changes: 41 additions & 1 deletion tests/integration/dbapi/async/V2/test_queries_async.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import date, datetime
from decimal import Decimal
from os import environ
from typing import List

from pytest import mark, raises
from pytest import fixture, mark, raises

from firebolt.async_db import Binary, Connection, Cursor, OperationalError
from firebolt.async_db.cursor import QueryStatus
from firebolt.common._types import ColType, Column
from tests.integration.conftest import API_ENDPOINT_ENV
from tests.integration.dbapi.utils import assert_deep_eq

VALS_TO_INSERT_2 = ",".join(
Expand Down Expand Up @@ -411,3 +413,41 @@ async def test_bytea_roundtrip(
assert (
bytes_data.decode("utf-8") == data
), "Invalid bytea data returned after roundtrip"


@fixture
async def setup_db(connection_system_engine_no_db: Connection, use_db_name: str):
use_db_name = use_db_name + "_async"
with connection_system_engine_no_db.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {use_db_name}")
yield
await cursor.execute(f"DROP DATABASE {use_db_name}")


@mark.xfail("dev" not in environ[API_ENDPOINT_ENV], reason="Only works on dev")
async def test_use_database(
setup_db,
connection_system_engine_no_db: Connection,
use_db_name: str,
database_name: str,
) -> None:
test_db_name = use_db_name + "_async"
test_table_name = "verify_use_db_async"
"""Use database works as expected."""
with connection_system_engine_no_db.cursor() as c:
await c.execute(f"USE DATABASE {test_db_name}")
assert c.database == test_db_name
await c.execute(f"CREATE TABLE {test_table_name} (id int)")
await c.execute(
"SELECT table_name FROM information_schema.tables "
f"WHERE table_name = '{test_table_name}'"
)
assert (await c.fetchone())[0] == test_table_name, "Table was not created"
# Change DB and verify table is not there
await c.execute(f"USE DATABASE {database_name}")
assert c.database == database_name
await c.execute(
"SELECT table_name FROM information_schema.tables "
f"WHERE table_name = '{test_table_name}'"
)
assert (await c.fetchone()) is None, "Database was not changed"
Loading

0 comments on commit b9d5c22

Please sign in to comment.