Skip to content

Commit

Permalink
Handle group by metrics in instance converters
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Mar 27, 2024
1 parent 7d2c32a commit c7579e7
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 2 deletions.
26 changes: 24 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,18 @@
from metricflow.dataset.dataset import DataSet
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.filters.time_constraint import TimeRangeConstraint
from metricflow.instances import InstanceSet, MetadataInstance, MetricInstance, TimeDimensionInstance
from metricflow.instances import (
GroupByMetricInstance,
InstanceSet,
MetadataInstance,
MetricInstance,
TimeDimensionInstance,
)
from metricflow.mf_logging.formatting import indent
from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.instance_converters import (
AddGroupByMetrics,
AddLinkToLinkableElements,
AddMetadata,
AddMetrics,
Expand Down Expand Up @@ -82,6 +89,7 @@
from metricflow.protocols.sql_client import SqlEngine
from metricflow.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver, SingleColumnCorrelationKey
from metricflow.specs.specs import (
GroupByMetricSpec,
InstanceSpecSet,
MeasureSpec,
MetadataSpec,
Expand Down Expand Up @@ -601,6 +609,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
# Add select columns that would compute the metrics to the select columns.
metric_select_columns = []
metric_instances = []
group_by_metric_instances = []
for metric_spec in node.metric_specs:
metric = self._metric_lookup.get_metric(metric_spec.reference)

Expand Down Expand Up @@ -714,7 +723,20 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
spec=metric_spec,
)
)
output_instance_set = output_instance_set.transform(AddMetrics(metric_instances))
group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(output_column_association,),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
spec=GroupByMetricSpec(element_name=metric_spec.element_name, entity_links=()),
)
)

transform_func = (
AddGroupByMetrics(group_by_metric_instances)
if node.for_group_by_source_node
else AddMetrics(metric_instances)
)
output_instance_set = output_instance_set.transform(transform_func)

combined_select_column_set = non_metric_select_column_set.merge(
SelectColumnSet(metric_columns=metric_select_columns)
Expand Down
71 changes: 71 additions & 0 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from metricflow.instances import (
DimensionInstance,
EntityInstance,
GroupByMetricInstance,
InstanceSet,
InstanceSetTransform,
MdoInstance,
Expand All @@ -35,6 +36,7 @@
DimensionSpec,
EntityReference,
EntitySpec,
GroupByMetricSpec,
InstanceSpec,
InstanceSpecSet,
LinkableInstanceSpec,
Expand Down Expand Up @@ -102,12 +104,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
group_by_metric_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
entity_columns=entity_cols,
group_by_metric_columns=group_by_metric_cols,
metadata_columns=metadata_cols,
)

Expand Down Expand Up @@ -254,12 +260,16 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102
metadata_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.metadata_instances])
)
group_by_metric_cols = list(
chain.from_iterable([self._make_sql_column_expression(x) for x in instance_set.group_by_metric_instances])
)
return SelectColumnSet(
metric_columns=metric_cols,
measure_columns=measure_cols,
dimension_columns=dimension_cols,
time_dimension_columns=time_dimension_cols,
entity_columns=entity_cols,
group_by_metric_columns=group_by_metric_cols,
metadata_columns=metadata_cols,
)

Expand Down Expand Up @@ -444,11 +454,28 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
)
)

# Handle group by metric instances
group_by_metric_instances_with_additional_link = []
for group_by_metric_instance in instance_set.group_by_metric_instances:
# The new group by metric spec should include the join on entity.
transformed_group_by_metric_spec_from_right = GroupByMetricSpec(
element_name=group_by_metric_instance.spec.element_name,
entity_links=self._join_on_entity.as_linkless_prefix + group_by_metric_instance.spec.entity_links,
)
group_by_metric_instances_with_additional_link.append(
GroupByMetricInstance(
associated_columns=group_by_metric_instance.associated_columns,
defined_from=group_by_metric_instance.defined_from,
spec=transformed_group_by_metric_spec_from_right,
)
)

return InstanceSet(
measure_instances=(),
dimension_instances=tuple(dimension_instances_with_additional_link),
time_dimension_instances=tuple(time_dimension_instances_with_additional_link),
entity_instances=tuple(entity_instances_with_additional_link),
group_by_metric_instances=tuple(group_by_metric_instances_with_additional_link),
metric_instances=(),
metadata_instances=(),
)
Expand Down Expand Up @@ -484,12 +511,16 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
x for x in instance_set.time_dimension_instances if self._should_pass(x.spec)
)
filtered_entity_instances = tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec))
filtered_group_by_metric_instances = tuple(
x for x in instance_set.group_by_metric_instances if self._should_pass(x.spec)
)

