Skip to content

Commit

Permalink
Consolidate query interface params (#717)
Browse files Browse the repository at this point in the history
This PR makes the Dimension, TimeDimension, and Entity objects that are passed into the where filter Jinja template implement protocols. These protocols allow for different implementations of these objects depending on the context (group_by, where, order_by parameters) while also constraining these implementations to all have the same method signatures. This will create consistency across these contexts.

The Dimension protocol is also implemented for parameters in certain methods (i.e. GroupByOrderByDimension). This will allow for more complex Dimension objects in the future that aren't feasible to serialize into a string, such as Dimension('demographic').grain('month').alias('monthly_demographics'). These parameters are added as optional parameters to these methods, so everything should be backward compatible.

I added some unit tests. This change mostly refactors existing behavior to implement protocols and adds additional optional parameters. These optional parameters are simply converted into the existing parameters, so I relied on the existing tests to ensure there is no breaking behavior.
  • Loading branch information
DevonFulcher authored Aug 25, 2023
1 parent 1c64a15 commit 3e2cea0
Show file tree
Hide file tree
Showing 10 changed files with 544 additions and 140 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20230817-100659.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Adds the option for users to specify group by parameters with object syntax
matching the where/filter expressions.
time: 2023-08-17T10:06:59.615022-05:00
custom:
Author: DevonFulcher
Issue: None
48 changes: 39 additions & 9 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dbt_semantic_interfaces.references import EntityReference, MetricReference
from dbt_semantic_interfaces.type_enums import DimensionType

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import (
DataflowPlanNodeOutputDataSetResolver,
Expand Down Expand Up @@ -50,6 +51,7 @@
from metricflow.query.query_parser import MetricFlowQueryParser
from metricflow.random_id import random_id
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.query_interface import QueryInterfaceMetric, QueryParameter
from metricflow.specs.specs import InstanceSpecSet, MetricFlowQuerySpec
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.telemetry.models import TelemetryLevel
Expand Down Expand Up @@ -94,39 +96,55 @@ class MetricFlowQueryRequest:
"""

request_id: MetricFlowRequestId
metric_names: Sequence[str]
group_by_names: Sequence[str]
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[QueryInterfaceMetric]] = None
group_by_names: Optional[Sequence[str]] = None
group_by: Optional[Sequence[QueryParameter]] = None
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
where_constraint: Optional[str] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[QueryParameter]] = None
output_table: Optional[str] = None
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC

@staticmethod
def create_with_random_request_id( # noqa: D
metric_names: Sequence[str],
group_by_names: Sequence[str],
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameter]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[str] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameter]] = None,
output_table: Optional[str] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC,
) -> MetricFlowQueryRequest:
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
assert not (
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"
assert not (
order_by_names and order_by
), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!"
return MetricFlowQueryRequest(
request_id=MetricFlowRequestId(mf_rid=f"{random_id()}"),
metric_names=metric_names,
metrics=metrics,
group_by_names=group_by_names,
group_by=group_by,
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
order_by_names=order_by_names,
order_by=order_by,
output_table=output_table,
sql_optimization_level=sql_optimization_level,
query_type=query_type,
Expand Down Expand Up @@ -263,8 +281,10 @@ def get_dimension_values(
@abstractmethod
def explain_get_dimension_values( # noqa: D
self,
metric_names: List[str],
get_group_by_values: str,
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[QueryParameter] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
Expand Down Expand Up @@ -381,12 +401,15 @@ def query(self, mf_request: MetricFlowQueryRequest) -> MetricFlowQueryResult: #
def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> MetricFlowExplainResult:
query_spec = self._query_parser.parse_and_validate_query(
metric_names=mf_query_request.metric_names,
metrics=mf_query_request.metrics,
group_by_names=mf_query_request.group_by_names,
group_by=mf_query_request.group_by,
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
where_constraint_str=mf_query_request.where_constraint,
order=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
)
logger.info(f"Query spec is:\n{pformat_big_objects(query_spec)}")

Expand Down Expand Up @@ -616,15 +639,22 @@ def get_dimension_values( # noqa: D
@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
def explain_get_dimension_values( # noqa: D
self,
metric_names: List[str],
get_group_by_values: str,
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[QueryParameter] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
assert not (
get_group_by_values and group_by
), "Both get_group_by_values and group_by were set, but if a group by is specified you should only use one of these!"
return self._create_execution_plan(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=metric_names,
group_by_names=[get_group_by_values],
metrics=metrics,
group_by_names=[get_group_by_values] if get_group_by_values else None,
group_by=[group_by] if group_by else None,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
query_type=MetricFlowQueryType.DIMENSION_VALUES,
Expand Down
85 changes: 68 additions & 17 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.dataflow_plan import BaseOutput
from metricflow.dataset.dataset import DataSet
Expand All @@ -30,6 +31,7 @@
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.query_interface import QueryInterfaceMetric, QueryParameter
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
Expand Down Expand Up @@ -167,14 +169,17 @@ def _top_fuzzy_matches(

def parse_and_validate_query(
self,
metric_names: Sequence[str],
group_by_names: Sequence[str],
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameter]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
order: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameter]] = None,
time_granularity: Optional[TimeGranularity] = None,
) -> MetricFlowQuerySpec:
"""Parse the query into spec objects, validating them in the process.
Expand All @@ -185,13 +190,16 @@ def parse_and_validate_query(
try:
return self._parse_and_validate_query(
metric_names=metric_names,
metrics=metrics,
group_by_names=group_by_names,
group_by=group_by,
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
where_constraint_str=where_constraint_str,
order=order,
order_by=order_by,
time_granularity=time_granularity,
)
finally:
Expand Down Expand Up @@ -238,7 +246,8 @@ def _validate_linkable_specs(
suggestion_sections = {}
for invalid_group_by in invalid_group_bys:
suggestions = MetricFlowQueryParser._top_fuzzy_matches(
item=invalid_group_by.qualified_name, candidate_items=valid_group_by_names_for_metrics
item=invalid_group_by.qualified_name,
candidate_items=valid_group_by_names_for_metrics,
)
section_key = f"Suggestions for invalid dimension '{invalid_group_by.qualified_name}'"
section_value = pformat_big_objects(suggestions)
Expand Down Expand Up @@ -281,27 +290,63 @@ def _construct_metric_specs_for_query(
)
return tuple(metric_specs)

def _get_group_by_names(
    self, group_by_names: Optional[Sequence[str]], group_by: Optional[Sequence[QueryParameter]]
) -> Sequence[str]:
    """Resolve the group-by names from either the string form or the object form of the parameter.

    Exactly one of the two inputs may be provided; object-form parameters are
    rendered as "<name>__<grain>" when a grain is present, otherwise "<name>".
    """
    assert not (
        group_by_names and group_by
    ), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"
    if group_by_names:
        return group_by_names
    if group_by:
        rendered_names = []
        for param in group_by:
            # A parameter with a grain qualifies the name with that grain.
            rendered_names.append(f"{param.name}__{param.grain}" if param.grain else param.name)
        return rendered_names
    return []

def _get_metric_names(
    self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryInterfaceMetric]]
) -> Sequence[str]:
    """Resolve metric names from either the string form or the object form of the parameter.

    Exactly one of the two inputs must be provided (enforced by the helper below).
    """
    assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
    if metric_names:
        return metric_names
    if metrics:
        # Object-form metrics contribute their name attribute.
        return [metric.name for metric in metrics]
    return []

def _get_where_filter(
    self,
    where_constraint: Optional[WhereFilter] = None,
    where_constraint_str: Optional[str] = None,
) -> Optional[WhereFilter]:
    """Resolve the where filter from either the object form or the raw SQL-template string form.

    At most one of the two inputs may be provided; returns None when neither is set.
    """
    assert not (
        where_constraint and where_constraint_str
    ), "Both where_constraint and where_constraint_str were set, but if a where is specified you should only use one of these!"
    if where_constraint_str:
        # Wrap the raw SQL template string in a WhereFilter implementation.
        return PydanticWhereFilter(where_sql_template=where_constraint_str)
    return where_constraint

def _get_order(self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameter]]) -> Sequence[str]:
    """Resolve the order-by names from either the string form or the object form of the parameter.

    Exactly one of the two inputs may be provided; object-form parameters are
    rendered as "<name>__<grain>" when a grain is present, otherwise "<name>".
    """
    assert not (
        order and order_by
    ), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!"
    if order:
        return order
    if order_by:
        return [f"{item.name}__{item.grain}" if item.grain else item.name for item in order_by]
    return []

