Use list of where filters instead of single where filter
courtneyholcomb authored and WilliamDee committed Oct 7, 2024
1 parent 811aca8 commit 89f25ad
Showing 21 changed files with 249 additions and 192 deletions.
2 changes: 1 addition & 1 deletion dbt-metricflow/dbt_metricflow/cli/main.py
@@ -281,7 +281,7 @@ def query(
limit=limit,
time_constraint_start=start_time,
time_constraint_end=end_time,
where_constraint=where,
where_constraints=[where] if where else None,
order_by_names=order,
)

@@ -49,7 +49,7 @@ def _resolve_dependencies(self, saved_query_name: str) -> SavedQueryDependencySe

parse_result = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(saved_query_name),
where_filter=None,
where_filters=None,
limit=None,
time_constraint_start=None,
time_constraint_end=None,
47 changes: 28 additions & 19 deletions metricflow-semantics/metricflow_semantics/query/query_parser.py
@@ -15,7 +15,6 @@
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow_semantics.assert_one_arg import assert_at_most_one_arg_set
from metricflow_semantics.filters.merge_where import merge_to_single_where_filter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
@@ -93,7 +92,7 @@ def __init__( # noqa: D107
def parse_and_validate_saved_query(
self,
saved_query_parameter: SavedQueryParameter,
where_filter: Optional[WhereFilter],
where_filters: Optional[Sequence[WhereFilter]],
limit: Optional[int],
time_constraint_start: Optional[datetime.datetime],
time_constraint_end: Optional[datetime.datetime],
@@ -107,19 +106,19 @@ def parse_and_validate_saved_query(
saved_query = self._get_saved_query(saved_query_parameter)

# Merge interface could streamline this.
where_filters: List[WhereFilter] = []
parsed_where_filters: List[WhereFilter] = []
if saved_query.query_params.where is not None:
where_filters.extend(saved_query.query_params.where.where_filters)
if where_filter is not None:
where_filters.append(where_filter)
parsed_where_filters.extend(saved_query.query_params.where.where_filters)
if where_filters is not None:
parsed_where_filters.extend(where_filters)

return self._parse_and_validate_query(
metric_names=saved_query.query_params.metrics,
metrics=None,
group_by_names=saved_query.query_params.group_by,
group_by=None,
where_constraint=merge_to_single_where_filter(PydanticWhereFilterIntersection(where_filters=where_filters)),
where_constraint_str=None,
where_constraints=PydanticWhereFilterIntersection(where_filters=parsed_where_filters),
where_constraint_strs=None,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
limit=limit,
@@ -309,8 +308,8 @@ def parse_and_validate_query(
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
where_constraints: Optional[Sequence[WhereFilter]] = None,
where_constraint_strs: Optional[Sequence[str]] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
min_max_only: bool = False,
@@ -329,8 +328,8 @@
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
where_constraint_str=where_constraint_str,
where_constraints=where_constraints,
where_constraint_strs=where_constraint_strs,
order_by_names=order_by_names,
order_by=order_by,
min_max_only=min_max_only,
@@ -346,8 +345,8 @@ def _parse_and_validate_query(
limit: Optional[int],
time_constraint_start: Optional[datetime.datetime],
time_constraint_end: Optional[datetime.datetime],
where_constraint: Optional[WhereFilter],
where_constraint_str: Optional[str],
where_constraints: Optional[Sequence[WhereFilter]],
where_constraint_strs: Optional[Sequence[str]],
order_by_names: Optional[Sequence[str]],
order_by: Optional[Sequence[OrderByQueryParameter]],
min_max_only: bool,
@@ -357,7 +356,7 @@
assert_at_most_one_arg_set(metric_names=metric_names, metrics=metrics)
assert_at_most_one_arg_set(group_by_names=group_by_names, group_by=group_by)
assert_at_most_one_arg_set(order_by_names=order_by_names, order_by=order_by)
assert_at_most_one_arg_set(where_constraint=where_constraint, where_constraint_str=where_constraint_str)
assert_at_most_one_arg_set(where_constraints=where_constraints, where_constraint_strs=where_constraint_strs)

metric_names = metric_names or ()
metrics = metrics or ()
@@ -455,10 +454,20 @@ def _parse_and_validate_query(

where_filters: List[PydanticWhereFilter] = []

if where_constraint is not None:
where_filters.append(PydanticWhereFilter(where_sql_template=where_constraint.where_sql_template))
if where_constraint_str is not None:
where_filters.append(PydanticWhereFilter(where_sql_template=where_constraint_str))
if where_constraints is not None:
where_filters.extend(
[
PydanticWhereFilter(where_sql_template=constraint.where_sql_template)
for constraint in where_constraints
]
)
if where_constraint_strs is not None:
where_filters.extend(
[
PydanticWhereFilter(where_sql_template=where_constraint_str)
for where_constraint_str in where_constraint_strs
]
)

resolver_input_for_filter = ResolverInputForQueryLevelWhereFilterIntersection(
where_filter_intersection=PydanticWhereFilterIntersection(where_filters=where_filters)
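With this change, callers of MetricFlowQueryParser.parse_and_validate_query pass a sequence of filter strings via where_constraint_strs; each string becomes its own PydanticWhereFilter, and the resulting filters are wrapped together in one PydanticWhereFilterIntersection. A minimal sketch of the new calling convention, assuming a query_parser built against a real semantic manifest (the metric and dimension names are illustrative):

    result = query_parser.parse_and_validate_query(
        metric_names=("bookings",),
        group_by_names=("metric_time__day",),
        # Each entry becomes its own PydanticWhereFilter; the filters are then
        # combined in a single PydanticWhereFilterIntersection.
        where_constraint_strs=[
            "{{ Dimension('booking__is_instant') }} = '1'",
            "{{ TimeDimension('metric_time', 'day') }} >= '2020-01-01'",
        ],
    )
    query_spec = result.query_spec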
@@ -328,7 +328,7 @@ def test_parse_and_validate_where_constraint_dims(
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('booking__invalid_dim') }} = '1'",
where_constraint_strs=["{{ Dimension('booking__invalid_dim') }} = '1'"],
)

with pytest.raises(InvalidQueryException, match="Error parsing where filter"):
@@ -337,15 +337,15 @@
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('invalid_format') }} = '1'",
where_constraint_strs=["{{ Dimension('invalid_format') }} = '1'"],
)

result = bookings_query_parser.parse_and_validate_query(
metric_names=["bookings"],
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('booking__is_instant') }} = '1'",
where_constraint_strs=["{{ Dimension('booking__is_instant') }} = '1'"],
)
assert_object_snapshot_equal(request=request, mf_test_configuration=mf_test_configuration, obj=result)
assert (
@@ -366,7 +366,7 @@ def test_parse_and_validate_where_constraint_metric_time(
query_parser.parse_and_validate_query(
metric_names=["revenue"],
group_by_names=[MTD],
where_constraint_str="{{ TimeDimension('metric_time', 'day') }} > '2020-01-15'",
where_constraint_strs=["{{ TimeDimension('metric_time', 'day') }} > '2020-01-15'"],
)


@@ -622,5 +622,5 @@ def test_invalid_group_by_metric(bookings_query_parser: MetricFlowQueryParser) -
"""Tests that a query for an invalid group by metric gives an appropriate group by metric suggestion."""
with pytest.raises(InvalidQueryException, match="Metric\\('bookings', group_by=\\['listing'\\]\\)"):
bookings_query_parser.parse_and_validate_query(
metric_names=("bookings",), where_constraint_str="{{ Metric('listings', ['garbage']) }} > 1"
metric_names=("bookings",), where_constraint_strs=["{{ Metric('listings', ['garbage']) }} > 1"]
)
18 changes: 8 additions & 10 deletions metricflow/engine/metricflow_engine.py
@@ -90,7 +90,7 @@ class MetricFlowQueryRequest:
limit: Limit the result to this many rows.
time_constraint_start: Get data for the start of this time range.
time_constraint_end: Get data for the end of this time range.
where_constraint: A SQL string using group by names that can be used like a where clause on the output data.
where_constraints: A sequence of SQL strings that can be used like a where clause on the output data.
order_by_names: metric and group by names to order by. A "-" can be used to specify reverse order e.g. "-ds".
order_by: metric, dimension, or entity objects to order by.
output_table: If specified, output the result data to this table instead of a result data_table.
@@ -107,7 +107,7 @@ class MetricFlowQueryRequest:
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
where_constraint: Optional[str] = None
where_constraints: Optional[Sequence[str]] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[OrderByQueryParameter]] = None
min_max_only: bool = False
@@ -125,7 +125,7 @@ def create_with_random_request_id( # noqa: D102
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[str] = None,
where_constraints: Optional[Sequence[str]] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
@@ -145,7 +145,7 @@ def create_with_random_request_id( # noqa: D102
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
where_constraints=where_constraints,
order_by_names=order_by_names,
order_by=order_by,
sql_optimization_level=sql_optimization_level,
@@ -469,11 +469,9 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
raise InvalidQueryException("Group by items can't be specified with a saved query.")
query_spec = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(mf_query_request.saved_query_name),
where_filter=(
PydanticWhereFilter(where_sql_template=mf_query_request.where_constraint)
if mf_query_request.where_constraint is not None
else None
),
where_filters=[
    PydanticWhereFilter(where_sql_template=where_constraint)
    for where_constraint in mf_query_request.where_constraints
]
if mf_query_request.where_constraints is not None
else None,
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
@@ -489,7 +487,7 @@
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
where_constraint_str=mf_query_request.where_constraint,
where_constraint_strs=mf_query_request.where_constraints,
order_by_names=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
min_max_only=mf_query_request.min_max_only,
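MetricFlowQueryRequest mirrors the rename at the request level: the singular where_constraint string becomes a where_constraints sequence. A hedged sketch of how a caller might build a request under the new API; the saved query name, the engine object (mf_engine), and its query entry point are placeholders not shown in this diff:

    request = MetricFlowQueryRequest.create_with_random_request_id(
        saved_query_name="bookings_summary",  # hypothetical saved query
        # Previously a single `where_constraint` string; now a sequence of SQL
        # template strings, each applied like a where clause on the output.
        where_constraints=[
            "{{ Dimension('booking__is_instant') }} = '1'",
            "{{ TimeDimension('metric_time', 'day') }} >= '2020-01-01'",
        ],
        limit=100,
    )
    result = mf_engine.query(request)  # assumed engine entry point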
42 changes: 22 additions & 20 deletions tests_metricflow/dataflow/builder/test_dataflow_plan_builder.py
@@ -357,7 +357,7 @@ def test_where_constrained_plan(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
group_by_names=("booking__is_instant",),
where_constraint_str="{{ Dimension('listing__country_latest') }} = 'us'",
where_constraint_strs=["{{ Dimension('listing__country_latest') }} = 'us'"],
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -386,7 +386,7 @@ def test_where_constrained_plan_time_dimension(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
group_by_names=("booking__is_instant",),
where_constraint_str="{{ TimeDimension('metric_time', 'day') }} >= '2020-01-01'",
where_constraint_strs=["{{ TimeDimension('metric_time', 'day') }} >= '2020-01-01'"],
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -416,7 +416,7 @@ def test_where_constrained_with_common_linkable_plan(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
group_by_names=("listing__country_latest",),
where_constraint_str="{{ Dimension('listing__country_latest') }} = 'us'",
where_constraint_strs=["{{ Dimension('listing__country_latest') }} = 'us'"],
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -553,7 +553,7 @@ def test_distinct_values_plan(
query_spec = query_parser.parse_and_validate_query(
metric_names=(),
group_by_names=("listing__country_latest",),
where_constraint_str="{{ Dimension('listing__country_latest') }} = 'us'",
where_constraint_strs=["{{ Dimension('listing__country_latest') }} = 'us'"],
order_by_names=("-listing__country_latest",),
limit=100,
).query_spec
@@ -583,7 +583,7 @@ def test_distinct_values_plan_with_join(
"""Tests a plan to get distinct values of 2 dimensions, where a join is required."""
query_spec = query_parser.parse_and_validate_query(
group_by_names=("user__home_state_latest", "listing__is_lux_latest"),
where_constraint_str="{{ Dimension('listing__country_latest') }} = 'us'",
where_constraint_strs=["{{ Dimension('listing__country_latest') }} = 'us'"],
order_by_names=("-listing__is_lux_latest",),
limit=100,
).query_spec
@@ -1188,9 +1188,9 @@ def test_join_to_time_spine_with_filters(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings_fill_nulls_with_0",),
group_by_names=("metric_time__day",),
where_constraint=PydanticWhereFilter(
where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
),
where_constraints=[
PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"))
],
time_constraint_start=datetime.datetime(2020, 1, 3),
time_constraint_end=datetime.datetime(2020, 1, 5),
).query_spec
@@ -1221,9 +1221,9 @@ def test_offset_window_metric_filter_and_query_have_different_granularities(
query_spec = query_parser.parse_and_validate_query(
metric_names=("booking_fees_last_week_per_booker_this_week",),
group_by_names=("metric_time__month",),
where_constraint=PydanticWhereFilter(
where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
),
where_constraints=[
PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"))
],
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -1252,9 +1252,9 @@ def test_offset_to_grain_metric_filter_and_query_have_different_granularities(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings_at_start_of_month",),
group_by_names=("metric_time__month",),
where_constraint=PydanticWhereFilter(
where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'")
),
where_constraints=[
PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"))
],
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -1281,7 +1281,7 @@ def test_metric_in_query_where_filter(
) -> None:
"""Test querying a metric that has a metric in its where filter."""
query_spec = query_parser.parse_and_validate_query(
metric_names=("listings",), where_constraint_str="{{ Metric('bookings', ['listing'])}} > 2"
metric_names=("listings",), where_constraint_strs=["{{ Metric('bookings', ['listing'])}} > 2"]
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)

@@ -1340,11 +1340,13 @@ def test_all_available_metric_filters(
entity_spec = group_by_metric_spec.metric_subquery_entity_spec
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
where_constraint=PydanticWhereFilter(
where_sql_template=string.Template("{{ Metric('$metric_name', ['$entity_name']) }} > 2").substitute(
metric_name=linkable_metric.element_name, entity_name=entity_spec.qualified_name
),
),
where_constraints=[
PydanticWhereFilter(
where_sql_template=string.Template(
"{{ Metric('$metric_name', ['$entity_name']) }} > 2"
).substitute(metric_name=linkable_metric.element_name, entity_name=entity_spec.qualified_name),
)
],
).query_spec
dataflow_plan_builder.build_plan(query_spec)

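As the updated tests show, where_constraints takes WhereFilter objects while where_constraint_strs takes raw template strings, and assert_at_most_one_arg_set rejects a query that sets both. A combined sketch in the style of these tests, assuming the same query_parser and dataflow_plan_builder fixtures:

    from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter

    # Object form: WhereFilter instances go in `where_constraints`
    # (use `where_constraint_strs` instead for raw strings, but not both).
    query_spec = query_parser.parse_and_validate_query(
        metric_names=("bookings",),
        group_by_names=("booking__is_instant",),
        where_constraints=[
            PydanticWhereFilter(where_sql_template="{{ Dimension('listing__country_latest') }} = 'us'"),
            PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'day') }} = '2020-01-01'"),
        ],
    ).query_spec
    dataflow_plan = dataflow_plan_builder.build_plan(query_spec)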