Skip to content

Commit

Permalink
Specific implementations for query parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 21, 2023
1 parent 47042ca commit 68edcfb
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 100 deletions.
33 changes: 19 additions & 14 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Sequence, Tuple
from typing import List, Optional, Sequence, Tuple, Union

import pandas as pd
from dbt_semantic_interfaces.implementations.elements.dimension import PydanticDimensionTypeParams
Expand Down Expand Up @@ -47,7 +47,12 @@
DataflowToExecutionPlanConverter,
)
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.query_parameter import QueryParameterDimension, QueryParameterMetric
from metricflow.protocols.query_parameter import (
GroupByQueryParameter,
MetricQueryParameter,
OrderByQueryParameter,
TimeDimensionQueryParameter,
)
from metricflow.protocols.sql_client import SqlClient
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.query.query_parser import MetricFlowQueryParser
Expand Down Expand Up @@ -98,31 +103,31 @@ class MetricFlowQueryRequest:

request_id: MetricFlowRequestId
metric_names: Optional[Sequence[str]] = None
metrics: Optional[Sequence[QueryParameterMetric]] = None
metrics: Optional[Sequence[MetricQueryParameter]] = None
group_by_names: Optional[Sequence[str]] = None
group_by: Optional[Sequence[QueryParameterDimension]] = None
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
where_constraint: Optional[str] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[QueryParameterDimension]] = None
order_by: Optional[Sequence[OrderByQueryParameter]] = None
output_table: Optional[str] = None
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC

