Skip to content

Commit

Permalink
Add record/replay support.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Jun 25, 2024
1 parent d0a259f commit 7d879f0
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 13 deletions.
40 changes: 27 additions & 13 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
DbtConfigError,
)
from dbt_common.exceptions import DbtDatabaseError
from dbt_common.record import get_record_mode_from_env, RecorderMode
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag
from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -370,20 +372,32 @@ def connect():

if creds.query_tag:
session_parameters.update({"QUERY_TAG": creds.query_tag})
handle = None

# In replay mode, we won't connect to a real database at all, while
# in record and diff modes we do, but insert an intermediate handle
# object which monitors native connection activity.
rec_mode = get_record_mode_from_env()
handle = None
if rec_mode != RecorderMode.REPLAY:
handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)

handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)
if rec_mode is not None:
# If using the record/replay mechanism, regardless of mode, we
# use a wrapper.
handle = SnowflakeRecordReplayHandle(handle, connection)

return handle

Expand Down
59 changes: 59 additions & 0 deletions dbt/adapters/snowflake/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import dataclasses
from typing import Optional

from dbt.adapters.record import RecordReplayHandle, RecordReplayCursor
from dbt_common.record import record_function, Record, Recorder


class SnowflakeRecordReplayHandle(RecordReplayHandle):
def cursor(self):
cursor = None if self.native_handle is None else self.native_handle.cursor()
return SnowflakeRecordReplayCursor(cursor, self.connection)


@dataclasses.dataclass
class CursorGetSqlStateParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqlStateResult:
msg: Optional[str]


class CursorGetSqlStateRecord(Record):
params_cls = CursorGetSqlStateParams
result_cls = CursorGetSqlStateResult


Recorder.register_record_type(CursorGetSqlStateRecord)


@dataclasses.dataclass
class CursorGetSqfidParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqfidResult:
msg: Optional[str]


class CursorGetSqfidRecord(Record):
params_cls = CursorGetSqfidParams
result_cls = CursorGetSqfidResult


Recorder.register_record_type(CursorGetSqfidRecord)


class SnowflakeRecordReplayCursor(RecordReplayCursor):
@property
@record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name")
def sqlstate(self):
return self.native_cursor.sqlstate

@property
@record_function(CursorGetSqfidRecord, method=True, id_field_name="connection_name")
def sfqid(self):
return self.native_cursor.sfqid

0 comments on commit 7d879f0

Please sign in to comment.