From 53acaed06df97959e9b7aec66ac2a307af8c1e2b Mon Sep 17 00:00:00 2001 From: Devon Fulcher Date: Wed, 20 Sep 2023 11:47:45 -0500 Subject: [PATCH] Support for sort order in query interface (#775) * added support for descending * added tests * changie * fixed linting & test * moved query interface and query parameters into protocols folder * lint & format * fixed typing --- .../unreleased/Features-20230918-155524.yaml | 6 +++ metricflow/engine/metricflow_engine.py | 22 ++++----- metricflow/errors/errors.py | 7 +++ .../{specs => protocols}/query_interface.py | 38 +++++----------- metricflow/protocols/query_parameter.py | 45 +++++++++++++++++++ metricflow/query/query_parser.py | 22 ++++----- .../specs/query_param_implementations.py | 1 + metricflow/specs/where_filter_dimension.py | 11 ++++- metricflow/specs/where_filter_entity.py | 9 +++- .../specs/where_filter_time_dimension.py | 8 +++- metricflow/test/conftest.py | 3 +- metricflow/test/query/test_query_parser.py | 26 ++++++----- .../test/specs/test_where_filter_dimension.py | 11 +++++ .../test/specs/test_where_filter_entity.py | 11 +++++ 14 files changed, 154 insertions(+), 66 deletions(-) create mode 100644 .changes/unreleased/Features-20230918-155524.yaml rename metricflow/{specs => protocols}/query_interface.py (75%) create mode 100644 metricflow/protocols/query_parameter.py create mode 100644 metricflow/test/specs/test_where_filter_dimension.py create mode 100644 metricflow/test/specs/test_where_filter_entity.py diff --git a/.changes/unreleased/Features-20230918-155524.yaml b/.changes/unreleased/Features-20230918-155524.yaml new file mode 100644 index 0000000000..d6d4124074 --- /dev/null +++ b/.changes/unreleased/Features-20230918-155524.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Support for sort order in query interface +time: 2023-09-18T15:55:24.086263-05:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index af8aafc89a..21f7bc899b 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -47,12 +47,12 @@ DataflowToExecutionPlanConverter, ) from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter +from metricflow.protocols.query_parameter import QueryParameterDimension, QueryParameterMetric from metricflow.protocols.sql_client import SqlClient from metricflow.query.query_exceptions import InvalidQueryException from metricflow.query.query_parser import MetricFlowQueryParser from metricflow.random_id import random_id from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.query_interface import QueryInterfaceMetric, QueryParameter from metricflow.specs.specs import InstanceSpecSet, MetricFlowQuerySpec from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from metricflow.telemetry.models import TelemetryLevel @@ -98,15 +98,15 @@ class MetricFlowQueryRequest: request_id: MetricFlowRequestId metric_names: Optional[Sequence[str]] = None - metrics: Optional[Sequence[QueryInterfaceMetric]] = None + metrics: Optional[Sequence[QueryParameterMetric]] = None group_by_names: Optional[Sequence[str]] = None - group_by: Optional[Sequence[QueryParameter]] = None + group_by: Optional[Sequence[QueryParameterDimension]] = None limit: Optional[int] = None time_constraint_start: Optional[datetime.datetime] = None time_constraint_end: Optional[datetime.datetime] = None where_constraint: Optional[str] = None order_by_names: Optional[Sequence[str]] = None - order_by: Optional[Sequence[QueryParameter]] = None + order_by: Optional[Sequence[QueryParameterDimension]] = None output_table: Optional[str] = None sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4 query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC @@ -114,15 +114,15 @@ class MetricFlowQueryRequest: @staticmethod def create_with_random_request_id( # noqa: D metric_names: Optional[Sequence[str]] = None, - metrics: Optional[Sequence[QueryInterfaceMetric]] = None, + metrics: Optional[Sequence[QueryParameterMetric]] = None, group_by_names: Optional[Sequence[str]] = None, - group_by: Optional[Sequence[QueryParameter]] = None, + group_by: Optional[Sequence[QueryParameterDimension]] = None, limit: Optional[int] = None, time_constraint_start: Optional[datetime.datetime] = None, time_constraint_end: Optional[datetime.datetime] = None, where_constraint: Optional[str] = None, order_by_names: Optional[Sequence[str]] = None, - order_by: Optional[Sequence[QueryParameter]] = None, + order_by: Optional[Sequence[QueryParameterDimension]] = None, output_table: Optional[str] = None, sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4, query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC, @@ -286,9 +286,9 @@ def get_dimension_values( def explain_get_dimension_values( # noqa: D self, metric_names: Optional[List[str]] = None, - metrics: Optional[Sequence[QueryInterfaceMetric]] = None, + metrics: Optional[Sequence[QueryParameterMetric]] = None, get_group_by_values: Optional[str] = None, - group_by: Optional[QueryParameter] = None, + group_by: Optional[QueryParameterDimension] = None, time_constraint_start: Optional[datetime.datetime] = None, time_constraint_end: Optional[datetime.datetime] = None, ) -> MetricFlowExplainResult: @@ -682,9 +682,9 @@ def get_dimension_values( # noqa: D def explain_get_dimension_values( # noqa: D self, metric_names: Optional[List[str]] = None, - metrics: Optional[Sequence[QueryInterfaceMetric]] = None, + metrics: Optional[Sequence[QueryParameterMetric]] = None, get_group_by_values: Optional[str] = None, - group_by: Optional[QueryParameter] = None, + group_by: Optional[QueryParameterDimension] = None, time_constraint_start: Optional[datetime.datetime] = None, time_constraint_end: Optional[datetime.datetime] = None, ) -> MetricFlowExplainResult: diff --git a/metricflow/errors/errors.py b/metricflow/errors/errors.py index d27b0e01dc..bddb576e8b 100644 --- a/metricflow/errors/errors.py +++ b/metricflow/errors/errors.py @@ -75,3 +75,10 @@ class SqlBindParametersNotSupportedError(Exception): class UnknownMetricLinkingError(ValueError): """Raised during linking when a user attempts to use a metric that isn't specified.""" + + +class InvalidQuerySyntax(Exception): + """Raised when query syntax is invalid. Primarily used in the where clause.""" + + def __init__(self, msg: str) -> None: # noqa: D + super().__init__(msg) diff --git a/metricflow/specs/query_interface.py b/metricflow/protocols/query_interface.py similarity index 75% rename from metricflow/specs/query_interface.py rename to metricflow/protocols/query_interface.py index 1939c7dbe2..ec120699c9 100644 --- a/metricflow/specs/query_interface.py +++ b/metricflow/protocols/query_interface.py @@ -2,36 +2,12 @@ from typing import Optional, Protocol, Sequence -from dbt_semantic_interfaces.type_enums import TimeGranularity - -from metricflow.time.date_part import DatePart - class QueryInterfaceMetric(Protocol): - """Metric in the query interface.""" - - @property - def name(self) -> str: - """The name of the metric.""" - raise NotImplementedError + """Represents the interface for Metric in the query interface.""" - -class QueryParameter(Protocol): - """A query parameter with a grain.""" - - @property - def name(self) -> str: - """The name of the item.""" - raise NotImplementedError - - @property - def grain(self) -> Optional[TimeGranularity]: - """The time granularity.""" - raise NotImplementedError - - @property - def date_part(self) -> Optional[DatePart]: - """Date part to extract from the dimension.""" + def descending(self, _is_descending: bool) -> QueryInterfaceMetric: + """Set the sort order for order-by.""" raise NotImplementedError @@ -46,6 +22,9 @@ def alias(self, _alias: str) -> QueryInterfaceDimension: """Renaming the column.""" raise NotImplementedError + def descending(self, _is_descending: bool) -> QueryInterfaceDimension: + """Set the sort order for order-by.""" + def date_part(self, _date_part: str) -> QueryInterfaceDimension: """Date part to extract from the dimension.""" raise NotImplementedError @@ -78,6 +57,7 @@ def create( self, time_dimension_name: str, time_granularity_name: str, + descending: bool = False, date_part_name: Optional[str] = None, entity_path: Sequence[str] = (), ) -> QueryInterfaceTimeDimension: @@ -88,7 +68,9 @@ def create( class QueryInterfaceEntity(Protocol): """Represents the interface for Entity in the query interface.""" - pass + def descending(self, _is_descending: bool) -> QueryInterfaceEntity: + """Set the sort order for order-by.""" + raise NotImplementedError class QueryInterfaceEntityFactory(Protocol): diff --git a/metricflow/protocols/query_parameter.py b/metricflow/protocols/query_parameter.py new file mode 100644 index 0000000000..2b6b4e02bc --- /dev/null +++ b/metricflow/protocols/query_parameter.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import Optional, Protocol + +from dbt_semantic_interfaces.type_enums import TimeGranularity + +from metricflow.time.date_part import DatePart + + +class QueryParameterDimension(Protocol): + """A query parameter with a grain.""" + + @property + def name(self) -> str: + """The name of the item.""" + raise NotImplementedError + + @property + def grain(self) -> Optional[TimeGranularity]: + """The time granularity.""" + raise NotImplementedError + + @property + def descending(self) -> bool: + """Set the sort order for order-by.""" + raise NotImplementedError + + @property + def date_part(self) -> Optional[DatePart]: + """Date part to extract from the dimension.""" + raise NotImplementedError + + +class QueryParameterMetric(Protocol): + """Metric in the query interface.""" + + @property + def name(self) -> str: + """The name of the metric.""" + raise NotImplementedError + + @property + def descending(self) -> bool: + """Set the sort order for order-by.""" + raise NotImplementedError diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index b1ec546c71..a8e7deda0b 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -28,9 +28,9 @@ from metricflow.filters.time_constraint import TimeRangeConstraint from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName +from metricflow.protocols.query_parameter import QueryParameterDimension, QueryParameterMetric from metricflow.query.query_exceptions import InvalidQueryException from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.query_interface import QueryInterfaceMetric, QueryParameter from metricflow.specs.specs import ( DimensionSpec, EntitySpec, @@ -169,16 +169,16 @@ def _top_fuzzy_matches( def parse_and_validate_query( self, metric_names: Optional[Sequence[str]] = None, - metrics: Optional[Sequence[QueryInterfaceMetric]] = None, + metrics: Optional[Sequence[QueryParameterMetric]] = None, group_by_names: Optional[Sequence[str]] = None, - group_by: Optional[Sequence[QueryParameter]] = None, + group_by: Optional[Sequence[QueryParameterDimension]] = None, limit: Optional[int] = None, time_constraint_start: Optional[datetime.datetime] = None, time_constraint_end: Optional[datetime.datetime] = None, where_constraint: Optional[WhereFilter] = None, where_constraint_str: Optional[str] = None, order: Optional[Sequence[str]] = None, - order_by: Optional[Sequence[QueryParameter]] = None, + order_by: Optional[Sequence[QueryParameterDimension]] = None, time_granularity: Optional[TimeGranularity] = None, ) -> MetricFlowQuerySpec: """Parse the query into spec objects, validating them in the process. @@ -290,7 +290,7 @@ def _construct_metric_specs_for_query( return tuple(metric_specs) def _get_metric_names( - self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryInterfaceMetric]] + self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryParameterMetric]] ) -> Sequence[str]: assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics) return metric_names if metric_names else [m.name for m in metrics] if metrics else [] @@ -307,7 +307,9 @@ def _get_where_filter( PydanticWhereFilter(where_sql_template=where_constraint_str) if where_constraint_str else where_constraint ) - def _get_order(self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameter]]) -> Sequence[str]: + def _get_order( + self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameterDimension]] + ) -> Sequence[str]: assert not ( order and order_by ), "Both order_by_names and order_by were set, but if an order by is specified you should only use one of these!" @@ -316,16 +318,16 @@ def _get_order(self, order: Optional[Sequence[str]], order_by: Optional[Sequence def _parse_and_validate_query( self, metric_names: Optional[Sequence[str]] = None, - metrics: Optional[Sequence[QueryInterfaceMetric]] = None, + metrics: Optional[Sequence[QueryParameterMetric]] = None, group_by_names: Optional[Sequence[str]] = None, - group_by: Optional[Sequence[QueryParameter]] = None, + group_by: Optional[Sequence[QueryParameterDimension]] = None, limit: Optional[int] = None, time_constraint_start: Optional[datetime.datetime] = None, time_constraint_end: Optional[datetime.datetime] = None, where_constraint: Optional[WhereFilter] = None, where_constraint_str: Optional[str] = None, order: Optional[Sequence[str]] = None, - order_by: Optional[Sequence[QueryParameter]] = None, + order_by: Optional[Sequence[QueryParameterDimension]] = None, time_granularity: Optional[TimeGranularity] = None, ) -> MetricFlowQuerySpec: metric_names = self._get_metric_names(metric_names, metrics) @@ -665,7 +667,7 @@ def _parse_linkable_elements( self, metric_references: Sequence[MetricReference], qualified_linkable_names: Optional[Sequence[str]] = None, - linkable_elements: Optional[Sequence[QueryParameter]] = None, + linkable_elements: Optional[Sequence[QueryParameterDimension]] = None, ) -> QueryTimeLinkableSpecSet: """Convert the linkable spec names into the respective specification objects.""" # TODO: refactor to only support group_by object inputs (removing group_by_names param) diff --git a/metricflow/specs/query_param_implementations.py b/metricflow/specs/query_param_implementations.py index 62c42d6937..b791c46189 100644 --- a/metricflow/specs/query_param_implementations.py +++ b/metricflow/specs/query_param_implementations.py @@ -15,6 +15,7 @@ class DimensionQueryParameter: name: str grain: Optional[TimeGranularity] = None + descending: bool = False date_part: Optional[DatePart] = None def __post_init__(self) -> None: # noqa: D diff --git a/metricflow/specs/where_filter_dimension.py b/metricflow/specs/where_filter_dimension.py index d37fec3f8b..655fac2053 100644 --- a/metricflow/specs/where_filter_dimension.py +++ b/metricflow/specs/where_filter_dimension.py @@ -14,11 +14,12 @@ ) from typing_extensions import override -from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.query_interface import ( +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.protocols.query_interface import ( QueryInterfaceDimension, QueryInterfaceDimensionFactory, ) +from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import DimensionSpec @@ -44,6 +45,12 @@ def alias(self, _alias: str) -> QueryInterfaceDimension: """Renaming the column.""" raise NotImplementedError + def descending(self, _is_descending: bool) -> QueryInterfaceDimension: + """Set the sort order for order-by.""" + raise InvalidQuerySyntax( + "Can't set descending in the where clause. Try setting descending in the order_by clause instead" + ) + def __str__(self) -> str: """Returns the column name. diff --git a/metricflow/specs/where_filter_entity.py b/metricflow/specs/where_filter_entity.py index 83c983fd39..78af16c458 100644 --- a/metricflow/specs/where_filter_entity.py +++ b/metricflow/specs/where_filter_entity.py @@ -11,8 +11,9 @@ from dbt_semantic_interfaces.references import EntityReference from typing_extensions import override +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.protocols.query_interface import QueryInterfaceEntity, QueryInterfaceEntityFactory from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.query_interface import QueryInterfaceEntity, QueryInterfaceEntityFactory from metricflow.specs.specs import EntitySpec @@ -26,6 +27,12 @@ def _implements_protocol(self) -> QueryInterfaceEntity: def __init__(self, column_name: str): # noqa self.column_name = column_name + def descending(self, _is_descending: bool) -> QueryInterfaceEntity: + """Set the sort order for order-by.""" + raise InvalidQuerySyntax( + "Can't set descending in the where clause. Try setting descending in the order_by clause instead" + ) + def __str__(self) -> str: """Returns the column name. diff --git a/metricflow/specs/where_filter_time_dimension.py b/metricflow/specs/where_filter_time_dimension.py index 3b24a9017e..6dd2bf7388 100644 --- a/metricflow/specs/where_filter_time_dimension.py +++ b/metricflow/specs/where_filter_time_dimension.py @@ -9,8 +9,9 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from typing_extensions import override +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.protocols.query_interface import QueryInterfaceTimeDimension, QueryInterfaceTimeDimensionFactory from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.query_interface import QueryInterfaceTimeDimension, QueryInterfaceTimeDimensionFactory from metricflow.specs.specs import TimeDimensionSpec @@ -55,10 +56,15 @@ def create( self, time_dimension_name: str, time_granularity_name: str, + descending: bool = False, date_part_name: Optional[str] = None, entity_path: Sequence[str] = (), ) -> WhereFilterTimeDimension: """Create a WhereFilterTimeDimension.""" + if descending: + raise InvalidQuerySyntax( + "Can't set descending in the where clause. Try setting descending in the order_by clause instead" + ) structured_name = DunderedNameFormatter.parse_name(time_dimension_name) call_parameter_set = TimeDimensionCallParameterSet( time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name), diff --git a/metricflow/test/conftest.py b/metricflow/test/conftest.py index 13318c9f16..d2a60d713c 100644 --- a/metricflow/test/conftest.py +++ b/metricflow/test/conftest.py @@ -18,9 +18,10 @@ @dataclass -class MockQueryParameter: +class MockQueryParameterDimension: """This is a mock that is just used to test the query parser.""" name: str grain: Optional[TimeGranularity] = None + descending: bool = False date_part: Optional[DatePart] = None diff --git a/metricflow/test/query/test_query_parser.py b/metricflow/test/query/test_query_parser.py index ba93699eed..e861369302 100644 --- a/metricflow/test/query/test_query_parser.py +++ b/metricflow/test/query/test_query_parser.py @@ -21,7 +21,7 @@ OrderBySpec, TimeDimensionSpec, ) -from metricflow.test.conftest import MockQueryParameter +from metricflow.test.conftest import MockQueryParameterDimension from metricflow.test.fixtures.model_fixtures import query_parser_from_yaml from metricflow.test.model.example_project_configuration import EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE from metricflow.test.time.metric_time_dimension import MTD @@ -186,14 +186,14 @@ def test_query_parser(bookings_query_parser: MetricFlowQueryParser) -> None: # def test_query_parser_with_object_params(bookings_query_parser: MetricFlowQueryParser) -> None: # noqa: D - Metric = namedtuple("Metric", ["name"]) - metric = Metric("bookings") + Metric = namedtuple("Metric", ["name", "descending"]) + metric = Metric("bookings", False) group_by = [ - MockQueryParameter("booking__is_instant"), - MockQueryParameter("listing"), - MockQueryParameter(MTD), + MockQueryParameterDimension("booking__is_instant"), + MockQueryParameterDimension("listing"), + MockQueryParameterDimension(MTD), ] - order_by = [MockQueryParameter(MTD), MockQueryParameter("-bookings")] + order_by = [MockQueryParameterDimension(MTD), MockQueryParameterDimension("-bookings")] query_spec = bookings_query_parser.parse_and_validate_query(metrics=[metric], group_by=group_by, order_by=order_by) assert query_spec.metric_specs == (MetricSpec(element_name="bookings"),) assert query_spec.dimension_specs == ( @@ -414,32 +414,34 @@ def test_date_part_parsing() -> None: with pytest.raises(RequestTimeGranularityException): query_parser.parse_and_validate_query( metric_names=["revenue"], - group_by=[MockQueryParameter(name="metric_time", date_part=DatePart.DOW)], + group_by=[MockQueryParameterDimension(name="metric_time", date_part=DatePart.DOW)], ) # Can't query date part for cumulative metrics with pytest.raises(UnableToSatisfyQueryError): query_parser.parse_and_validate_query( metric_names=["revenue_cumulative"], - group_by=[MockQueryParameter(name="metric_time", date_part=DatePart.YEAR)], + group_by=[MockQueryParameterDimension(name="metric_time", date_part=DatePart.YEAR)], ) # Can't query date part for metrics with offset to grain with pytest.raises(UnableToSatisfyQueryError): query_parser.parse_and_validate_query( metric_names=["revenue_since_start_of_year"], - group_by=[MockQueryParameter(name="metric_time", date_part=DatePart.MONTH)], + group_by=[MockQueryParameterDimension(name="metric_time", date_part=DatePart.MONTH)], ) # Requested granularity doesn't match resolved granularity with pytest.raises(RequestTimeGranularityException): query_parser.parse_and_validate_query( metric_names=["revenue"], - group_by=[MockQueryParameter(name="metric_time", grain=TimeGranularity.YEAR, date_part=DatePart.MONTH)], + group_by=[ + MockQueryParameterDimension(name="metric_time", grain=TimeGranularity.YEAR, date_part=DatePart.MONTH) + ], ) # Date part is compatible query_parser.parse_and_validate_query( metric_names=["revenue"], - group_by=[MockQueryParameter(name="metric_time", date_part=DatePart.MONTH)], + group_by=[MockQueryParameterDimension(name="metric_time", date_part=DatePart.MONTH)], ) diff --git a/metricflow/test/specs/test_where_filter_dimension.py b/metricflow/test/specs/test_where_filter_dimension.py new file mode 100644 index 0000000000..a3c673daae --- /dev/null +++ b/metricflow/test/specs/test_where_filter_dimension.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import pytest + +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.specs.where_filter_dimension import WhereFilterDimension + + +def test_descending_cannot_be_set() -> None: # noqa + with pytest.raises(InvalidQuerySyntax): + WhereFilterDimension("bookings").descending(True) diff --git a/metricflow/test/specs/test_where_filter_entity.py b/metricflow/test/specs/test_where_filter_entity.py new file mode 100644 index 0000000000..ffbc3ad0a8 --- /dev/null +++ b/metricflow/test/specs/test_where_filter_entity.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import pytest + +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.specs.where_filter_entity import WhereFilterEntity + + +def test_descending_cannot_be_set() -> None: # noqa + with pytest.raises(InvalidQuerySyntax): + WhereFilterEntity("customer").descending(True)