diff --git a/metricflow/dataflow/nodes/compute_metrics.py b/metricflow/dataflow/nodes/compute_metrics.py index 10b6146345..ab36ad9e1c 100644 --- a/metricflow/dataflow/nodes/compute_metrics.py +++ b/metricflow/dataflow/nodes/compute_metrics.py @@ -17,21 +17,30 @@ class ComputeMetricsNode(ComputedMetricsOutput): """A node that computes metrics from input measures. Dimensions / entities are passed through.""" - def __init__(self, parent_node: BaseOutput, metric_specs: Sequence[MetricSpec]) -> None: + def __init__( + self, parent_node: BaseOutput, metric_specs: Sequence[MetricSpec], for_group_by_source_node: bool = False + ) -> None: """Constructor. Args: parent_node: Node where data is coming from. metric_specs: The specs for the metrics that this should compute. + for_group_by_source_node: Whether the node is part of a dataflow plan used for a group by source node. """ self._parent_node = parent_node self._metric_specs = tuple(metric_specs) + self._for_group_by_source_node = for_group_by_source_node super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,)) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX + @property + def for_group_by_source_node(self) -> bool: + """Whether or not this node is part of a dataflow plan used for a group by source node.""" + return self._for_group_by_source_node + @property def metric_specs(self) -> Sequence[MetricSpec]: """The metric instances that this node is supposed to compute and should have in the output.""" @@ -46,9 +55,12 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + tuple( + displayed_properties = tuple(super().displayed_properties) + tuple( DisplayedProperty("metric_spec", metric_spec) for metric_spec in self._metric_specs ) + if self.for_group_by_source_node: + displayed_properties += 
(DisplayedProperty("for_group_by_source_node", self.for_group_by_source_node),) + return displayed_properties @property def parent_node(self) -> BaseOutput: # noqa: D102 @@ -61,11 +73,16 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: if other_node.metric_specs != self.metric_specs: return False - return isinstance(other_node, self.__class__) and other_node.metric_specs == self.metric_specs + return ( + isinstance(other_node, self.__class__) + and other_node.metric_specs == self.metric_specs + and other_node.for_group_by_source_node == self.for_group_by_source_node + ) def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ComputeMetricsNode: # noqa: D102 assert len(new_parent_nodes) == 1 return ComputeMetricsNode( parent_node=new_parent_nodes[0], metric_specs=self.metric_specs, + for_group_by_source_node=self.for_group_by_source_node, ) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 22fc363dfb..a477e332d0 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -42,11 +42,18 @@ from metricflow.dataset.dataset import DataSet from metricflow.dataset.sql_dataset import SqlDataSet from metricflow.filters.time_constraint import TimeRangeConstraint -from metricflow.instances import InstanceSet, MetadataInstance, MetricInstance, TimeDimensionInstance +from metricflow.instances import ( + GroupByMetricInstance, + InstanceSet, + MetadataInstance, + MetricInstance, + TimeDimensionInstance, +) from metricflow.mf_logging.formatting import indent from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.instance_converters import ( + AddGroupByMetrics, AddLinkToLinkableElements, AddMetadata, AddMetrics, @@ -82,6 +89,7 @@ from metricflow.protocols.sql_client import SqlEngine 
from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey from metricflow.specs.specs import ( + GroupByMetricSpec, InstanceSpecSet, MeasureSpec, MetadataSpec, @@ -601,6 +609,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: # Add select columns that would compute the metrics to the select columns. metric_select_columns = [] metric_instances = [] + group_by_metric_instances = [] for metric_spec in node.metric_specs: metric = self._metric_lookup.get_metric(metric_spec.reference) @@ -714,7 +723,20 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: spec=metric_spec, ) ) - output_instance_set = output_instance_set.transform(AddMetrics(metric_instances)) + group_by_metric_instances.append( + GroupByMetricInstance( + associated_columns=(output_column_association,), + defined_from=MetricModelReference(metric_name=metric_spec.element_name), + spec=GroupByMetricSpec(element_name=metric_spec.element_name, entity_links=()), + ) + ) + + transform_func = ( + AddGroupByMetrics(group_by_metric_instances) + if node.for_group_by_source_node + else AddMetrics(metric_instances) + ) + output_instance_set = output_instance_set.transform(transform_func) combined_select_column_set = non_metric_select_column_set.merge( SelectColumnSet(metric_columns=metric_select_columns) diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 689d709c2d..56d677c947 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -20,6 +20,7 @@ from metricflow.instances import ( DimensionInstance, EntityInstance, + GroupByMetricInstance, InstanceSet, InstanceSetTransform, MdoInstance, @@ -35,6 +36,7 @@ DimensionSpec, EntityReference, EntitySpec, + GroupByMetricSpec, InstanceSpec, InstanceSpecSet, LinkableInstanceSpec, @@ -102,12 +104,16 @@ def transform(self, 
instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102 metadata_cols = list( chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances]) ) + group_by_metric_cols = list( + chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances]) + ) return SelectColumnSet( metric_columns=metric_cols, measure_columns=measure_cols, dimension_columns=dimension_cols, time_dimension_columns=time_dimension_cols, entity_columns=entity_cols, + group_by_metric_columns=group_by_metric_cols, metadata_columns=metadata_cols, ) @@ -254,12 +260,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102 metadata_cols = list( chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances]) ) + group_by_metric_cols = list( + chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances]) + ) return SelectColumnSet( metric_columns=metric_cols, measure_columns=measure_cols, dimension_columns=dimension_cols, time_dimension_columns=time_dimension_cols, entity_columns=entity_cols, + group_by_metric_columns=group_by_metric_cols, metadata_columns=metadata_cols, ) @@ -444,11 +454,28 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 ) ) + # Handle group by metric instances + group_by_metric_instances_with_additional_link = [] + for group_by_metric_instance in instance_set.group_by_metric_instances: + # The new group by metric spec should include the join on entity. 
+ transformed_group_by_metric_spec_from_right = GroupByMetricSpec( + element_name=group_by_metric_instance.spec.element_name, + entity_links=self._join_on_entity.as_linkless_prefix + group_by_metric_instance.spec.entity_links, + ) + group_by_metric_instances_with_additional_link.append( + GroupByMetricInstance( + associated_columns=group_by_metric_instance.associated_columns, + defined_from=group_by_metric_instance.defined_from, + spec=transformed_group_by_metric_spec_from_right, + ) + ) + return InstanceSet( measure_instances=(), dimension_instances=tuple(dimension_instances_with_additional_link), time_dimension_instances=tuple(time_dimension_instances_with_additional_link), entity_instances=tuple(entity_instances_with_additional_link), + group_by_metric_instances=tuple(group_by_metric_instances_with_additional_link), metric_instances=(), metadata_instances=(), ) @@ -484,12 +511,16 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 x for x in instance_set.time_dimension_instances if self._should_pass(x.spec) ) filtered_entity_instances = tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec)) + filtered_group_by_metric_instances = tuple( + x for x in instance_set.group_by_metric_instances if self._should_pass(x.spec) + ) output = InstanceSet( measure_instances=instance_set.measure_instances, dimension_instances=filtered_dimension_instances, time_dimension_instances=filtered_time_dimension_instances, entity_instances=filtered_entity_instances, + group_by_metric_instances=filtered_group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances, ) @@ -549,6 +580,9 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 x for x in instance_set.time_dimension_instances if self._should_pass(x.spec) ), entity_instances=tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec)), + group_by_metric_instances=tuple( + x for x 
in instance_set.group_by_metric_instances if self._should_pass(x.spec) + ), metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)), metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)), ) @@ -588,6 +622,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances, ) @@ -629,6 +664,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances, ) @@ -683,6 +719,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances, ) @@ -700,11 +737,30 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances + tuple(self._metric_instances), 
metadata_instances=instance_set.metadata_instances, ) +class AddGroupByMetrics(InstanceSetTransform[InstanceSet]): + """Adds the given group by metric instances to the instance set.""" + + def __init__(self, group_by_metric_instances: List[GroupByMetricInstance]) -> None: # noqa: D107 + self._group_by_metric_instances = group_by_metric_instances + + def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 + return InstanceSet( + measure_instances=instance_set.measure_instances, + dimension_instances=instance_set.dimension_instances, + time_dimension_instances=instance_set.time_dimension_instances, + entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances + tuple(self._group_by_metric_instances), + metric_instances=instance_set.metric_instances, + metadata_instances=instance_set.metadata_instances, + ) + + class RemoveMeasures(InstanceSetTransform[InstanceSet]): """Remove measures from the instance set.""" @@ -714,6 +770,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances, ) @@ -728,6 +785,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=(), metadata_instances=instance_set.metadata_instances, ) @@ -938,11 +996,23 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 ) ) + output_group_by_metric_instances = [] + for 
input_group_by_metric_instance in instance_set.group_by_metric_instances: + output_group_by_metric_instances.append( + GroupByMetricInstance( + associated_columns=( + self._column_association_resolver.resolve_spec(input_group_by_metric_instance.spec), + ), + spec=input_group_by_metric_instance.spec, + defined_from=input_group_by_metric_instance.defined_from, + ) + ) return InstanceSet( measure_instances=tuple(output_measure_instances), dimension_instances=tuple(output_dimension_instances), time_dimension_instances=tuple(output_time_dimension_instances), entity_instances=tuple(output_entity_instances), + group_by_metric_instances=tuple(output_group_by_metric_instances), metric_instances=tuple(output_metric_instances), metadata_instances=tuple(output_metadata_instances), ) @@ -994,6 +1064,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 dimension_instances=instance_set.dimension_instances, time_dimension_instances=instance_set.time_dimension_instances, entity_instances=instance_set.entity_instances, + group_by_metric_instances=instance_set.group_by_metric_instances, metric_instances=instance_set.metric_instances, metadata_instances=instance_set.metadata_instances + tuple(self._metadata_instances), ) diff --git a/metricflow/plan_conversion/select_column_gen.py b/metricflow/plan_conversion/select_column_gen.py index eee5fa2087..3d7be3fd68 100644 --- a/metricflow/plan_conversion/select_column_gen.py +++ b/metricflow/plan_conversion/select_column_gen.py @@ -18,6 +18,7 @@ class SelectColumnSet: dimension_columns: List[SqlSelectColumn] = field(default_factory=list) time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list) entity_columns: List[SqlSelectColumn] = field(default_factory=list) + group_by_metric_columns: List[SqlSelectColumn] = field(default_factory=list) metadata_columns: List[SqlSelectColumn] = field(default_factory=list) def merge(self, other_set: SelectColumnSet) -> SelectColumnSet: @@ -28,6 +29,7 @@ def 
merge(self, other_set: SelectColumnSet) -> SelectColumnSet: dimension_columns=self.dimension_columns + other_set.dimension_columns, time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns, entity_columns=self.entity_columns + other_set.entity_columns, + group_by_metric_columns=self.group_by_metric_columns + other_set.group_by_metric_columns, metadata_columns=self.metadata_columns + other_set.metadata_columns, ) @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]: self.time_dimension_columns + self.entity_columns + self.dimension_columns + + self.group_by_metric_columns + self.metric_columns + self.measure_columns + self.metadata_columns @@ -50,5 +53,6 @@ def without_measure_columns(self) -> SelectColumnSet: dimension_columns=self.dimension_columns, time_dimension_columns=self.time_dimension_columns, entity_columns=self.entity_columns, + group_by_metric_columns=self.group_by_metric_columns, metadata_columns=self.metadata_columns, )