From d5c1fb726d1326c3d03e8e3b60518160fc2a6a52 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Thu, 17 Oct 2024 11:57:23 -0400 Subject: [PATCH] add query_id to SQLQueryStatus (demonstration only) --- dbt/adapters/snowflake/connections.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index fc2c09c19..aaf6e11a1 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from dataclasses import dataclass from io import StringIO -from time import sleep +from time import sleep, perf_counter from typing import Optional, Tuple, Union, Any, List, Iterable, TYPE_CHECKING @@ -43,8 +43,11 @@ DbtRuntimeError, DbtConfigError, ) +from dbt_common.events.contextvars import get_node_info from dbt_common.exceptions import DbtDatabaseError +from dbt_common.events.functions import fire_event from dbt_common.record import get_record_mode_from_env, RecorderMode +from dbt.adapters.events.types import SQLQueryStatus from dbt.adapters.exceptions.connection import FailedToConnectError from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials from dbt.adapters.sql import SQLConnectionManager @@ -84,6 +87,11 @@ def snowflake_private_key(private_key: RSAPrivateKey) -> bytes: ) +@dataclass +class SnowflakeAdapterResponse(AdapterResponse): + pass + + @dataclass class SnowflakeCredentials(Credentials): account: str @@ -531,6 +539,8 @@ def add_query( bindings: Optional[Any] = None, abridge_sql_log: bool = False, ) -> Tuple[Connection, Any]: + pre = perf_counter() + if bindings: # The snowflake connector is stricter than, e.g., psycopg2 - # which allows any iterable thing to be passed as a binding. @@ -556,6 +566,15 @@ def add_query( if cursor is None: self._raise_cursor_not_found_error(sql) + fire_event( + SQLQueryStatus( + status=str(self.get_response(cursor)), + elapsed=perf_counter() - pre, + node_info=get_node_info(), + query_id=cursor.sfqid, + ) + ) + return connection, cursor def _stripped_queries(self, sql: str) -> List[str]: