From 86d1568955f0c89fbfe63c898e43eca652ba9fa0 Mon Sep 17 00:00:00 2001 From: tlento Date: Mon, 9 Oct 2023 18:26:58 -0700 Subject: [PATCH] Add accessor for collected filter call parameter sets to WhereFilterIntersection The call_parameter_sets for each of the WhereFilters contained in a WhereFilterIntersection currently have to be accessed one at a time in a list. In addition to making it harder to run sensible validations against an implementation of the WhereFilterIntersection, this also complicates runtime processing for any implementation (e.g., MetricFlow) that needs to access these parameter sets as a collection. This adds a property to the protocol spec for getting a sequence of pairs between the filter expression sql and the call parameter sets it contains, which allows for downstream flexibility for managing the WhereFilter components of a WhereFilterIntersection. --- .../implementations/filters/where_filter.py | 28 +++++++- .../protocols/where_filter.py | 12 +++- .../where_filter/test_parse_calls.py | 70 +++++++++++++++++++ 3 files changed, 107 insertions(+), 3 deletions(-) diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index ae4fe1d9..0be414e6 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import Callable, Generator, List +from typing import Callable, Generator, List, Tuple from typing_extensions import Self -from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets +from dbt_semantic_interfaces.call_parameter_sets import ( + FilterCallParameterSets, + ParseWhereFilterException, +) from dbt_semantic_interfaces.implementations.base import ( HashableBaseModel, PydanticCustomInputParser, @@ -13,6 +16,7 @@ from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import ( WhereFilterParser, ) +from dbt_semantic_interfaces.pretty_print import pformat_big_objects class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel): @@ -109,3 +113,23 @@ def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Se f"Expected input to be of type string, list, PydanticWhereFilter, PydanticWhereFilterIntersection, " f"or dict but got {type(input)} with value {input}" ) + + @property + def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParameterSets]]: + """Gets the call parameter sets for each filter expression.""" + filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = [] + invalid_filter_expressions: List[Tuple[str, Exception]] = [] + for where_filter in self.where_filters: + try: + filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) + except Exception as e: + invalid_filter_expressions.append((where_filter.where_sql_template, e)) + + if invalid_filter_expressions: + raise ParseWhereFilterException( + f"Encountered one or more errors when parsing the set of filter expressions " + f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n " + f"{pformat_big_objects(invalid_filter_expressions)}" + ) + + return filter_parameter_sets diff --git a/dbt_semantic_interfaces/protocols/where_filter.py b/dbt_semantic_interfaces/protocols/where_filter.py index 927b87e9..7792e006 100644 --- a/dbt_semantic_interfaces/protocols/where_filter.py +++ b/dbt_semantic_interfaces/protocols/where_filter.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Protocol, Sequence +from typing import Protocol, Sequence, Tuple from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets @@ -40,3 +40,13 @@ class WhereFilterIntersection(Protocol): def where_filters(self) -> Sequence[WhereFilter]: """The collection of WhereFilters to be applied to the input data set.""" pass + + @property + @abstractmethod + def filter_expression_parameter_sets(self) -> Sequence[Tuple[str, FilterCallParameterSets]]: + """Mapping from distinct filter expressions to the call parameter sets associated with them. + + We use a tuple, rather than a Mapping, in case the call parameter sets may vary between + filter expression specifications. + """ + pass diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 4ad77c5b..f3e069b8 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -11,6 +11,10 @@ ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, +) +from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( + ParameterSetFactory, ) from dbt_semantic_interfaces.references import ( DimensionReference, @@ -145,3 +149,69 @@ def test_invalid_entity_name_error() -> None: with pytest.raises(ParseWhereFilterException, match="Entity name is in an incorrect format"): bad_entity_filter.call_parameter_sets + + +def test_where_filter_interesection_extract_call_parameter_sets() -> None: + """Tests the collection of call parameter sets for a set of where filters.""" + time_filter = PydanticWhereFilter( + where_sql_template=("""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""") + ) + entity_filter = PydanticWhereFilter( + where_sql_template=( + """{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'""" + ) + ) + filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter]) + + parse_result = dict(filter_intersection.filter_expression_parameter_sets) + + assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets( + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference(element_name="metric_time"), + entity_path=(), + time_granularity=TimeGranularity.MONTH, + ), + ) + ) + assert parse_result.get(entity_filter.where_sql_template) == FilterCallParameterSets( + dimension_call_parameter_sets=(), + entity_call_parameter_sets=( + EntityCallParameterSet( + entity_path=(), + entity_reference=EntityReference("listing"), + ), + EntityCallParameterSet( + entity_path=(EntityReference("listing"),), + entity_reference=EntityReference("user"), + ), + ), + ) + + +def test_where_filter_intersection_error_collection() -> None: + """Tests the error behaviors when parsing where filters and collecting the call parameter sets for each. + + This should result in a single exception with all broken filters represented. + """ + metric_time_in_dimension_error = PydanticWhereFilter( + where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'" + ) + valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ") + entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}") + filter_intersection = PydanticWhereFilterIntersection( + where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error] + ) + + with pytest.raises(ParseWhereFilterException) as exc_info: + filter_intersection.filter_expression_parameter_sets + + error_string = str(exc_info.value) + # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. + assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string + assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string + # We cannot simply scan for name because the error message contains the filter list, so we assert against the error + assert ( + ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address") + not in error_string + )