Skip to content

Commit

Permalink
Extract poll refresh pipeline from cursor (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db authored Dec 3, 2024
1 parent 5066e76 commit 69aa772
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 167 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
- Fix behavior flag use in init of DatabricksAdapter (thanks @VersusFacit!) ([836](https://github.com/databricks/dbt-databricks/pull/836))
- Restrict pydantic to V1 per dbt Labs' request ([843](https://github.com/databricks/dbt-databricks/pull/843))
- Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847))
- Refactoring location of DLT polling code ([849](https://github.com/databricks/dbt-databricks/pull/849))
- Switching to Hatch and pyproject.toml for project config ([853](https://github.com/databricks/dbt-databricks/pull/853))

## dbt-databricks 1.8.7 (October 10, 2024)
Expand Down
55 changes: 55 additions & 0 deletions dbt/adapters/databricks/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,60 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
return response_json["run_id"]


class DltPipelineApi(PollableApi):
def __init__(self, session: Session, host: str, polling_interval: int):
super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)

def poll_for_completion(self, pipeline_id: str) -> None:
self._poll_api(
url=f"/{pipeline_id}",
params={},
get_state_func=lambda response: response.json()["state"],
terminal_states={"IDLE", "FAILED", "DELETED"},
expected_end_state="IDLE",
unexpected_end_state_func=self._get_exception,
)

def _get_exception(self, response: Response) -> None:
response_json = response.json()
cause = response_json.get("cause")
if cause:
raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}")
else:
latest_update = response_json.get("latest_updates")[0]
last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update)
raise DbtRuntimeError(
f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}"
)

def get_update_error(self, pipeline_id: str, update_id: str) -> str:
response = self.session.get(f"/{pipeline_id}/events")
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


class DatabricksApiClient:
def __init__(
self,
Expand All @@ -481,6 +535,7 @@ def __init__(
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
self.workflows = WorkflowJobApi(session, host)
self.workflow_permissions = JobPermissionsApi(session, host)
self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)

@staticmethod
def create(
Expand Down
157 changes: 6 additions & 151 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError
from dbt_common.utils import cast_to_str
from requests import Session

import databricks.sql as dbsql
from databricks.sql.client import Connection as DatabricksSQLConnection
Expand All @@ -35,7 +34,6 @@
)
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.api_client import DatabricksApiClient
from dbt.adapters.databricks.auth import BearerAuth
from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks.events.connection_events import (
ConnectionAcquire,
Expand All @@ -61,7 +59,6 @@
CursorCreate,
)
from dbt.adapters.databricks.events.other_events import QueryError
from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError
from dbt.adapters.databricks.logging import logger
from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
from dbt.adapters.databricks.utils import redact_credentials
Expand Down Expand Up @@ -227,97 +224,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
bindings = [self._fix_binding(binding) for binding in bindings]
self._cursor.execute(sql, bindings)

def poll_refresh_pipeline(self, pipeline_id: str) -> None:
# interval in seconds
polling_interval = 10

# timeout in seconds
timeout = 60 * 60

stopped_states = ("COMPLETED", "FAILED", "CANCELED")
host: str = self._creds.host or ""
headers = (
self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore
)
session = Session()
session.auth = BearerAuth(headers)
session.headers = {"User-Agent": self._user_agent}
pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the most recently created update for the pipeline
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

state = latest_update.get("state")
# we use update_id to retrieve the update in the polling loop
update_id = latest_update.get("update_id", "")
prev_state = state

logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))

start = time.time()
exceeded_timeout = False
while state not in stopped_states:
if time.time() - start > timeout:
exceeded_timeout = True
break

# should we do exponential backoff?
time.sleep(polling_interval)

pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the update we are currently polling
update = _find_update(pipeline, update_id)
if not update:
raise DbtRuntimeError(
f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
)

state = update.get("state")
if state != prev_state:
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
prev_state = state

if state == "FAILED":
logger.error(
PipelineRefreshError(
pipeline_id,
update_id,
_get_update_error_msg(session, host, pipeline_id, update_id),
)
)

# another update may have been created due to retry_on_fail settings
# get the latest update and see if it is a new one
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

latest_update_id = latest_update.get("update_id", "")
if latest_update_id != update_id:
update_id = latest_update_id
state = None

if exceeded_timeout:
raise DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")

if state == "CANCELED":
raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")

return

@classmethod
def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None

