Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add record/replay support #1106

Merged
merged 8 commits into from
Jul 16, 2024
Next Next commit
Add record/replay support.
peterallenwebb committed Jun 25, 2024
commit 7d879f0de92e7e2eccdb541369098e6783c2a294
40 changes: 27 additions & 13 deletions dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
@@ -36,13 +36,15 @@
DbtConfigError,
)
from dbt_common.exceptions import DbtDatabaseError
from dbt_common.record import get_record_mode_from_env, RecorderMode
from dbt.adapters.exceptions.connection import FailedToConnectError
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.sql import SQLConnectionManager
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.events.functions import warn_or_error
from dbt.adapters.events.types import AdapterEventWarning, AdapterEventError
from dbt_common.ui import line_wrap_message, warning_tag
from dbt.adapters.snowflake.record import SnowflakeRecordReplayHandle

if TYPE_CHECKING:
import agate
@@ -370,20 +372,32 @@ def connect():

if creds.query_tag:
session_parameters.update({"QUERY_TAG": creds.query_tag})
handle = None

# In replay mode, we won't connect to a real database at all, while
# in record and diff modes we do, but insert an intermediate handle
# object which monitors native connection activity.
rec_mode = get_record_mode_from_env()
handle = None
if rec_mode != RecorderMode.REPLAY:
handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)

handle = snowflake.connector.connect(
account=creds.account,
database=creds.database,
schema=creds.schema,
warehouse=creds.warehouse,
role=creds.role,
autocommit=True,
client_session_keep_alive=creds.client_session_keep_alive,
application="dbt",
insecure_mode=creds.insecure_mode,
session_parameters=session_parameters,
**creds.auth_args(),
)
if rec_mode is not None:
# If using the record/replay mechanism, regardless of mode, we
# use a wrapper.
handle = SnowflakeRecordReplayHandle(handle, connection)

return handle

59 changes: 59 additions & 0 deletions dbt/adapters/snowflake/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import dataclasses
from typing import Optional

from dbt.adapters.record import RecordReplayHandle, RecordReplayCursor
from dbt_common.record import record_function, Record, Recorder


class SnowflakeRecordReplayHandle(RecordReplayHandle):
def cursor(self):
cursor = None if self.native_handle is None else self.native_handle.cursor()
return SnowflakeRecordReplayCursor(cursor, self.connection)


@dataclasses.dataclass
class CursorGetSqlStateParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqlStateResult:
msg: Optional[str]


class CursorGetSqlStateRecord(Record):
params_cls = CursorGetSqlStateParams
result_cls = CursorGetSqlStateResult


Recorder.register_record_type(CursorGetSqlStateRecord)


@dataclasses.dataclass
class CursorGetSqfidParams:
connection_name: str


@dataclasses.dataclass
class CursorGetSqfidResult:
msg: Optional[str]


class CursorGetSqfidRecord(Record):
params_cls = CursorGetSqfidParams
result_cls = CursorGetSqfidResult


Recorder.register_record_type(CursorGetSqfidRecord)


class SnowflakeRecordReplayCursor(RecordReplayCursor):
@property
@record_function(CursorGetSqlStateRecord, method=True, id_field_name="connection_name")
def sqlstate(self):
return self.native_cursor.sqlstate

@property
@record_function(CursorGetSqfidRecord, method=True, id_field_name="connection_name")
def sfqid(self):
return self.native_cursor.sfqid