Skip to content

Commit

Permalink
Allow the use of saved queries in the engine / CLI.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Oct 5, 2023
1 parent 32f8ae0 commit 93b2792
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 24 deletions.
13 changes: 10 additions & 3 deletions metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import warnings
from importlib.metadata import version as pkg_version
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Sequence

import click
import jinja2
Expand Down Expand Up @@ -248,13 +248,18 @@ def tutorial(ctx: click.core.Context, cfg: CLIContext, msg: bool, clean: bool) -
default=False,
help="Shows inline descriptions of nodes in displayed SQL",
)
@click.option(
"--saved-query",
required=False,
help="Specify the name of the saved query to use for applicable parameters",
)
@pass_config
@exception_handler
@log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter)
def query(
cfg: CLIContext,
metrics: List[str],
group_by: List[str] = [],
metrics: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[str]] = None,
where: Optional[str] = None,
start_time: Optional[dt.datetime] = None,
end_time: Optional[dt.datetime] = None,
Expand All @@ -266,12 +271,14 @@ def query(
display_plans: bool = False,
decimals: int = DEFAULT_RESULT_DECIMAL_PLACES,
show_sql_descriptions: bool = False,
saved_query: Optional[str] = None,
) -> None:
"""Create a new query with MetricFlow and assembles a MetricFlowQueryResult."""
start = time.time()
spinner = Halo(text="Initiating query…", spinner="dots")
spinner.start()
mf_request = MetricFlowQueryRequest.create_with_random_request_id(
saved_query_name=saved_query,
metric_names=metrics,
group_by_names=group_by,
limit=limit,
Expand Down
3 changes: 2 additions & 1 deletion metricflow/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def query_options(function: Callable) -> Callable:
)(function)
function = click.option(
"--metrics",
type=click_custom.SequenceParamType(min_length=1),
# Validity checks for this parameter was moved to the MetricFlowEngine.
type=click_custom.SequenceParamType(min_length=0),
default="",
help="Metrics to query for: syntax is --metrics bookings or for multiple metrics --metrics bookings,messages",
)(function)
Expand Down
57 changes: 37 additions & 20 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

import pandas as pd
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.pretty_print import pformat_big_objects
from dbt_semantic_interfaces.references import EntityReference, MeasureReference, MetricReference
from dbt_semantic_interfaces.type_enums import DimensionType

from metricflow.assert_one_arg import assert_exactly_one_arg_set
from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder
from metricflow.dataflow.builder.node_data_set import (
DataflowPlanNodeOutputDataSetResolver,
Expand Down Expand Up @@ -53,6 +53,7 @@
from metricflow.query.query_parser import MetricFlowQueryParser
from metricflow.random_id import random_id
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.query_param_implementations import SavedQueryParameter
from metricflow.specs.specs import InstanceSpecSet, MetricFlowQuerySpec
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.telemetry.models import TelemetryLevel
Expand Down Expand Up @@ -84,6 +85,8 @@ class MetricFlowQueryType(Enum):
class MetricFlowQueryRequest:
"""Encapsulates the parameters for a metric query.
TODO: This has turned into a bag of parameters that make it difficult to use without a bunch of conditionals.
metric_names: Names of the metrics to query.
metrics: Metric objects to query.
group_by_names: Names of the dimensions and entities to query.
Expand All @@ -100,6 +103,7 @@ class MetricFlowQueryRequest:
"""

request_id: MetricFlowRequestId
saved_query_name: Optional[str] = None
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[MetricQueryParameter]] = None
group_by_names: Optional[Sequence[str]] = None
Expand All @@ -116,6 +120,7 @@ class MetricFlowQueryRequest:

