Skip to content

Commit

Permalink
GroupByMetric instance converters (#1100)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Mar 28, 2024
1 parent fcb7dfa commit bbd2901
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 5 deletions.
23 changes: 20 additions & 3 deletions metricflow/dataflow/nodes/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,30 @@
class ComputeMetricsNode(ComputedMetricsOutput):
"""A node that computes metrics from input measures. Dimensions / entities are passed through."""

def __init__(self, parent_node: BaseOutput, metric_specs: Sequence[MetricSpec]) -> None:
def __init__(
self, parent_node: BaseOutput, metric_specs: Sequence[MetricSpec], for_group_by_source_node: bool = False
) -> None:
"""Constructor.
Args:
parent_node: Node where data is coming from.
metric_specs: The specs for the metrics that this should compute.
for_group_by_source_node: Whether the node is part of a dataflow plan used for a group by source node.
"""
self._parent_node = parent_node
self._metric_specs = tuple(metric_specs)
self._for_group_by_source_node = for_group_by_source_node
super().__init__(node_id=self.create_unique_id(), parent_nodes=(self._parent_node,))

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_COMPUTE_METRICS_ID_PREFIX

@property
def for_group_by_source_node(self) -> bool:
"""Whether or not this node is part of a dataflow plan used for a group by source node."""
return self._for_group_by_source_node

@property
def metric_specs(self) -> Sequence[MetricSpec]:
"""The metric instances that this node is supposed to compute and should have in the output."""
Expand All @@ -46,9 +55,12 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + tuple(
displayed_properties = tuple(super().displayed_properties) + tuple(
DisplayedProperty("metric_spec", metric_spec) for metric_spec in self._metric_specs
)
if self.for_group_by_source_node:
displayed_properties += (DisplayedProperty("for_group_by_source_node", self.for_group_by_source_node),)
return displayed_properties

@property
def parent_node(self) -> BaseOutput: # noqa: D102
Expand All @@ -61,11 +73,16 @@ def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa:
if other_node.metric_specs != self.metric_specs:
return False

return isinstance(other_node, self.__class__) and other_node.metric_specs == self.metric_specs
return (
isinstance(other_node, self.__class__)
and other_node.metric_specs == self.metric_specs
and other_node.for_group_by_source_node == self.for_group_by_source_node
)

def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> ComputeMetricsNode: # noqa: D102
assert len(new_parent_nodes) == 1
return ComputeMetricsNode(
parent_node=new_parent_nodes[0],
metric_specs=self.metric_specs,
for_group_by_source_node=self.for_group_by_source_node,
)
26 changes: 24 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,18 @@
from metricflow.dataset.dataset import DataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.instances import InstanceSet, MetadataInstance, MetricInstance, TimeDimensionInstance
from metricflow.instances import (
GroupByMetricInstance,
InstanceSet,
MetadataInstance,
MetricInstance,
TimeDimensionInstance,
)
from metricflow.mf_logging.formatting import indent
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.instance_converters import (
AddGroupByMetrics,
AddLinkToLinkableElements,
AddMetadata,
AddMetrics,
Expand Down Expand Up @@ -82,6 +89,7 @@
from metricflow.protocols.sql_client import SqlEngine
from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey
from metricflow.specs.specs import (
GroupByMetricSpec,
InstanceSpecSet,
MeasureSpec,
MetadataSpec,
Expand Down Expand Up @@ -601,6 +609,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
# Add select columns that would compute the metrics to the select columns.
metric_select_columns = []
metric_instances = []
group_by_metric_instances = []
for metric_spec in node.metric_specs:
metric = self._metric_lookup.get_metric(metric_spec.reference)

Expand Down Expand Up @@ -714,7 +723,20 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
spec=metric_spec,
)
)
output_instance_set = output_instance_set.transform(AddMetrics(metric_instances))
group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(output_column_association,),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
spec=GroupByMetricSpec(element_name=metric_spec.element_name, entity_links=()),
)
)

transform_func = (
AddGroupByMetrics(group_by_metric_instances)
if node.for_group_by_source_node
else AddMetrics(metric_instances)
)
output_instance_set = output_instance_set.transform(transform_func)

combined_select_column_set = non_metric_select_column_set.merge(
SelectColumnSet(metric_columns=metric_select_columns)
Expand Down
71 changes: 71 additions & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from metricflow.instances import (
DimensionInstance,
EntityInstance,
GroupByMetricInstance,
InstanceSet,
InstanceSetTransform,
MdoInstance,
Expand All @@ -35,6 +36,7 @@
DimensionSpec,
EntityReference,
EntitySpec,
GroupByMetricSpec,
InstanceSpec,
InstanceSpecSet,
LinkableInstanceSpec,
Expand Down Expand Up @@ -102,12 +104,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
group_by_metric_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
entity_columns=entity_cols,
group_by_metric_columns=group_by_metric_cols,
metadata_columns=metadata_cols,
)

Expand Down Expand Up @@ -254,12 +260,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
group_by_metric_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
entity_columns=entity_cols,
group_by_metric_columns=group_by_metric_cols,
metadata_columns=metadata_cols,
)

