diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index 48eb6e4538..4313c8a9a7 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++ b/metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -43,7 +43,6 @@
 from metricflow.errors.errors import UnableToSatisfyQueryError
 from metricflow.model.objects.metric import MetricType, MetricTimeWindow
 from metricflow.model.semantic_model import SemanticModel
-from metricflow.model.spec_converters import WhereConstraintConverter
 from metricflow.object_utils import pformat_big_objects, assert_exactly_one_arg_set
 from metricflow.plan_conversion.column_resolver import DefaultColumnAssociationResolver
 from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor
@@ -186,7 +185,6 @@ def _build_metrics_output_node(
             logger.info(f"Generating compute metrics node for {metric_spec}")
             metric_reference = metric_spec.as_reference
             metric = self._metric_semantics.get_metric(metric_reference)
-            metric_input_measure_specs = self._metric_semantics.measures_for_metric(metric_reference)
 
             if metric.type == MetricType.DERIVED:
                 metric_input_specs = self._metric_semantics.metric_input_specs_for_metric(metric_reference)
@@ -223,12 +221,6 @@ def _build_metrics_output_node(
                 f"{pformat_big_objects(metric_input_measure_specs=metric_input_measure_specs)}"
             )
             combined_where = where_constraint
-            if metric.constraint:
-                metric_constraint = WhereConstraintConverter.convert_to_spec_where_constraint(
-                    self._data_source_semantics, metric.constraint
-                )
-                combined_where = combined_where.combine(metric_constraint) if combined_where else metric_constraint
-
             if metric_spec.constraint:
                 combined_where = (
                     combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
@@ -250,6 +242,9 @@ def _build_metrics_output_node(
                     aggregated_measures_node=aggregated_measures_node,
                 )
             )
+
+        assert len(output_nodes) > 0, "ComputeMetricsNode was not properly constructed"
+
         if len(output_nodes) == 1:
             return output_nodes[0]
 
diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py
index a5207e7e6e..43a01d2816 100644
--- a/metricflow/query/query_parser.py
+++ b/metricflow/query/query_parser.py
@@ -205,6 +205,36 @@ def _validate_linkable_specs(
                 context=suggestion_sections,
             )
 
+    def _construct_metric_specs_for_query(
+        self, metric_references: Tuple[MetricReference, ...]
+    ) -> Tuple[MetricSpec, ...]:
+        """Populate MetricSpecs
+
+        Construct MetricSpecs in the preaggregated state to pass into the DataflowPlanBuilder.
+
+        NOTE: Currently, we populate only the metrics provided in the query, but with derived metrics,
+        there are nested metrics that do not get populated here, but rather in the MetricSemantic during
+        the builder process. This is a process that should be refined altogether.
+        """
+        metric_specs = []
+        for metric_reference in metric_references:
+            metric = self._metric_semantics.get_metric(metric_reference)
+            metric_where_constraint: Optional[SpecWhereClauseConstraint] = None
+            if metric.constraint:
+                # add constraint to MetricSpec
+                metric_where_constraint = WhereConstraintConverter.convert_to_spec_where_constraint(
+                    self._data_source_semantics, metric.constraint
+                )
+            # TODO: Directly initializing Spec object instead of using a factory method since
+            # importing WhereConstraintConverter is a problem in specs.py
+            metric_specs.append(
+                MetricSpec(
+                    element_name=metric_reference.element_name,
+                    constraint=metric_where_constraint,
+                )
+            )
+        return tuple(metric_specs)
+
     def _parse_and_validate_query(
         self,
         metric_names: Sequence[str],
@@ -230,7 +260,6 @@ def _parse_and_validate_query(
         # Get metric references used for validations
         # In a case of derived metric, all the input metrics would be here.
         metric_references = self._parse_metric_names(metric_names)
-
         if time_constraint_start is None:
             time_constraint_start = TimeRangeConstraint.ALL_TIME_BEGIN()
         elif time_constraint_start < TimeRangeConstraint.ALL_TIME_BEGIN():
@@ -376,8 +405,11 @@ def _parse_and_validate_query(
             metric_references=metric_references, time_dimension_specs=where_time_specs
         )
 
+        base_metric_references = self._parse_metric_names(metric_names, traverse_metric_inputs=False)
+        metric_specs = self._construct_metric_specs_for_query(base_metric_references)
+
         return MetricFlowQuerySpec(
-            metric_specs=tuple(MetricSpec.from_element_name(metric_name) for metric_name in metric_names),
+            metric_specs=metric_specs,
             dimension_specs=requested_linkable_specs.dimension_specs,
             identifier_specs=requested_linkable_specs.identifier_specs,
             time_dimension_specs=time_dimension_specs,
@@ -485,7 +517,9 @@ def _find_smallest_metric_time_dimension_spec_granularity(
         else:
             return None
 
-    def _parse_metric_names(self, metric_names: Sequence[str]) -> Tuple[MetricReference, ...]:
+    def _parse_metric_names(
+        self, metric_names: Sequence[str], traverse_metric_inputs: bool = True
+    ) -> Tuple[MetricReference, ...]:
         """Converts metric names into metric names. An exception is thrown if the name is invalid."""
         # The config must be lower-case, so we lower case for case-insensitivity against query inputs from the user.
 
@@ -508,12 +542,12 @@ def _parse_metric_names(self, metric_names: Sequence[str]) -> Tuple[MetricRefere
                     f"Unknown metric: '{metric_name}'",
                     context=suggestions,
                 )
-            metric = self._metric_semantics.get_metric(metric_reference)
-            if metric.type == MetricType.DERIVED:
-                input_metrics = self._parse_metric_names([metric.name for metric in metric.input_metrics])
-                metric_references.extend(list(input_metrics))
-            else:
-                metric_references.append(metric_reference)
+            metric_references.append(metric_reference)
+            if traverse_metric_inputs:
+                metric = self._metric_semantics.get_metric(metric_reference)
+                if metric.type == MetricType.DERIVED:
+                    input_metrics = self._parse_metric_names([metric.name for metric in metric.input_metrics])
+                    metric_references.extend(list(input_metrics))
         return tuple(metric_references)
 
     def _parse_linkable_element_names(