def _parse_and_validate_query(
self,
metric_names: Sequence[str],
group_by_names: Sequence[str],
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameter]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
order: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameter]] = None,
time_granularity: Optional[TimeGranularity] = None,
) -> MetricFlowQuerySpec:
assert not (
where_constraint and where_constraint_str
), "Both where_constraint and where_constraint_str should not be set"

where_filter: Optional[WhereFilter]
if where_constraint_str:
where_filter = PydanticWhereFilter(where_sql_template=where_constraint_str)
else:
where_filter = where_constraint
metric_names = self._get_metric_names(metric_names, metrics)
group_by_names = self._get_group_by_names(group_by_names, group_by)
where_filter = self._get_where_filter(where_constraint, where_constraint_str)
order = self._get_order(order, order_by)

# Get metric references used for validations
# In a case of derived metric, all the input metrics would be here.
Expand Down Expand Up @@ -507,7 +552,8 @@ def _adjust_time_range_constraint(
)
partial_time_dimension_spec_to_time_dimension_spec = (
self._time_granularity_solver.resolve_granularity_for_partial_time_dimension_specs(
metric_references=metric_references, partial_time_dimension_specs=(partial_metric_time_spec,)
metric_references=metric_references,
partial_time_dimension_specs=(partial_metric_time_spec,),
)
)
adjust_to_granularity = partial_time_dimension_spec_to_time_dimension_spec[
Expand All @@ -527,7 +573,10 @@ def _find_replacement_for_metric_time_dimension(
== self._metric_time_dimension_reference.element_name
and partial_time_dimension_spec_to_replace.entity_links == ()
):
return partial_time_dimension_spec_to_replace, replace_with_time_dimension_spec
return (
partial_time_dimension_spec_to_replace,
replace_with_time_dimension_spec,
)

raise RuntimeError(f"Replacement for metric time dimension '{self._metric_time_dimension_reference}' not found")

Expand Down Expand Up @@ -596,7 +645,9 @@ def _parse_metric_names(
return tuple(metric_references)

def _parse_linkable_element_names(
self, qualified_linkable_names: Sequence[str], metric_references: Sequence[MetricReference]
self,
qualified_linkable_names: Sequence[str],
metric_references: Sequence[MetricReference],
) -> QueryTimeLinkableSpecSet:
"""Convert the linkable spec names into the respective specification objects."""
qualified_linkable_names = [x.lower() for x in qualified_linkable_names]
Expand Down
Loading

0 comments on commit 3e2cea0

Please sign in to comment.