diff --git a/.changes/unreleased/Fixes-20240129-110034.yaml b/.changes/unreleased/Fixes-20240129-110034.yaml new file mode 100644 index 00000000..705768a2 --- /dev/null +++ b/.changes/unreleased/Fixes-20240129-110034.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Referencing entities in where filter causes an error. +time: 2024-01-29T11:00:34.690723-08:00 +custom: + Author: plypaul + Issue: "256" diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index 0be414e6..96a74581 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -1,5 +1,7 @@ from __future__ import annotations +import textwrap +import traceback from typing import Callable, Generator, List, Tuple from typing_extensions import Self @@ -16,7 +18,6 @@ from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import ( WhereFilterParser, ) -from dbt_semantic_interfaces.pretty_print import pformat_big_objects class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel): @@ -126,10 +127,14 @@ def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParamete invalid_filter_expressions.append((where_filter.where_sql_template, e)) if invalid_filter_expressions: - raise ParseWhereFilterException( - f"Encountered one or more errors when parsing the set of filter expressions " - f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n " - f"{pformat_big_objects(invalid_filter_expressions)}" - ) + lines = ["Encountered error(s) while parsing:\n"] + for where_sql_template, exception in invalid_filter_expressions: + lines.append("Filter:") + lines.append(textwrap.indent(where_sql_template, prefix=" ")) + lines.append("Error Message:") + lines.append(textwrap.indent(str(exception), prefix=" ")) + lines.append("Traceback:") + lines.append(textwrap.indent("".join(traceback.format_tb(exception.__traceback__)), prefix=" ")) + raise ParseWhereFilterException("\n".join(lines)) return filter_parameter_sets diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py index af8fc085..128b0553 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -26,7 +26,7 @@ class ParameterSetFactory: @staticmethod def _exception_message_for_incorrect_format(element_name: str) -> str: return ( - f"Name is in an incorrect format: '{element_name}'. It should be of the form: " + f"Name is in an incorrect format: {repr(element_name)}. It should be of the form: " f"__" ) @@ -87,14 +87,17 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di @staticmethod def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet: """Gets called by Jinja when rendering {{ Entity(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(entity_name) - if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None: + structured_dundered_name = DunderedNameFormatter.parse_name(entity_name) + if structured_dundered_name.time_granularity is not None: raise ParseWhereFilterException( - f"Entity name is in an incorrect format: '{entity_name}'. " - f"It should not contain any dunders (double underscores, or __)." + f"Name is in an incorrect format: {repr(entity_name)}. " f"It should not contain a time grain suffix." ) + additional_entity_path_elements = tuple( + EntityReference(element_name=entity_path_item) for entity_path_item in entity_path + ) + return EntityCallParameterSet( - entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path), - entity_reference=EntityReference(element_name=entity_name), + entity_path=additional_entity_path_elements + structured_dundered_name.entity_links, + entity_reference=EntityReference(element_name=structured_dundered_name.element_name), ) diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 03ba267e..c304d500 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -13,9 +13,6 @@ PydanticWhereFilter, PydanticWhereFilterIntersection, ) -from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( - ParameterSetFactory, -) from dbt_semantic_interfaces.references import ( DimensionReference, EntityReference, @@ -145,9 +142,9 @@ def test_metric_time_in_dimension_call_error() -> None: # noqa: D def test_invalid_entity_name_error() -> None: """Test to ensure we throw an error if an entity name is invalid.""" - bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order' )}}") + bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}") - with pytest.raises(ParseWhereFilterException, match="Entity name is in an incorrect format"): + with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"): bad_entity_filter.call_parameter_sets @@ -198,7 +195,7 @@ def test_where_filter_intersection_error_collection() -> None: where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'" ) valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ") - entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}") + entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order__day') }}") filter_intersection = PydanticWhereFilterIntersection( where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error] ) @@ -208,12 +205,13 @@ def test_where_filter_intersection_error_collection() -> None: error_string = str(exc_info.value) # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. - assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string - assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string - # We cannot simply scan for name because the error message contains the filter list, so we assert against the error assert ( - ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address") - not in error_string + "Name is in an incorrect format: 'order_id__is_food_order__day'. It should not contain a time grain " + "suffix." in error_string + ) + assert ( + "Name is in an incorrect format: 'order_id__order_time__month'. It should be of the form: " + "__" in error_string ) diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py index cf741f32..553d9fa1 100644 --- a/tests/parsing/test_where_filter_parsing.py +++ b/tests/parsing/test_where_filter_parsing.py @@ -13,6 +13,7 @@ import pytest +from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet from dbt_semantic_interfaces.implementations.base import HashableBaseModel from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, @@ -21,6 +22,7 @@ from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import ( WhereFilterParser, ) +from dbt_semantic_interfaces.references import EntityReference from dbt_semantic_interfaces.type_enums.date_part import DatePart __BOOLEAN_EXPRESSION__ = "1 > 0" @@ -162,3 +164,26 @@ def test_dimension_date_part() -> None: # noqa param_sets = WhereFilterParser.parse_call_parameter_sets(where) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR + + +def test_entity_without_primary_entity_prefix() -> None: # noqa + where = "{{ Entity('non_primary_entity') }} = '1'" + param_sets = WhereFilterParser.parse_call_parameter_sets(where) + assert len(param_sets.entity_call_parameter_sets) == 1 + assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( + entity_path=(), + entity_reference=EntityReference(element_name="non_primary_entity"), + ) + + +def test_entity() -> None: # noqa + where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'" + param_sets = WhereFilterParser.parse_call_parameter_sets(where) + assert len(param_sets.entity_call_parameter_sets) == 1 + assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( + entity_path=( + EntityReference(element_name="entity_0"), + EntityReference(element_name="entity_1"), + ), + entity_reference=EntityReference(element_name="entity_2"), + )