Skip to content

Commit

Permalink
fix merge conflicts from ambiguous group by changes
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Dec 16, 2023
1 parent 0ad4052 commit 07d64f8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 28 deletions.
48 changes: 28 additions & 20 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,14 @@ def _build_aggregated_conversion_node(
entity_spec: EntitySpec,
window: Optional[MetricTimeWindow],
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
constant_properties: Optional[Sequence[ConstantPropertyInput]] = None,
) -> BaseOutput:
"""Builds a node that contains aggregated values of conversions and opportunities."""
# Build measure recipes
base_required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs(
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
filter_specs=base_measure_spec.filter_specs,
)
base_measure_recipe = self._find_dataflow_recipe(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
Expand Down Expand Up @@ -261,7 +260,6 @@ def _build_aggregated_conversion_node(
aggregated_base_measure_node = self.build_aggregated_measure(
metric_input_measure_spec=base_measure_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
)

Expand Down Expand Up @@ -334,7 +332,6 @@ def _build_aggregated_conversion_node(
aggregated_conversions_node = self.build_aggregated_measure(
metric_input_measure_spec=conversion_measure_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
measure_recipe=recipe_with_join_conversion_source_node,
)
Expand All @@ -346,23 +343,19 @@ def _build_conversion_metric_output_node(
self,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> ComputeMetricsNode:
"""Builds a compute metric node for a conversion metric."""
combined_where = where_constraint
if metric_spec.constraint:
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
)

metric = self._metric_lookup.get_metric(metric_spec.reference)
metric_reference = metric_spec.reference
metric = self._metric_lookup.get_metric(metric_reference)
conversion_type_params = metric.type_params.conversion_type_params
assert conversion_type_params, "A conversion metric should have type_params.conversion_type_params defined."
base_measure, conversion_measure = self._build_input_measure_specs_for_conversion_metric(
metric_reference=metric_spec.reference,
conversion_type_params=conversion_type_params,
column_association_resolver=self._column_association_resolver,
filter_spec_factory=filter_spec_factory,
descendent_filter_specs=metric_spec.filter_specs,
)
entity_spec = EntitySpec.from_name(conversion_type_params.entity)
logger.info(
Expand All @@ -376,7 +369,6 @@ def _build_conversion_metric_output_node(
base_measure_spec=base_measure,
conversion_measure_spec=conversion_measure,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
time_range_constraint=time_range_constraint,
entity_spec=entity_spec,
window=conversion_type_params.window,
Expand Down Expand Up @@ -546,7 +538,7 @@ def _build_any_metric_output_node(
return self._build_conversion_metric_output_node(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
filter_spec_factory=filter_spec_factory,
time_range_constraint=time_range_constraint,
)

Expand Down Expand Up @@ -957,7 +949,8 @@ def _build_input_measure_specs_for_conversion_metric(
self,
metric_reference: MetricReference,
conversion_type_params: ConversionTypeParams,
column_association_resolver: ColumnAssociationResolver,
filter_spec_factory: WhereSpecFactory,
descendent_filter_specs: Sequence[WhereFilterSpec],
) -> Tuple[MetricInputMeasureSpec, MetricInputMeasureSpec]:
"""Return [base_measure_input, conversion_measure_input] for computing a conversion metric."""
metric = self._metric_lookup.get_metric(metric_reference)
Expand All @@ -969,7 +962,7 @@ def _build_input_measure_specs_for_conversion_metric(
), f"A conversion metric should exactly 2 measures. Got{metric.input_measures}"

def _get_matching_measure(
measure_to_match: MeasureReference, input_measures: Sequence[MetricInputMeasure]
measure_to_match: MeasureReference, input_measures: Sequence[MetricInputMeasure], is_base_measure: bool
) -> MetricInputMeasureSpec:
matched_measure = next(
filter(
Expand All @@ -979,22 +972,37 @@ def _get_matching_measure(
None,
)
assert matched_measure, f"Unable to find {measure_to_match} in {input_measures}."
if is_base_measure:
filter_specs: List[WhereFilterSpec] = []
filter_specs.extend(
filter_spec_factory.create_from_where_filter_intersection(
filter_location=WhereFilterLocation.for_metric(metric_reference),
filter_intersection=metric.filter,
)
)
filter_specs.extend(descendent_filter_specs)
filter_specs.extend(
filter_spec_factory.create_from_where_filter_intersection(
filter_location=WhereFilterLocation.for_metric(metric_reference),
filter_intersection=matched_measure.filter,
)
)
return MetricInputMeasureSpec(
measure_spec=MeasureSpec.from_name(matched_measure.name),
fill_nulls_with=matched_measure.fill_nulls_with,
constraint=WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(matched_measure.filter),
filter_specs=tuple(filter_specs) if is_base_measure else (),
alias=matched_measure.alias,
)

base_input_measure = _get_matching_measure(
measure_to_match=conversion_type_params.base_measure.measure_reference,
input_measures=metric.input_measures,
is_base_measure=True,
)
conversion_input_measure = _get_matching_measure(
measure_to_match=conversion_type_params.conversion_measure.measure_reference,
input_measures=metric.input_measures,
is_base_measure=False,
)
return base_input_measure, conversion_input_measure

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def visit_measure_node(self, node: MeasureGroupByItemSourceNode) -> PushDownResu
metric = self._semantic_manifest_lookup.metric_lookup.get_metric(node.child_metric_reference)

patterns_to_apply: Tuple[SpecPattern, ...] = ()
if metric.type is MetricType.SIMPLE:
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
pass
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
assert False, f"A measure should have a simple or cumulative metric as a child, but got {metric.type}"
Expand Down Expand Up @@ -293,7 +293,11 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> PushDownRe

# For metrics with offset_to_grain, don't allow date_part group-by-items
patterns_to_apply: Sequence[SpecPattern] = ()
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
if (
metric.type is MetricType.SIMPLE
or metric.type is MetricType.CUMULATIVE
or metric.type is MetricType.CONVERSION
):
pass
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
for input_metric in metric.input_metrics:
Expand Down
19 changes: 14 additions & 5 deletions metricflow/query/group_by_item/resolution_dag/dag_builder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import logging
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilterIntersection
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.references import MeasureReference, MetricReference
from dbt_semantic_interfaces.type_enums import MetricType

from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
Expand Down Expand Up @@ -42,9 +43,17 @@ def _build_dag_component_for_metric(

# For a base metric, the parents are measure nodes
if len(metric.input_metrics) == 0:
measure_references_for_metric = tuple(
input_measure.measure_reference for input_measure in metric.input_measures
)
measure_references_for_metric: Tuple[MeasureReference, ...]
if metric.type is MetricType.CONVERSION:
conversion_type_params = metric.type_params.conversion_type_params
assert (
conversion_type_params
), "A conversion metric should have type_params.conversion_type_params defined."
measure_references_for_metric = (conversion_type_params.base_measure.measure_reference,)
else:
measure_references_for_metric = tuple(
input_measure.measure_reference for input_measure in metric.input_measures
)

source_candidates_for_measure_nodes = tuple(
MeasureGroupByItemSourceNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def validate_metric_in_resolution_dag(
metric = self._get_metric(metric_reference)
query_includes_metric_time = self._group_by_items_include_metric_time(resolver_input_for_query)

if metric.type is MetricType.SIMPLE:
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
elif metric.type is MetricType.CUMULATIVE:
if (
Expand Down

0 comments on commit 07d64f8

Please sign in to comment.