Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add record/replay support #1106

Merged
merged 8 commits into from
Jul 16, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240716-174655.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add record/replay support.
peterallenwebb marked this conversation as resolved.
Show resolved Hide resolved
time: 2024-07-16T17:46:55.11204-04:00
custom:
Author: peterallenwebb
Issue: "1106"
40 changes: 27 additions & 13 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
DbtConfigError,
)
from dbt_common.exceptions import DbtDatabaseError
from dbt_common.record import get_record_mode_from_env, RecorderMode
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag
from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle

if TYPE_CHECKING:
import agate
Expand Down Expand Up @@ -374,20 +376,32 @@ def connect():

if creds.query_tag:
session_parameters.update({"QUERY_TAG": creds.query_tag})
handle = None

# In replay mode, we won't connect to a real database at all, while
# in record and diff modes we do, but insert an intermediate handle
# object which monitors native connection activity.
rec_mode = get_record_mode_from_env()
handle = None
if rec_mode != RecorderMode.REPLAY:
handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)

handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)
if rec_mode is not None:
# If using the record/replay mechanism, regardless of mode, we
# use a wrapper.
handle = SnowflakeRecordReplayHandle(handle, connection)

return handle

Expand Down
2 changes: 2 additions & 0 deletions dbt/adapters/snowflake/record/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor
from dbt.adapters.snowflake.record.handle import SnowflakeRecordReplayHandle
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/cursor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dbt_common.record import record_function

from dbt.adapters.record import RecordReplayCursor
from dbt.adapters.snowflake.record.cursor.sfqid import CursorGetSfqidRecord
from dbt.adapters.snowflake.record.cursor.sqlstate import CursorGetSqlStateRecord


class SnowflakeRecordReplayCursor(RecordReplayCursor):
"""A custom extension of RecordReplayCursor that adds the sqlstate
and sfqid properties which are specific to snowflake-connector."""

@property
@property
@record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name")
def sqlstate(self):
return self.native_cursor.sqlstate

@property
@record_function(CursorGetSfqidRecord, method=True, id_field_name="connection_name")
def sfqid(self):
return self.native_cursor.sfqid
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/sfqid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import dataclasses
from typing import Optional

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetSfqidParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSfqidResult:
msg: Optional[str]


@Recorder.register_record_type
class CursorGetSfqidRecord(Record):
params_cls = CursorGetSfqidParams
result_cls = CursorGetSfqidResult
group = "Database"
21 changes: 21 additions & 0 deletions dbt/adapters/snowflake/record/cursor/sqlstate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import dataclasses
from typing import Optional

from dbt_common.record import Record, Recorder


@dataclasses.dataclass
class CursorGetSqlStateParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqlStateResult:
msg: Optional[str]


@Recorder.register_record_type
class CursorGetSqlStateRecord(Record):
params_cls = CursorGetSqlStateParams
result_cls = CursorGetSqlStateResult
group = "Database"
12 changes: 12 additions & 0 deletions dbt/adapters/snowflake/record/handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dbt.adapters.record import RecordReplayHandle

from dbt.adapters.snowflake.record.cursor.cursor import SnowflakeRecordReplayCursor


class SnowflakeRecordReplayHandle(RecordReplayHandle):
"""A custom extension of RecordReplayHandle that returns a
snowflake-connector-specific SnowflakeRecordReplayCursor object."""

def cursor(self):
cursor = None if self.native_handle is None else self.native_handle.cursor()
return SnowflakeRecordReplayCursor(cursor, self.connection)
Loading