Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataflow Plan for Min & Max of Distinct Values Query #854

Merged
merged 17 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions metricflow/dag/id_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DATAFLOW_NODE_SET_MEASURE_AGGREGATION_TIME = "sma"
DATAFLOW_NODE_SEMI_ADDITIVE_JOIN_ID_PREFIX = "saj"
DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX = "jts"
DATAFLOW_NODE_MIN_MAX_ID_PREFIX = "mm"
DATAFLOW_NODE_ADD_UUID_COLUMN_PREFIX = "auid"
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"

Expand Down
7 changes: 6 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
JoinToTimeSpineNode,
MinMaxNode,
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
Expand Down Expand Up @@ -630,8 +631,12 @@ def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Dat
distinct=True,
)

min_max_node: Optional[MinMaxNode] = None
if query_spec.min_max_only:
min_max_node = MinMaxNode(parent_node=distinct_values_node)

sink_node = self.build_sink_node(
parent_node=distinct_values_node,
parent_node=min_max_node or distinct_values_node,
order_by_specs=query_spec.order_by_specs,
limit=query_spec.limit,
)
Expand Down
35 changes: 35 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DATAFLOW_NODE_JOIN_SELF_OVER_TIME_RANGE_ID_PREFIX,
DATAFLOW_NODE_JOIN_TO_STANDARD_OUTPUT_ID_PREFIX,
DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX,
DATAFLOW_NODE_MIN_MAX_ID_PREFIX,
DATAFLOW_NODE_ORDER_BY_LIMIT_ID_PREFIX,
DATAFLOW_NODE_PASS_FILTER_ELEMENTS_ID_PREFIX,
DATAFLOW_NODE_READ_SQL_SOURCE_ID_PREFIX,
Expand Down Expand Up @@ -181,6 +182,10 @@ def visit_metric_time_dimension_transform_node( # noqa: D
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> VisitorOutputT: # noqa: D
pass

@abstractmethod
def visit_min_max_node(self, node: MinMaxNode) -> VisitorOutputT:  # noqa: D
    """Visit a MinMaxNode and return this visitor's output for it.

    Abstract hook: concrete visitors must provide the implementation.
    """
    pass

@abstractmethod
def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> VisitorOutputT:  # noqa: D
    """Visit an AddGeneratedUuidColumnNode and return this visitor's output for it.

    Abstract hook: concrete visitors must provide the implementation.
    """
    pass
Expand Down Expand Up @@ -1262,6 +1267,36 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ConstrainT
)


class MinMaxNode(BaseOutput):
"""Calculate the min and max of a single instance data set."""

def __init__(self, parent_node: BaseOutput) -> None:  # noqa: D
    """Create a min/max node whose only parent is ``parent_node``."""
    # Keep a direct reference so the `parent_node` property can hand it back.
    self._parent_node = parent_node
    super().__init__(node_id=self.create_unique_id(), parent_nodes=[parent_node])

@classmethod
def id_prefix(cls) -> str:  # noqa: D
    """Return the ID prefix constant reserved for min/max nodes."""
    return DATAFLOW_NODE_MIN_MAX_ID_PREFIX
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@plypaul is this still ok given your other changes to an enumerated ID prefix type?


def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D
    """Dispatch to the visitor's ``visit_min_max_node`` hook and return its result."""
    return visitor.visit_min_max_node(self)

@property
def description(self) -> str:  # noqa: D
    """Short human-readable summary of what this node does."""
    return "Calculate min and max"

@property
def parent_node(self) -> BaseOutput:  # noqa: D
    """The single parent node supplied at construction time."""
    return self._parent_node

def functionally_identical(self, other_node: DataflowPlanNode) -> bool:  # noqa: D
    """Report whether ``other_node`` is an instance of this node's class.

    The constructor takes nothing besides the parent node, so only the type
    is compared here.
    """
    own_type = self.__class__
    return isinstance(other_node, own_type)