output = InstanceSet(
measure_instances=instance_set.measure_instances,
dimension_instances=filtered_dimension_instances,
time_dimension_instances=filtered_time_dimension_instances,
entity_instances=filtered_entity_instances,
group_by_metric_instances=filtered_group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -549,6 +580,9 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
x for x in instance_set.time_dimension_instances if self._should_pass(x.spec)
),
entity_instances=tuple(x for x in instance_set.entity_instances if self._should_pass(x.spec)),
group_by_metric_instances=tuple(
x for x in instance_set.group_by_metric_instances if self._should_pass(x.spec)
),
metric_instances=tuple(x for x in instance_set.metric_instances if self._should_pass(x.spec)),
metadata_instances=tuple(x for x in instance_set.metadata_instances if self._should_pass(x.spec)),
)
Expand Down Expand Up @@ -588,6 +622,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -629,6 +664,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -683,6 +719,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand All @@ -700,11 +737,30 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances + tuple(self._metric_instances),
metadata_instances=instance_set.metadata_instances,
)


class AddGroupByMetrics(InstanceSetTransform[InstanceSet]):
    """Appends the given group-by-metric instances to an instance set.

    Only the group-by-metric instances of the set change; measure, dimension, time dimension,
    entity, metric, and metadata instances are carried over from the input set as-is.
    """

    def __init__(self, group_by_metric_instances: List[GroupByMetricInstance]) -> None:  # noqa: D107
        self._group_by_metric_instances = group_by_metric_instances

    def transform(self, instance_set: InstanceSet) -> InstanceSet:  # noqa: D102
        # Instances already present in the set stay first; the ones supplied at construction time
        # are appended after them, in the order they were given.
        combined_group_by_metric_instances = instance_set.group_by_metric_instances + tuple(
            self._group_by_metric_instances
        )
        return InstanceSet(
            measure_instances=instance_set.measure_instances,
            dimension_instances=instance_set.dimension_instances,
            time_dimension_instances=instance_set.time_dimension_instances,
            entity_instances=instance_set.entity_instances,
            group_by_metric_instances=combined_group_by_metric_instances,
            metric_instances=instance_set.metric_instances,
            metadata_instances=instance_set.metadata_instances,
        )


class RemoveMeasures(InstanceSetTransform[InstanceSet]):
"""Remove measures from the instance set."""

Expand All @@ -714,6 +770,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances,
)
Expand All @@ -728,6 +785,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=(),
metadata_instances=instance_set.metadata_instances,
)
Expand Down Expand Up @@ -938,11 +996,23 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
)
)

output_group_by_metric_instances = []
for input_group_by_metric_instance in instance_set.group_by_metric_instances:
output_group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(
self._column_association_resolver.resolve_spec(input_group_by_metric_instance.spec),
),
spec=input_group_by_metric_instance.spec,
defined_from=input_group_by_metric_instance.defined_from,
)
)
return InstanceSet(
measure_instances=tuple(output_measure_instances),
dimension_instances=tuple(output_dimension_instances),
time_dimension_instances=tuple(output_time_dimension_instances),
entity_instances=tuple(output_entity_instances),
group_by_metric_instances=tuple(output_group_by_metric_instances),
metric_instances=tuple(output_metric_instances),
metadata_instances=tuple(output_metadata_instances),
)
Expand Down Expand Up @@ -994,6 +1064,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
dimension_instances=instance_set.dimension_instances,
time_dimension_instances=instance_set.time_dimension_instances,
entity_instances=instance_set.entity_instances,
group_by_metric_instances=instance_set.group_by_metric_instances,
metric_instances=instance_set.metric_instances,
metadata_instances=instance_set.metadata_instances + tuple(self._metadata_instances),
)
4 changes: 4 additions & 0 deletions metricflow/plan_conversion/select_column_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class SelectColumnSet:
dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
time_dimension_columns: List[SqlSelectColumn] = field(default_factory=list)
entity_columns: List[SqlSelectColumn] = field(default_factory=list)
group_by_metric_columns: List[SqlSelectColumn] = field(default_factory=list)
metadata_columns: List[SqlSelectColumn] = field(default_factory=list)

def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
Expand All @@ -28,6 +29,7 @@ def merge(self, other_set: SelectColumnSet) -> SelectColumnSet:
dimension_columns=self.dimension_columns + other_set.dimension_columns,
time_dimension_columns=self.time_dimension_columns + other_set.time_dimension_columns,
entity_columns=self.entity_columns + other_set.entity_columns,
group_by_metric_columns=self.group_by_metric_columns + other_set.group_by_metric_columns,
metadata_columns=self.metadata_columns + other_set.metadata_columns,
)

Expand All @@ -38,6 +40,7 @@ def as_tuple(self) -> Tuple[SqlSelectColumn, ...]:
self.time_dimension_columns
+ self.entity_columns
+ self.dimension_columns
+ self.group_by_metric_columns
+ self.metric_columns
+ self.measure_columns
+ self.metadata_columns
Expand All @@ -50,5 +53,6 @@ def without_measure_columns(self) -> SelectColumnSet:
dimension_columns=self.dimension_columns,
time_dimension_columns=self.time_dimension_columns,
entity_columns=self.entity_columns,
group_by_metric_columns=self.group_by_metric_columns,
metadata_columns=self.metadata_columns,
)

0 comments on commit c7579e7

Please sign in to comment.