@staticmethod
def create_with_random_request_id( # noqa: D
saved_query_name: Optional[str] = None,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
Expand All @@ -130,15 +135,9 @@ def create_with_random_request_id( # noqa: D
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC,
) -> MetricFlowQueryRequest:
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
assert not (
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"
assert not (
order_by_names and order_by
), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!"
return MetricFlowQueryRequest(
request_id=MetricFlowRequestId(mf_rid=f"{random_id()}"),
saved_query_name=saved_query_name,
metric_names=metric_names,
metrics=metrics,
group_by_names=group_by_names,
Expand Down Expand Up @@ -415,18 +414,36 @@ def all_time_constraint(self) -> TimeRangeConstraint:
return TimeRangeConstraint.all_time()

def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> MetricFlowExplainResult:
query_spec = self._query_parser.parse_and_validate_query(
metric_names=mf_query_request.metric_names,
metrics=mf_query_request.metrics,
group_by_names=mf_query_request.group_by_names,
group_by=mf_query_request.group_by,
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
where_constraint_str=mf_query_request.where_constraint,
order_by_names=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
)
if mf_query_request.saved_query_name is not None:
if mf_query_request.metrics or mf_query_request.metric_names:
raise InvalidQueryException("Metrics can't be specified with a saved query.")
if mf_query_request.group_by or mf_query_request.group_by_names:
raise InvalidQueryException("Group by items can't be specified with a saved query.")
query_spec = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(mf_query_request.saved_query_name),
where_filter=(
PydanticWhereFilter(where_sql_template=mf_query_request.where_constraint)
if mf_query_request.where_constraint is not None
else None
),
limit=mf_query_request.limit,
order_by_parameters=mf_query_request.order_by,
)
else:
if not (mf_query_request.metrics or mf_query_request.metric_names):
raise InvalidQueryException("Metrics must be specified with queries.")
query_spec = self._query_parser.parse_and_validate_query(
metric_names=mf_query_request.metric_names,
metrics=mf_query_request.metrics,
group_by_names=mf_query_request.group_by_names,
group_by=mf_query_request.group_by,
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
where_constraint_str=mf_query_request.where_constraint,
order_by_names=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
)
logger.info(f"Query spec is:\n{pformat_big_objects(query_spec)}")

if self._semantic_manifest_lookup.metric_lookup.contains_cumulative_or_time_offset_metric(
Expand Down
8 changes: 8 additions & 0 deletions metricflow/protocols/query_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ def order_by(self) -> InputOrderByParameter:
def descending(self) -> bool:
"""Indicates if the order should be ascending or descending."""
raise NotImplementedError


class SavedQueryParameter(Protocol):
"""Name of the saved query to execute."""

@property
def name(self) -> str: # noqa: D
raise NotImplementedError
63 changes: 63 additions & 0 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.pretty_print import pformat_big_objects
from dbt_semantic_interfaces.protocols import SavedQuery
from dbt_semantic_interfaces.protocols.dimension import DimensionType
from dbt_semantic_interfaces.protocols.metric import MetricType
from dbt_semantic_interfaces.protocols.where_filter import WhereFilter
Expand All @@ -32,10 +33,13 @@
GroupByParameter,
MetricQueryParameter,
OrderByQueryParameter,
SavedQueryParameter,
TimeDimensionQueryParameter,
)
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.python_object import parse_object_builder_naming_scheme
from metricflow.specs.query_param_implementations import MetricParameter
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
Expand Down Expand Up @@ -172,6 +176,65 @@ def _top_fuzzy_matches(
top_ranked_suggestions = [x for x in top_ranked_suggestions if x[1] > min_score]
return [x[0] for x in top_ranked_suggestions]

def parse_and_validate_saved_query(
self,
saved_query_parameter: SavedQueryParameter,
where_filter: Optional[WhereFilter],
limit: Optional[int],
order_by_parameters: Optional[Sequence[OrderByQueryParameter]],
) -> MetricFlowQuerySpec:
"""Parse and validate a query using parameters from a pre-defined / saved query.
Additional parameters act in conjunction with the parameters in the saved query.
"""
saved_query = self._get_saved_query(saved_query_parameter)

# This logic could be encapsulated in the WhereFilter through a merge interface.
where_conditions: List[str] = []
if saved_query.where is not None:
where_conditions.extend(
tuple(saevd_query_where_filter.where_sql_template for saevd_query_where_filter in saved_query.where)
)
if where_filter is not None:
where_conditions.append(where_filter.where_sql_template)

if len(where_conditions) == 0:
combined_where_filter = None
if len(where_conditions) == 1:
combined_where_filter = PydanticWhereFilter(where_sql_template=where_conditions[0])
else:
where_conditions_with_parenthesis = tuple(f"({where_condition})" for where_condition in where_conditions)
combined_where_filter = PydanticWhereFilter(
where_sql_template=" AND ".join(where_conditions_with_parenthesis)
)
return self.parse_and_validate_query(
metrics=tuple(MetricParameter(name=metric_name) for metric_name in saved_query.metrics),
group_by=tuple(
parse_object_builder_naming_scheme(group_by_item_name) for group_by_item_name in saved_query.group_bys
),
where_constraint=combined_where_filter,
limit=limit,
order_by=order_by_parameters,
)

def _get_saved_query(self, saved_query_parameter: SavedQueryParameter) -> SavedQuery:
matching_saved_queries = [
saved_query
for saved_query in self._model.semantic_manifest.saved_queries
if saved_query.name == saved_query_parameter.name
]

if len(matching_saved_queries) != 1:
known_saved_query_names = sorted(
saved_query.name for saved_query in self._model.semantic_manifest.saved_queries
)
raise InvalidQueryException(
f"Did not find saved query `{saved_query_parameter.name}` in known saved queries:\n"
f"{pformat_big_objects(known_saved_query_names)}"
)

return matching_saved_queries[0]

def parse_and_validate_query(
self,
metric_names: Optional[Sequence[str]] = None,
Expand Down
90 changes: 90 additions & 0 deletions metricflow/specs/python_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import List

from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter

from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.protocols.query_parameter import GroupByParameter
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.specs.query_param_implementations import DimensionOrEntityParameter, TimeDimensionParameter


def parse_object_builder_naming_scheme(group_by_item_name: str) -> GroupByParameter:
"""Parses a string following the object-builder naming scheme into the corresponding GroupByParameter.
The implementation of the query parameter classes seems incomplete and there needs to be follow up with the author
of the query interface classes for the best approach. Right now, it seems like using the where filter is the only
way to handle this conversion. However, it seems like this functionality should be abstracted into a module that
handles operations related to the object-builder naming scheme.
Additional issues:
* The call parameter sets in DSI does not support date part.
* Conversion from the element name / entity path to the name field in the query parameter objects requires going
through StructuredLinkableSpecName.
TODO: Replace this method once the aforementioned issues are resolved.
"""
try:
call_parameter_sets = PydanticWhereFilter(
where_sql_template="{{ " + group_by_item_name + " }}"
).call_parameter_sets
except ParseWhereFilterException as e:
raise InvalidQueryException(f"Error parsing `{group_by_item_name}`") from e

group_by_parameters: List[GroupByParameter] = []

for dimension_call_parameter_set in call_parameter_sets.dimension_call_parameter_sets:
if len(dimension_call_parameter_set.entity_path) != 1:
raise NotImplementedError(
f"DimensionOrEntityParameter only supports a single item in the entity path. Got "
f"{dimension_call_parameter_set} while handling `{group_by_item_name}`"
)
group_by_parameters.append(
DimensionOrEntityParameter(
name=StructuredLinkableSpecName(
element_name=dimension_call_parameter_set.dimension_reference.element_name,
entity_link_names=tuple(
entity_reference.element_name for entity_reference in dimension_call_parameter_set.entity_path
),
).qualified_name
)
)

for entity_call_parameter_set in call_parameter_sets.entity_call_parameter_sets:
if len(entity_call_parameter_set.entity_path) != 1:
raise NotImplementedError(
f"DimensionOrEntityParameter only supports a single item in the entity path. Got "
f"{entity_call_parameter_set} while handling `{group_by_item_name}`"
)
group_by_parameters.append(
DimensionOrEntityParameter(
name=StructuredLinkableSpecName(
element_name=entity_call_parameter_set.entity_reference.element_name,
entity_link_names=tuple(
entity_reference.element_name for entity_reference in entity_call_parameter_set.entity_path
),
).qualified_name
)
)

for time_dimension_parameter_set in call_parameter_sets.time_dimension_call_parameter_sets:
group_by_parameters.append(
TimeDimensionParameter(
name=StructuredLinkableSpecName(
element_name=time_dimension_parameter_set.time_dimension_reference.element_name,
entity_link_names=tuple(
entity_reference.element_name for entity_reference in time_dimension_parameter_set.entity_path
),
).qualified_name,
grain=time_dimension_parameter_set.time_granularity,
)
)

if len(group_by_parameters) != 1:
raise InvalidQueryException(
f"Did not get exactly 1 parameter while parsing `{group_by_item_name}`. Got: {group_by_parameters}"
)

return group_by_parameters[0]
14 changes: 14 additions & 0 deletions metricflow/specs/query_param_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from dataclasses import dataclass
from typing import Optional

from dbt_semantic_interfaces.protocols import ProtocolHint
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.protocols.query_parameter import InputOrderByParameter
from metricflow.protocols.query_parameter import SavedQueryParameter as SavedQueryParameterProtocol
from metricflow.time.date_part import DatePart


Expand Down Expand Up @@ -47,3 +50,14 @@ class OrderByParameter:

order_by: InputOrderByParameter
descending: bool = False


@dataclass(frozen=True)
class SavedQueryParameter(ProtocolHint[SavedQueryParameterProtocol]):
"""Dataclass implementation of SavedQueryParameterProtocol."""

name: str

@override
def _implements_protocol(self) -> SavedQueryParameterProtocol:
return self

0 comments on commit 93b2792

Please sign in to comment.