def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> MinMaxNode:  # noqa: D
    """Return a fresh MinMaxNode attached to the single node in ``new_parent_nodes``.

    Exactly one replacement parent is expected; anything else trips the assertion.
    """
    assert len(new_parent_nodes) == 1
    replacement_parent = new_parent_nodes[0]
    return MinMaxNode(parent_node=replacement_parent)


class AddGeneratedUuidColumnNode(BaseOutput):
"""Adds a UUID column."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
JoinToBaseOutputNode,
JoinToTimeSpineNode,
MetricTimeDimensionTransformNode,
MinMaxNode,
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
Expand Down Expand Up @@ -424,3 +425,7 @@ def visit_join_conversion_events_node( # noqa: D
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult:  # noqa: D
    """Log the visited node type, then defer to the default handler for this node."""
    self._log_visit_node_type(node)
    return self._default_handler(node)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
JoinToBaseOutputNode,
JoinToTimeSpineNode,
MetricTimeDimensionTransformNode,
MinMaxNode,
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
Expand Down Expand Up @@ -341,3 +342,7 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> OptimizeBranchResult:  # noqa: D
    """Log the visited node type, then defer to the default base-output handler."""
    self._log_visit_node_type(node)
    return self._default_base_output_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult:  # noqa: D
    """Log the visited node type, then defer to the default base-output handler."""
    self._log_visit_node_type(node)
    return self._default_base_output_handler(node)
4 changes: 4 additions & 0 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class MetricFlowQueryRequest:
where_constraint: Optional[str] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[OrderByQueryParameter]] = None
min_max_only: bool = False
output_table: Optional[str] = None
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC
Expand All @@ -129,6 +130,7 @@ def create_with_random_request_id( # noqa: D
output_table: Optional[str] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
query_type: MetricFlowQueryType = MetricFlowQueryType.METRIC,
min_max_only: bool = False,
) -> MetricFlowQueryRequest:
return MetricFlowQueryRequest(
request_id=MetricFlowRequestId(mf_rid=f"{random_id()}"),
Expand All @@ -146,6 +148,7 @@ def create_with_random_request_id( # noqa: D
output_table=output_table,
sql_optimization_level=sql_optimization_level,
query_type=query_type,
min_max_only=min_max_only,
)


Expand Down Expand Up @@ -434,6 +437,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
where_constraint_str=mf_query_request.where_constraint,
order_by_names=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
min_max_only=mf_query_request.min_max_only,
)
logger.info(f"Query spec is:\n{pformat_big_objects(query_spec)}")

Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def visit_entity_spec(self, entity_spec: EntitySpec) -> ColumnAssociation: # no

def visit_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation:  # noqa: D
    """Resolve a metadata spec to the column association for its output column.

    Returns:
        A ColumnAssociation named after the spec's qualified name, with a
        single-column correlation key.
    """
    # `column_name` must be passed exactly once; a repeated keyword argument is a
    # SyntaxError. The qualified name is used (not the bare element name) because
    # the min/max conversion builds two metadata specs from the same column name
    # that differ only in aggregation type.
    # NOTE(review): presumably `qualified_name` folds in the aggregation type -- confirm on MetadataSpec.
    return ColumnAssociation(
        column_name=metadata_spec.qualified_name,
        single_column_correlation_key=SingleColumnCorrelationKey(),
    )

Expand Down
51 changes: 45 additions & 6 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
from dbt_semantic_interfaces.references import MetricModelReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType
from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords

Expand All @@ -27,6 +28,7 @@
JoinToBaseOutputNode,
JoinToTimeSpineNode,
MetricTimeDimensionTransformNode,
MinMaxNode,
OrderByLimitNode,
ReadSqlSourceNode,
SemiAdditiveJoinNode,
Expand All @@ -37,12 +39,7 @@
from metricflow.dataset.dataset import DataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.instances import (
InstanceSet,
MetadataInstance,
MetricInstance,
TimeDimensionInstance,
)
from metricflow.instances import InstanceSet, MetadataInstance, MetricInstance, TimeDimensionInstance
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.instance_converters import (
AddLinkToLinkableElements,
Expand All @@ -51,6 +48,7 @@
AliasAggregatedMeasures,
ChangeAssociatedColumns,
ChangeMeasureAggregationState,
ConvertToMetadata,
CreateSelectColumnForCombineOutputNode,
CreateSelectColumnsForInstances,
CreateSelectColumnsWithMeasuresAggregated,
Expand Down Expand Up @@ -1367,6 +1365,47 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
),
)

def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet:  # noqa: D
    """Build a SELECT computing MIN and MAX over the parent's single output column.

    The parent node is converted first and becomes the FROM source. One aggregated
    select column is produced per aggregation type (MIN, then MAX), each paired with
    a metadata instance. Those metadata instances replace the parent's instances in
    the returned data set.
    """
    source_data_set = node.parent_node.accept(self)
    source_alias = self._next_unique_table_alias()
    source_select_columns = source_data_set.sql_select_node.select_columns
    assert len(source_select_columns) == 1, "MinMaxNode supports exactly one parent select column."
    source_column_name = source_select_columns[0].column_alias

    aggregated_columns: List[SqlSelectColumn] = []
    min_max_instances: List[MetadataInstance] = []
    for aggregation in (AggregationType.MIN, AggregationType.MAX):
        spec = MetadataSpec.from_name(name=source_column_name, agg_type=aggregation)
        association = self._column_association_resolver.resolve_spec(spec)

        # A new column reference is built on every pass, as each aggregate gets its own expression tree.
        source_column = SqlColumnReferenceExpression(
            SqlColumnReference(table_alias=source_alias, column_name=source_column_name)
        )
        aggregate_expr = SqlFunctionExpression.build_expression_from_aggregation_type(
            aggregation_type=aggregation,
            sql_column_expression=source_column,
        )
        aggregated_columns.append(SqlSelectColumn(expr=aggregate_expr, column_alias=association.column_name))
        min_max_instances.append(MetadataInstance(associated_columns=(association,), spec=spec))

    output_instance_set = source_data_set.instance_set.transform(ConvertToMetadata(min_max_instances))
    output_select = SqlSelectStatementNode(
        description=node.description,
        select_columns=tuple(aggregated_columns),
        from_source=source_data_set.sql_select_node,
        from_source_alias=source_alias,
        joins_descs=(),
        group_bys=(),
        order_bys=(),
    )
    return SqlDataSet(instance_set=output_instance_set, sql_select_node=output_select)

def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> SqlDataSet:
"""Implements the behaviour of AddGeneratedUuidColumnNode.

