Skip to content

Commit

Permalink
update callsites for PydanticWhereFilter.call_parameter_sets and Pyda…
Browse files Browse the repository at this point in the history
…nticWhereFilterIntersection.filter_expression_parameter_sets
  • Loading branch information
WilliamDee committed Nov 5, 2024
1 parent cb1fe3e commit a169c25
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 34 deletions.
14 changes: 12 additions & 2 deletions dbt_semantic_interfaces/validations/saved_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ class SavedQueryRule(SemanticManifestValidationRule[SemanticManifestT], Generic[

@staticmethod
@validate_safely("Validate the group-by field in a saved query.")
def _check_group_bys(valid_group_by_element_names: Set[str], saved_query: SavedQuery) -> Sequence[ValidationIssue]:
def _check_group_bys(
valid_group_by_element_names: Set[str], saved_query: SavedQuery, custom_granularity_names: Sequence[str]
) -> Sequence[ValidationIssue]:
issues: List[ValidationIssue] = []

for group_by_item in saved_query.query_params.group_by:
# TODO: Replace with more appropriate abstractions once available.
parameter_sets: FilterCallParameterSets
try:
parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{" + group_by_item + "}}")
parameter_sets = WhereFilterParser.parse_call_parameter_sets(
where_sql_template="{{" + group_by_item + "}}", custom_granularity_names=custom_granularity_names
)
except Exception as e:
issues.append(
generate_exception_issue(
Expand Down Expand Up @@ -245,6 +249,11 @@ def _check_limit(saved_query: SavedQuery) -> Sequence[ValidationIssue]:
@validate_safely("Validate all saved queries in a semantic manifest.")
def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D
issues: List[ValidationIssue] = []
custom_granularity_names = [
granularity.name
for time_spine in semantic_manifest.project_configuration.time_spines
for granularity in time_spine.custom_granularities
]
valid_metric_names = {metric.name for metric in semantic_manifest.metrics}
valid_group_by_element_names = valid_metric_names.union({METRIC_TIME_ELEMENT_NAME})
for semantic_model in semantic_manifest.semantic_models:
Expand All @@ -261,6 +270,7 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati
issues += SavedQueryRule._check_group_bys(
valid_group_by_element_names=valid_group_by_element_names,
saved_query=saved_query,
custom_granularity_names=custom_granularity_names,
)
issues += SavedQueryRule._check_order_by(saved_query)
issues += SavedQueryRule._check_limit(saved_query)
Expand Down
40 changes: 28 additions & 12 deletions dbt_semantic_interfaces/validations/where_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def _validate_time_granularity_names_for_saved_query(
element_type=SavedQueryElementType.WHERE,
element_value=where_filter.where_sql_template,
),
filter_call_param_sets=where_filter.call_parameter_sets,
filter_call_param_sets=where_filter.call_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)

Expand Down Expand Up @@ -104,7 +106,7 @@ def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List
return issues
for where_filter in saved_query.query_params.where.where_filters:
try:
where_filter.call_parameter_sets
where_filter.call_parameter_sets(custom_granularity_names=valid_granularity_names)
except Exception as e:
issues.append(
generate_exception_issue(
Expand Down Expand Up @@ -140,7 +142,7 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ

if metric.filter is not None:
try:
metric.filter.filter_expression_parameter_sets
metric.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names)
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -155,15 +157,17 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ
else:
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets,
filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)

if metric.type_params:
measure = metric.type_params.measure
if measure is not None and measure.filter is not None:
try:
measure.filter.filter_expression_parameter_sets
measure.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names)
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -179,14 +183,16 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ
else:
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets,
filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)

numerator = metric.type_params.numerator
if numerator is not None and numerator.filter is not None:
try:
numerator.filter.filter_expression_parameter_sets
numerator.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names)
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -201,14 +207,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ
else:
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets,
filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)

denominator = metric.type_params.denominator
if denominator is not None and denominator.filter is not None:
try:
denominator.filter.filter_expression_parameter_sets
denominator.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
)
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -223,14 +233,18 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ
else:
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets,
filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)

