Skip to content

Commit

Permalink
Allow the use of saved queries in the engine / CLI.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Sep 14, 2023
1 parent f30cb59 commit 0f95aa5
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 27 deletions.
13 changes: 10 additions & 3 deletions metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import warnings
from importlib.metadata import version as pkg_version
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Sequence

import click
import jinja2
Expand Down Expand Up @@ -248,13 +248,18 @@ def tutorial(ctx: click.core.Context, cfg: CLIContext, msg: bool, clean: bool) -
default=False,
help="Shows inline descriptions of nodes in displayed SQL",
)
@click.option(
"--saved-query",
required=False,
help="Specify the name of the saved query to use for applicable parameters",
)
@pass_config
@exception_handler
@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
def query(
cfg: CLIContext,
metrics: List[str],
group_by: List[str] = [],
metrics: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[str]] = None,
where: Optional[str] = None,
start_time: Optional[dt.datetime] = None,
end_time: Optional[dt.datetime] = None,
Expand All @@ -266,12 +271,14 @@ def query(
display_plans: bool = False,
decimals: int = DEFAULT_RESULT_DECIMAL_PLACES,
show_sql_descriptions: bool = False,
saved_query: Optional[str] = None,
) -> None:
"""Create a new query with MetricFlow and assembles a MetricFlowQueryResult."""
start = time.time()
spinner = Halo(text="Initiating query…", spinner="dots")
spinner.start()
mf_request = MetricFlowQueryRequest.create_with_random_request_id(
saved_query_name=saved_query,
metric_names=metrics,
group_by_names=group_by,
limit=limit,
Expand Down
2 changes: 1 addition & 1 deletion metricflow/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def query_options(function: Callable) -> Callable:
)(function)
function = click.option(
"--metrics",
type=click_custom.SequenceParamType(min_length=1),
type=click_custom.SequenceParamType(min_length=0),
default="",
help="Metrics to query for: syntax is --metrics bookings or for multiple metrics --metrics bookings,messages",
)(function)
Expand Down
12 changes: 4 additions & 8 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from dbt_semantic_interfaces.references import EntityReference, MeasureReference, MetricReference
from dbt_semantic_interfaces.type_enums import DimensionType

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import (
DataflowPlanNodeOutputDataSetResolver,
Expand Down Expand Up @@ -97,6 +96,7 @@ class MetricFlowQueryRequest:
"""

request_id: MetricFlowRequestId
saved_query_name: Optional[str] = None
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[QueryInterfaceMetric]] = None
group_by_names: Optional[Sequence[str]] = None
Expand All @@ -113,6 +113,7 @@ class MetricFlowQueryRequest:

@staticmethod
def create_with_random_request_id( # noqa: D
saved_query_name: Optional[str] = None,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
Expand All @@ -127,15 +128,9 @@ def create_with_random_request_id( # noqa: D
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC,
) -> MetricFlowQueryRequest:
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
assert not (
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"
assert not (
order_by_names and order_by
), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!"
return MetricFlowQueryRequest(
request_id=MetricFlowRequestId(mf_rid=f"{random_id()}"),
saved_query_name=saved_query_name,
metric_names=metric_names,
metrics=metrics,
group_by_names=group_by_names,
Expand Down Expand Up @@ -413,6 +408,7 @@ def all_time_constraint(self) -> TimeRangeConstraint:

def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> MetricFlowExplainResult:
query_spec = self._query_parser.parse_and_validate_query(
saved_query_name=mf_query_request.saved_query_name,
metric_names=mf_query_request.metric_names,
metrics=mf_query_request.metrics,
group_by_names=mf_query_request.group_by_names,
Expand Down
103 changes: 88 additions & 15 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dbt_semantic_interfaces.pretty_print import pformat_big_objects
from dbt_semantic_interfaces.protocols.dimension import DimensionType
from dbt_semantic_interfaces.protocols.metric import MetricType
from dbt_semantic_interfaces.protocols.saved_query import SavedQuery
from dbt_semantic_interfaces.protocols.where_filter import WhereFilter
from dbt_semantic_interfaces.references import (
DimensionReference,
Expand All @@ -20,7 +21,6 @@
)
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver
from metricflow.dataflow.dataflow_plan import BaseOutput
from metricflow.dataset.dataset import DataSet
Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__( # noqa: D
node_output_resolver: DataflowPlanNodeOutputDataSetResolver,
) -> None:
self._column_association_resolver = column_association_resolver
# TODO: Rename model -> manifest lookup
self._model = model
self._metric_lookup = model.metric_lookup
self._semantic_model_lookup = model.semantic_model_lookup
Expand Down Expand Up @@ -168,6 +169,7 @@ def _top_fuzzy_matches(

def parse_and_validate_query(
self,
saved_query_name: Optional[str] = None,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
Expand All @@ -188,6 +190,7 @@ def parse_and_validate_query(
start_time = time.time()
try:
return self._parse_and_validate_query(
saved_query_name=saved_query_name,
metric_names=metric_names,
metrics=metrics,
group_by_names=group_by_names,
Expand Down Expand Up @@ -290,11 +293,24 @@ def _construct_metric_specs_for_query(
return tuple(metric_specs)

def _get_group_by_names(
self, group_by_names: Optional[Sequence[str]], group_by: Optional[Sequence[QueryParameter]]
self,
saved_query_name: Optional[str],
group_by_names: Optional[Sequence[str]],
group_by: Optional[Sequence[QueryParameter]],
) -> Sequence[str]:
assert not (
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"

if saved_query_name is not None:
if group_by_names or group_by:
raise InvalidQueryException(
"When a saved query is specified, group-by items should not be specified at query-time."
)
return self._get_saved_query(saved_query_name).group_by_item_names
elif group_by and group_by_names:
raise InvalidQueryException("At most one of the parameters `group_by` and `group_by_names` should be set.")

return (
group_by_names
if group_by_names
Expand All @@ -303,32 +319,79 @@ def _get_group_by_names(
else []
)

def _get_saved_query(self, saved_query_name: str) -> SavedQuery: # noqa: D
matching_saved_queries = [
saved_query
for saved_query in self._model.semantic_manifest.saved_queries
if saved_query.name == saved_query_name
]

if len(matching_saved_queries) != 1:
known_saved_query_names = sorted(
saved_query.name for saved_query in self._model.semantic_manifest.saved_queries
)
raise InvalidQueryException(
f"Did not find saved query `{saved_query_name}` in known saved queries:\n"
f"{pformat_big_objects(known_saved_query_names)}"
)

return matching_saved_queries[0]

def _get_metric_names(
self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryInterfaceMetric]]
self,
metric_names: Optional[Sequence[str]],
metrics: Optional[Sequence[QueryInterfaceMetric]],
saved_query_name: Optional[str],
) -> Sequence[str]:
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
if saved_query_name is not None:
if metric_names or metrics:
raise InvalidQueryException(
"When a saved query is specified, metrics should not be specified at query-time."
)
metric_names = self._get_saved_query(saved_query_name).metric_names
elif metrics and metric_names:
raise InvalidQueryException("At most one of the parameters `metrics` and `metric_names` should be set.")

return metric_names if metric_names else [m.name for m in metrics] if metrics else []

def _get_where_filter(
self,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
saved_query_name: Optional[str],
where_constraint: Optional[WhereFilter],
where_constraint_str: Optional[str],
) -> Optional[WhereFilter]:
assert not (
where_constraint and where_constraint_str
), "Both where_constraint and where_constraint_str were set, but if a where is specified you should only use one of these!"
if saved_query_name is not None:
saved_query_where_filters = self._get_saved_query(saved_query_name).where
if where_constraint or where_constraint_str:
raise InvalidQueryException(
f"The saved query `{saved_query_name}` already defines a where filter, and additional query-time"
f"filters are not yet supported."
)
if len(saved_query_where_filters) == 1:
return saved_query_where_filters[0]
elif len(saved_query_where_filters) > 1:
raise InvalidQueryException(
f"The saved query `{saved_query_name}` defines multiple where filters, which is not yet supported."
)

return None
elif where_constraint and where_constraint_str:
raise InvalidQueryException(
"At most one of the parameters `where_constraint` and `where_constraint_str` should be set."
)

return (
PydanticWhereFilter(where_sql_template=where_constraint_str) if where_constraint_str else where_constraint
)

def _get_order(self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameter]]) -> Sequence[str]:
assert not (
order and order_by
), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!"
if order and order_by:
raise InvalidQueryException("At most one of the parameters `order` and `order_by` should be set.")
return order if order else [f"{o.name}__{o.grain}" if o.grain else o.name for o in order_by] if order_by else []

def _parse_and_validate_query(
self,
saved_query_name: Optional[str] = None,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryInterfaceMetric]] = None,
group_by_names: Optional[Sequence[str]] = None,
Expand All @@ -342,9 +405,19 @@ def _parse_and_validate_query(
order_by: Optional[Sequence[QueryParameter]] = None,
time_granularity: Optional[TimeGranularity] = None,
) -> MetricFlowQuerySpec:
metric_names = self._get_metric_names(metric_names, metrics)
group_by_names = self._get_group_by_names(group_by_names, group_by)
where_filter = self._get_where_filter(where_constraint, where_constraint_str)
metric_names = self._get_metric_names(
metric_names=metric_names,
metrics=metrics,
saved_query_name=saved_query_name,
)
group_by_names = self._get_group_by_names(
group_by_names=group_by_names, group_by=group_by, saved_query_name=saved_query_name
)
where_filter = self._get_where_filter(
saved_query_name=saved_query_name,
where_constraint=where_constraint,
where_constraint_str=where_constraint_str,
)
order = self._get_order(order, order_by)

# Get metric references used for validations
Expand Down

0 comments on commit 0f95aa5

Please sign in to comment.