Expand Down
12 changes: 12 additions & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,18 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D
)


class ConvertToMetadata(InstanceSetTransform[InstanceSet]):
    """Discard every instance in the incoming set and emit only the supplied metadata instances."""

    def __init__(self, metadata_instances: Sequence[MetadataInstance]) -> None:  # noqa: D
        """Remember the metadata instances that each transformed set will contain."""
        self._metadata_instances = metadata_instances

    def transform(self, instance_set: InstanceSet) -> InstanceSet:  # noqa: D
        """Return a new instance set holding only the stored metadata instances.

        The incoming ``instance_set`` is not consulted; nothing from it is carried over.
        """
        replacement_instances = tuple(self._metadata_instances)
        return InstanceSet(metadata_instances=replacement_instances)


def create_select_columns_for_instance_sets(
column_resolver: ColumnAssociationResolver,
table_alias_to_instance_set: OrderedDict[str, InstanceSet],
Expand Down
43 changes: 43 additions & 0 deletions metricflow/query/issues/parsing/invalid_min_max_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from dataclasses import dataclass

from typing_extensions import override

from metricflow.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow.query.issues.issues_base import (
MetricFlowQueryIssueType,
MetricFlowQueryResolutionIssue,
)
from metricflow.query.resolver_inputs.base_resolver_inputs import MetricFlowQueryResolverInput


@dataclass(frozen=True)
class InvalidMinMaxOnlyIssue(MetricFlowQueryResolutionIssue):
    """Describes an issue with the query where `min_max_only` is used in an unsupported combination."""

    # The `min_max_only` value supplied with the offending query.
    min_max_only: bool

    @staticmethod
    def from_parameters(  # noqa: D
        min_max_only: bool, query_resolution_path: MetricFlowQueryResolutionPath
    ) -> InvalidMinMaxOnlyIssue:
        """Create an ERROR-level issue with no parent issues at the given resolution path."""
        return InvalidMinMaxOnlyIssue(
            min_max_only=min_max_only,
            query_resolution_path=query_resolution_path,
            parent_issues=(),
            issue_type=MetricFlowQueryIssueType.ERROR,
        )

    @override
    def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
        """Return the fixed user-facing message; ``associated_input`` does not affect it."""
        message = "`min_max_only` must be used with exactly one `group_by`, and cannot be used with `metrics`, `order_by`, or `limit`."
        return message

    @override
    def with_path_prefix(self, path_prefix_node: MetricFlowQueryResolutionPath) -> InvalidMinMaxOnlyIssue:
        """Return a copy of this issue with the prefix applied to its path and to every parent issue."""
        prefixed_parents = tuple(parent.with_path_prefix(path_prefix_node) for parent in self.parent_issues)
        prefixed_path = self.query_resolution_path.with_path_prefix(path_prefix_node)
        return InvalidMinMaxOnlyIssue(
            issue_type=self.issue_type,
            parent_issues=prefixed_parents,
            query_resolution_path=prefixed_path,
            min_max_only=self.min_max_only,
        )
5 changes: 5 additions & 0 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ResolverInputForGroupByItem,
ResolverInputForLimit,
ResolverInputForMetric,
ResolverInputForMinMaxOnly,
ResolverInputForOrderByItem,
ResolverInputForQuery,
ResolverInputForQueryLevelWhereFilterIntersection,
Expand Down Expand Up @@ -307,11 +308,13 @@ def parse_and_validate_query(
where_constraint_str: Optional[str] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
min_max_only: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@courtneyholcomb please coordinate with @plypaul , this file changed dramatically in his pending stack of changes.

) -> MetricFlowQuerySpec:
"""Parse the query into spec objects, validating them in the process.

e.g. make sure that the given metric is a valid metric name.
"""
# TODO: validate min_max_only - can only be called for non-metric queries
assert_at_most_one_arg_set(metric_names=metric_names, metrics=metrics)
assert_at_most_one_arg_set(group_by_names=group_by_names, group_by=group_by)
assert_at_most_one_arg_set(order_by_names=order_by_names, order_by=order_by)
Expand Down Expand Up @@ -423,13 +426,15 @@ def parse_and_validate_query(
resolver_inputs_for_order_by.extend(MetricFlowQueryParser._parse_order_by(order_by=order_by))

resolver_input_for_limit = ResolverInputForLimit(limit=limit)
resolver_input_for_min_max_only = ResolverInputForMinMaxOnly(min_max_only=min_max_only)

resolver_input_for_query = ResolverInputForQuery(
metric_inputs=tuple(resolver_inputs_for_metrics),
group_by_item_inputs=tuple(resolver_inputs_for_group_by_items),
order_by_item_inputs=tuple(resolver_inputs_for_order_by),
limit_input=resolver_input_for_limit,
filter_input=resolver_input_for_filter,
min_max_only=resolver_input_for_min_max_only,
)

logger.info("Resolver input for query is:\n" + indent_log_line(mf_pformat(resolver_input_for_query)))
Expand Down
Loading