Expand Down Expand Up @@ -444,11 +454,28 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
)
)

# Handle group by metric instances
group_by_metric_instances_with_additional_link = []
for group_by_metric_instance in instance_set.group_by_metric_instances:
# The new group by metric spec should include the join on entity.
transformed_group_by_metric_spec_from_right = GroupByMetricSpec(
element_name=group_by_metric_instance.spec.element_name,
entity_links=self._join_on_entity.as_linkless_prefix + group_by_metric_instance.spec.entity_links,
)
group_by_metric_instances_with_additional_link.append(
GroupByMetricInstance(
associated_columns=group_by_metric_instance.associated_columns,
defined_from=group_by_metric_instance.defined_from,
spec=transformed_group_by_metric_spec_from_right,
)
)

return InstanceSet(
measure_instances=(),
dimension_instances=tuple(dimension_instances_with_additional_link),
time_dimension_instances=tuple(time_dimension_instances_with_additional_link),
entity_instances=tuple(entity_instances_with_additional_link),
group_by_metric_instances=tuple(group_by_metric_instances_with_additional_link),
metric_instances=(),
metadata_instances=(),
)
Expand Down Expand Up @@ -484,12 +511,16 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
x for x in instance_set.time_dimension_instances if self._should_pass(x.spec)
)
filtered_entity_instances = tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec))
filtered_group_by_metric_instances = tuple(
x for x in instance_set.group_by_metric_instances if self._should_pass(x.spec)
)

output = InstanceSet(
measure_instances=instance_set.measure_instances,
dimension_instances=filtered_dimension_instances,
time_dimension_instances=filtered_time_dimension_instances,
entity_instances=filtered_entity_instances,
group_by_metric_instances=filtered_group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -549,6 +580,9 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
x for x in instance_set.time_dimension_instances if self._should_pass(x.spec)
),
entity_instances=tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec)),
group_by_metric_instances=tuple(
x for x in instance_set.group_by_metric_instances if self._should_pass(x.spec)
),
metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)),
metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)),
)
Expand Down Expand Up @@ -588,6 +622,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -629,6 +664,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -683,6 +719,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand All @@ -700,11 +737,30 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances + tuple(self._metric_instances),
metadata_instances=instance_set.metadata_instances,
)


class AddGroupByMetrics(InstanceSetTransform[InstanceSet]):
"""Adds the given metric instances to the instance set."""

def __init__(self, group_by_metric_instances: List[GroupByMetricInstance]) -> None: # noqa: D107
self._group_by_metric_instances = group_by_metric_instances

def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
return InstanceSet(
measure_instances=instance_set.measure_instances,
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances + tuple(self._group_by_metric_instances),
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)


class RemoveMeasures(InstanceSetTransform[InstanceSet]):
"""Remove measures from the instance set."""

Expand All @@ -714,6 +770,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand All @@ -728,6 +785,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=(),
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -938,11 +996,23 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
)
)

output_group_by_metric_instances = []
for input_group_by_metric_instance in instance_set.group_by_metric_instances:
output_group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(
self._column_association_resolver.resolve_spec(input_group_by_metric_instance.spec),
),
spec=input_group_by_metric_instance.spec,
defined_from=input_group_by_metric_instance.defined_from,
)
)
return InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=tuple(output_dimension_instances),
time_dimension_instances=tuple(output_time_dimension_instances),
entity_instances=tuple(output_entity_instances),
group_by_metric_instances=tuple(output_group_by_metric_instances),
metric_instances=tuple(output_metric_instances),
metadata_instances=tuple(output_metadata_instances),
)
Expand Down Expand Up @@ -994,6 +1064,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances + tuple(self._metadata_instances),
)
4 changes: 4 additions & 0 deletions metricflow/plan_conversion/select_column_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SelectColumnSet:
dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
entity_columns: List[SqlSelectColumn] = field(default_factory=list)
group_by_metric_columns: List[SqlSelectColumn] = field(default_factory=list)
metadata_columns: List[SqlSelectColumn] = field(default_factory=list)

def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
Expand All @@ -28,6 +29,7 @@ def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
dimension_columns=self.dimension_columns + other_set.dimension_columns,
time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns,
entity_columns=self.entity_columns + other_set.entity_columns,
group_by_metric_columns=self.group_by_metric_columns + other_set.group_by_metric_columns,
metadata_columns=self.metadata_columns + other_set.metadata_columns,
)

Expand All @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
self.time_dimension_columns
+ self.entity_columns
+ self.dimension_columns
+ self.group_by_metric_columns
+ self.metric_columns
+ self.measure_columns
+ self.metadata_columns
Expand All @@ -50,5 +53,6 @@ def without_measure_columns(self) -> SelectColumnSet:
dimension_columns=self.dimension_columns,
time_dimension_columns=self.time_dimension_columns,
entity_columns=self.entity_columns,
group_by_metric_columns=self.group_by_metric_columns,
metadata_columns=self.metadata_columns,
)

0 comments on commit bbd2901

Please sign in to comment.