@property
def hex_query_id(self) -> str:
"""Return the hex GUID for this query
Expand Down Expand Up @@ -475,12 +381,15 @@ class DatabricksConnectionManager(SparkConnectionManager):
credentials_provider: Optional[TCredentialProvider] = None
_user_agent = f"dbt-databricks/{__version__}"

def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
super().__init__(profile, mp_context)
creds = cast(DatabricksCredentials, self.profile.credentials)
self.api_client = DatabricksApiClient.create(creds, 15 * 60)

def cancel_open(self) -> list[str]:
cancelled = super().cancel_open()
creds = cast(DatabricksCredentials, self.profile.credentials)
api_client = DatabricksApiClient.create(creds, 15 * 60)
logger.info("Cancelling open python jobs")
PythonRunTracker.cancel_runs(api_client)
PythonRunTracker.cancel_runs(self.api_client)
return cancelled

def compare_dbr_version(self, major: int, minor: int) -> int:
Expand Down Expand Up @@ -1079,60 +988,6 @@ def exponential_backoff(attempt: int) -> int:
)


def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

response = session.get(pipeline_url)
if response.status_code != 200:
raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")

return response.json()


def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
updates = pipeline.get("latest_updates", [])
if not updates:
raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")

if not id:
return updates[0]

matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None


def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
response = session.get(events_url)
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


def _get_compute_name(query_header_context: Any) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
Expand Down
14 changes: 3 additions & 11 deletions dbt/adapters/databricks/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
DatabricksDBTConnection,
DatabricksSQLConnectionWrapper,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.python_models.python_submissions import (
Expand Down Expand Up @@ -807,19 +805,13 @@ def get_from_relation(
"""Get the relation config from the relation."""

relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection())
wrapper: DatabricksSQLConnectionWrapper = connection.handle

# Ensure any current refreshes are completed before returning the relation config
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
if tblproperties.pipeline_id:
# TODO fix this path so that it doesn't need a cursor
# It just calls APIs to poll the pipeline status
cursor = wrapper.cursor()
try:
cursor.poll_refresh_pipeline(tblproperties.pipeline_id)
finally:
cursor.close()
adapter.connections.api_client.dlt_pipelines.poll_for_completion(
tblproperties.pipeline_id
)
return relation_config


Expand Down
71 changes: 71 additions & 0 deletions tests/unit/api_client/test_dlt_pipeline_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.databricks.api_client import DltPipelineApi
from tests.unit.api_client.api_test_base import ApiTestBase


class TestDltPipelineApi(ApiTestBase):
@pytest.fixture
def api(self, session, host):
return DltPipelineApi(session, host, 1)

@pytest.fixture
def pipeline_id(self):
return "pipeline_id"

@pytest.fixture
def update_id(self):
return "update_id"

def test_get_update_error__non_200(self, api, session, pipeline_id, update_id):
session.get.return_value.status_code = 500
with pytest.raises(DbtRuntimeError):
api.get_update_error(pipeline_id, update_id)

def test_get_update_error__200_no_events(self, api, session, pipeline_id, update_id):
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {"events": []}
assert api.get_update_error(pipeline_id, update_id) == ""

def test_get_update_error__200_no_error_events(self, api, session, pipeline_id, update_id):
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {
"events": [{"event_type": "update_progress", "origin": {"update_id": update_id}}]
}
assert api.get_update_error(pipeline_id, update_id) == ""

def test_get_update_error__200_error_events(self, api, session, pipeline_id, update_id):
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {
"events": [
{
"message": "I failed",
"details": {"update_progress": {"state": "FAILED"}},
"event_type": "update_progress",
"origin": {"update_id": update_id},
}
]
}
assert api.get_update_error(pipeline_id, update_id) == "I failed"

def test_poll_for_completion__non_200(self, api, session, pipeline_id):
self.assert_non_200_raises_error(lambda: api.poll_for_completion(pipeline_id), session)

def test_poll_for_completion__200(self, api, session, host, pipeline_id):
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {"state": "IDLE"}
api.poll_for_completion(pipeline_id)
session.get.assert_called_once_with(
f"https://{host}/api/2.0/pipelines/{pipeline_id}", json=None, params={}
)

def test_poll_for_completion__failed_with_cause(self, api, session, pipeline_id):
session.get.return_value.status_code = 200
session.get.return_value.json.return_value = {
"state": "FAILED",
"pipeline_id": pipeline_id,
"cause": "I failed",
}
with pytest.raises(DbtRuntimeError, match=f"Pipeline {pipeline_id} failed: I failed"):
api.poll_for_completion(pipeline_id)
Loading

0 comments on commit 69aa772

Please sign in to comment.