From 40948531d38de0dd96df3fee398f19cab911382e Mon Sep 17 00:00:00 2001 From: Petro Tiurin <93913847+ptiurin@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:32:00 +0000 Subject: [PATCH] fix: Readd database support for resource manager (#333) --- src/firebolt/model/V1/database.py | 181 ++++++++++++++++++ src/firebolt/model/V1/provider.py | 16 ++ src/firebolt/service/V1/base.py | 6 + src/firebolt/service/V1/binding.py | 102 +++++++++- src/firebolt/service/V1/database.py | 127 ++++++++++++ src/firebolt/service/V1/provider.py | 10 + src/firebolt/service/V1/region.py | 66 +++++++ src/firebolt/service/manager.py | 21 +- .../resource_manager/V1/conftest.py | 1 + .../resource_manager/V1/test_database.py | 60 ++++++ tests/unit/conftest.py | 12 +- tests/unit/service/V1/conftest.py | 107 ++++++++++- tests/unit/service/V1/test_bindings.py | 98 ++++++++++ tests/unit/service/V1/test_database.py | 131 +++++++++++++ tests/unit/service/V1/test_engine.py | 4 +- tests/unit/service/V1/test_region.py | 30 +++ .../unit/service/V1/test_resource_manager.py | 6 + 17 files changed, 965 insertions(+), 13 deletions(-) create mode 100644 src/firebolt/model/V1/database.py create mode 100644 src/firebolt/model/V1/provider.py create mode 100644 src/firebolt/service/V1/database.py create mode 100644 src/firebolt/service/V1/provider.py create mode 100644 src/firebolt/service/V1/region.py create mode 100644 tests/integration/resource_manager/V1/test_database.py create mode 100644 tests/unit/service/V1/test_database.py create mode 100644 tests/unit/service/V1/test_region.py diff --git a/src/firebolt/model/V1/database.py b/src/firebolt/model/V1/database.py new file mode 100644 index 00000000000..27d868088de --- /dev/null +++ b/src/firebolt/model/V1/database.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any, List, Optional, Sequence + +from pydantic import Field, PrivateAttr + +from firebolt.model.V1 import FireboltBaseModel +from firebolt.model.V1.region import RegionKey +from firebolt.service.V1.engine import EngineService +from firebolt.service.V1.types import EngineStatusSummary +from firebolt.utils.exception import AttachedEngineInUseError +from firebolt.utils.urls import ACCOUNT_DATABASE_URL + +if TYPE_CHECKING: + from firebolt.model.V1.binding import Binding + from firebolt.model.V1.engine import Engine + from firebolt.service.V1.database import DatabaseService + +logger = logging.getLogger(__name__) + + +class DatabaseKey(FireboltBaseModel): + account_id: str + database_id: str + + +class FieldMask(FireboltBaseModel): + paths: Sequence[str] = Field(alias="paths") + + +class Database(FireboltBaseModel): + """ + A Firebolt database. + + Databases belong to a region and have a description, + but otherwise are not configurable. + """ + + # internal + _service: DatabaseService = PrivateAttr() + + # required + name: str = Field(min_length=1, max_length=255, regex=r"^[0-9a-zA-Z_]+$") + compute_region_key: RegionKey = Field(alias="compute_region_id") + + # optional + database_key: Optional[DatabaseKey] = Field(None, alias="id") + description: Optional[str] = Field(None, max_length=255) + emoji: Optional[str] = Field(None, max_length=255) + current_status: Optional[str] + health_status: Optional[str] + data_size_full: Optional[int] + data_size_compressed: Optional[int] + is_system_database: Optional[bool] + storage_bucket_name: Optional[str] + create_time: Optional[datetime] + create_actor: Optional[str] + last_update_time: Optional[datetime] + last_update_actor: Optional[str] + desired_status: Optional[str] + + @classmethod + def parse_obj_with_service( + cls, obj: Any, database_service: DatabaseService + ) -> Database: + database = cls.parse_obj(obj) + database._service = database_service + return database + + @property + def database_id(self) -> Optional[str]: + if self.database_key is None: + return None + return self.database_key.database_id + + def get_attached_engines(self) -> List[Engine]: + """Get a list of engines that are attached to this database.""" + + return self._service.resource_manager.bindings.get_engines_bound_to_database( # noqa: E501 + database=self + ) + + def attach_to_engine( + self, engine: Engine, is_default_engine: bool = False + ) -> Binding: + """ + Attach an engine to this database. + + Args: + engine: The engine to attach. + is_default_engine: + Whether this engine should be used as default for this database. + Only one engine can be set as default for a single database. + This will overwrite any existing default. + """ + + return self._service.resource_manager.bindings.create( + engine=engine, database=self, is_default_engine=is_default_engine + ) + + def delete(self) -> Database: + """ + Delete a database from Firebolt. + + Raises an error if there are any attached engines. + """ + + for engine in self.get_attached_engines(): + if engine.current_status_summary in { + EngineStatusSummary.ENGINE_STATUS_SUMMARY_STARTING, + EngineStatusSummary.ENGINE_STATUS_SUMMARY_STOPPING, + }: + raise AttachedEngineInUseError(method_name="delete") + + logger.info( + f"Deleting Database (database_id={self.database_id}, name={self.name})" + ) + response = self._service.client.delete( + url=ACCOUNT_DATABASE_URL.format( + account_id=self._service.account_id, database_id=self.database_id + ), + headers={"Content-type": "application/json"}, + ) + return Database.parse_obj_with_service( + response.json()["database"], self._service + ) + + def update(self, description: str) -> Database: + """ + Updates a database description. + """ + + class _DatabaseUpdateRequest(FireboltBaseModel): + """Helper model for sending Database creation requests.""" + + account_id: str + database: Database + database_id: str + update_mask: FieldMask + + self.description = description + + logger.info( + f"Updating Database (database_id={self.database_id}, " + f"name={self.name}, description={self.description})" + ) + + payload = _DatabaseUpdateRequest( + account_id=self._service.account_id, + database=self, + database_id=self.database_id, + update_mask=FieldMask(paths=["description"]), + ).jsonable_dict(by_alias=True) + + response = self._service.client.patch( + url=ACCOUNT_DATABASE_URL.format( + account_id=self._service.account_id, database_id=self.database_id + ), + headers={"Content-type": "application/json"}, + json=payload, + ) + + return Database.parse_obj_with_service( + response.json()["database"], self._service + ) + + def get_default_engine(self) -> Optional[Engine]: + """ + Returns: default engine of the database, or None if default engine is missing + """ + rm = self._service.resource_manager + assert isinstance(rm.engines, EngineService), "Expected EngineService V1" + default_engines: List[Engine] = [ + rm.engines.get(binding.engine_id) + for binding in rm.bindings.get_many(database_id=self.database_id) + if binding.is_default_engine + ] + + return None if len(default_engines) == 0 else default_engines[0] diff --git a/src/firebolt/model/V1/provider.py b/src/firebolt/model/V1/provider.py new file mode 100644 index 00000000000..01ac13176ed --- /dev/null +++ b/src/firebolt/model/V1/provider.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from firebolt.model.V1 import FireboltBaseModel + + +class Provider(FireboltBaseModel, frozen=True): # type: ignore + provider_id: str = Field(alias="id") + name: str + + # optional + create_time: Optional[datetime] + display_name: Optional[str] + last_update_time: Optional[datetime] diff --git a/src/firebolt/service/V1/base.py b/src/firebolt/service/V1/base.py index fa21eaf1510..24b6caea66c 100644 --- a/src/firebolt/service/V1/base.py +++ b/src/firebolt/service/V1/base.py @@ -1,3 +1,5 @@ +from typing import Optional + from firebolt.client import ClientV1 as Client from firebolt.service.manager import ResourceManager @@ -13,3 +15,7 @@ def client(self) -> Client: @property def account_id(self) -> str: return self.resource_manager.account_id + + @property + def default_region_setting(self) -> Optional[str]: + return self.resource_manager.default_region diff --git a/src/firebolt/service/V1/binding.py b/src/firebolt/service/V1/binding.py index 12ff19bd962..e58bac95e4b 100644 --- a/src/firebolt/service/V1/binding.py +++ b/src/firebolt/service/V1/binding.py @@ -1,15 +1,35 @@ import logging from typing import List, Optional -from firebolt.model.V1.binding import Binding +from firebolt.model.V1.binding import Binding, BindingKey +from firebolt.model.V1.database import Database +from firebolt.model.V1.engine import Engine from firebolt.service.V1.base import BaseService -from firebolt.utils.urls import ACCOUNT_BINDINGS_URL +from firebolt.service.V1.database import DatabaseService +from firebolt.service.V1.engine import EngineService +from firebolt.utils.exception import AlreadyBoundError +from firebolt.utils.urls import ( + ACCOUNT_BINDINGS_URL, + ACCOUNT_DATABASE_BINDING_URL, +) from firebolt.utils.util import prune_dict logger = logging.getLogger(__name__) class BindingService(BaseService): + def get_by_key(self, binding_key: BindingKey) -> Binding: + """Get a binding by its BindingKey""" + response = self.client.get( + url=ACCOUNT_DATABASE_BINDING_URL.format( + account_id=binding_key.account_id, + database_id=binding_key.database_id, + engine_id=binding_key.engine_id, + ) + ) + binding: dict = response.json()["binding"] + return Binding.parse_obj(binding) + def get_many( self, database_id: Optional[str] = None, @@ -47,3 +67,81 @@ def get_many( ), ) return [Binding.parse_obj(i["node"]) for i in response.json()["edges"]] + + def get_database_bound_to_engine(self, engine: Engine) -> Optional[Database]: + """Get the database to which an engine is bound, if any.""" + try: + binding = self.get_many(engine_id=engine.engine_id)[0] + except IndexError: + return None + try: + assert isinstance( + self.resource_manager.databases, DatabaseService + ), "Expected DatabaseService V1" + return self.resource_manager.databases.get(id_=binding.database_id) + except (KeyError, IndexError): + return None + + def get_engines_bound_to_database(self, database: Database) -> List[Engine]: + """Get a list of engines that are bound to a database.""" + + bindings = self.get_many(database_id=database.database_id) + if not bindings: + return [] + assert isinstance( + self.resource_manager.engines, EngineService + ), "Expected EngineService V1" + return self.resource_manager.engines.get_by_ids( + ids=[b.engine_id for b in bindings] + ) + + def create( + self, engine: Engine, database: Database, is_default_engine: bool + ) -> Binding: + """ + Create a new binding between an engine and a database. + + Args: + engine: Engine to bind. + database: Database to bind. + is_default_engine: + Whether this engine should be used as default for this database. + Only one engine can be set as default for a single database. + This will overwrite any existing default. + + Returns: + New binding between the engine and database. + """ + + existing_database = self.get_database_bound_to_engine(engine=engine) + if existing_database is not None: + raise AlreadyBoundError( + f"The engine {engine.name} is already bound " + f"to {existing_database.name}!" + ) + + logger.info( + f"Attaching Engine (engine_id={engine.engine_id}, name={engine.name}) " + f"to Database (database_id={database.database_id}, " + f"name={database.name})" + ) + binding = Binding( + binding_key=BindingKey( + account_id=self.account_id, + database_id=database.database_id, + engine_id=engine.engine_id, + ), + is_default_engine=is_default_engine, + ) + + response = self.client.post( + url=ACCOUNT_DATABASE_BINDING_URL.format( + account_id=self.account_id, + database_id=database.database_id, + engine_id=engine.engine_id, + ), + json=binding.jsonable_dict( + by_alias=True, include={"binding_key": ..., "is_default_engine": ...} + ), + ) + return Binding.parse_obj(response.json()["binding"]) diff --git a/src/firebolt/service/V1/database.py b/src/firebolt/service/V1/database.py new file mode 100644 index 00000000000..2673f0a740d --- /dev/null +++ b/src/firebolt/service/V1/database.py @@ -0,0 +1,127 @@ +import logging +from typing import List, Optional, Union + +from firebolt.model.V1 import FireboltBaseModel +from firebolt.model.V1.database import Database +from firebolt.service.V1.base import BaseService +from firebolt.service.V1.types import DatabaseOrder +from firebolt.utils.urls import ( + ACCOUNT_DATABASE_BY_NAME_URL, + ACCOUNT_DATABASE_URL, + ACCOUNT_DATABASES_URL, +) +from firebolt.utils.util import prune_dict + +logger = logging.getLogger(__name__) + + +class DatabaseService(BaseService): + def get(self, id_: str) -> Database: + """Get a Database from Firebolt by its ID.""" + + response = self.client.get( + url=ACCOUNT_DATABASE_URL.format(account_id=self.account_id, database_id=id_) + ) + return Database.parse_obj_with_service( + obj=response.json()["database"], database_service=self + ) + + def get_by_name(self, name: str) -> Database: + """Get a database from Firebolt by its name.""" + + database_id = self.get_id_by_name(name=name) + return self.get(id_=database_id) + + def get_id_by_name(self, name: str) -> str: + """Get a database ID from Firebolt by its name.""" + + response = self.client.get( + url=ACCOUNT_DATABASE_BY_NAME_URL.format(account_id=self.account_id), + params={"database_name": name}, + ) + database_id = response.json()["database_id"]["database_id"] + return database_id + + def get_many( + self, + name_contains: Optional[str] = None, + attached_engine_name_eq: Optional[str] = None, + attached_engine_name_contains: Optional[str] = None, + order_by: Optional[Union[str, DatabaseOrder]] = None, + ) -> List[Database]: + """ + Get a list of databases on Firebolt. + + Args: + name_contains: Filter for databases with a name containing this substring + attached_engine_name_eq: Filter for databases by an exact engine name + attached_engine_name_contains: Filter for databases by engines with a + name containing this substring + order_by: Method by which to order the results. + See :py:class:`firebolt.service.types.DatabaseOrder` + + Returns: + A list of databases matching the filters + """ + + if isinstance(order_by, str): + order_by = DatabaseOrder[order_by].name + + params = { + "page.first": "1000", + "order_by": order_by, + "filter.name_contains": name_contains, + "filter.attached_engine_name_eq": attached_engine_name_eq, + "filter.attached_engine_name_contains": attached_engine_name_contains, + } + + response = self.client.get( + url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), + params=prune_dict(params), + ) + + return [ + Database.parse_obj_with_service(obj=d["node"], database_service=self) + for d in response.json()["edges"] + ] + + def create( + self, name: str, region: Optional[str] = None, description: Optional[str] = None + ) -> Database: + """ + Create a new Database on Firebolt. + + Args: + name: Name of the database + region: Region name in which to create the database + + Returns: + The newly created database + """ + + class _DatabaseCreateRequest(FireboltBaseModel): + """Helper model for sending database creation requests.""" + + account_id: str + database: Database + + if region is None: + region_key = self.resource_manager.regions.default_region.key + else: + region_key = self.resource_manager.regions.get_by_name(name=region).key + database = Database( + name=name, compute_region_key=region_key, description=description + ) + + logger.info(f"Creating Database (name={name})") + response = self.client.post( + url=ACCOUNT_DATABASES_URL.format(account_id=self.account_id), + headers={"Content-type": "application/json"}, + json=_DatabaseCreateRequest( + account_id=self.account_id, + database=database, + ).jsonable_dict(by_alias=True), + ) + return Database.parse_obj_with_service( + obj=response.json()["database"], database_service=self + ) diff --git a/src/firebolt/service/V1/provider.py b/src/firebolt/service/V1/provider.py new file mode 100644 index 00000000000..2d9ddec2692 --- /dev/null +++ b/src/firebolt/service/V1/provider.py @@ -0,0 +1,10 @@ +from firebolt.client import Client +from firebolt.model.V1.provider import Provider +from firebolt.utils.urls import PROVIDERS_URL + + +def get_provider_id(client: Client) -> str: + """Get the AWS provider_id.""" + response = client.get(url=PROVIDERS_URL) + providers = [Provider.parse_obj(i["node"]) for i in response.json()["edges"]] + return providers[0].provider_id diff --git a/src/firebolt/service/V1/region.py b/src/firebolt/service/V1/region.py new file mode 100644 index 00000000000..a6676b7943c --- /dev/null +++ b/src/firebolt/service/V1/region.py @@ -0,0 +1,66 @@ +from typing import Dict, List + +from firebolt.model.V1.region import Region, RegionKey +from firebolt.service.manager import ResourceManager +from firebolt.service.V1.base import BaseService +from firebolt.utils.urls import REGIONS_URL +from firebolt.utils.util import cached_property + + +class RegionService(BaseService): + def __init__(self, resource_manager: ResourceManager): + """ + Service to manage AWS regions (us-east-1, etc) + + Args: + resource_manager: Resource manager to use + """ + + super().__init__(resource_manager=resource_manager) + + @cached_property + def regions(self) -> List[Region]: + """List of available AWS regions on Firebolt.""" + + response = self.client.get(url=REGIONS_URL, params={"page.first": 5000}) + return [Region.parse_obj(i["node"]) for i in response.json()["edges"]] + + @cached_property + def regions_by_name(self) -> Dict[str, Region]: + """Dict of {RegionLookup to Region}""" + + return {r.name: r for r in self.regions} + + @cached_property + def regions_by_key(self) -> Dict[RegionKey, Region]: + """Dict of {RegionKey to Region}""" + + return {r.key: r for r in self.regions} + + @cached_property + def default_region(self) -> Region: + """Default AWS region, could be provided from environment.""" + + if not self.default_region_setting: + raise ValueError( + "default_region parameter must be set when initializing " + "the resource manager." + ) + return self.get_by_name(name=self.default_region_setting) + + def get_by_name(self, name: str) -> Region: + """Get an AWS region by its name (eg. us-east-1).""" + + return self.regions_by_name[name] + + def get_by_key(self, key: RegionKey) -> Region: + """Get an AWS region by its key.""" + + return self.regions_by_key[key] + + def get_by_id(self, id_: str) -> Region: + """Get an AWS region by region_id.""" + + return self.get_by_key( + RegionKey(provider_id=self.resource_manager.provider_id, region_id=id_) + ) diff --git a/src/firebolt/service/manager.py b/src/firebolt/service/manager.py index dddb31b8fa3..0b66c1d86bf 100644 --- a/src/firebolt/service/manager.py +++ b/src/firebolt/service/manager.py @@ -14,6 +14,7 @@ ) from firebolt.common import Settings from firebolt.db import connect +from firebolt.service.V1.provider import get_provider_id from firebolt.utils.util import fix_url_schema DEFAULT_TIMEOUT_SECONDS: int = 60 * 2 @@ -48,11 +49,12 @@ class ResourceManager: "_connection", "regions", "instance_types", - "_provider_id", + "provider_id", "databases", "engines", "engine_revisions", "bindings", + "default_region", "_version", ) @@ -62,6 +64,8 @@ def __init__( auth: Optional[Auth] = None, account_name: Optional[str] = None, api_endpoint: str = DEFAULT_API_URL, + # Legacy parameters + default_region: Optional[str] = None, ): if settings: logger.warning(SETTINGS_DEPRECATION_MESSAGE) @@ -73,6 +77,7 @@ def __init__( auth = settings.auth account_name = settings.account_name api_endpoint = settings.server + default_region = settings.default_region for param, name in ((auth, "auth"),): if not param: @@ -113,6 +118,8 @@ def __init__( self.account_name = account_name self.api_endpoint = api_endpoint self.account_id = self._client.account_id + self.default_region = default_region + self.provider_id: Optional[str] = None if version == 2: self._init_services_v2() elif version == 1: @@ -131,13 +138,25 @@ def _init_services_v2(self) -> None: self.databases = DatabaseService(resource_manager=self) self.engines = EngineService(resource_manager=self) + # Not applicable to V2 + self.provider_id = None + def _init_services_v1(self) -> None: # avoid circular import from firebolt.service.V1.binding import BindingService + from firebolt.service.V1.database import DatabaseService from firebolt.service.V1.engine import EngineService + from firebolt.service.V1.region import RegionService + + # Cloud Platform Resources (AWS) + self.regions = RegionService(resource_manager=self) # type: ignore + # Firebolt Resources self.bindings = BindingService(resource_manager=self) # type: ignore self.engines = EngineService(resource_manager=self) # type: ignore + self.databases = DatabaseService(resource_manager=self) # type: ignore + + self.provider_id = get_provider_id(client=self._client) def __del__(self) -> None: if hasattr(self, "_client"): diff --git a/tests/integration/resource_manager/V1/conftest.py b/tests/integration/resource_manager/V1/conftest.py index 2ca845aa94f..410f93e7f4d 100644 --- a/tests/integration/resource_manager/V1/conftest.py +++ b/tests/integration/resource_manager/V1/conftest.py @@ -11,4 +11,5 @@ def resource_manager( return ResourceManager( auth=password_auth, api_endpoint=api_endpoint, + default_region="us-east-1", ) diff --git a/tests/integration/resource_manager/V1/test_database.py b/tests/integration/resource_manager/V1/test_database.py new file mode 100644 index 00000000000..60f918df7f7 --- /dev/null +++ b/tests/integration/resource_manager/V1/test_database.py @@ -0,0 +1,60 @@ +from firebolt.service.manager import ResourceManager + + +def test_database_get_default_engine( + resource_manager: ResourceManager, + database_name: str, + stopped_engine_name: str, + engine_name: str, +): + """ + Checks that the default engine is either running or stopped engine + """ + db = resource_manager.databases.get_by_name(database_name) + + engine = db.get_default_engine() + assert engine is not None, "default engine is None, but shouldn't" + assert engine.name in [ + stopped_engine_name, + engine_name, + ], "Returned default engine name is neither of known engines" + + +def test_create_new_database(resource_manager: ResourceManager, database_name: str): + new_database_name = database_name + "_rm_test" + + db = resource_manager.databases.create( + name=new_database_name, description="test database" + ) + assert db is not None, "new database is None, but shouldn't" + assert db.name == new_database_name, "new database name doesn't match" + + db.delete() + + +def test_get_by_id(resource_manager: ResourceManager, database_name: str): + db = resource_manager.databases.get_by_name(database_name) + db_id = db.database_id + assert db_id is not None, "database id is None, but shouldn't" + + test_id = resource_manager.databases.get_id_by_name(database_name) + assert test_id is not None, "database id is None, but shouldn't" + assert test_id == db_id, "database id doesn't match" + + db_by_id = resource_manager.databases.get(db_id) + assert db_by_id is not None, "database by id is None, but shouldn't" + assert db_by_id.name == database_name, "database by id name doesn't match" + + +def test_update_description(resource_manager: ResourceManager, database_name: str): + db = resource_manager.databases.get_by_name(database_name) + + new_description = "new test description" + db.update(description=new_description) + assert db.description == new_description, "new description doesn't match" + + +def test_get_many(resource_manager: ResourceManager, database_name: str): + dbs = resource_manager.databases.get_many() + assert len(dbs) > 0, "no databases returned, but shouldn't" + assert any(db.name == database_name for db in dbs), "database not found" diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a4f6d9c3045..165c5b111bb 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -9,7 +9,6 @@ from firebolt.client.auth import Auth, ClientCredentials, UsernamePassword from firebolt.client.client import ClientV2 from firebolt.common.settings import Settings -from firebolt.model.V1.region import Region from firebolt.utils.exception import ( AccountNotFoundError, DatabaseError, @@ -314,17 +313,18 @@ def check_credentials( @fixture -def region_1() -> Region: - return "us-east-1" +def region_string() -> str: + return "mock_region_1" @fixture def settings( - server: str, region_1: str, username_password_auth: Auth, account_name: str + server: str, region_string: str, username_password_auth: Auth, account_name: str ) -> Settings: - return Settings( + seett = Settings( server=server, auth=username_password_auth, - default_region=region_1, + default_region=region_string, account_name=account_name, ) + return seett diff --git a/tests/unit/service/V1/conftest.py b/tests/unit/service/V1/conftest.py index e17e0b39716..1b67ab45309 100644 --- a/tests/unit/service/V1/conftest.py +++ b/tests/unit/service/V1/conftest.py @@ -1,6 +1,6 @@ import json from re import Pattern, compile -from typing import Callable +from typing import Callable, List from urllib.parse import urlparse import httpx @@ -9,12 +9,15 @@ from firebolt.client.auth.base import Auth from firebolt.model.V1.binding import Binding, BindingKey +from firebolt.model.V1.database import Database, DatabaseKey from firebolt.model.V1.engine import Engine, EngineKey, EngineSettings +from firebolt.model.V1.provider import Provider from firebolt.model.V1.region import Region, RegionKey from firebolt.utils.exception import AccountNotFoundError from firebolt.utils.urls import ( ACCOUNT_BINDINGS_URL, ACCOUNT_BY_NAME_URL, + ACCOUNT_DATABASE_BINDING_URL, ACCOUNT_DATABASE_BY_NAME_URL, ACCOUNT_DATABASE_URL, ACCOUNT_DATABASES_URL, @@ -23,6 +26,7 @@ ACCOUNT_LIST_ENGINES_URL, ACCOUNT_URL, AUTH_URL, + ENGINES_BY_IDS_URL, PROVIDERS_URL, REGIONS_URL, ) @@ -67,6 +71,19 @@ def mock_engine(engine_name, region_key, engine_settings, account_id, server) -> ) +@fixture +def provider() -> Provider: + return Provider( + provider_id="mock_provider_id", + name="mock_provider_name", + ) + + +@fixture +def mock_providers(provider) -> List[Provider]: + return [provider] + + @fixture def provider_callback(provider_url: str, mock_providers) -> Callable: def do_mock( @@ -87,6 +104,33 @@ def provider_url(server: str) -> str: return f"https://{server}{PROVIDERS_URL}" +@fixture +def region_1(provider) -> Region: + return Region( + key=RegionKey( + provider_id=provider.provider_id, + region_id="mock_region_id_1", + ), + name="mock_region_1", + ) + + +@fixture +def region_2(provider) -> Region: + return Region( + key=RegionKey( + provider_id=provider.provider_id, + region_id="mock_region_id_2", + ), + name="mock_region_2", + ) + + +@fixture +def mock_regions(region_1, region_2) -> List[Region]: + return [region_1, region_2] + + @fixture def region_callback(region_url: str, mock_regions) -> Callable: def do_mock( @@ -223,6 +267,21 @@ def account_engine_url(server: str, account_id, mock_engine) -> str: ) +@fixture +def db_description() -> str: + return "database description" + + +@fixture +def mock_database(region_1: Region, account_id: str, database_id: str) -> Database: + return Database( + name="mock_db_name", + description="mock_db_description", + compute_region_key=region_1.key, + database_key=DatabaseKey(account_id=account_id, database_id=database_id), + ) + + @fixture def create_databases_callback(databases_url: str, mock_database) -> Callable: def do_mock( @@ -385,6 +444,26 @@ def bindings_url(server: str, account_id: str, mock_engine: Engine) -> str: ) +@fixture +def database_bindings_url(server: str, account_id: str, mock_database: Database) -> str: + return ( + f"https://{server}" + + ACCOUNT_BINDINGS_URL.format(account_id=account_id) + + f"?page.first=5000&filter.id_database_id_eq={mock_database.database_id}" + ) + + +@fixture +def create_binding_url( + server: str, account_id: str, mock_database: Database, mock_engine: Engine +) -> str: + return f"https://{server}" + ACCOUNT_DATABASE_BINDING_URL.format( + account_id=account_id, + database_id=mock_database.database_id, + engine_id=mock_engine.engine_id, + ) + + @fixture def create_binding_callback(create_binding_url: str, binding) -> Callable: def do_mock( @@ -401,11 +480,11 @@ def do_mock( @fixture -def binding(account_id, mock_engine, db_id) -> Binding: +def binding(account_id, mock_engine, database_id) -> Binding: return Binding( binding_key=BindingKey( account_id=account_id, - database_id=db_id, + database_id=database_id, engine_id=mock_engine.engine_id, ), is_default_engine=True, @@ -427,6 +506,23 @@ def do_mock( return do_mock +@fixture +def bindings_database_callback( + database_bindings_url: str, binding: Binding +) -> Callable: + def do_mock( + request: httpx.Request = None, + **kwargs, + ) -> Response: + assert request.url == database_bindings_url + return Response( + status_code=httpx.codes.OK, + json=list_to_paginated_response([binding]), + ) + + return do_mock + + @fixture def auth_url(server: str) -> str: return f"https://{server}{AUTH_URL}" @@ -501,3 +597,8 @@ def do_mock( @fixture def auth(username_password_auth) -> Auth: return username_password_auth + + +@fixture +def engines_by_id_url(server: str) -> str: + return f"https://{server}" + ENGINES_BY_IDS_URL diff --git a/tests/unit/service/V1/test_bindings.py b/tests/unit/service/V1/test_bindings.py index 93bc19e8086..b4c3bf60542 100644 --- a/tests/unit/service/V1/test_bindings.py +++ b/tests/unit/service/V1/test_bindings.py @@ -1,17 +1,23 @@ from re import Pattern from typing import Callable +from pytest import raises from pytest_httpx import HTTPXMock from firebolt.common.settings import Settings +from firebolt.model.V1.binding import Binding +from firebolt.model.V1.database import Database from firebolt.model.V1.engine import Engine from firebolt.service.manager import ResourceManager +from firebolt.utils.exception import AlreadyBoundError def test_get_many_bindings( httpx_mock: HTTPXMock, auth_callback: Callable, auth_url: str, + provider_callback: Callable, + provider_url: str, account_id_callback: Callable, account_id_url: Pattern, bindings_url: str, @@ -20,6 +26,7 @@ def test_get_many_bindings( mock_engine: Engine, ): httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(bindings_callback, url=bindings_url) @@ -27,3 +34,94 @@ def test_get_many_bindings( bindings = resource_manager.bindings.get_many(engine_id=mock_engine.engine_id) assert len(bindings) > 0 assert any(binding.is_default_engine for binding in bindings) + + +def test_create_binding( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + account_id_callback: Callable, + account_id_url: Pattern, + bindings_url: str, + binding: Binding, + create_binding_url: str, + settings: Settings, + mock_engine: Engine, + mock_database: Database, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_response(url=bindings_url, method="GET", json={"edges": []}) + httpx_mock.add_response( + url=create_binding_url, method="POST", json={"binding": binding.dict()} + ) + + resource_manager = ResourceManager(settings=settings) + binding = resource_manager.bindings.create( + engine=mock_engine, database=mock_database, is_default_engine=True + ) + assert binding.engine_id == mock_engine.engine_id + assert binding.database_id == mock_database.database_id + + +def test_create_binding_existing_db( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + account_id_callback: Callable, + account_id_url: Pattern, + bindings_url: str, + bindings_callback: Callable, + database_url: str, + database_callback: Callable, + settings: Settings, + mock_engine: Engine, + mock_database: Database, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_callback(bindings_callback, url=bindings_url) + httpx_mock.add_callback(database_callback, url=database_url) + + resource_manager = ResourceManager(settings=settings) + with raises(AlreadyBoundError): + resource_manager.bindings.create( + engine=mock_engine, database=mock_database, is_default_engine=True + ) + + +def test_get_engines_bound_to_db( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + account_id_callback: Callable, + account_id_url: Pattern, + database_bindings_url: str, + bindings_database_callback: Callable, + settings: Settings, + mock_engine: Engine, + mock_database: Database, + engines_by_id_url: str, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_callback(bindings_database_callback, url=database_bindings_url) + httpx_mock.add_response( + url=engines_by_id_url, method="POST", json={"engines": [mock_engine.dict()]} + ) + + resource_manager = ResourceManager(settings=settings) + engines = resource_manager.bindings.get_engines_bound_to_database( + database=mock_database + ) + assert len(engines) > 0 + assert any(engine.engine_id == mock_engine.engine_id for engine in engines) diff --git a/tests/unit/service/V1/test_database.py b/tests/unit/service/V1/test_database.py new file mode 100644 index 00000000000..42099359959 --- /dev/null +++ b/tests/unit/service/V1/test_database.py @@ -0,0 +1,131 @@ +from re import Pattern, compile +from typing import Callable + +from pytest_httpx import HTTPXMock + +from firebolt.common import Settings +from firebolt.model.V1.database import Database +from firebolt.service.manager import ResourceManager + + +def test_database_create( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + region_callback: Callable, + region_url: str, + settings: Settings, + account_id_callback: Callable, + account_id_url: Pattern, + create_databases_callback: Callable, + databases_url: str, + db_name: str, + db_description: str, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(region_callback, url=region_url) + httpx_mock.add_callback(create_databases_callback, url=databases_url, method="POST") + + manager = ResourceManager(settings=settings) + database = manager.databases.create(name=db_name, description=db_description) + + assert database.name == db_name + assert database.description == db_description + + +def test_database_get_by_name( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + settings: Settings, + account_id_callback: Callable, + account_id_url: Pattern, + database_get_by_name_callback: Callable, + database_get_by_name_url: str, + database_get_callback: Callable, + database_get_url: str, + mock_database: Database, +): + + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(database_get_by_name_callback, url=database_get_by_name_url) + httpx_mock.add_callback(database_get_callback, url=database_get_url) + + manager = ResourceManager(settings=settings) + database = manager.databases.get_by_name(name=mock_database.name) + + assert database.name == mock_database.name + + +def test_database_get_many( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + settings: Settings, + account_id_callback: Callable, + account_id_url: Pattern, + database_get_by_name_callback: Callable, + database_get_by_name_url: str, + databases_get_callback: Callable, + databases_url: str, + mock_database: Database, +): + + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback( + databases_get_callback, + url=compile(databases_url + "?[a-zA-Z0-9=&]*"), + method="GET", + ) + + manager = ResourceManager(settings=settings) + databases = manager.databases.get_many( + name_contains=mock_database.name, + attached_engine_name_eq="mockengine", + attached_engine_name_contains="mockengine", + ) + + assert len(databases) == 1 + assert databases[0].name == mock_database.name + + +def test_database_update( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + settings: Settings, + account_id_callback: Callable, + account_id_url: Pattern, + database_update_callback: Callable, + database_url: str, + mock_database: Database, +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + httpx_mock.add_callback(database_update_callback, url=database_url, method="PATCH") + + manager = ResourceManager(settings=settings) + + mock_database._service = manager.databases + database = mock_database.update(description="new description") + + assert database.description == "new description" diff --git a/tests/unit/service/V1/test_engine.py b/tests/unit/service/V1/test_engine.py index 2f63eec8138..64c803b8c7e 100644 --- a/tests/unit/service/V1/test_engine.py +++ b/tests/unit/service/V1/test_engine.py @@ -12,6 +12,8 @@ def test_engine_start_stop( httpx_mock: HTTPXMock, auth_callback: Callable, auth_url: str, + provider_callback: Callable, + provider_url: str, settings: Settings, mock_engine: Engine, account_id_callback: Callable, @@ -20,7 +22,7 @@ def test_engine_start_stop( account_engine_url: str, ): httpx_mock.add_callback(auth_callback, url=auth_url) - + httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(account_id_callback, url=account_id_url) httpx_mock.add_callback(auth_callback, url=auth_url) diff --git a/tests/unit/service/V1/test_region.py b/tests/unit/service/V1/test_region.py new file mode 100644 index 00000000000..f16fc312396 --- /dev/null +++ b/tests/unit/service/V1/test_region.py @@ -0,0 +1,30 @@ +from re import Pattern +from typing import Callable, List + +from pytest_httpx import HTTPXMock + +from firebolt.common import Settings +from firebolt.model.V1.region import Region +from firebolt.service.manager import ResourceManager + + +def test_region( + httpx_mock: HTTPXMock, + auth_callback: Callable, + auth_url: str, + provider_callback: Callable, + provider_url: str, + region_callback: Callable, + region_url: str, + account_id_callback: Callable, + account_id_url: Pattern, + settings: Settings, + mock_regions: List[Region], +): + httpx_mock.add_callback(auth_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) + httpx_mock.add_callback(region_callback, url=region_url) + httpx_mock.add_callback(account_id_callback, url=account_id_url) + + manager = ResourceManager(settings=settings) + assert manager.regions.regions == mock_regions diff --git a/tests/unit/service/V1/test_resource_manager.py b/tests/unit/service/V1/test_resource_manager.py index 008a5b08b4d..c44fb416860 100644 --- a/tests/unit/service/V1/test_resource_manager.py +++ b/tests/unit/service/V1/test_resource_manager.py @@ -20,6 +20,8 @@ def test_rm_credentials( user: str, password: str, auth_url: str, + provider_callback: Callable, + provider_url: str, account_id_url: Pattern, account_id_callback: Callable, access_token: str, @@ -28,6 +30,7 @@ def test_rm_credentials( url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(account_id_callback, url=account_id_url) @@ -62,6 +65,8 @@ def test_rm_token_cache( user: str, password: str, auth_url: str, + provider_callback: Callable, + provider_url: str, account_id_url: Pattern, account_id_callback: Callable, access_token: str, @@ -70,6 +75,7 @@ def test_rm_token_cache( url = "https://url" httpx_mock.add_callback(check_credentials_callback, url=auth_url) + httpx_mock.add_callback(provider_callback, url=provider_url) httpx_mock.add_callback(check_token_callback, url=url) httpx_mock.add_callback(account_id_callback, url=account_id_url)