diff --git a/.changes/unreleased/Features-20230913-173608.yaml b/.changes/unreleased/Features-20230913-173608.yaml new file mode 100644 index 0000000000..64e281be69 --- /dev/null +++ b/.changes/unreleased/Features-20230913-173608.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support Saved Queries in MetricFlow +time: 2023-09-13T17:36:08.874392-07:00 +custom: + Author: plypaul + Issue: "765" diff --git a/metricflow/cli/main.py b/metricflow/cli/main.py index 8da115e26a..22c8462707 100644 --- a/metricflow/cli/main.py +++ b/metricflow/cli/main.py @@ -10,7 +10,7 @@ import time import warnings from importlib.metadata import version as pkg_version -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Sequence import click import jinja2 @@ -248,13 +248,18 @@ def tutorial(ctx: click.core.Context, cfg: CLIContext, msg: bool, clean: bool) - default=False, help="Shows inline descriptions of nodes in displayed SQL", ) +@click.option( + "--saved-query", + required=False, + help="Specify the name of the saved query to use for applicable parameters", +) @pass_config @exception_handler @log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter) def query( cfg: CLIContext, - metrics: List[str], - group_by: List[str] = [], + metrics: Optional[Sequence[str]] = None, + group_by: Optional[Sequence[str]] = None, where: Optional[str] = None, start_time: Optional[dt.datetime] = None, end_time: Optional[dt.datetime] = None, @@ -266,12 +271,14 @@ def query( display_plans: bool = False, decimals: int = DEFAULT_RESULT_DECIMAL_PLACES, show_sql_descriptions: bool = False, + saved_query: Optional[str] = None, ) -> None: """Create a new query with MetricFlow and assembles a MetricFlowQueryResult.""" start = time.time() spinner = Halo(text="Initiating query…", spinner="dots") spinner.start() mf_request = MetricFlowQueryRequest.create_with_random_request_id( + saved_query_name=saved_query, metric_names=metrics, group_by_names=group_by, limit=limit, diff 
--git a/metricflow/cli/utils.py b/metricflow/cli/utils.py index 9e304c2c80..fbfa9cb827 100644 --- a/metricflow/cli/utils.py +++ b/metricflow/cli/utils.py @@ -46,7 +46,7 @@ def query_options(function: Callable) -> Callable: )(function) function = click.option( "--metrics", - type=click_custom.SequenceParamType(min_length=1), + type=click_custom.SequenceParamType(min_length=0), default="", help="Metrics to query for: syntax is --metrics bookings or for multiple metrics --metrics bookings,messages", )(function) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 1d459bca78..b209641307 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -13,7 +13,6 @@ from dbt_semantic_interfaces.references import EntityReference, MeasureReference, MetricReference from dbt_semantic_interfaces.type_enums import DimensionType -from metricflow.assert_one_arg import assert_exactly_one_arg_set from metricflow.dataflow.builder.dataflow_plan_builder import DataflowPlanBuilder from metricflow.dataflow.builder.node_data_set import ( DataflowPlanNodeOutputDataSetResolver, @@ -97,6 +96,7 @@ class MetricFlowQueryRequest: """ request_id: MetricFlowRequestId + saved_query_name: Optional[str] = None metric_names: Optional[Sequence[str]] = None metrics: Optional[Sequence[QueryInterfaceMetric]] = None group_by_names: Optional[Sequence[str]] = None @@ -113,6 +113,7 @@ class MetricFlowQueryRequest: @staticmethod def create_with_random_request_id( # noqa: D + saved_query_name: Optional[str] = None, metric_names: Optional[Sequence[str]] = None, metrics: Optional[Sequence[QueryInterfaceMetric]] = None, group_by_names: Optional[Sequence[str]] = None, @@ -127,15 +128,9 @@ def create_with_random_request_id( # noqa: D sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4, query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC, ) -> MetricFlowQueryRequest: - 
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics) - assert not ( - group_by_names and group_by - ), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!" - assert not ( - order_by_names and order_by - ), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!" return MetricFlowQueryRequest( request_id=MetricFlowRequestId(mf_rid=f"{random_id()}"), + saved_query_name=saved_query_name, metric_names=metric_names, metrics=metrics, group_by_names=group_by_names, @@ -413,6 +408,7 @@ def all_time_constraint(self) -> TimeRangeConstraint: def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> MetricFlowExplainResult: query_spec = self._query_parser.parse_and_validate_query( + saved_query_name=mf_query_request.saved_query_name, metric_names=mf_query_request.metric_names, metrics=mf_query_request.metrics, group_by_names=mf_query_request.group_by_names, diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index 0027e68322..17c395a93d 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -11,6 +11,7 @@ from dbt_semantic_interfaces.pretty_print import pformat_big_objects from dbt_semantic_interfaces.protocols.dimension import DimensionType from dbt_semantic_interfaces.protocols.metric import MetricType +from dbt_semantic_interfaces.protocols.saved_query import SavedQuery from dbt_semantic_interfaces.protocols.where_filter import WhereFilter from dbt_semantic_interfaces.references import ( DimensionReference, @@ -20,7 +21,6 @@ ) from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity -from metricflow.assert_one_arg import assert_exactly_one_arg_set from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver from metricflow.dataflow.dataflow_plan import BaseOutput from metricflow.dataset.dataset import DataSet @@ 
-116,6 +116,7 @@ def __init__( # noqa: D node_output_resolver: DataflowPlanNodeOutputDataSetResolver, ) -> None: self._column_association_resolver = column_association_resolver + # TODO: Rename model -> manifest lookup self._model = model self._metric_lookup = model.metric_lookup self._semantic_model_lookup = model.semantic_model_lookup @@ -168,6 +169,7 @@ def _top_fuzzy_matches( def parse_and_validate_query( self, + saved_query_name: Optional[str] = None, metric_names: Optional[Sequence[str]] = None, metrics: Optional[Sequence[QueryInterfaceMetric]] = None, group_by_names: Optional[Sequence[str]] = None, @@ -188,6 +190,7 @@ def parse_and_validate_query( start_time = time.time() try: return self._parse_and_validate_query( + saved_query_name=saved_query_name, metric_names=metric_names, metrics=metrics, group_by_names=group_by_names, @@ -290,11 +293,24 @@ def _construct_metric_specs_for_query( return tuple(metric_specs) def _get_group_by_names( - self, group_by_names: Optional[Sequence[str]], group_by: Optional[Sequence[QueryParameter]] + self, + saved_query_name: Optional[str], + group_by_names: Optional[Sequence[str]], + group_by: Optional[Sequence[QueryParameter]], ) -> Sequence[str]: assert not ( group_by_names and group_by ), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!" + + if saved_query_name is not None: + if group_by_names or group_by: + raise InvalidQueryException( + "When a saved query is specified, group-by items should not be specified at query-time." 
+ ) + return self._get_saved_query(saved_query_name).group_by_item_names + elif group_by and group_by_names: + raise InvalidQueryException("At most one of the parameters `group_by` and `group_by_names` should be set.") + return ( group_by_names if group_by_names @@ -303,32 +319,79 @@ def _get_group_by_names( else [] ) + def _get_saved_query(self, saved_query_name: str) -> SavedQuery: # noqa: D + matching_saved_queries = [ + saved_query + for saved_query in self._model.semantic_manifest.saved_queries + if saved_query.name == saved_query_name + ] + + if len(matching_saved_queries) != 1: + known_saved_query_names = sorted( + saved_query.name for saved_query in self._model.semantic_manifest.saved_queries + ) + raise InvalidQueryException( + f"Did not find saved query `{saved_query_name}` in known saved queries:\n" + f"{pformat_big_objects(known_saved_query_names)}" + ) + + return matching_saved_queries[0] + def _get_metric_names( - self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryInterfaceMetric]] + self, + metric_names: Optional[Sequence[str]], + metrics: Optional[Sequence[QueryInterfaceMetric]], + saved_query_name: Optional[str], ) -> Sequence[str]: - assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics) + if saved_query_name is not None: + if metric_names or metrics: + raise InvalidQueryException( + "When a saved query is specified, metrics should not be specified at query-time." 
+ ) + metric_names = self._get_saved_query(saved_query_name).metric_names + elif metrics and metric_names: + raise InvalidQueryException("At most one of the parameters `metrics` and `metric_names` should be set.") + return metric_names if metric_names else [m.name for m in metrics] if metrics else [] def _get_where_filter( self, - where_constraint: Optional[WhereFilter] = None, - where_constraint_str: Optional[str] = None, + saved_query_name: Optional[str], + where_constraint: Optional[WhereFilter], + where_constraint_str: Optional[str], ) -> Optional[WhereFilter]: - assert not ( - where_constraint and where_constraint_str - ), "Both where_constraint and where_constraint_str were set, but if a where is specified you should only use one of these!" + if saved_query_name is not None: + saved_query_where_filters = self._get_saved_query(saved_query_name).where + if where_constraint or where_constraint_str: + raise InvalidQueryException( + f"The saved query `{saved_query_name}` already defines a where filter, and additional query-time " + f"filters are not yet supported." + ) + if len(saved_query_where_filters) == 1: + return saved_query_where_filters[0] + elif len(saved_query_where_filters) > 1: + raise InvalidQueryException( + f"The saved query `{saved_query_name}` defines multiple where filters, which is not yet supported." + ) + + return None + elif where_constraint and where_constraint_str: + raise InvalidQueryException( + "At most one of the parameters `where_constraint` and `where_constraint_str` should be set." + ) + return ( PydanticWhereFilter(where_sql_template=where_constraint_str) if where_constraint_str else where_constraint ) def _get_order(self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameter]]) -> Sequence[str]: - assert not ( - order and order_by - ), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!" 
+ if order and order_by: + raise InvalidQueryException("At most one of the parameters `order` and `order_by` should be set.") return order if order else [f"{o.name}__{o.grain}" if o.grain else o.name for o in order_by] if order_by else [] def _parse_and_validate_query( self, + saved_query_name: Optional[str] = None, metric_names: Optional[Sequence[str]] = None, metrics: Optional[Sequence[QueryInterfaceMetric]] = None, group_by_names: Optional[Sequence[str]] = None, @@ -342,9 +405,19 @@ def _parse_and_validate_query( order_by: Optional[Sequence[QueryParameter]] = None, time_granularity: Optional[TimeGranularity] = None, ) -> MetricFlowQuerySpec: - metric_names = self._get_metric_names(metric_names, metrics) - group_by_names = self._get_group_by_names(group_by_names, group_by) - where_filter = self._get_where_filter(where_constraint, where_constraint_str) + metric_names = self._get_metric_names( + metric_names=metric_names, + metrics=metrics, + saved_query_name=saved_query_name, + ) + group_by_names = self._get_group_by_names( + group_by_names=group_by_names, group_by=group_by, saved_query_name=saved_query_name + ) + where_filter = self._get_where_filter( + saved_query_name=saved_query_name, + where_constraint=where_constraint, + where_constraint_str=where_constraint_str, + ) order = self._get_order(order, order_by) # Get metric references used for validations diff --git a/metricflow/test/cli/test_cli.py b/metricflow/test/cli/test_cli.py index cc7a12f3c3..a03944e64b 100644 --- a/metricflow/test/cli/test_cli.py +++ b/metricflow/test/cli/test_cli.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import shutil import textwrap from contextlib import contextmanager @@ -7,6 +8,7 @@ from typing import Iterator import pytest +from _pytest.fixtures import FixtureRequest from dbt_semantic_interfaces.parsing.dir_to_model import ( parse_yaml_files_to_validation_ready_semantic_manifest, ) @@ -25,7 +27,14 @@ ) from metricflow.protocols.sql_client import 
SqlEngine from metricflow.test.fixtures.cli_fixtures import MetricFlowCliRunner +from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState from metricflow.test.model.example_project_configuration import EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE +from metricflow.test.snapshot_utils import assert_object_snapshot_equal + +logger = logging.getLogger(__name__) + + +# TODO: Use snapshots to compare CLI output for all tests here. def test_query(capsys: pytest.CaptureFixture, cli_runner: MetricFlowCliRunner) -> None: # noqa: D @@ -144,3 +153,38 @@ def test_list_entities(capsys: pytest.CaptureFixture, cli_runner: MetricFlowCliR assert "guest" in resp.output assert "host" in resp.output + + +def test_saved_query( # noqa: D + request: FixtureRequest, + capsys: pytest.CaptureFixture, + mf_test_session_state: MetricFlowTestSessionState, + cli_runner: MetricFlowCliRunner, +) -> None: + # Disabling capsys to resolve error "ValueError: I/O operation on closed file". Better solution TBD. + with capsys.disabled(): + resp = cli_runner.run( + query, args=["--saved-query", "p0_booking", "--order", "metric_time__day,listing__capacity_latest"] + ) + + assert resp.exit_code == 0 + + assert_object_snapshot_equal( + request=request, mf_test_session_state=mf_test_session_state, obj_id="cli_output", obj=resp.output + ) + + +def test_saved_query_explain( # noqa: D + capsys: pytest.CaptureFixture, + mf_test_session_state: MetricFlowTestSessionState, + cli_runner: MetricFlowCliRunner, +) -> None: + # Disabling capsys to resolve error "ValueError: I/O operation on closed file". Better solution TBD. + with capsys.disabled(): + resp = cli_runner.run( + query, + args=["--explain", "--saved-query", "p0_booking", "--order", "metric_time__day,listing__capacity_latest"], + ) + + # Currently difficult to compare explain output due to randomly generated IDs. 
+ assert resp.exit_code == 0 diff --git a/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/saved_queries.yaml b/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/saved_queries.yaml new file mode 100644 index 0000000000..77b4ebf832 --- /dev/null +++ b/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/saved_queries.yaml @@ -0,0 +1,12 @@ +--- +saved_query: + name: p0_booking + description: Booking-related metrics that are of the highest priority. + metrics: + - bookings + - instant_bookings + group_bys: + - metric_time__day + - listing__capacity_latest + where: + - "{{ Dimension('listing__capacity_latest') }} > 3" diff --git a/metricflow/test/snapshot_utils.py b/metricflow/test/snapshot_utils.py index ac4fcad18f..838420c029 100644 --- a/metricflow/test/snapshot_utils.py +++ b/metricflow/test/snapshot_utils.py @@ -172,9 +172,10 @@ def assert_snapshot_text_equal( # Create parent directory for the plan text files. os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, "w") as snapshot_text_file: - snapshot_text_file.write(snapshot_text) # Add a new line at the end of the file so that PRSs don't show the "no newline" symbol on Github. - snapshot_text_file.write("\n") + if len(snapshot_text) > 1 and snapshot_text[-1] != "\n": + snapshot_text = snapshot_text + "\n" + snapshot_text_file.write(snapshot_text) # Throw an exception if the plan is not there. if not os.path.exists(file_path): @@ -198,8 +199,7 @@ def assert_snapshot_text_equal( # Read the existing plan from the file and compare with the actual plan with open(file_path, "r") as snapshot_text_file: - # Remove the newline that was added from above. - expected_snapshot_text = snapshot_text_file.read().rstrip() + expected_snapshot_text = snapshot_text_file.read() if exclude_line_regex: # Filter out lines that should be ignored. 
diff --git a/metricflow/test/snapshots/test_cli.py/str/test_saved_query__cli_output.txt b/metricflow/test/snapshots/test_cli.py/str/test_saved_query__cli_output.txt new file mode 100644 index 0000000000..534002ae25 --- /dev/null +++ b/metricflow/test/snapshots/test_cli.py/str/test_saved_query__cli_output.txt @@ -0,0 +1,11 @@ +| metric_time__day | listing__capacity_latest | bookings | instant_bookings | +|:--------------------|---------------------------:|-----------:|-------------------:| +| 2019-12-01 00:00:00 | 5.00 | 1.00 | 0.00 | +| 2019-12-18 00:00:00 | 4.00 | 4.00 | 2.00 | +| 2019-12-19 00:00:00 | 4.00 | 6.00 | 6.00 | +| 2019-12-19 00:00:00 | 5.00 | 2.00 | 0.00 | +| 2019-12-20 00:00:00 | 5.00 | 2.00 | 0.00 | +| 2020-01-01 00:00:00 | 4.00 | 2.00 | 1.00 | +| 2020-01-02 00:00:00 | 4.00 | 3.00 | 3.00 | +| 2020-01-02 00:00:00 | 5.00 | 1.00 | 0.00 | +| 2020-01-03 00:00:00 | 5.00 | 1.00 | 0.00 |