for input_metric in metric.type_params.metrics or []:
if input_metric.filter is not None:
try:
input_metric.filter.filter_expression_parameter_sets
input_metric.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
)
except Exception as e:
issues.append(
generate_exception_issue(
Expand All @@ -246,7 +260,9 @@ def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequ
else:
issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric(
context=context,
filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets,
filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets(
custom_granularity_names=valid_granularity_names
),
valid_granularity_names=valid_granularity_names,
)
return issues
Expand Down
46 changes: 33 additions & 13 deletions tests/implementations/where_filter/test_parse_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_extract_dimension_call_parameter_sets() -> None: # noqa: D
AND {{ Dimension('user__country', entity_path=['listing']) }} == 'US'\
"""
)
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(
Expand All @@ -61,7 +61,7 @@ def test_extract_dimension_with_grain_call_parameter_sets() -> None: # noqa: D
{{ Dimension('metric_time').grain('WEEK') }} > 2023-09-18
"""
)
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
Expand All @@ -81,7 +81,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D
where_sql_template=(
"""{{ TimeDimension('user__created_at', 'month', entity_path=['listing']) }} = '2020-01-01'"""
)
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
time_dimension_call_parameter_sets=(
Expand All @@ -100,7 +100,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D
where_sql_template=(
"""{{ TimeDimension('user__created_at__month', entity_path=['listing']) }} = '2020-01-01'"""
)
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
time_dimension_call_parameter_sets=(
Expand All @@ -119,7 +119,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D
def test_extract_metric_time_dimension_call_parameter_sets() -> None: # noqa: D
parse_result = PydanticWhereFilter(
where_sql_template="""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'"""
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
time_dimension_call_parameter_sets=(
Expand All @@ -137,7 +137,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D
where_sql_template=(
"""{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'"""
)
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
Expand All @@ -157,7 +157,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D
def test_extract_metric_call_parameter_sets() -> None: # noqa: D
parse_result = PydanticWhereFilter(
where_sql_template=("{{ Metric('bookings', group_by=['listing']) }} > 2")
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
Expand All @@ -172,7 +172,7 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D

parse_result = PydanticWhereFilter(
where_sql_template=("{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2")
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
Expand All @@ -186,15 +186,17 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D
)

with pytest.raises(ParseWhereFilterException):
PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets
PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets(
custom_granularity_names=()
)


def test_invalid_entity_name_error() -> None:
"""Test to ensure we throw an error if an entity name is invalid."""
bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}")

with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"):
bad_entity_filter.call_parameter_sets
bad_entity_filter.call_parameter_sets(custom_granularity_names=())


def test_where_filter_interesection_extract_call_parameter_sets() -> None:
Expand All @@ -209,7 +211,7 @@ def test_where_filter_interesection_extract_call_parameter_sets() -> None:
)
filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter])

parse_result = dict(filter_intersection.filter_expression_parameter_sets)
parse_result = dict(filter_intersection.filter_expression_parameter_sets(custom_granularity_names=()))

assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets(
time_dimension_call_parameter_sets=(
Expand Down Expand Up @@ -250,7 +252,7 @@ def test_where_filter_intersection_error_collection() -> None:
)

with pytest.raises(ParseWhereFilterException) as exc_info:
filter_intersection.filter_expression_parameter_sets
filter_intersection.filter_expression_parameter_sets(custom_granularity_names=())

error_string = str(exc_info.value)
# These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find.
Expand All @@ -261,7 +263,7 @@ def test_where_filter_intersection_error_collection() -> None:
def test_time_dimension_without_granularity() -> None: # noqa: D
parse_result = PydanticWhereFilter(
where_sql_template="{{ TimeDimension('booking__created_at') }} > 2023-09-18"
).call_parameter_sets
).call_parameter_sets(custom_granularity_names=())

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
Expand All @@ -274,3 +276,21 @@ def test_time_dimension_without_granularity() -> None: # noqa: D
),
entity_call_parameter_sets=(),
)


def test_time_dimension_with_custom_granularity() -> None: # noqa: D
parse_result = PydanticWhereFilter(
where_sql_template="{{ TimeDimension('booking__created_at', 'martian_week') }} > 2023-09-18"
).call_parameter_sets(custom_granularity_names=("martian_week",))

assert parse_result == FilterCallParameterSets(
dimension_call_parameter_sets=(),
time_dimension_call_parameter_sets=(
TimeDimensionCallParameterSet(
entity_path=(EntityReference("booking"),),
time_dimension_reference=TimeDimensionReference(element_name="created_at"),
time_granularity_name="martian_week",
),
),
entity_call_parameter_sets=(),
)
14 changes: 7 additions & 7 deletions tests/parsing/test_where_filter_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,14 @@ def test_where_filter_intersection_from_partially_deserialized_list_of_strings()
],
)
def test_time_dimension_date_part(where: str) -> None: # noqa
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.time_dimension_call_parameter_sets) == 1
assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR


def test_dimension_date_part() -> None: # noqa
where = "{{ Dimension('metric_time').grain('DAY').date_part('YEAR') }} > '2023-01-01'"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.time_dimension_call_parameter_sets) == 1
assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR

Expand Down Expand Up @@ -218,14 +218,14 @@ def test_time_dimension_grain( # noqa
where_and_expected_call_params: Tuple[str, Union[TimeDimensionCallParameterSet, DimensionCallParameterSet]]
) -> None:
where, expected_call_params = where_and_expected_call_params
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=("martian_week",))
assert len(param_sets.time_dimension_call_parameter_sets) == 1
assert param_sets.time_dimension_call_parameter_sets[0] == expected_call_params


def test_entity_without_primary_entity_prefix() -> None: # noqa
where = "{{ Entity('non_primary_entity') }} = '1'"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.entity_call_parameter_sets) == 1
assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet(
entity_path=(),
Expand All @@ -235,7 +235,7 @@ def test_entity_without_primary_entity_prefix() -> None: # noqa

def test_entity() -> None: # noqa
where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.entity_call_parameter_sets) == 1
assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet(
entity_path=(
Expand All @@ -248,7 +248,7 @@ def test_entity() -> None: # noqa

def test_metric() -> None: # noqa
where = "{{ Metric('metric', group_by=['dimension']) }} = 10"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.metric_call_parameter_sets) == 1
assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet(
group_by=(LinkableElementReference(element_name="dimension"),),
Expand All @@ -257,7 +257,7 @@ def test_metric() -> None: # noqa

# Without kwarg syntax
where = "{{ Metric('metric', ['dimension']) }} = 10"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=())
assert len(param_sets.metric_call_parameter_sets) == 1
assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet(
group_by=(LinkableElementReference(element_name="dimension"),),
Expand Down

0 comments on commit a169c25

Please sign in to comment.