diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index c805aa8f..cad7562c 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -49,14 +49,18 @@ class SavedQueryRule(SemanticManifestValidationRule[SemanticManifestT], Generic[ @staticmethod @validate_safely("Validate the group-by field in a saved query.") - def _check_group_bys(valid_group_by_element_names: Set[str], saved_query: SavedQuery) -> Sequence[ValidationIssue]: + def _check_group_bys( + valid_group_by_element_names: Set[str], saved_query: SavedQuery, custom_granularity_names: Sequence[str] + ) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] for group_by_item in saved_query.query_params.group_by: # TODO: Replace with more appropriate abstractions once available. parameter_sets: FilterCallParameterSets try: - parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{" + group_by_item + "}}") + parameter_sets = WhereFilterParser.parse_call_parameter_sets( + where_sql_template="{{" + group_by_item + "}}", custom_granularity_names=custom_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -245,6 +249,11 @@ def _check_limit(saved_query: SavedQuery) -> Sequence[ValidationIssue]: @validate_safely("Validate all saved queries in a semantic manifest.") def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D issues: List[ValidationIssue] = [] + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] valid_metric_names = {metric.name for metric in semantic_manifest.metrics} valid_group_by_element_names = valid_metric_names.union({METRIC_TIME_ELEMENT_NAME}) for semantic_model in semantic_manifest.semantic_models: @@ -261,6 +270,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati issues += SavedQueryRule._check_group_bys( valid_group_by_element_names=valid_group_by_element_names, saved_query=saved_query, + custom_granularity_names=custom_granularity_names, ) issues += SavedQueryRule._check_order_by(saved_query) issues += SavedQueryRule._check_limit(saved_query) diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py index 57ebaa30..d01dde39 100644 --- a/dbt_semantic_interfaces/validations/where_filters.py +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -73,7 +73,9 @@ def _validate_time_granularity_names_for_saved_query( element_type=SavedQueryElementType.WHERE, element_value=where_filter.where_sql_template, ), - filter_call_param_sets=where_filter.call_parameter_sets, + filter_call_param_sets=where_filter.call_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) @@ -104,7 +106,7 @@ def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List return issues for where_filter in saved_query.query_params.where.where_filters: try: - where_filter.call_parameter_sets + where_filter.call_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -140,7 +142,7 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ if metric.filter is not None: try: - metric.filter.filter_expression_parameter_sets + metric.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -155,7 +157,9 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) @@ -163,7 +167,7 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ measure = metric.type_params.measure if measure is not None and measure.filter is not None: try: - measure.filter.filter_expression_parameter_sets + measure.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -179,14 +183,16 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) numerator = metric.type_params.numerator if numerator is not None and numerator.filter is not None: try: - numerator.filter.filter_expression_parameter_sets + numerator.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) except Exception as e: issues.append( generate_exception_issue( @@ -201,14 +207,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) denominator = metric.type_params.denominator if denominator is not None and denominator.filter is not None: try: - denominator.filter.filter_expression_parameter_sets + denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -223,14 +233,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) for input_metric in metric.type_params.metrics or []: if input_metric.filter is not None: try: - input_metric.filter.filter_expression_parameter_sets + input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -246,7 +260,9 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ else: issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( context=context, - filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, + filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), valid_granularity_names=valid_granularity_names, ) return issues diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 2a7f9e89..02f284f5 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -34,7 +34,7 @@ def test_extract_dimension_call_parameter_sets() -> None: # noqa: D AND {{ Dimension('user__country', entity_path=['listing']) }} == 'US'\ """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=( @@ -61,7 +61,7 @@ def test_extract_dimension_with_grain_call_parameter_sets() -> None: # noqa: D {{ Dimension('metric_time').grain('WEEK') }} > 2023-09-18 """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -81,7 +81,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at', 'month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -100,7 +100,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at__month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -119,7 +119,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D def test_extract_metric_time_dimension_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -137,7 +137,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -157,7 +157,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -172,7 +172,7 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -186,7 +186,9 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D ) with pytest.raises(ParseWhereFilterException): - PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets + PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets( + custom_granularity_names=() + ) def test_invalid_entity_name_error() -> None: @@ -194,7 +196,7 @@ def test_invalid_entity_name_error() -> None: bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}") with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"): - bad_entity_filter.call_parameter_sets + bad_entity_filter.call_parameter_sets(custom_granularity_names=()) def test_where_filter_interesection_extract_call_parameter_sets() -> None: @@ -209,7 +211,7 @@ def test_where_filter_interesection_extract_call_parameter_sets() -> None: ) filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter]) - parse_result = dict(filter_intersection.filter_expression_parameter_sets) + parse_result = dict(filter_intersection.filter_expression_parameter_sets(custom_granularity_names=())) assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -250,7 +252,7 @@ def test_where_filter_intersection_error_collection() -> None: ) with pytest.raises(ParseWhereFilterException) as exc_info: - filter_intersection.filter_expression_parameter_sets + filter_intersection.filter_expression_parameter_sets(custom_granularity_names=()) error_string = str(exc_info.value) # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. @@ -261,7 +263,7 @@ def test_where_filter_intersection_error_collection() -> None: def test_time_dimension_without_granularity() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="{{ TimeDimension('booking__created_at') }} > 2023-09-18" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -274,3 +276,21 @@ def test_time_dimension_without_granularity() -> None: # noqa: D ), entity_call_parameter_sets=(), ) + + +def test_time_dimension_with_custom_granularity() -> None: # noqa: D + parse_result = PydanticWhereFilter( + where_sql_template="{{ TimeDimension('booking__created_at', 'martian_week') }} > 2023-09-18" + ).call_parameter_sets(custom_granularity_names=("martian_week",)) + + assert parse_result == FilterCallParameterSets( + dimension_call_parameter_sets=(), + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + entity_path=(EntityReference("booking"),), + time_dimension_reference=TimeDimensionReference(element_name="created_at"), + time_granularity_name="martian_week", + ), + ), + entity_call_parameter_sets=(), + ) diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py index 0886ea72..8764a97a 100644 --- a/tests/parsing/test_where_filter_parsing.py +++ b/tests/parsing/test_where_filter_parsing.py @@ -165,14 +165,14 @@ def test_where_filter_intersection_from_partially_deserialized_list_of_strings() ], ) def test_time_dimension_date_part(where: str) -> None: # noqa - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR def test_dimension_date_part() -> None: # noqa where = "{{ Dimension('metric_time').grain('DAY').date_part('YEAR') }} > '2023-01-01'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR @@ -218,14 +218,14 @@ def test_time_dimension_grain( # noqa where_and_expected_call_params: Tuple[str, Union[TimeDimensionCallParameterSet, DimensionCallParameterSet]] ) -> None: where, expected_call_params = where_and_expected_call_params - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=("martian_week",)) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0] == expected_call_params def test_entity_without_primary_entity_prefix() -> None: # noqa where = "{{ Entity('non_primary_entity') }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=(), @@ -235,7 +235,7 @@ def test_entity_without_primary_entity_prefix() -> None: # noqa def test_entity() -> None: # noqa where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=( @@ -248,7 +248,7 @@ def test_entity() -> None: # noqa def test_metric() -> None: # noqa where = "{{ Metric('metric', group_by=['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),), @@ -257,7 +257,7 @@ def test_metric() -> None: # noqa # Without kwarg syntax where = "{{ Metric('metric', ['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),),