Skip to content

Commit

Permalink
add query recording options (#135)
Browse files Browse the repository at this point in the history
* add query recording options

* make record mode single var, add better types

* fix tests, tweak logic

* fix comment

* clean up and fix test

* changelog

* add test

* use test fixture
  • Loading branch information
emmyoop authored May 29, 2024
1 parent a28ff8a commit 315aeae
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 59 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240528-110518.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Allow dynamic selection of record types when recording.
time: 2024-05-28T11:05:18.290107-05:00
custom:
Author: emmyoop
Issue: "140"
51 changes: 44 additions & 7 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,18 @@ class Diff:
class RecorderMode(Enum):
RECORD = 1
REPLAY = 2
RECORD_QUERIES = 3


class Recorder:
_record_cls_by_name: Dict[str, Type] = {}
_record_name_by_params_name: Dict[str, str] = {}

def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None:
def __init__(
self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None
) -> None:
self.mode = mode
self.types = types
self._records_by_type: Dict[str, List[Record]] = {}
self._replay_diffs: List["Diff"] = []

Expand Down Expand Up @@ -118,13 +122,14 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]:
records_by_type: Dict[str, List[Record]] = {}

for record_type_name in loaded_dct:
# TODO: this breaks with QueryRecord on replay since it's
# not in common so isn't part of cls._record_cls_by_name yet
record_cls = cls._record_cls_by_name[record_type_name]
rec_list = []
for record_dct in loaded_dct[record_type_name]:
rec = record_cls.from_dict(record_dct)
rec_list.append(rec) # type: ignore
records_by_type[record_type_name] = rec_list

return records_by_type

def expect_record(self, params: Any) -> Any:
Expand All @@ -147,17 +152,46 @@ def print_diffs(self) -> None:


def get_record_mode_from_env() -> Optional[RecorderMode]:
replay_val = os.environ.get("DBT_REPLAY")
if replay_val is not None and replay_val != "0" and replay_val.lower() != "false":
return RecorderMode.REPLAY
"""
Get the record mode from the environment variables.
If the mode is not set to 'RECORD' or 'REPLAY', return None.
Expected format: 'DBT_RECORDER_MODE=RECORD'
"""
record_mode = os.environ.get("DBT_RECORDER_MODE")

record_val = os.environ.get("DBT_RECORD")
if record_val is not None and record_val != "0" and record_val.lower() != "false":
if record_mode is None:
return None

if record_mode.lower() == "record":
return RecorderMode.RECORD
# replaying requires a file path, otherwise treat as noop
elif (
record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None
):
return RecorderMode.REPLAY

# if you don't specify record/replay it's a noop
return None


def get_record_types_from_env() -> Optional[List]:
"""
Get the record subset from the environment variables.
If no types are provided, there will be no filtering.
Invalid types will be ignored.
Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord'
"""
record_types_str = os.environ.get("DBT_RECORDER_TYPES")

# if all is specified we don't want any type filtering
if record_types_str is None or record_types_str.lower == "all":
return None

return record_types_str.split(",")


def record_function(record_type, method=False, tuple_result=False):
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
Expand All @@ -176,6 +210,9 @@ def record_replay_wrapper(*args, **kwargs):
if recorder is None:
return func_to_record(*args, **kwargs)

if recorder.types is not None and record_type.__name__ not in recorder.types:
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
# params constructor.
param_args = args[1:] if method else args
Expand Down
17 changes: 16 additions & 1 deletion docs/guides/record_replay.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th

The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively.

With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them.
With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them.

## How to record/replay
If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions.

