diff --git a/metricflow/model/semantics/linkable_spec_resolver.py b/metricflow/model/semantics/linkable_spec_resolver.py index fbe2513915..3a5f6e9655 100644 --- a/metricflow/model/semantics/linkable_spec_resolver.py +++ b/metricflow/model/semantics/linkable_spec_resolver.py @@ -356,6 +356,7 @@ def as_spec_set(self) -> LinkableSpecSet: # noqa: D102 GroupByMetricSpec( element_name=linkable_metric.element_name, entity_links=linkable_metric.join_path.entity_links, + metric_subquery_entity_links=linkable_metric.metric_subquery_entity_links, ) for linkable_metrics in self.path_key_to_linkable_metrics.values() for linkable_metric in linkable_metrics diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 05ff5212e4..fb4d065180 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -727,7 +727,11 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: GroupByMetricInstance( associated_columns=(output_column_association,), defined_from=MetricModelReference(metric_name=metric_spec.element_name), - spec=GroupByMetricSpec(element_name=metric_spec.element_name, entity_links=()), + spec=GroupByMetricSpec( + element_name=metric_spec.element_name, + entity_links=(), + metric_subquery_entity_links=(), # TODO + ), ) ) diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index c60fba87be..e7c4b60a4f 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -462,6 +462,7 @@ def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 transformed_group_by_metric_spec_from_right = GroupByMetricSpec( element_name=group_by_metric_instance.spec.element_name, entity_links=self._join_on_entity.as_linkless_prefix + group_by_metric_instance.spec.entity_links, + metric_subquery_entity_links=group_by_metric_instance.spec.metric_subquery_entity_links, ) group_by_metric_instances_with_additional_link.append( GroupByMetricInstance( diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index 7bbf104f25..6b8083c70c 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -518,7 +518,8 @@ def build_query_spec_for_group_by_metric_source_node( self, group_by_metric_spec: GroupByMetricSpec ) -> MetricFlowQuerySpec: """Query spec that can be used to build a source node for this spec in the DFP.""" + group_by_metric_spec.metric_subquery_entity_links return self.parse_and_validate_query( metrics=(MetricParameter(group_by_metric_spec.reference.element_name),), - group_by=(DimensionOrEntityParameter(group_by_metric_spec.entity_spec.qualified_name),), + group_by=(DimensionOrEntityParameter(group_by_metric_spec.metric_subquery_entity_spec.qualified_name),), ) diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index ca4728642d..8e86a247be 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -242,16 +242,35 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT @dataclass(frozen=True) class GroupByMetricSpec(LinkableInstanceSpec, SerializableDataclass): - """Metric used in group by or where filter.""" + """Metric used in group by or where filter. + + Args: + element_name: Name of the metric being joined. + entity_links: Sequence of entities joined to join the metric subquery to the outer query. Last entity is the one + joining the subquery to the outer query. + metric_subquery_entity_links: Sequence of entities used in the metric subquery to join the metric to the entity. + Does not include the top-level entity (to avoid duplicating the last element of `entity_links`). + + """ + + metric_subquery_entity_links: Tuple[EntityReference, ...] @property def without_first_entity_link(self) -> GroupByMetricSpec: # noqa: D102 assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}" - return GroupByMetricSpec(element_name=self.element_name, entity_links=self.entity_links[1:]) + return GroupByMetricSpec( + element_name=self.element_name, + entity_links=self.entity_links[1:], + metric_subquery_entity_links=self.metric_subquery_entity_links, + ) @property def without_entity_links(self) -> GroupByMetricSpec: # noqa: D102 - return GroupByMetricSpec(element_name=self.element_name, entity_links=()) + return GroupByMetricSpec( + element_name=self.element_name, + entity_links=(), + metric_subquery_entity_links=self.metric_subquery_entity_links, + ) @property def last_entity_link(self) -> EntityReference: # noqa: D102 @@ -264,12 +283,16 @@ def from_name(name: str) -> GroupByMetricSpec: # noqa: D102 return GroupByMetricSpec( entity_links=tuple(EntityReference(idl) for idl in structured_name.entity_link_names), element_name=structured_name.element_name, + metric_subquery_entity_links=(), ) @property - def entity_spec(self) -> EntitySpec: - """Entity that the metric will be grouped by on aggregation.""" - return EntitySpec(element_name=self.last_entity_link.element_name, entity_links=self.entity_links[:-1]) + def metric_subquery_entity_spec(self) -> EntitySpec: + """Spec for the entity that the metric will be grouped by it the metric subquery.""" + return EntitySpec( + element_name=self.last_entity_link.element_name, + entity_links=self.metric_subquery_entity_links, + ) def __eq__(self, other: Any) -> bool: # type: ignore[misc] # noqa: D105 if not isinstance(other, GroupByMetricSpec): diff --git a/tests/model/test_where_filter_spec.py b/tests/model/test_where_filter_spec.py index a4df031e72..f7f1d41a0f 100644 --- a/tests/model/test_where_filter_spec.py +++ b/tests/model/test_where_filter_spec.py @@ -370,7 +370,9 @@ def test_metric_in_filter( # noqa: D103 ) -> None: where_filter = PydanticWhereFilter(where_sql_template="{{ Metric('bookings', group_by=['listing']) }} > 2") - group_by_metric_spec = GroupByMetricSpec(element_name="bookings", entity_links=(EntityReference("listing"),)) + group_by_metric_spec = GroupByMetricSpec( + element_name="bookings", entity_links=(EntityReference("listing"),), metric_subquery_entity_links=() + ) where_filter_spec = WhereSpecFactory( column_association_resolver=column_association_resolver, spec_resolution_lookup=create_spec_lookup( diff --git a/tests/naming/conftest.py b/tests/naming/conftest.py index 5bd2701b22..07f8fb07d1 100644 --- a/tests/naming/conftest.py +++ b/tests/naming/conftest.py @@ -50,5 +50,9 @@ def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D103 entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), ), # GroupByMetrics - GroupByMetricSpec(element_name="bookings", entity_links=(EntityReference(element_name="listing"),)), + GroupByMetricSpec( + element_name="bookings", + entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), + ), ) diff --git a/tests/naming/test_object_builder_naming_scheme.py b/tests/naming/test_object_builder_naming_scheme.py index 2bebc6c5b5..e003503b3d 100644 --- a/tests/naming/test_object_builder_naming_scheme.py +++ b/tests/naming/test_object_builder_naming_scheme.py @@ -52,6 +52,7 @@ def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> N GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), ) ) == "Metric('bookings', group_by=['listing'])" @@ -121,5 +122,9 @@ def test_spec_pattern( # noqa: D103 assert tuple( object_builder_naming_scheme.spec_pattern("Metric('bookings', group_by=['listing'])").match(specs) ) == ( - GroupByMetricSpec(element_name="bookings", entity_links=(EntityReference(element_name="listing"),)), + GroupByMetricSpec( + element_name="bookings", + entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), + ), ) diff --git a/tests/specs/patterns/test_entity_link_pattern.py b/tests/specs/patterns/test_entity_link_pattern.py index aea60bc2c0..73e5fc2562 100644 --- a/tests/specs/patterns/test_entity_link_pattern.py +++ b/tests/specs/patterns/test_entity_link_pattern.py @@ -56,6 +56,7 @@ def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D103 GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), ), ) @@ -127,7 +128,11 @@ def test_group_by_metric_match(specs: Sequence[LinkableInstanceSpec]) -> None: ) assert tuple(pattern.match(specs)) == ( - GroupByMetricSpec(element_name="bookings", entity_links=(EntityReference(element_name="listing"),)), + GroupByMetricSpec( + element_name="bookings", + entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), + ), ) diff --git a/tests/specs/patterns/test_typed_patterns.py b/tests/specs/patterns/test_typed_patterns.py index 754a06d660..d0625cd741 100644 --- a/tests/specs/patterns/test_typed_patterns.py +++ b/tests/specs/patterns/test_typed_patterns.py @@ -60,6 +60,7 @@ def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D103 GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing"),), + metric_subquery_entity_links=(), ), ) @@ -157,5 +158,6 @@ def test_group_by_metric_pattern(specs: Sequence[LinkableInstanceSpec]) -> None: GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference("listing"),), + metric_subquery_entity_links=(), ), ) diff --git a/tests/test_specs.py b/tests/test_specs.py index da922e61c9..27c3a26be6 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -145,6 +145,7 @@ def spec_set() -> InstanceSpecSet: # noqa: D103 GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing_id"),), + metric_subquery_entity_links=(), ), ), ) @@ -165,6 +166,7 @@ def test_spec_set_linkable_specs(spec_set: InstanceSpecSet) -> None: # noqa: D1 GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing_id"),), + metric_subquery_entity_links=(), ), } @@ -188,6 +190,7 @@ def test_spec_set_all_specs(spec_set: InstanceSpecSet) -> None: # noqa: D103 GroupByMetricSpec( element_name="bookings", entity_links=(EntityReference(element_name="listing_id"),), + metric_subquery_entity_links=(), ), }