Skip to content

Commit

Permalink
make record mode single var, add better types
Browse files Browse the repository at this point in the history
  • Loading branch information
emmyoop committed May 23, 2024
1 parent 622188c commit 8a1b6a4
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 23 deletions.
65 changes: 51 additions & 14 deletions dbt_common/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ class Recorder:
_record_cls_by_name: Dict[str, Type] = {}
_record_name_by_params_name: Dict[str, str] = {}

def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None:
def __init__(
self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None
) -> None:
self.mode = mode
self.types = types
self._records_by_type: Dict[str, List[Record]] = {}
self._replay_diffs: List["Diff"] = []

Expand Down Expand Up @@ -148,21 +151,58 @@ def print_diffs(self) -> None:


def get_record_mode_from_env() -> Optional[RecorderMode]:
replay_val = os.environ.get("DBT_REPLAY")
if replay_val is not None and replay_val != "0" and replay_val.lower() != "false":
return RecorderMode.REPLAY
"""
Get the record mode from the environment variables.
record_val = os.environ.get("DBT_RECORD")
if record_val is not None and record_val != "0" and record_val.lower() != "false":
return RecorderMode.RECORD
If the mode is not set to 'RECORD' or 'REPLAY', return None.
Expected format: 'DBT_RECORDER_MODE=RECORD'
"""
record_mode = os.environ.get("DBT_RECORDER_MODE")

record_val = os.environ.get("DBT_RECORD_QUERIES")
if record_val is not None and record_val != "0" and record_val.lower() != "false":
return RecorderMode.RECORD_QUERIES
if record_mode is None:
return None

if record_mode.lower() == "record":
return RecorderMode.RECORD
elif record_mode.lower() == "replay":
return RecorderMode.REPLAY

# if you don't specify record/replay it's a noop
return None


def get_record_types_from_env() -> Optional[List]:
"""
Get the record subset from the environment variables.
If no types are provided, there will be no filtering.
Invalid types will be ignored.
Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord'
"""
record_types_str = os.environ.get("DBT_RECORDER_TYPES")

# if all is specified we don't want any type filtering
if record_types_str is None or record_types_str.lower == "all":
return None

record_types = record_types_str.split(",")

for type in record_types:
# Types not defined in common are not in the record_types list yet
# TODO: is there a better way to do this without hardcoding? We can't just
# wait for later because if it's QueryRecord (not defined in common) we don't
# want to remove it to ensure everything else is filtered out....
if type not in Recorder._record_cls_by_name and type != "QueryRecord":
print(f"Invalid record type: {type}") # TODO: remove after testing
record_types.remove(type)

# if everything is invalid we don't want any type filtering
if len(record_types) == 0:
return None

return record_types


def record_function(record_type, method=False, tuple_result=False):
def record_function_inner(func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
Expand All @@ -181,10 +221,7 @@ def record_replay_wrapper(*args, **kwargs):
if recorder is None:
return func_to_record(*args, **kwargs)

if (
recorder.mode == RecorderMode.RECORD_QUERIES
and record_type.__name__ != "QueryRecord"
):
if recorder.types is not None and record_type.__name__ not in recorder.types:
return func_to_record(*args, **kwargs)

# For methods, peel off the 'self' argument before calling the
Expand Down
13 changes: 12 additions & 1 deletion docs/guides/record_replay.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,18 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th

The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively.

With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them.
With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them.

## How to record/replay
If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions.

`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var.

example

```bash
DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run
```

## Final Thoughts

Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class TestRecord(Record):


def test_decorator_records():
prev = os.environ.get("DBT_RECORD", None)
prev = os.environ.get("DBT_RECORDER_MODE", None)
try:
os.environ["DBT_RECORD"] = "True"
os.environ["DBT_RECORDER_MODE"] = "Record"
recorder = Recorder(RecorderMode.RECORD)
set_invocation_context({})
get_invocation_context().recorder = recorder
Expand All @@ -47,15 +47,15 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str:

finally:
if prev is None:
os.environ.pop("DBT_RECORD", None)
os.environ.pop("DBT_RECORDER_MODE", None)
else:
os.environ["DBT_RECORD"] = prev
os.environ["DBT_RECORDER_MODE"] = prev


def test_decorator_replays():
prev = os.environ.get("DBT_RECORD", None)
prev = os.environ.get("DBT_RECORDER_MODE", None)
try:
os.environ["DBT_RECORD"] = "True"
os.environ["DBT_RECORDER_MODE"] = "Replay"
recorder = Recorder(RecorderMode.REPLAY)
set_invocation_context({})
get_invocation_context().recorder = recorder
Expand All @@ -76,6 +76,6 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str:

finally:
if prev is None:
os.environ.pop("DBT_RECORD", None)
os.environ.pop("DBT_RECORDER_MODE", None)
else:
os.environ["DBT_RECORD"] = prev
os.environ["DBT_RECORDER_MODE"] = prev

0 comments on commit 8a1b6a4

Please sign in to comment.