I think this works
benc-db committed Nov 20, 2024
1 parent 119e7b5 commit 4ddbaf9
Showing 4 changed files with 37 additions and 107 deletions.
33 changes: 24 additions & 9 deletions dbt/adapters/databricks/api_client.py
@@ -452,16 +452,31 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str:
         return response_json["run_id"]
 
 
-class DltPipelineApi(DatabricksApi):
-    def __init__(self, session: Session, host: str):
-        super().__init__(session, host, "/api/2.0/pipelines")
+class DltPipelineApi(PollableApi):
+    def __init__(self, session: Session, host: str, polling_interval: int):
+        super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60)
 
-    def state(self, pipeline_id: str) -> dict:
-        response = self.session.get(f"/{pipeline_id}")
-        if response.status_code != 200:
-            raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}")
+    def poll_for_completion(self, pipeline_id: str) -> None:
+        self._poll_api(
+            url=f"/{pipeline_id}",
+            params={},
+            get_state_func=lambda response: response.json()["state"],
+            terminal_states={"IDLE", "FAILED", "DELETED"},
+            expected_end_state="IDLE",
+            unexpected_end_state_func=self._get_exception,
+        )
 
-        return response.json()
+    def _get_exception(self, response: Response) -> None:
+        response_json = response.json()
+        cause = response_json.get("cause")
+        if cause:
+            raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}")
+        else:
+            latest_update = response_json.get("latest_updates")[0]
+            last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update)
+            raise DbtRuntimeError(
+                f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}"
+            )
 
     def get_update_error(self, pipeline_id: str, update_id: str) -> str:
         response = self.session.get(f"/{pipeline_id}/events")
@@ -512,7 +527,7 @@ def __init__(
         self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
         self.workflows = WorkflowJobApi(session, host)
         self.workflow_permissions = JobPermissionsApi(session, host)
-        self.dlt_pipelines = DltPipelineApi(session, host)
+        self.dlt_pipelines = DltPipelineApi(session, host, polling_interval)
 
     @staticmethod
     def create(
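
Note on the new base class: PollableApi and its _poll_api helper already exist elsewhere in the adapter and are not shown in this diff. The sketch below only illustrates the polling contract implied by the call site above; the standalone function name, the requests.Session usage, and the default interval and timeout values are assumptions, not the adapter's actual implementation.

import time
from typing import Callable

import requests


def poll_api(
    session: requests.Session,
    url: str,
    params: dict,
    get_state_func: Callable[[requests.Response], str],
    terminal_states: set[str],
    expected_end_state: str,
    unexpected_end_state_func: Callable[[requests.Response], None],
    polling_interval: int = 10,
    timeout: int = 60 * 60,
) -> requests.Response:
    """Poll until the extracted state is terminal; delegate failures to the caller's hook."""
    start = time.time()
    while True:
        response = session.get(url, params=params)
        state = get_state_func(response)
        if state in terminal_states:
            if state != expected_end_state:
                # the caller-supplied hook raises with a resource-specific message
                unexpected_end_state_func(response)
            return response
        if time.time() - start > timeout:
            raise TimeoutError(f"Timed out waiting for {url} to reach {expected_end_state}")
        time.sleep(polling_interval)

Under this contract, poll_for_completion reads the pipeline's top-level "state" field, treats IDLE as success, and routes FAILED or DELETED results through _get_exception, which raises a DbtRuntimeError carrying the failure cause.
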
92 changes: 0 additions & 92 deletions dbt/adapters/databricks/connections.py
@@ -59,7 +59,6 @@
     CursorCreate,
 )
 from dbt.adapters.databricks.events.other_events import QueryError
-from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError
 from dbt.adapters.databricks.logging import logger
 from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker
 from dbt.adapters.databricks.utils import redact_credentials
@@ -704,97 +703,6 @@ def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse:
         message = "OK"
         return DatabricksAdapterResponse(_message=message, query_id=query_id)  # type: ignore
 
-    def poll_for_pipeline_completion(self, pipeline_id: str) -> None:
-        # interval in seconds
-        polling_interval = 10
-
-        # timeout in seconds
-        timeout = 60 * 60
-
-        stopped_states = ("COMPLETED", "FAILED", "CANCELED")
-        pipeline = self.api_client.dlt_pipelines.state(pipeline_id)
-        # get the most recently created update for the pipeline
-        latest_update = self._find_update(pipeline)
-        if not latest_update:
-            raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
-
-        state = latest_update.get("state")
-        # we use update_id to retrieve the update in the polling loop
-        update_id = latest_update.get("update_id", "")
-        prev_state = state
-
-        logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
-
-        start = time.time()
-        exceeded_timeout = False
-        while state not in stopped_states:
-            if time.time() - start > timeout:
-                exceeded_timeout = True
-                break
-
-            # should we do exponential backoff?
-            time.sleep(polling_interval)
-
-            pipeline = self.api_client.dlt_pipelines.state(pipeline_id)
-            # get the update we are currently polling
-            update = self._find_update(pipeline, update_id)
-            if not update:
-                raise DbtRuntimeError(
-                    f"Error getting pipeline update info: {pipeline_id}, update: {update_id}"
-                )
-
-            state = update.get("state")
-            if state != prev_state:
-                logger.info(PipelineRefresh(pipeline_id, update_id, str(state)))
-                prev_state = state
-
-            if state == "FAILED":
-                logger.error(
-                    PipelineRefreshError(
-                        pipeline_id,
-                        update_id,
-                        self.api_client.dlt_pipelines.get_update_error(pipeline_id, update_id),
-                    )
-                )
-
-                # another update may have been created due to retry_on_fail settings
-                # get the latest update and see if it is a new one
-                latest_update = self._find_update(pipeline)
-                if not latest_update:
-                    raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}")
-
-                latest_update_id = latest_update.get("update_id", "")
-                if latest_update_id != update_id:
-                    update_id = latest_update_id
-                    state = None
-
-        if exceeded_timeout:
-            raise DbtRuntimeError("timed out waiting for materialized view refresh")
-
-        if state == "FAILED":
-            msg = self.api_client.dlt_pipelines.get_update_error(pipeline_id, update_id)
-            raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}")
-
-        if state == "CANCELED":
-            raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled")
-
-        return
-
-    @staticmethod
-    def _find_update(pipeline: dict, id: str = "") -> Optional[dict]:
-        updates = pipeline.get("latest_updates", [])
-        if not updates:
-            raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}")
-
-        if not id:
-            return updates[0]
-
-        matches = [x for x in updates if x.get("update_id") == id]
-        if matches:
-            return matches[0]
-
-        return None
-
 
 class ExtendedSessionConnectionManager(DatabricksConnectionManager):
     def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
4 changes: 3 additions & 1 deletion dbt/adapters/databricks/impl.py
@@ -809,7 +809,9 @@ def get_from_relation(
         # Ensure any current refreshes are completed before returning the relation config
         tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"])
         if tblproperties.pipeline_id:
-            adapter.connections.poll_for_pipeline_completion(tblproperties.pipeline_id)
+            adapter.connections.api_client.dlt_pipelines.poll_for_completion(
+                tblproperties.pipeline_id
+            )
         return relation_config
 
 
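
With this call-site swap, failure reporting reaches the caller as a raised DbtRuntimeError whose message is assembled by _get_exception, instead of a logged PipelineRefreshError followed by a separate raise. A minimal sketch of handling that at a hypothetical call site (the wrapper function is illustrative, and the DbtRuntimeError import path is assumed to match its use in api_client.py):

import logging

from dbt_common.exceptions import DbtRuntimeError  # import path assumed

logger = logging.getLogger(__name__)


def wait_for_refresh(api_client, pipeline_id: str) -> None:
    """Hypothetical helper: block until the DLT pipeline refresh reaches a terminal state."""
    try:
        api_client.dlt_pipelines.poll_for_completion(pipeline_id)
    except DbtRuntimeError:
        # the failure cause (or the latest update's error event) is already part
        # of the exception message, so logging the exception is enough here
        logger.exception("DLT pipeline %s refresh failed", pipeline_id)
        raise
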
15 changes: 10 additions & 5 deletions tests/unit/test_adapter.py
@@ -341,13 +341,15 @@ def _test_databricks_sql_connector_http_header_connection(self, http_headers, co
         assert connection.credentials.token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
         assert connection.credentials.schema == "analytics"
 
-    def test_list_relations_without_caching__no_relations(self):
+    @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__no_relations(self, _):
         with mock.patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = []
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
             assert adapter.list_relations("database", "schema") == []
 
-    def test_list_relations_without_caching__some_relations(self):
+    @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__some_relations(self, _):
         with mock.patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = [("name", "table", "hudi", "owner")]
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
@@ -361,7 +363,8 @@ def test_list_relations_without_caching__some_relations(self):
             assert relation.owner == "owner"
             assert relation.is_hudi
 
-    def test_list_relations_without_caching__hive_relation(self):
+    @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_list_relations_without_caching__hive_relation(self, _):
         with mock.patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked:
             mocked.return_value = [("name", "table", None, None)]
             adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn"))
@@ -374,7 +377,8 @@ def test_list_relations_without_caching__hive_relation(self):
             assert relation.type == DatabricksRelationType.Table
             assert not relation.has_information()
 
-    def test_get_schema_for_catalog__no_columns(self):
+    @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_get_schema_for_catalog__no_columns(self, _):
         with mock.patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info:
             list_info.return_value = [(Mock(), "info")]
             with mock.patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns:
@@ -383,7 +387,8 @@ def test_get_schema_for_catalog__no_columns(self):
                 table = adapter._get_schema_for_catalog("database", "schema", "name")
                 assert len(table.rows) == 0
 
-    def test_get_schema_for_catalog__some_columns(self):
+    @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
+    def test_get_schema_for_catalog__some_columns(self, _):
         with mock.patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info:
             list_info.return_value = [(Mock(), "info")]
             with mock.patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns:
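
The new @mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") decorators stub out API-client creation so that constructing a DatabricksAdapter in these unit tests does not build a real Databricks API client; the patched callable is injected as an extra positional argument, which the tests accept and ignore as `_`. A standalone sketch of the same pattern (the test name and assertion are illustrative, not taken from the suite):

from multiprocessing import get_context
from unittest import mock

from dbt.adapters.databricks.impl import DatabricksAdapter


@mock.patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create")
def test_adapter_constructs_without_real_api_client(mock_create):
    # mock.patch replaces the factory with a MagicMock and passes it in as mock_create
    adapter = DatabricksAdapter(mock.Mock(flags={}), get_context("spawn"))
    assert adapter is not None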
