From 622188c86b246c4e15ff191409801793638b4e33 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Fri, 17 May 2024 14:17:19 -0500 Subject: [PATCH 1/8] add query recording options --- dbt_common/record.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dbt_common/record.py b/dbt_common/record.py index b2b5ba48..930dc2e0 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -59,6 +59,7 @@ class Diff: class RecorderMode(Enum): RECORD = 1 REPLAY = 2 + RECORD_QUERIES = 3 class Recorder: @@ -155,6 +156,10 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_val is not None and record_val != "0" and record_val.lower() != "false": return RecorderMode.RECORD + record_val = os.environ.get("DBT_RECORD_QUERIES") + if record_val is not None and record_val != "0" and record_val.lower() != "false": + return RecorderMode.RECORD_QUERIES + return None @@ -176,6 +181,12 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) + if ( + recorder.mode == RecorderMode.RECORD_QUERIES + and record_type.__name__ != "QueryRecord" + ): + return func_to_record(*args, **kwargs) + # For methods, peel off the 'self' argument before calling the # params constructor. param_args = args[1:] if method else args From 8a1b6a45f131131b05352d3a56527f122a008972 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Thu, 23 May 2024 11:20:50 -0500 Subject: [PATCH 2/8] make record mode single var, add better types --- dbt_common/record.py | 65 ++++++++++++++++++++++++++++-------- docs/guides/record_replay.md | 13 +++++++- tests/unit/test_record.py | 16 ++++----- 3 files changed, 71 insertions(+), 23 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 930dc2e0..5e225c0d 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -66,8 +66,11 @@ class Recorder: _record_cls_by_name: Dict[str, Type] = {} _record_name_by_params_name: Dict[str, str] = {} - def __init__(self, mode: RecorderMode, recording_path: Optional[str] = None) -> None: + def __init__( + self, mode: RecorderMode, types: Optional[List], recording_path: Optional[str] = None + ) -> None: self.mode = mode + self.types = types self._records_by_type: Dict[str, List[Record]] = {} self._replay_diffs: List["Diff"] = [] @@ -148,21 +151,58 @@ def print_diffs(self) -> None: def get_record_mode_from_env() -> Optional[RecorderMode]: - replay_val = os.environ.get("DBT_REPLAY") - if replay_val is not None and replay_val != "0" and replay_val.lower() != "false": - return RecorderMode.REPLAY + """ + Get the record mode from the environment variables. - record_val = os.environ.get("DBT_RECORD") - if record_val is not None and record_val != "0" and record_val.lower() != "false": - return RecorderMode.RECORD + If the mode is not set to 'RECORD' or 'REPLAY', return None. + Expected format: 'DBT_RECORDER_MODE=RECORD' + """ + record_mode = os.environ.get("DBT_RECORDER_MODE") - record_val = os.environ.get("DBT_RECORD_QUERIES") - if record_val is not None and record_val != "0" and record_val.lower() != "false": - return RecorderMode.RECORD_QUERIES + if record_mode is None: + return None + if record_mode.lower() == "record": + return RecorderMode.RECORD + elif record_mode.lower() == "replay": + return RecorderMode.REPLAY + + # if you don't specify record/replay it's a noop return None +def get_record_types_from_env() -> Optional[List]: + """ + Get the record subset from the environment variables. + + If no types are provided, there will be no filtering. + Invalid types will be ignored. + Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord' + """ + record_types_str = os.environ.get("DBT_RECORDER_TYPES") + + # if all is specified we don't want any type filtering + if record_types_str is None or record_types_str.lower == "all": + return None + + record_types = record_types_str.split(",") + + for type in record_types: + # Types not defined in common are not in the record_types list yet + # TODO: is there a better way to do this without hardcoding? We can't just + # wait for later because if it's QueryRecord (not defined in common) we don't + # want to remove it to ensure everything else is filtered out.... + if type not in Recorder._record_cls_by_name and type != "QueryRecord": + print(f"Invalid record type: {type}") # TODO: remove after testing + record_types.remove(type) + + # if everything is invalid we don't want any type filtering + if len(record_types) == 0: + return None + + return record_types + + def record_function(record_type, method=False, tuple_result=False): def record_function_inner(func_to_record): # To avoid runtime overhead and other unpleasantness, we only apply the @@ -181,10 +221,7 @@ def record_replay_wrapper(*args, **kwargs): if recorder is None: return func_to_record(*args, **kwargs) - if ( - recorder.mode == RecorderMode.RECORD_QUERIES - and record_type.__name__ != "QueryRecord" - ): + if recorder.types is not None and record_type.__name__ not in recorder.types: return func_to_record(*args, **kwargs) # For methods, peel off the 'self' argument before calling the diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index 9d9d87f2..c182da0a 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -28,7 +28,18 @@ Note also the `LoadFileRecord` class passed as a parameter to this decorator. Th The final detail needed is to define the classes specified by `params_cls` and `result_cls`, which must be dataclasses with properties whose order and names correspond to the parameters passed to the recorded function. In this case those are the `LoadFileParams` and `LoadFileResult` classes, respectively. -With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. +With these decorators applied and classes defined, dbt is able to record all file access during a run, and mock out the accesses during replay, isolating dbt from actually loading files. At least it would if dbt only used this function for all file access, which is only mostly true. We hope to continue improving the usefulness of this mechanism by adding more recorded functions and routing more operations through them. + +## How to record/replay +If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a no-op. Invalid values are ignored and do not throw exceptions. + +`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. + +example + +```bash +DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run +``` ## Final Thoughts diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index aa7af69b..7829762c 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -25,9 +25,9 @@ class TestRecord(Record): def test_decorator_records(): - prev = os.environ.get("DBT_RECORD", None) + prev = os.environ.get("DBT_RECORDER_MODE", None) try: - os.environ["DBT_RECORD"] = "True" + os.environ["DBT_RECORDER_MODE"] = "Record" recorder = Recorder(RecorderMode.RECORD) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -47,15 +47,15 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: finally: if prev is None: - os.environ.pop("DBT_RECORD", None) + os.environ.pop("DBT_RECORDER_MODE", None) else: - os.environ["DBT_RECORD"] = prev + os.environ["DBT_RECORDER_MODE"] = prev def test_decorator_replays(): - prev = os.environ.get("DBT_RECORD", None) + prev = os.environ.get("DBT_RECORDER_MODE", None) try: - os.environ["DBT_RECORD"] = "True" + os.environ["DBT_RECORDER_MODE"] = "Replay" recorder = Recorder(RecorderMode.REPLAY) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -76,6 +76,6 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: finally: if prev is None: - os.environ.pop("DBT_RECORD", None) + os.environ.pop("DBT_RECORDER_MODE", None) else: - os.environ["DBT_RECORD"] = prev + os.environ["DBT_RECORDER_MODE"] = prev From 7406809e592c61764760764920e01b5c2f8d4c62 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Thu, 23 May 2024 11:57:23 -0500 Subject: [PATCH 3/8] fix tests, tweak logic --- dbt_common/record.py | 9 +++++++-- docs/guides/record_replay.md | 6 +++++- tests/unit/test_record.py | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 5e225c0d..187719ed 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -122,6 +122,9 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]: records_by_type: Dict[str, List[Record]] = {} for record_type_name in loaded_dct: + # TODO: this break with QueryRecord on replay since it's + # not in common so isn't part of cls._record_cls_by_name yet + record_cls = cls._record_cls_by_name[record_type_name] rec_list = [] for record_dct in loaded_dct[record_type_name]: @@ -164,7 +167,8 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_mode.lower() == "record": return RecorderMode.RECORD - elif record_mode.lower() == "replay": + # replaying requires a file path, otherwise treat as noop + elif record_mode.lower() == "replay" and os.environ["DBT_RECORDER_REPLAY_PATH"] is not None: return RecorderMode.REPLAY # if you don't specify record/replay it's a noop @@ -191,7 +195,8 @@ def get_record_types_from_env() -> Optional[List]: # Types not defined in common are not in the record_types list yet # TODO: is there a better way to do this without hardcoding? We can't just # wait for later because if it's QueryRecord (not defined in common) we don't - # want to remove it to ensure everything else is filtered out.... + # want to remove it to ensure everything else is filtered out.... This is also + # a problem with replaying QueryRecords generally if type not in Recorder._record_cls_by_name and type != "QueryRecord": print(f"Invalid record type: {type}") # TODO: remove after testing record_types.remove(type) diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index c182da0a..2103fd1b 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -35,12 +35,16 @@ If `DBT_RECORDER_MODE` is not `replay` or `record`, case insensitive, this is a `DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses. Any invalid types will be ignored. `all` is a valid type and behaves the same as not populating the env var. -example ```bash DBT_RECORDER_MODE=record DBT_RECODER_TYPES=QueryRecord,GetEnvRecord dbt run ``` +replay need the file to replay +```bash +DBT_RECORDER_MODE=replay DBT_RECORDER_REPLAY_PATH=recording.json dbt run +``` + ## Final Thoughts We are aware of the potential limitations of this mechanism, since it makes several strong assumptions, not least of which are: diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index 7829762c..bce824bc 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -28,7 +28,7 @@ def test_decorator_records(): prev = os.environ.get("DBT_RECORDER_MODE", None) try: os.environ["DBT_RECORDER_MODE"] = "Record" - recorder = Recorder(RecorderMode.RECORD) + recorder = Recorder(RecorderMode.RECORD, None) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -56,7 +56,7 @@ def test_decorator_replays(): prev = os.environ.get("DBT_RECORDER_MODE", None) try: os.environ["DBT_RECORDER_MODE"] = "Replay" - recorder = Recorder(RecorderMode.REPLAY) + recorder = Recorder(RecorderMode.REPLAY, None) set_invocation_context({}) get_invocation_context().recorder = recorder From 84cca75b89df118bd00e24722192c8e9c4688054 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Fri, 24 May 2024 09:34:55 -0500 Subject: [PATCH 4/8] fix comment --- dbt_common/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 187719ed..0f4750d8 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -122,7 +122,7 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]: records_by_type: Dict[str, List[Record]] = {} for record_type_name in loaded_dct: - # TODO: this break with QueryRecord on replay since it's + # TODO: this breaks with QueryRecord on replay since it's # not in common so isn't part of cls._record_cls_by_name yet record_cls = cls._record_cls_by_name[record_type_name] From 306576ec94af4bb7afdfeb8d4ca628091fe05a67 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 28 May 2024 09:30:14 -0500 Subject: [PATCH 5/8] clean up and fix test --- dbt_common/record.py | 13 ++++++------- tests/unit/test_record.py | 6 ++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index 0f4750d8..ab5b16da 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -124,14 +124,12 @@ def load(cls, file_name: str) -> Dict[str, List[Record]]: for record_type_name in loaded_dct: # TODO: this breaks with QueryRecord on replay since it's # not in common so isn't part of cls._record_cls_by_name yet - record_cls = cls._record_cls_by_name[record_type_name] rec_list = [] for record_dct in loaded_dct[record_type_name]: rec = record_cls.from_dict(record_dct) rec_list.append(rec) # type: ignore records_by_type[record_type_name] = rec_list - return records_by_type def expect_record(self, params: Any) -> Any: @@ -168,7 +166,9 @@ def get_record_mode_from_env() -> Optional[RecorderMode]: if record_mode.lower() == "record": return RecorderMode.RECORD # replaying requires a file path, otherwise treat as noop - elif record_mode.lower() == "replay" and os.environ["DBT_RECORDER_REPLAY_PATH"] is not None: + elif ( + record_mode.lower() == "replay" and os.environ.get("DBT_RECORDER_REPLAY_PATH") is not None + ): return RecorderMode.REPLAY # if you don't specify record/replay it's a noop @@ -193,10 +193,9 @@ def get_record_types_from_env() -> Optional[List]: for type in record_types: # Types not defined in common are not in the record_types list yet - # TODO: is there a better way to do this without hardcoding? We can't just - # wait for later because if it's QueryRecord (not defined in common) we don't - # want to remove it to ensure everything else is filtered out.... This is also - # a problem with replaying QueryRecords generally + # TODO: This is related to a problem with replay noted above. Will solve + # at a future date. Leaving it hardcoded for now to unblock. Will remove + # after resolving MNTL-308. if type not in Recorder._record_cls_by_name and type != "QueryRecord": print(f"Invalid record type: {type}") # TODO: remove after testing record_types.remove(type) diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index bce824bc..af9fc0fe 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -54,8 +54,10 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: def test_decorator_replays(): prev = os.environ.get("DBT_RECORDER_MODE", None) + prev_path = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) try: os.environ["DBT_RECORDER_MODE"] = "Replay" + os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" recorder = Recorder(RecorderMode.REPLAY, None) set_invocation_context({}) get_invocation_context().recorder = recorder @@ -79,3 +81,7 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: os.environ.pop("DBT_RECORDER_MODE", None) else: os.environ["DBT_RECORDER_MODE"] = prev + if prev_path is None: + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + else: + os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_path From fb4dff6b3930aa15f5086319556bdab609f5d068 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 28 May 2024 11:05:38 -0500 Subject: [PATCH 6/8] changelog --- .changes/unreleased/Under the Hood-20240528-110518.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Under the Hood-20240528-110518.yaml diff --git a/.changes/unreleased/Under the Hood-20240528-110518.yaml b/.changes/unreleased/Under the Hood-20240528-110518.yaml new file mode 100644 index 00000000..58b243c7 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240528-110518.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Allow dynamic selection of record types when recording. +time: 2024-05-28T11:05:18.290107-05:00 +custom: + Author: emmyoop + Issue: "140" From 08fa4ebdc8c521546b62c0d4249d07ee6995b522 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Tue, 28 May 2024 13:52:27 -0500 Subject: [PATCH 7/8] add test --- dbt_common/record.py | 17 +----------- tests/unit/test_record.py | 58 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/dbt_common/record.py b/dbt_common/record.py index ab5b16da..5edc9c7f 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -189,22 +189,7 @@ def get_record_types_from_env() -> Optional[List]: if record_types_str is None or record_types_str.lower == "all": return None - record_types = record_types_str.split(",") - - for type in record_types: - # Types not defined in common are not in the record_types list yet - # TODO: This is related to a problem with replay noted above. Will solve - # at a future date. Leaving it hardcoded for now to unblock. Will remove - # after resolving MNTL-308. - if type not in Recorder._record_cls_by_name and type != "QueryRecord": - print(f"Invalid record type: {type}") # TODO: remove after testing - record_types.remove(type) - - # if everything is invalid we don't want any type filtering - if len(record_types) == 0: - return None - - return record_types + return record_types_str.split(",") def record_function(record_type, method=False, tuple_result=False): diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index af9fc0fe..92a17b8e 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -24,6 +24,24 @@ class TestRecord(Record): result_cls = TestRecordResult +@dataclasses.dataclass +class NotTestRecordParams: + a: int + b: str + c: Optional[str] = None + + +@dataclasses.dataclass +class NotTestRecordResult: + return_val: str + + +@Recorder.register_record_type +class NotTestRecord(Record): + params_cls = NotTestRecordParams + result_cls = NotTestRecordResult + + def test_decorator_records(): prev = os.environ.get("DBT_RECORDER_MODE", None) try: @@ -52,6 +70,46 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: os.environ["DBT_RECORDER_MODE"] = prev +def test_record_types(): + prev_mode = os.environ.get("DBT_RECORDER_MODE", None) + prev_types = os.environ.get("DBT_RECORDER_TYPES", None) + try: + os.environ["DBT_RECORDER_MODE"] = "Record" + os.environ["DBT_RECORDER_TYPES"] = "TestRecord" + recorder = Recorder(RecorderMode.RECORD, ["TestRecord"]) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + @record_function(NotTestRecord) + def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + not_test_func(456, "def") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + assert NotTestRecord not in recorder._records_by_type + + finally: + if prev_mode is None: + os.environ.pop("DBT_RECORDER_MODE", None) + else: + os.environ["DBT_RECORDER_MODE"] = prev_mode + if prev_types is None: + os.environ.pop("DBT_RECORDER_TYPES", None) + else: + os.environ["DBT_RECORDER_TYPES"] = prev_types + + def test_decorator_replays(): prev = os.environ.get("DBT_RECORDER_MODE", None) prev_path = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) From eb0c3fffd2cd8657c8b48ac6b5a286d018673751 Mon Sep 17 00:00:00 2001 From: Emily Rockman Date: Wed, 29 May 2024 09:38:36 -0500 Subject: [PATCH 8/8] use test fixture --- tests/unit/test_record.py | 193 ++++++++++++++++++-------------------- 1 file changed, 93 insertions(+), 100 deletions(-) diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index 92a17b8e..8dca1d19 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -1,5 +1,6 @@ import dataclasses import os +import pytest from typing import Optional from dbt_common.context import set_invocation_context, get_invocation_context @@ -42,104 +43,96 @@ class NotTestRecord(Record): result_cls = NotTestRecordResult -def test_decorator_records(): - prev = os.environ.get("DBT_RECORDER_MODE", None) - try: - os.environ["DBT_RECORDER_MODE"] = "Record" - recorder = Recorder(RecorderMode.RECORD, None) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - return str(a) + b + (c if c else "") - - test_func(123, "abc") - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params - assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result - - finally: - if prev is None: - os.environ.pop("DBT_RECORDER_MODE", None) - else: - os.environ["DBT_RECORDER_MODE"] = prev - - -def test_record_types(): +@pytest.fixture(scope="function", autouse=True) +def setup(): + # capture the previous state of the environment variables prev_mode = os.environ.get("DBT_RECORDER_MODE", None) - prev_types = os.environ.get("DBT_RECORDER_TYPES", None) - try: - os.environ["DBT_RECORDER_MODE"] = "Record" - os.environ["DBT_RECORDER_TYPES"] = "TestRecord" - recorder = Recorder(RecorderMode.RECORD, ["TestRecord"]) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - return str(a) + b + (c if c else "") - - @record_function(NotTestRecord) - def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: - return str(a) + b + (c if c else "") - - test_func(123, "abc") - not_test_func(456, "def") - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params - assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result - assert NotTestRecord not in recorder._records_by_type - - finally: - if prev_mode is None: - os.environ.pop("DBT_RECORDER_MODE", None) - else: - os.environ["DBT_RECORDER_MODE"] = prev_mode - if prev_types is None: - os.environ.pop("DBT_RECORDER_TYPES", None) - else: - os.environ["DBT_RECORDER_TYPES"] = prev_types - - -def test_decorator_replays(): - prev = os.environ.get("DBT_RECORDER_MODE", None) - prev_path = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) - try: - os.environ["DBT_RECORDER_MODE"] = "Replay" - os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" - recorder = Recorder(RecorderMode.REPLAY, None) - set_invocation_context({}) - get_invocation_context().recorder = recorder - - expected_record = TestRecord( - params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") - ) - - recorder._records_by_type["TestRecord"] = [expected_record] - - @record_function(TestRecord) - def test_func(a: int, b: str, c: Optional[str] = None) -> str: - raise Exception("This should not actually be called") - - res = test_func(123, "abc") - - assert res == "123abc" - - finally: - if prev is None: - os.environ.pop("DBT_RECORDER_MODE", None) - else: - os.environ["DBT_RECORDER_MODE"] = prev - if prev_path is None: - os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) - else: - os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_path + prev_type = os.environ.get("DBT_RECORDER_TYPES", None) + prev_fp = os.environ.get("DBT_RECORDER_REPLAY_PATH", None) + # clear the environment variables + os.environ.pop("DBT_RECORDER_MODE", None) + os.environ.pop("DBT_RECORDER_TYPES", None) + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + yield + # reset the environment variables to their previous state + if prev_mode is None: + os.environ.pop("DBT_RECORDER_MODE", None) + else: + os.environ["DBT_RECORDER_MODE"] = prev_mode + if prev_type is None: + os.environ.pop("DBT_RECORDER_TYPES", None) + else: + os.environ["DBT_RECORDER_TYPES"] = prev_type + if prev_fp is None: + os.environ.pop("DBT_RECORDER_REPLAY_PATH", None) + else: + os.environ["DBT_RECORDER_REPLAY_PATH"] = prev_fp + + +def test_decorator_records(setup): + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + + +def test_record_types(setup): + os.environ["DBT_RECORDER_MODE"] = "Record" + os.environ["DBT_RECORDER_TYPES"] = "TestRecord" + recorder = Recorder(RecorderMode.RECORD, ["TestRecord"]) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + @record_function(NotTestRecord) + def not_test_func(a: int, b: str, c: Optional[str] = None) -> str: + return str(a) + b + (c if c else "") + + test_func(123, "abc") + not_test_func(456, "def") + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + assert recorder._records_by_type["TestRecord"][-1].params == expected_record.params + assert recorder._records_by_type["TestRecord"][-1].result == expected_record.result + assert NotTestRecord not in recorder._records_by_type + + +def test_decorator_replays(setup): + os.environ["DBT_RECORDER_MODE"] = "Replay" + os.environ["DBT_RECORDER_REPLAY_PATH"] = "record.json" + recorder = Recorder(RecorderMode.REPLAY, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + expected_record = TestRecord( + params=TestRecordParams(123, "abc"), result=TestRecordResult("123abc") + ) + + recorder._records_by_type["TestRecord"] = [expected_record] + + @record_function(TestRecord) + def test_func(a: int, b: str, c: Optional[str] = None) -> str: + raise Exception("This should not actually be called") + + res = test_func(123, "abc") + + assert res == "123abc"