diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index d825701242..3ce778c13f 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -63,6 +63,7 @@ from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import FilterSpecResolutionLookUp +from metricflow.query.query_parser import MetricFlowQueryParser from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( ConstantPropertySpec, @@ -123,6 +124,7 @@ def __init__( # noqa: D107 semantic_manifest_lookup: SemanticManifestLookup, node_output_resolver: DataflowPlanNodeOutputDataSetResolver, column_association_resolver: ColumnAssociationResolver, + query_parser: MetricFlowQueryParser, ) -> None: self._semantic_model_lookup = semantic_manifest_lookup.semantic_model_lookup self._metric_lookup = semantic_manifest_lookup.metric_lookup @@ -130,6 +132,7 @@ def __init__( # noqa: D107 self._source_node_set = source_node_set self._column_association_resolver = column_association_resolver self._node_data_set_resolver = node_output_resolver + self._query_parser = query_parser def build_plan( self, @@ -825,7 +828,8 @@ def _find_dataflow_recipe( # MetricGroupBy source nodes could be extremely large (and potentially slow). candidate_nodes_for_right_side_of_join += [ self._build_query_output_node( - query_spec=group_by_metric_spec.query_spec_for_source_node, for_group_by_source_node=True + query_spec=self._query_parser.build_query_spec_for_group_by_metric_source_node(group_by_metric_spec), + for_group_by_source_node=True, ) for group_by_metric_spec in linkable_spec_set.group_by_metric_specs ] diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 5bf4160a9d..4f3d6d72a1 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -382,11 +382,16 @@ def __init__( ) node_output_resolver.cache_output_data_sets(source_node_set.all_nodes) + self._query_parser = query_parser or MetricFlowQueryParser( + semantic_manifest_lookup=self._semantic_manifest_lookup, + ) + self._dataflow_plan_builder = DataflowPlanBuilder( source_node_set=source_node_set, semantic_manifest_lookup=self._semantic_manifest_lookup, column_association_resolver=self._column_association_resolver, node_output_resolver=node_output_resolver, + query_parser=self._query_parser, ) self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter( column_association_resolver=self._column_association_resolver, @@ -399,10 +404,6 @@ def __init__( ) self._executor = SequentialPlanExecutor() - self._query_parser = query_parser or MetricFlowQueryParser( - semantic_manifest_lookup=self._semantic_manifest_lookup, - ) - @log_call(module_name=__name__, telemetry_reporter=_telemetry_reporter) def query(self, mf_request: MetricFlowQueryRequest) -> MetricFlowQueryResult: # noqa: D102 logger.info(f"Starting query request:\n{indent(mf_pformat(mf_request))}") diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index 75dcd34ee2..7bbf104f25 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -54,7 +54,9 @@ from metricflow.specs.patterns.base_time_grain import BaseTimeGrainPattern from metricflow.specs.patterns.metric_time_pattern import MetricTimePattern from metricflow.specs.patterns.none_date_part import NoneDatePartPattern +from metricflow.specs.query_param_implementations import DimensionOrEntityParameter, MetricParameter from metricflow.specs.specs import ( + GroupByMetricSpec, InstanceSpec, InstanceSpecSet, MetricFlowQuerySpec, @@ -511,3 +513,12 @@ def _parse_and_validate_query( return query_spec.with_time_range_constraint(time_constraint) return query_spec + + def build_query_spec_for_group_by_metric_source_node( + self, group_by_metric_spec: GroupByMetricSpec + ) -> MetricFlowQuerySpec: + """Query spec that can be used to build a source node for this spec in the DFP.""" + return self.parse_and_validate_query( + metrics=(MetricParameter(group_by_metric_spec.reference.element_name),), + group_by=(DimensionOrEntityParameter(group_by_metric_spec.entity_spec.qualified_name),), + ) diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index 41ee7e6f38..ca4728642d 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -253,6 +253,11 @@ def without_first_entity_link(self) -> GroupByMetricSpec: # noqa: D102 def without_entity_links(self) -> GroupByMetricSpec: # noqa: D102 return GroupByMetricSpec(element_name=self.element_name, entity_links=()) + @property + def last_entity_link(self) -> EntityReference: # noqa: D102 + assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}" + return self.entity_links[-1] + @staticmethod def from_name(name: str) -> GroupByMetricSpec: # noqa: D102 structured_name = StructuredLinkableSpecName.from_name(name) @@ -261,6 +266,11 @@ def from_name(name: str) -> GroupByMetricSpec: # noqa: D102 element_name=structured_name.element_name, ) + @property + def entity_spec(self) -> EntitySpec: + """Entity that the metric will be grouped by on aggregation.""" + return EntitySpec(element_name=self.last_entity_link.element_name, entity_links=self.entity_links[:-1]) + def __eq__(self, other: Any) -> bool: # type: ignore[misc] # noqa: D105 if not isinstance(other, GroupByMetricSpec): return False @@ -281,14 +291,6 @@ def as_spec_set(self) -> InstanceSpecSet: def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_group_by_metric_spec(self) - @property - def query_spec_for_source_node(self) -> MetricFlowQuerySpec: - """Query spec that can be used to build a source node for this spec in the DFP.""" - return MetricFlowQuerySpec( - metric_specs=(MetricSpec(element_name=self.element_name),), - entity_specs=tuple(EntitySpec.from_name(entity_link.element_name) for entity_link in self.entity_links), - ) - @dataclass(frozen=True) class LinklessEntitySpec(EntitySpec, SerializableDataclass): diff --git a/tests/fixtures/manifest_fixtures.py b/tests/fixtures/manifest_fixtures.py index 5e9285803d..802ac1dad1 100644 --- a/tests/fixtures/manifest_fixtures.py +++ b/tests/fixtures/manifest_fixtures.py @@ -163,6 +163,7 @@ def dataflow_plan_builder(self) -> DataflowPlanBuilder: semantic_manifest_lookup=self.semantic_manifest_lookup, node_output_resolver=self._node_output_resolver.copy(), column_association_resolver=self.column_association_resolver, + query_parser=self.query_parser, ) @staticmethod