wip
benc-db committed Nov 19, 2024
1 parent da76ae9 commit cfde84e
Showing 3 changed files with 138 additions and 161 deletions.
40 changes: 40 additions & 0 deletions dbt/adapters/databricks/api_client.py
@@ -452,6 +452,45 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
return response_json["run_id"]


class DltPipelineApi(DatabricksApi):
def __init__(self, session: Session, host: str):
super().__init__(session, host, "/api/2.0/pipelines")

def state(self, pipeline_id: str) -> dict:
response = self.session.get(f"/{pipeline_id}")
if response.status_code != 200:
raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")

return response.json()

def get_update_error(self, pipeline_id: str, update_id: str) -> str:
response = self.session.get(f"/{pipeline_id}/events")
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


class DatabricksApiClient:
def __init__(
self,
@@ -473,6 +512,7 @@ def __init__(
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
self.workflows = WorkflowJobApi(session, host)
self.workflow_permissions = JobPermissionsApi(session, host)
self.dlt_pipelines = DltPipelineApi(session, host)

@staticmethod
def create(
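
The new DltPipelineApi is reached through DatabricksApiClient.create, the same factory the connection manager uses below. Here is a minimal sketch of driving it directly; the helper function name is ours, `creds` is assumed to be a DatabricksCredentials loaded from the profile, and the 15-minute timeout mirrors the value passed in DatabricksConnectionManager.__init__:

from dbt.adapters.databricks.api_client import DatabricksApiClient
from dbt.adapters.databricks.credentials import DatabricksCredentials


def latest_update_error(creds: DatabricksCredentials, pipeline_id: str) -> str:
    """Return the failure message for a pipeline's newest update, or ""."""
    # 15 * 60 matches the timeout the connection manager passes below.
    client = DatabricksApiClient.create(creds, 15 * 60)

    # state() wraps GET /api/2.0/pipelines/{pipeline_id}; the newest
    # update is first in latest_updates.
    pipeline = client.dlt_pipelines.state(pipeline_id)
    latest = (pipeline.get("latest_updates") or [{}])[0]

    if latest.get("state") == "FAILED":
        # get_update_error() scans update_progress events for a FAILED state.
        return client.dlt_pipelines.get_update_error(
            pipeline_id, latest.get("update_id", "")
        )
    return ""
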
247 changes: 97 additions & 150 deletions dbt/adapters/databricks/connections.py
@@ -17,7 +17,6 @@
from dbt_common.events.functions import fire_event
from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError
from dbt_common.utils import cast_to_str
from requests import Session

import databricks.sql as dbsql
from databricks.sql.client import Connection as DatabricksSQLConnection
@@ -35,7 +34,6 @@
)
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.api_client import DatabricksApiClient
from dbt.adapters.databricks.auth import BearerAuth
from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks.events.connection_events import (
ConnectionAcquire,
@@ -227,97 +225,6 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
bindings = [self._fix_binding(binding) for binding in bindings]
self._cursor.execute(sql, bindings)

def poll_refresh_pipeline(self, pipeline_id: str) -> None:
# interval in seconds
polling_interval = 10

# timeout in seconds
timeout = 60 * 60

stopped_states = ("COMPLETED", "FAILED", "CANCELED")
host: str = self._creds.host or ""
headers = (
self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore
)
session = Session()
session.auth = BearerAuth(headers)
session.headers = {"User-Agent": self._user_agent}
pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the most recently created update for the pipeline
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

state = latest_update.get("state")
# we use update_id to retrieve the update in the polling loop
update_id = latest_update.get("update_id", "")
prev_state = state

logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))

start = time.time()
exceeded_timeout = False
while state not in stopped_states:
if time.time() - start > timeout:
exceeded_timeout = True
break

# should we do exponential backoff?
time.sleep(polling_interval)

pipeline = _get_pipeline_state(session, host, pipeline_id)
# get the update we are currently polling
update = _find_update(pipeline, update_id)
if not update:
raise DbtRuntimeError(
f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
)

state = update.get("state")
if state != prev_state:
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
prev_state = state

if state == "FAILED":
logger.error(
PipelineRefreshError(
pipeline_id,
update_id,
_get_update_error_msg(session, host, pipeline_id, update_id),
)
)

# another update may have been created due to retry_on_fail settings
# get the latest update and see if it is a new one
latest_update = _find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

latest_update_id = latest_update.get("update_id", "")
if latest_update_id != update_id:
update_id = latest_update_id
state = None

if exceeded_timeout:
raise DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = _get_update_error_msg(session, host, pipeline_id, update_id)
raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")

if state == "CANCELED":
raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")

return

@classmethod
def findUpdate(cls, updates: list, id: str) -> Optional[dict]:
matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None

@property
def hex_query_id(self) -> str:
"""Return the hex GUID for this query
@@ -475,12 +382,15 @@ class DatabricksConnectionManager(SparkConnectionManager):
credentials_provider: Optional[TCredentialProvider] = None
_user_agent = f"dbt-databricks/{__version__}"

def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
super().__init__(profile, mp_context)
creds = cast(DatabricksCredentials, self.profile.credentials)
self.api_client = DatabricksApiClient.create(creds, 15 * 60)

def cancel_open(self) -> list[str]:
cancelled = super().cancel_open()
creds = cast(DatabricksCredentials, self.profile.credentials)
api_client = DatabricksApiClient.create(creds, 15 * 60)
logger.info("Cancelling open python jobs")
PythonRunTracker.cancel_runs(api_client)
PythonRunTracker.cancel_runs(self.api_client)
return cancelled

def compare_dbr_version(self, major: int, minor: int) -> int:
@@ -794,6 +704,97 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse:
message = "OK"
return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore

def poll_for_pipeline_completion(self, pipeline_id: str) -> None:
# interval in seconds
polling_interval = 10

# timeout in seconds
timeout = 60 * 60

stopped_states = ("COMPLETED", "FAILED", "CANCELED")
pipeline = self.api_client.dlt_pipelines.state(pipeline_id)
# get the most recently created update for the pipeline
latest_update = self._find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

state = latest_update.get("state")
# we use update_id to retrieve the update in the polling loop
update_id = latest_update.get("update_id", "")
prev_state = state

logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))

start = time.time()
exceeded_timeout = False
while state not in stopped_states:
if time.time() - start > timeout:
exceeded_timeout = True
break

# should we do exponential backoff?
time.sleep(polling_interval)

pipeline = self.api_client.dlt_pipelines.state(pipeline_id)
# get the update we are currently polling
update = self._find_update(pipeline, update_id)
if not update:
raise DbtRuntimeError(
f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
)

state = update.get("state")
if state != prev_state:
logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
prev_state = state

if state == "FAILED":
logger.error(
PipelineRefreshError(
pipeline_id,
update_id,
self.api_client.dlt_pipelines.get_update_error(pipeline_id, update_id),
)
)

# another update may have been created due to retry_on_fail settings
# get the latest update and see if it is a new one
latest_update = self._find_update(pipeline)
if not latest_update:
raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")

latest_update_id = latest_update.get("update_id", "")
if latest_update_id != update_id:
update_id = latest_update_id
state = None

if exceeded_timeout:
raise DbtRuntimeError("timed out waiting for materialized view refresh")

if state == "FAILED":
msg = self.api_client.dlt_pipelines.get_update_error(pipeline_id, update_id)
raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")

if state == "CANCELED":
raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")

return

@staticmethod
def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
updates = pipeline.get("latest_updates", [])
if not updates:
raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")

if not id:
return updates[0]

matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None


class ExtendedSessionConnectionManager(DatabricksConnectionManager):
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
@@ -1079,60 +1080,6 @@ def exponential_backoff(attempt: int) -> int:
)


def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

response = session.get(pipeline_url)
if response.status_code != 200:
raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")

return response.json()


def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
updates = pipeline.get("latest_updates", [])
if not updates:
raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")

if not id:
return updates[0]

matches = [x for x in updates if x.get("update_id") == id]
if matches:
return matches[0]

return None


def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str:
events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events"
response = session.get(events_url)
if response.status_code != 200:
raise DbtRuntimeError(
f"Error getting pipeline event info for {pipeline_id}: {response.text}"
)

events = response.json().get("events", [])
update_events = [
e
for e in events
if e.get("event_type", "") == "update_progress"
and e.get("origin", {}).get("update_id") == update_id
]

error_events = [
e
for e in update_events
if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED"
]

msg = ""
if error_events:
msg = error_events[0].get("message", "")

return msg


def _get_compute_name(query_header_context: Any) -> Optional[str]:
# Get the name of the specified compute resource from the node's
# config.
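
The heart of the relocated poll_for_pipeline_completion is a fixed-interval poll with a hard timeout and a set of terminal states; when a retry_on_fail retry spawns a fresh update, the loop switches to the new update_id and clears state so polling continues. Below is a framework-free sketch of just that core pattern, with `fetch_state` standing in for the state() + _find_update pair:

import time
from typing import Callable

STOPPED_STATES = ("COMPLETED", "FAILED", "CANCELED")


def wait_for_terminal_state(
    fetch_state: Callable[[], str],
    polling_interval: float = 10.0,  # seconds between polls
    timeout: float = 60 * 60,  # give up after an hour
) -> str:
    """Poll until fetch_state() returns a terminal state or time runs out."""
    start = time.time()
    state = fetch_state()
    while state not in STOPPED_STATES:
        if time.time() - start > timeout:
            raise TimeoutError("timed out waiting for pipeline to finish")
        time.sleep(polling_interval)
        state = fetch_state()
    return state
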
12 changes: 1 addition & 11 deletions dbt/adapters/databricks/impl.py
@@ -34,8 +34,6 @@
from dbt.adapters.databricks.connections import (
USE_LONG_SESSIONS,
DatabricksConnectionManager,
DatabricksDBTConnection,
DatabricksSQLConnectionWrapper,
ExtendedSessionConnectionManager,
)
from dbt.adapters.databricks.python_models.python_submissions import (
@@ -807,19 +805,11 @@ def get_from_relation(
"""Get the relation config from the relation."""

relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation)
connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection())
wrapper: DatabricksSQLConnectionWrapper = connection.handle

# Ensure any current refreshes are completed before returning the relation config
tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
if tblproperties.pipeline_id:
# TODO fix this path so that it doesn't need a cursor
# It just calls APIs to poll the pipeline status
cursor = wrapper.cursor()
try:
cursor.poll_refresh_pipeline(tblproperties.pipeline_id)
finally:
cursor.close()
adapter.connections.poll_for_pipeline_completion(tblproperties.pipeline_id)
return relation_config


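
With this change, get_from_relation delegates straight to adapter.connections.poll_for_pipeline_completion, so no cursor is opened just to poll. For reference, here is a doctest-style sketch of the _find_update helper that method relies on (now a staticmethod in connections.py above); the payload shape follows the fields the code reads (latest_updates, update_id, state), and the values are made up:

from dbt.adapters.databricks.connections import DatabricksConnectionManager

pipeline = {
    "pipeline_id": "abc-123",
    "latest_updates": [
        {"update_id": "u2", "state": "RUNNING"},  # newest update first
        {"update_id": "u1", "state": "FAILED"},
    ],
}

# With no id, the newest update comes back.
assert DatabricksConnectionManager._find_update(pipeline) == {
    "update_id": "u2",
    "state": "RUNNING",
}

# With an id, the matching update comes back, or None if nothing matches.
assert DatabricksConnectionManager._find_update(pipeline, "u1")["state"] == "FAILED"
assert DatabricksConnectionManager._find_update(pipeline, "missing") is None
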
