Skip to content

Commit

Permalink
consolidate metric constraints in DataflowPlanBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Jan 10, 2023
1 parent f309c74 commit 2a22ca6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 17 deletions.
11 changes: 3 additions & 8 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from metricflow.errors.errors import UnableToSatisfyQueryError
from metricflow.model.objects.metric import MetricType, MetricTimeWindow
from metricflow.model.semantic_model import SemanticModel
from metricflow.model.spec_converters import WhereConstraintConverter
from metricflow.object_utils import pformat_big_objects, assert_exactly_one_arg_set
from metricflow.plan_conversion.column_resolver import DefaultColumnAssociationResolver
from metricflow.plan_conversion.node_processor import PreDimensionJoinNodeProcessor
Expand Down Expand Up @@ -186,7 +185,6 @@ def _build_metrics_output_node(
logger.info(f"Generating compute metrics node for {metric_spec}")
metric_reference = metric_spec.as_reference
metric = self._metric_semantics.get_metric(metric_reference)
metric_input_measure_specs = self._metric_semantics.measures_for_metric(metric_reference)

if metric.type == MetricType.DERIVED:
metric_input_specs = self._metric_semantics.metric_input_specs_for_metric(metric_reference)
Expand Down Expand Up @@ -223,12 +221,6 @@ def _build_metrics_output_node(
f"{pformat_big_objects(metric_input_measure_specs=metric_input_measure_specs)}"
)
combined_where = where_constraint
if metric.constraint:
metric_constraint = WhereConstraintConverter.convert_to_spec_where_constraint(
self._data_source_semantics, metric.constraint
)
combined_where = combined_where.combine(metric_constraint) if combined_where else metric_constraint

if metric_spec.constraint:
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
Expand All @@ -250,6 +242,9 @@ def _build_metrics_output_node(
aggregated_measures_node=aggregated_measures_node,
)
)

assert len(output_nodes) > 0, "ComputeMetricsNode was not properly constructed"

if len(output_nodes) == 1:
return output_nodes[0]

Expand Down
52 changes: 43 additions & 9 deletions metricflow/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,36 @@ def _validate_linkable_specs(
context=suggestion_sections,
)

def _construct_metric_specs_for_query(
self, metric_references: Tuple[MetricReference, ...]
) -> Tuple[MetricSpec, ...]:
"""Populate MetricSpecs
Construct MetricSpecs in the preaggregated state to pass into the DataflowPlanBuilder.
NOTE: Currently, we populate only the metrics provided in the query, but with derived metrics,
there are nested metrics that do not get populated here, but rather in the MetricSemantic during
the builder process. This is a process that should be refined altogther.
"""
metric_specs = []
for metric_reference in metric_references:
metric = self._metric_semantics.get_metric(metric_reference)
metric_where_constraint: Optional[SpecWhereClauseConstraint] = None
if metric.constraint:
# add constraint to MetricSpec
metric_where_constraint = WhereConstraintConverter.convert_to_spec_where_constraint(
self._data_source_semantics, metric.constraint
)
# TODO: Directly initializing Spec object instead of using a factory method since
# importing WhereConstraintConverter is a problem in specs.py
metric_specs.append(
MetricSpec(
element_name=metric_reference.element_name,
constraint=metric_where_constraint,
)
)
return tuple(metric_specs)

def _parse_and_validate_query(
self,
metric_names: Sequence[str],
Expand All @@ -230,7 +260,6 @@ def _parse_and_validate_query(
# Get metric references used for validations
# In a case of derived metric, all the input metrics would be here.
metric_references = self._parse_metric_names(metric_names)

if time_constraint_start is None:
time_constraint_start = TimeRangeConstraint.ALL_TIME_BEGIN()
elif time_constraint_start < TimeRangeConstraint.ALL_TIME_BEGIN():
Expand Down Expand Up @@ -376,8 +405,11 @@ def _parse_and_validate_query(
metric_references=metric_references, time_dimension_specs=where_time_specs
)

base_metric_references = self._parse_metric_names(metric_names, traverse_metric_inputs=False)
metric_specs = self._construct_metric_specs_for_query(base_metric_references)

return MetricFlowQuerySpec(
metric_specs=tuple(MetricSpec.from_element_name(metric_name) for metric_name in metric_names),
metric_specs=metric_specs,
dimension_specs=requested_linkable_specs.dimension_specs,
identifier_specs=requested_linkable_specs.identifier_specs,
time_dimension_specs=time_dimension_specs,
Expand Down Expand Up @@ -485,7 +517,9 @@ def _find_smallest_metric_time_dimension_spec_granularity(
else:
return None

def _parse_metric_names(self, metric_names: Sequence[str]) -> Tuple[MetricReference, ...]:
def _parse_metric_names(
self, metric_names: Sequence[str], traverse_metric_inputs: bool = True
) -> Tuple[MetricReference, ...]:
"""Converts metric names into metric names. An exception is thrown if the name is invalid."""

# The config must be lower-case, so we lower case for case-insensitivity against query inputs from the user.
Expand All @@ -508,12 +542,12 @@ def _parse_metric_names(self, metric_names: Sequence[str]) -> Tuple[MetricRefere
f"Unknown metric: '{metric_name}'",
context=suggestions,
)
metric = self._metric_semantics.get_metric(metric_reference)
if metric.type == MetricType.DERIVED:
input_metrics = self._parse_metric_names([metric.name for metric in metric.input_metrics])
metric_references.extend(list(input_metrics))
else:
metric_references.append(metric_reference)
metric_references.append(metric_reference)
if traverse_metric_inputs:
metric = self._metric_semantics.get_metric(metric_reference)
if metric.type == MetricType.DERIVED:
input_metrics = self._parse_metric_names([metric.name for metric in metric.input_metrics])
metric_references.extend(list(input_metrics))
return tuple(metric_references)

def _parse_linkable_element_names(
Expand Down

0 comments on commit 2a22ca6

Please sign in to comment.