`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var.


```bash
DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run
```

replay need the file to replay
```bash
DBT_RECORDER_MODE=replay DBT_RECORDER_REPLAY_PATH=recording.json dbt run
```

## Final Thoughts

Expand Down
159 changes: 108 additions & 51 deletions tests/unit/test_record.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import os
import pytest
from typing import Optional

from dbt_common.context import set_invocation_context, get_invocation_context
Expand All @@ -24,58 +25,114 @@ class TestRecord(Record):
result_cls = TestRecordResult


def test_decorator_records():
prev = os.environ.get("DBT_RECORD", None)
try:
os.environ["DBT_RECORD"] = "True"
recorder = Recorder(RecorderMode.RECORD)
set_invocation_context({})
get_invocation_context().recorder = recorder

@record_function(TestRecord)
def test_func(a: int, b: str, c: Optional[str] = None) -> str:
return str(a) + b + (c if c else "")

test_func(123, "abc")

expected_record = TestRecord(
params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc")
)

assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params
assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result

finally:
if prev is None:
os.environ.pop("DBT_RECORD", None)
else:
os.environ["DBT_RECORD"] = prev


def test_decorator_replays():
prev = os.environ.get("DBT_RECORD", None)
try:
os.environ["DBT_RECORD"] = "True"
recorder = Recorder(RecorderMode.REPLAY)
set_invocation_context({})
get_invocation_context().recorder = recorder

expected_record = TestRecord(
params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc")
)

recorder._records_by_type["TestRecord"] = [expected_record]
@dataclasses.dataclass
class NotTestRecordParams:
a: int
b: str
c: Optional[str] = None

@record_function(TestRecord)
def test_func(a: int, b: str, c: Optional[str] = None) -> str:
raise Exception("This should not actually be called")

res = test_func(123, "abc")
@dataclasses.dataclass
class NotTestRecordResult:
return_val: str

assert res == "123abc"

finally:
if prev is None:
os.environ.pop("DBT_RECORD", None)
else:
os.environ["DBT_RECORD"] = prev
@Recorder.register_record_type
class NotTestRecord(Record):
params_cls = NotTestRecordParams
result_cls = NotTestRecordResult


@pytest.fixture(scope="function", autouse=True)
def setup():
# capture the previous state of the environment variables
prev_mode = os.environ.get("DBT_RECORDER_MODE", None)
prev_type = os.environ.get("DBT_RECORDER_TYPES", None)
prev_fp = os.environ.get("DBT_RECORDER_REPLAY_PATH", None)
# clear the environment variables
os.environ.pop("DBT_RECORDER_MODE", None)
os.environ.pop("DBT_RECORDER_TYPES", None)
os.environ.pop("DBT_RECORDER_REPLAY_PATH", None)
yield
# reset the environment variables to their previous state
if prev_mode is None:
os.environ.pop("DBT_RECORDER_MODE", None)
else:
os.environ["DBT_RECORDER_MODE"] = prev_mode
if prev_type is None:
os.environ.pop("DBT_RECORDER_TYPES", None)
else:
os.environ["DBT_RECORDER_TYPES"] = prev_type
if prev_fp is None:
os.environ.pop("DBT_RECORDER_REPLAY_PATH", None)
else:
os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_fp


def test_decorator_records(setup):
os.environ["DBT_RECORDER_MODE"] = "Record"
recorder = Recorder(RecorderMode.RECORD, None)
set_invocation_context({})
get_invocation_context().recorder = recorder

@record_function(TestRecord)
def test_func(a: int, b: str, c: Optional[str] = None) -> str:
return str(a) + b + (c if c else "")

test_func(123, "abc")

expected_record = TestRecord(
params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc")
)

assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params
assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result


def test_record_types(setup):
os.environ["DBT_RECORDER_MODE"] = "Record"
os.environ["DBT_RECORDER_TYPES"] = "TestRecord"
recorder = Recorder(RecorderMode.RECORD, ["TestRecord"])
set_invocation_context({})
get_invocation_context().recorder = recorder

@record_function(TestRecord)
def test_func(a: int, b: str, c: Optional[str] = None) -> str:
return str(a) + b + (c if c else "")

@record_function(NotTestRecord)
def not_test_func(a: int, b: str, c: Optional[str] = None) -> str:
return str(a) + b + (c if c else "")

test_func(123, "abc")
not_test_func(456, "def")

expected_record = TestRecord(
params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc")
)

assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params
assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result
assert NotTestRecord not in recorder._records_by_type


def test_decorator_replays(setup):
os.environ["DBT_RECORDER_MODE"] = "Replay"
os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json"
recorder = Recorder(RecorderMode.REPLAY, None)
set_invocation_context({})
get_invocation_context().recorder = recorder

expected_record = TestRecord(
params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc")
)

recorder._records_by_type["TestRecord"] = [expected_record]

@record_function(TestRecord)
def test_func(a: int, b: str, c: Optional[str] = None) -> str:
raise Exception("This should not actually be called")

res = test_func(123, "abc")

assert res == "123abc"

0 comments on commit 315aeae

Please sign in to comment.