diff --git a/core/dbt/clients/jinja_static.py b/core/dbt/clients/jinja_static.py index 6082f03f80c..d8746a7607d 100644 --- a/core/dbt/clients/jinja_static.py +++ b/core/dbt/clients/jinja_static.py @@ -157,68 +157,37 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper): return possible_macro_calls -def statically_parse_ref(input: str) -> RefArgs: +def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: """ - Returns a RefArgs object corresponding to an input jinja expression. + Returns a RefArgs or List[str] object, corresponding to ref or source respectively, given an input jinja expression. input: str representing how input node is referenced in tested model sql * examples: - "ref('my_model_a')" - "ref('my_model_a', version=3)" - "ref('package', 'my_model_a', version=3)" - - If input is not a well-formed jinja ref expression, a ParsingError is raised. - """ - try: - statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") - except ExtractionError: - raise ParsingError(f"Invalid jinja expression: {input}") - - if not statically_parsed.get("refs"): - raise ParsingError(f"Invalid ref expression: {input}") - - ref = list(statically_parsed["refs"])[0] - return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version")) - - -def statically_parse_source(input: str) -> List[str]: - """ - Returns a RefArgs object corresponding to an input jinja expression. - - input: str representing how input node is referenced in tested model sql - * examples: - "source('my_source_schema', 'my_source_name')" - If input is not a well-formed jinja source expression, ParsingError is raised. + If input is not a well-formed jinja ref or source expression, a ParsingError is raised. """ - try: - statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}") - except ExtractionError: - raise ParsingError(f"Invalid jinja expression: {input}") - - if not statically_parsed.get("sources"): - raise ParsingError(f"Invalid source expression: {input}") - - source = list(statically_parsed["sources"])[0] - source_name, source_table_name = source - return [source_name, source_table_name] - - -def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]: ref_or_source: Union[RefArgs, List[str]] - valid_ref = True - valid_source = True try: - ref_or_source = statically_parse_ref(expression) - except ParsingError: - valid_ref = False - try: - ref_or_source = statically_parse_source(expression) - except ParsingError: - valid_source = False - - if not valid_ref and not valid_source: - raise ParsingError(f"Invalid ref or source syntax: {expression}.") + statically_parsed = py_extract_from_source(f"{{{{ {expression} }}}}") + except ExtractionError: + raise ParsingError(f"Invalid jinja expression: {expression}") + + if statically_parsed.get("refs"): + raw_ref = list(statically_parsed["refs"])[0] + ref_or_source = RefArgs( + package=raw_ref.get("package"), + name=raw_ref.get("name"), + version=raw_ref.get("version"), + ) + elif statically_parsed.get("sources"): + source_name, source_table_name = list(statically_parsed["sources"])[0] + ref_or_source = [source_name, source_table_name] + else: + raise ParsingError(f"Invalid ref or source expression: {expression}") return ref_or_source diff --git a/tests/unit/clients/test_jinja_static.py b/tests/unit/clients/test_jinja_static.py index c714624300c..171976a6b50 100644 --- a/tests/unit/clients/test_jinja_static.py +++ b/tests/unit/clients/test_jinja_static.py @@ -3,9 +3,7 @@ from dbt.artifacts.resources import RefArgs from dbt.clients.jinja_static import ( statically_extract_macro_calls, - statically_parse_ref, statically_parse_ref_or_source, - statically_parse_source, ) from dbt.context.base import generate_base_context from dbt.exceptions import ParsingError @@ -61,37 +59,6 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls): assert possible_macro_calls == expected_possible_macro_calls -class TestStaticallyParseRef: - @pytest.mark.parametrize("invalid_expression", ["invalid", "source('schema', 'table')"]) - def test_invalid_expression(self, invalid_expression): - with pytest.raises(ParsingError): - statically_parse_ref(invalid_expression) - - @pytest.mark.parametrize( - "ref_expression,expected_ref_args", - [ - ("ref('model')", RefArgs(name="model")), - ("ref('package','model')", RefArgs(name="model", package="package")), - ("ref('model',v=3)", RefArgs(name="model", version=3)), - ("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)), - ], - ) - def test_valid_ref_expression(self, ref_expression, expected_ref_args): - ref_args = statically_parse_ref(ref_expression) - assert ref_args == expected_ref_args - - -class TestStaticallyParseSource: - @pytest.mark.parametrize("invalid_expression", ["invalid", "ref('package', 'model')"]) - def test_invalid_expression(self, invalid_expression): - with pytest.raises(ParsingError): - statically_parse_source(invalid_expression) - - def test_valid_ref_expression(self): - parsed_source = statically_parse_source("source('schema', 'table')") - assert parsed_source == ["schema", "table"] - - class TestStaticallyParseRefOrSource: def test_invalid_expression(self): with pytest.raises(ParsingError):