From 69aa772703f2743b1d40ba98a2a9dd33e4deee46 Mon Sep 17 00:00:00 2001
From: Ben Cassell <98852248+benc-db@users.noreply.github.com>
Date: Tue, 3 Dec 2024 10:32:37 -0800
Subject: [PATCH] Extract poll refresh pipeline from cursor (#849)

---
 CHANGELOG.md                                  |   1 +
 dbt/adapters/databricks/api_client.py         |  55 ++++++
 dbt/adapters/databricks/connections.py        | 157 +-----------------
 dbt/adapters/databricks/impl.py               |  14 +-
 .../unit/api_client/test_dlt_pipeline_api.py  |  71 ++++++++
 tests/unit/test_adapter.py                    |  15 +-
 6 files changed, 146 insertions(+), 167 deletions(-)
 create mode 100644 tests/unit/api_client/test_dlt_pipeline_api.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f8e07129..64697246 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,7 @@
 - Fix behavior flag use in init of DatabricksAdapter (thanks @VersusFacit!) ([836](https://github.com/databricks/dbt-databricks/pull/836))
 - Restrict pydantic to V1 per dbt Labs' request ([843](https://github.com/databricks/dbt-databricks/pull/843))
 - Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847))
+- Refactoring location of DLT polling code ([849](https://github.com/databricks/dbt-databricks/pull/849))
 - Switching to Hatch and pyproject.toml for project config ([853](https://github.com/databricks/dbt-databricks/pull/853))
 
 ## dbt-databricks 1.8.7 (October 10, 2024)
diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
index c14af9f5..cfabb235 100644
--- a/dbt/adapters/databricks/api_client.py
+++ b/dbt/adapters/databricks/api_client.py
@@ -460,6 +460,60 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
         return response_json["run_id"]
 
 
+class DltPipelineApi(PollableApi):
+    def __init__(self, session: Session, host: str, polling_interval: int):
+        super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)
+
+    def poll_for_completion(self, pipeline_id: str) -> None:
+        self._poll_api(
+            url=f"/{pipeline_id}",
+            params={},
+            get_state_func=lambda response: response.json()["state"],
+            terminal_states={"IDLE", "FAILED", "DELETED"},
+            expected_end_state="IDLE",
+            unexpected_end_state_func=self._get_exception,
+        )
+
+    def _get_exception(self, response: Response) -> None:
+        response_json = response.json()
+        cause = response_json.get("cause")
+        if cause:
+            raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}")
+        else:
+            latest_update = response_json.get("latest_updates")[0]
+            last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update)
+            raise DbtRuntimeError(
+                f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}"
+            )
+
+    def get_update_error(self, pipeline_id: str, update_id: str) -> str:
+        response = self.session.get(f"/{pipeline_id}/events")
+        if response.status_code != 200:
+            raise DbtRuntimeError(
+                f"Error getting pipeline event info for {pipeline_id}: {response.text}"
+            )
+
+        events = response.json().get("events", [])
+        update_events = [
+            e
+            for e in events
+            if e.get("event_type", "") == "update_progress"
+            and e.get("origin", {}).get("update_id") == update_id
+        ]
+
+        error_events = [
+            e
+            for e in update_events
+            if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
+        ]
+
+        msg = ""
+        if error_events:
+            msg = error_events[0].get("message", "")
+
+        return msg
+
+
 class DatabricksApiClient:
     def __init__(
         self,
@@ -481,6 +535,7 @@ def __init__(
         self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
         self.workflows = WorkflowJobApi(session, host)
         self.workflow_permissions = JobPermissionsApi(session, host)
+        self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)
 
     @staticmethod
     def create(
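`DltPipelineApi` delegates the actual wait loop to the `PollableApi` base class, which this patch does not touch. Below is a minimal sketch of the contract implied by the `_poll_api` call above; the function name, defaults, and loop details are assumptions inferred from the call site, not the real base class.

```python
import time
from typing import Callable, Container

from requests import Response, Session


def poll_api_sketch(
    session: Session,
    url: str,
    params: dict,
    get_state_func: Callable[[Response], str],
    terminal_states: Container[str],
    expected_end_state: str,
    unexpected_end_state_func: Callable[[Response], None],
    polling_interval: int = 10,
    timeout: int = 60 * 60,
) -> None:
    # The real PollableApi presumably resolves `url` against the host and
    # API prefix passed to its __init__ ("/api/2.0/pipelines" here).
    start = time.time()
    while time.time() - start < timeout:
        response = session.get(url, params=params)
        state = get_state_func(response)
        if state in terminal_states:
            if state != expected_end_state:
                # DltPipelineApi passes _get_exception here, which raises
                # DbtRuntimeError with the pipeline's failure cause.
                unexpected_end_state_func(response)
            return
        time.sleep(polling_interval)
    raise TimeoutError(f"Timed out polling {url}")
```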
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 49b1ba82..33ce5760 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -17,7 +17,6 @@
 from dbt_common.events.functions import fire_event
 from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError
 from dbt_common.utils import cast_to_str
-from requests import Session
 
 import databricks.sql as dbsql
 from databricks.sql.client import Connection as DatabricksSQLConnection
@@ -35,7 +34,6 @@
 )
 from dbt.adapters.databricks.__version__ import version as __version__
 from dbt.adapters.databricks.api_client import DatabricksApiClient
-from dbt.adapters.databricks.auth import BearerAuth
 from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
 from dbt.adapters.databricks.events.connection_events import (
     ConnectionAcquire,
@@ -61,7 +59,6 @@
     CursorCreate,
 )
 from dbt.adapters.databricks.events.other_events import QueryError
-from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
 from dbt.adapters.databricks.utils import redact_credentials
@@ -227,97 +224,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
             bindings = [self._fix_binding(binding) for binding in bindings]
         self._cursor.execute(sql, bindings)
 
-    def poll_refresh_pipeline(self, pipeline_id: str) -> None:
-        # interval in seconds
-        polling_interval = 10
-
-        # timeout in seconds
-        timeout = 60 * 60
-
-        stopped_states = ("COMPLETED", "FAILED", "CANCELED")
-        host: str = self._creds.host or ""
-        headers = (
-            self._cursor.connection.thrift_backend._auth_provider._header_factory  # type: ignore
-        )
-        session = Session()
-        session.auth = BearerAuth(headers)
-        session.headers = {"User-Agent": self._user_agent}
-        pipeline = _get_pipeline_state(session, host, pipeline_id)
-        # get the most recently created update for the pipeline
-        latest_update = _find_update(pipeline)
-        if not latest_update:
-            raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
-
-        state = latest_update.get("state")
-        # we use update_id to retrieve the update in the polling loop
-        update_id = latest_update.get("update_id", "")
-        prev_state = state
-
-        logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
-
-        start = time.time()
-        exceeded_timeout = False
-        while state not in stopped_states:
-            if time.time() - start > timeout:
-                exceeded_timeout = True
-                break
-
-            # should we do exponential backoff?
-            time.sleep(polling_interval)
-
-            pipeline = _get_pipeline_state(session, host, pipeline_id)
-            # get the update we are currently polling
-            update = _find_update(pipeline, update_id)
-            if not update:
-                raise DbtRuntimeError(
-                    f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
-                )
-
-            state = update.get("state")
-            if state != prev_state:
-                logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
-                prev_state = state
-
-            if state == "FAILED":
-                logger.error(
-                    PipelineRefreshError(
-                        pipeline_id,
-                        update_id,
-                        _get_update_error_msg(session, host, pipeline_id, update_id),
-                    )
-                )
-
-                # another update may have been created due to retry_on_fail settings
-                # get the latest update and see if it is a new one
-                latest_update = _find_update(pipeline)
-                if not latest_update:
-                    raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
-
-                latest_update_id = latest_update.get("update_id", "")
-                if latest_update_id != update_id:
-                    update_id = latest_update_id
-                    state = None
-
-        if exceeded_timeout:
-            raise DbtRuntimeError("timed out waiting for materialized view refresh")
-
-        if state == "FAILED":
-            msg = _get_update_error_msg(session, host, pipeline_id, update_id)
-            raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")
-
-        if state == "CANCELED":
-            raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")
-
-        return
-
-    @classmethod
-    def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
-        matches = [x for x in updates if x.get("update_id") == id]
-        if matches:
-            return matches[0]
-
-        return None
-
     @property
     def hex_query_id(self) -> str:
         """Return the hex GUID for this query
@@ -475,12 +381,15 @@ class DatabricksConnectionManager(SparkConnectionManager):
     credentials_provider: Optional[TCredentialProvider] = None
     _user_agent = f"dbt-databricks/{__version__}"
 
+    def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
+        super().__init__(profile, mp_context)
+        creds = cast(DatabricksCredentials, self.profile.credentials)
+        self.api_client = DatabricksApiClient.create(creds, 15 * 60)
+
     def cancel_open(self) -> list[str]:
         cancelled = super().cancel_open()
-        creds = cast(DatabricksCredentials, self.profile.credentials)
-        api_client = DatabricksApiClient.create(creds, 15 * 60)
         logger.info("Cancelling open python jobs")
-        PythonRunTracker.cancel_runs(api_client)
+        PythonRunTracker.cancel_runs(self.api_client)
         return cancelled
 
     def compare_dbr_version(self, major: int, minor: int) -> int:
@@ -1079,60 +988,6 @@ def exponential_backoff(attempt: int) -> int:
     )
 
 
-def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
-    pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"
-
-    response = session.get(pipeline_url)
-    if response.status_code != 200:
-        raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")
-
-    return response.json()
-
-
-def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
-    updates = pipeline.get("latest_updates", [])
-    if not updates:
-        raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")
-
-    if not id:
-        return updates[0]
-
-    matches = [x for x in updates if x.get("update_id") == id]
-    if matches:
-        return matches[0]
-
-    return None
-
-
-def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
-    events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
-    response = session.get(events_url)
-    if response.status_code != 200:
-        raise DbtRuntimeError(
-            f"Error getting pipeline event info for {pipeline_id}: {response.text}"
-        )
-
-    events = response.json().get("events", [])
-    update_events = [
-        e
-        for e in events
-        if e.get("event_type", "") == "update_progress"
-        and e.get("origin", {}).get("update_id") == update_id
-    ]
-
-    error_events = [
-        e
-        for e in update_events
-        if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
-    ]
-
-    msg = ""
-    if error_events:
-        msg = error_events[0].get("message", "")
-
-    return msg
-
-
 def _get_compute_name(query_header_context: Any) -> Optional[str]:
     # Get the name of the specified compute resource from the node's
     # config.
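Note that `_get_update_error_msg` moves into `DltPipelineApi.get_update_error` essentially unchanged. Its two-stage event filter is easiest to see on concrete data; the payload below is invented for illustration.

```python
# Keep update_progress events for the update in question, then keep only the
# ones whose state is FAILED; the first failure's message wins.
events = [
    {"event_type": "create_update", "origin": {"update_id": "u1"}},
    {
        "event_type": "update_progress",
        "origin": {"update_id": "u1"},
        "details": {"update_progress": {"state": "RUNNING"}},
    },
    {
        "event_type": "update_progress",
        "origin": {"update_id": "u1"},
        "details": {"update_progress": {"state": "FAILED"}},
        "message": "Update 'u1' is FAILED.",
    },
]
update_id = "u1"

update_events = [
    e
    for e in events
    if e.get("event_type", "") == "update_progress"
    and e.get("origin", {}).get("update_id") == update_id
]
error_events = [
    e
    for e in update_events
    if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]
msg = error_events[0].get("message", "") if error_events else ""
assert msg == "Update 'u1' is FAILED."
```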
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 3e0288b0..f68664e8 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -34,8 +34,6 @@
 from dbt.adapters.databricks.connections import (
     USE_LONG_SESSIONS,
     DatabricksConnectionManager,
-    DatabricksDBTConnection,
-    DatabricksSQLConnectionWrapper,
     ExtendedSessionConnectionManager,
 )
 from dbt.adapters.databricks.python_models.python_submissions import (
@@ -807,19 +805,13 @@ def get_from_relation(
         """Get the relation config from the relation."""
         relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
 
-        connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection())
-        wrapper: DatabricksSQLConnectionWrapper = connection.handle
-
         # Ensure any current refreshes are completed before returning the relation config
         tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
         if tblproperties.pipeline_id:
-            # TODO fix this path so that it doesn't need a cursor
-            # It just calls APIs to poll the pipeline status
-            cursor = wrapper.cursor()
-            try:
-                cursor.poll_refresh_pipeline(tblproperties.pipeline_id)
-            finally:
-                cursor.close()
+            adapter.connections.api_client.dlt_pipelines.poll_for_completion(
+                tblproperties.pipeline_id
+            )
 
         return relation_config
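With the cursor dependency gone, `get_from_relation` reaches the poller through plain attribute access, so it can be exercised with a mock adapter and no live connection. A sketch of that idea (not part of this PR's test suite):

```python
from unittest.mock import MagicMock

# Stand-in for the adapter argument to get_from_relation: the DLT wait now
# rides on the connection manager's api_client instead of a SQL cursor.
adapter = MagicMock()
adapter.connections.api_client.dlt_pipelines.poll_for_completion("pipeline-id")
adapter.connections.api_client.dlt_pipelines.poll_for_completion.assert_called_once_with(
    "pipeline-id"
)
```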
diff --git a/tests/unit/api_client/test_dlt_pipeline_api.py b/tests/unit/api_client/test_dlt_pipeline_api.py
new file mode 100644
index 00000000..7dd1418e
--- /dev/null
+++ b/tests/unit/api_client/test_dlt_pipeline_api.py
@@ -0,0 +1,71 @@
+import pytest
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.databricks.api_client import DltPipelineApi
+from tests.unit.api_client.api_test_base import ApiTestBase
+
+
+class TestDltPipelineApi(ApiTestBase):
+    @pytest.fixture
+    def api(self, session, host):
+        return DltPipelineApi(session, host, 1)
+
+    @pytest.fixture
+    def pipeline_id(self):
+        return "pipeline_id"
+
+    @pytest.fixture
+    def update_id(self):
+        return "update_id"
+
+    def test_get_update_error__non_200(self, api, session, pipeline_id, update_id):
+        session.get.return_value.status_code = 500
+        with pytest.raises(DbtRuntimeError):
+            api.get_update_error(pipeline_id, update_id)
+
+    def test_get_update_error__200_no_events(self, api, session, pipeline_id, update_id):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {"events": []}
+        assert api.get_update_error(pipeline_id, update_id) == ""
+
+    def test_get_update_error__200_no_error_events(self, api, session, pipeline_id, update_id):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {
+            "events": [{"event_type": "update_progress", "origin": {"update_id": update_id}}]
+        }
+        assert api.get_update_error(pipeline_id, update_id) == ""
+
+    def test_get_update_error__200_error_events(self, api, session, pipeline_id, update_id):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {
+            "events": [
+                {
+                    "message": "I failed",
+                    "details": {"update_progress": {"state": "FAILED"}},
+                    "event_type": "update_progress",
+                    "origin": {"update_id": update_id},
+                }
+            ]
+        }
+        assert api.get_update_error(pipeline_id, update_id) == "I failed"
+
+    def test_poll_for_completion__non_200(self, api, session, pipeline_id):
+        self.assert_non_200_raises_error(lambda: api.poll_for_completion(pipeline_id), session)
+
+    def test_poll_for_completion__200(self, api, session, host, pipeline_id):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {"state": "IDLE"}
+        api.poll_for_completion(pipeline_id)
+        session.get.assert_called_once_with(
+            f"https://{host}/api/2.0/pipelines/{pipeline_id}", json=None, params={}
+        )
+
+    def test_poll_for_completion__failed_with_cause(self, api, session, pipeline_id):
+        session.get.return_value.status_code = 200
+        session.get.return_value.json.return_value = {
+            "state": "FAILED",
+            "pipeline_id": pipeline_id,
+            "cause": "I failed",
+        }
+        with pytest.raises(DbtRuntimeError, match=f"Pipeline {pipeline_id} failed: I failed"):
+            api.poll_for_completion(pipeline_id)
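These tests drive the client entirely through a mocked `requests.Session`, which is what the extraction buys: no Thrift connection or cursor is needed to test the polling logic. The happy path also works standalone, assuming `PollableApi` issues plain `session.get` calls as the assertion in `test_poll_for_completion__200` suggests:

```python
from unittest.mock import MagicMock

from dbt.adapters.databricks.api_client import DltPipelineApi

session = MagicMock()
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {"state": "IDLE"}

# The hostname here is a placeholder, not a real workspace.
api = DltPipelineApi(session, "my-host.cloud.databricks.com", polling_interval=1)
# Returns immediately: IDLE is the expected terminal state.
api.poll_for_completion("pipeline_id")
```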
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index b26d5f80..78ae12cb 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -340,13 +340,15 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
         assert connection.credentials.token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
         assert connection.credentials.schema == "analytics"
 
-    def test_list_relations_without_caching__no_relations(self):
+    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__no_relations(self, _):
         with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = []
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
             assert adapter.list_relations("database", "schema") == []
 
-    def test_list_relations_without_caching__some_relations(self):
+    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__some_relations(self, _):
         with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = [("name", "table", "hudi", "owner")]
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
@@ -360,7 +362,8 @@ def test_list_relations_without_caching__some_relations(self):
             assert relation.owner == "owner"
             assert relation.is_hudi
 
-    def test_list_relations_without_caching__hive_relation(self):
+    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__hive_relation(self, _):
         with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = [("name", "table", None, None)]
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
@@ -373,7 +376,8 @@ def test_list_relations_without_caching__hive_relation(self):
         assert relation.type == DatabricksRelationType.Table
         assert not relation.has_information()
 
-    def test_get_schema_for_catalog__no_columns(self):
+    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_get_schema_for_catalog__no_columns(self, _):
         with patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info:
             list_info.return_value = [(Mock(), "info")]
             with patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns:
@@ -382,7 +386,8 @@ def test_get_schema_for_catalog__no_columns(self):
                 table = adapter._get_schema_for_catalog("database", "schema", "name")
                 assert len(table.rows) == 0
 
-    def test_get_schema_for_catalog__some_columns(self):
+    @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_get_schema_for_catalog__some_columns(self, _):
         with patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info:
             list_info.return_value = [(Mock(), "info")]
             with patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns:
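A closing note on these test changes: because `DatabricksConnectionManager.__init__` now calls `DatabricksApiClient.create` eagerly, any test that constructs a `DatabricksAdapter` must stub that factory or construction would attempt to build a real client. The tests patch `create` at its definition site, which works because the patch lands on the class object that every importing module shares, `connections.py` included. A self-contained check of that assumption:

```python
from unittest.mock import Mock, patch

from dbt.adapters.databricks.api_client import DatabricksApiClient

# Method patches apply to the class object itself, so every module holding a
# reference to DatabricksApiClient sees the stub while the patch is active.
with patch.object(DatabricksApiClient, "create") as create:
    assert DatabricksApiClient.create(Mock(), 15 * 60) is create.return_value
```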