@staticmethod
def create_with_random_request_id( # noqa: D
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryParameterMetric]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameterDimension]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[str] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameterDimension]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
output_table: Optional[str] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC,
Expand Down Expand Up @@ -286,9 +291,9 @@ def get_dimension_values(
def explain_get_dimension_values( # noqa: D
self,
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[QueryParameterMetric]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[QueryParameterDimension] = None,
group_by: Optional[Union[GroupByQueryParameter, TimeDimensionQueryParameter]] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
Expand Down Expand Up @@ -682,9 +687,9 @@ def get_dimension_values( # noqa: D
def explain_get_dimension_values( # noqa: D
self,
metric_names: Optional[List[str]] = None,
metrics: Optional[Sequence[QueryParameterMetric]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
get_group_by_values: Optional[str] = None,
group_by: Optional[QueryParameterDimension] = None,
group_by: Optional[Union[GroupByQueryParameter, TimeDimensionQueryParameter]] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
) -> MetricFlowExplainResult:
Expand All @@ -695,8 +700,8 @@ def explain_get_dimension_values( # noqa: D
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=metric_names,
metrics=metrics,
group_by_names=[get_group_by_values] if get_group_by_values else None,
group_by=[group_by] if group_by else None,
group_by_names=(get_group_by_values,) if get_group_by_values else None,
group_by=(group_by,) if group_by else None,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
query_type=MetricFlowQueryType.DIMENSION_VALUES,
Expand Down
40 changes: 27 additions & 13 deletions metricflow/protocols/query_parameter.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
from __future__ import annotations

from typing import Optional, Protocol
from typing import Optional, Protocol, Union, runtime_checkable

from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow.time.date_part import DatePart


class QueryParameterDimension(Protocol):
"""A query parameter with a grain."""
@runtime_checkable
class MetricQueryParameter(Protocol):
"""Metric requested in a query."""

@property
def name(self) -> str:
"""The name of the item."""
"""The name of the metric."""
raise NotImplementedError


@runtime_checkable
class GroupByQueryParameter(Protocol):
"""Generic group by parameter for queries. Might be an entity or a dimension."""

@property
def grain(self) -> Optional[TimeGranularity]:
"""The time granularity."""
def name(self) -> str:
"""The name of the metric."""
raise NotImplementedError


@runtime_checkable
class TimeDimensionQueryParameter(Protocol): # noqa: D
@property
def descending(self) -> bool:
"""Set the sort order for order-by."""
def name(self) -> str:
"""The name of the item."""
raise NotImplementedError

@property
def grain(self) -> Optional[TimeGranularity]:
"""The time granularity."""
raise NotImplementedError

@property
Expand All @@ -31,15 +45,15 @@ def date_part(self) -> Optional[DatePart]:
raise NotImplementedError


class QueryParameterMetric(Protocol):
"""Metric in the query interface."""
class OrderByQueryParameter(Protocol):
"""Parameter to order by, specifying ascending or descending."""

@property
def name(self) -> str:
"""The name of the metric."""
def order_by(self) -> Union[MetricQueryParameter, GroupByQueryParameter, TimeDimensionQueryParameter]:
"""Parameter to order results by."""
raise NotImplementedError

@property
def descending(self) -> bool:
"""Set the sort order for order-by."""
"""Indicates if the order should be ascending or descending."""
raise NotImplementedError
70 changes: 37 additions & 33 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
Expand All @@ -28,7 +28,12 @@
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow.protocols.query_parameter import QueryParameterDimension, QueryParameterMetric
from metricflow.protocols.query_parameter import (
GroupByQueryParameter,
MetricQueryParameter,
OrderByQueryParameter,
TimeDimensionQueryParameter,
)
from metricflow.query.query_exceptions import InvalidQueryException
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
Expand Down Expand Up @@ -169,16 +174,16 @@ def _top_fuzzy_matches(
def parse_and_validate_query(
self,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryParameterMetric]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameterDimension]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
order: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameterDimension]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
time_granularity: Optional[TimeGranularity] = None,
) -> MetricFlowQuerySpec:
"""Parse the query into spec objects, validating them in the process.
Expand Down Expand Up @@ -290,7 +295,7 @@ def _construct_metric_specs_for_query(
return tuple(metric_specs)

def _get_metric_names(
self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[QueryParameterMetric]]
self, metric_names: Optional[Sequence[str]], metrics: Optional[Sequence[MetricQueryParameter]]
) -> Sequence[str]:
assert_exactly_one_arg_set(metric_names=metric_names, metrics=metrics)
return metric_names if metric_names else [m.name for m in metrics] if metrics else []
Expand All @@ -308,7 +313,7 @@ def _get_where_filter(
)

def _get_order(
self, order: Optional[Sequence[str]], order_by: Optional[Sequence[QueryParameterDimension]]
self, order: Optional[Sequence[str]], order_by: Optional[Sequence[OrderByQueryParameter]]
) -> Sequence[str]:
assert not (
order and order_by
Expand All @@ -318,16 +323,16 @@ def _get_order(
def _parse_and_validate_query(
self,
metric_names: Optional[Sequence[str]] = None,
metrics: Optional[Sequence[QueryParameterMetric]] = None,
metrics: Optional[Sequence[MetricQueryParameter]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Sequence[QueryParameterDimension]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
order: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[QueryParameterDimension]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
time_granularity: Optional[TimeGranularity] = None,
) -> MetricFlowQuerySpec:
metric_names = self._get_metric_names(metric_names, metrics)
Expand Down Expand Up @@ -380,8 +385,8 @@ def _parse_and_validate_query(
# If the time constraint is all time, just ignore and not render
time_constraint = None

requested_linkable_specs = self._parse_linkable_elements(
qualified_linkable_names=group_by_names, linkable_elements=group_by, metric_references=metric_references
requested_linkable_specs = self._parse_group_by(
group_by_names=group_by_names, group_by=group_by, metric_references=metric_references
)
where_filter_spec: Optional[WhereFilterSpec] = None
if where_filter is not None:
Expand Down Expand Up @@ -426,9 +431,9 @@ def _parse_and_validate_query(
for metric_reference in metric_references:
metric = self._metric_lookup.get_metric(metric_reference)
if metric.filter is not None:
group_by_specs_for_one_metric = self._parse_linkable_elements(
qualified_linkable_names=group_by_names,
linkable_elements=group_by,
group_by_specs_for_one_metric = self._parse_group_by(
group_by_names=group_by_names,
group_by=group_by,
metric_references=(metric_reference,),
)

Expand Down Expand Up @@ -663,30 +668,32 @@ def _parse_metric_names(
metric_references.extend(list(input_metrics))
return tuple(metric_references)

def _parse_linkable_elements(
def _parse_group_by(
self,
metric_references: Sequence[MetricReference],
qualified_linkable_names: Optional[Sequence[str]] = None,
linkable_elements: Optional[Sequence[QueryParameterDimension]] = None,
group_by_names: Optional[Sequence[str]] = None,
group_by: Optional[Tuple[Union[GroupByQueryParameter, TimeDimensionQueryParameter], ...]] = None,
) -> QueryTimeLinkableSpecSet:
"""Convert the linkable spec names into the respective specification objects."""
# TODO: refactor to only support group_by object inputs (removing group_by_names param)
assert not (
qualified_linkable_names and linkable_elements
group_by_names and group_by
), "Both group_by_names and group_by were set, but if a group by is specified you should only use one of these!"

structured_names: List[StructuredLinkableSpecName] = []
if qualified_linkable_names:
qualified_linkable_names = [x.lower() for x in qualified_linkable_names]
structured_names = [StructuredLinkableSpecName.from_name(name) for name in qualified_linkable_names]
elif linkable_elements:
for linkable_element in linkable_elements:
parsed_name = StructuredLinkableSpecName.from_name(linkable_element.name)
if group_by_names:
group_by_names = [x.lower() for x in group_by_names]
structured_names = [StructuredLinkableSpecName.from_name(name) for name in group_by_names]
elif group_by:
for group_by_obj in group_by:
parsed_name = StructuredLinkableSpecName.from_name(group_by_obj.name)
structured_name = StructuredLinkableSpecName(
entity_link_names=parsed_name.entity_link_names,
element_name=parsed_name.element_name,
time_granularity=linkable_element.grain,
date_part=linkable_element.date_part,
time_granularity=group_by_obj.grain
if isinstance(group_by_obj, TimeDimensionQueryParameter)
else None,
date_part=group_by_obj.date_part if isinstance(group_by_obj, TimeDimensionQueryParameter) else None,
)
structured_names.append(structured_name)

Expand Down Expand Up @@ -729,15 +736,12 @@ def _parse_linkable_elements(
valid_group_bys_for_metrics = self._metric_lookup.element_specs_for_metrics(list(metric_references))
valid_group_by_names_for_metrics = sorted(
list(
set(
x.qualified_name if qualified_linkable_names else x.element_name
for x in valid_group_bys_for_metrics
)
set(x.qualified_name if group_by_names else x.element_name for x in valid_group_bys_for_metrics)
)
)

# If requested by name, show qualified name. If requested as object, show element name.
display_name = structured_name.qualified_name if qualified_linkable_names else element_name
display_name = structured_name.qualified_name if group_by_names else element_name
suggestions = {
f"Suggestions for '{display_name}'": pformat_big_objects(
MetricFlowQueryParser._top_fuzzy_matches(
Expand All @@ -748,7 +752,7 @@ def _parse_linkable_elements(
}
raise UnableToSatisfyQueryError(
f"Unknown element name '{element_name}' in dimension name '{display_name}'"
if qualified_linkable_names
if group_by_names
else f"Unknown dimension {element_name}",
context=suggestions,
)
Expand Down
Loading

0 comments on commit 68edcfb

Please sign in to comment.