Skip to content

Commit

Permalink
Consolidate conditional logic for time spine joins into MetricInputMe…
Browse files Browse the repository at this point in the history
…asureSpec.
  • Loading branch information
plypaul committed Nov 18, 2023
1 parent a28a552 commit 274d987
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 76 deletions.
218 changes: 149 additions & 69 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
CumulativeMeasureDescription,
InstanceSpecSet,
JoinToTimeSpineDescription,
LinkableInstanceSpec,
LinkableSpecSet,
LinklessEntitySpec,
Expand Down Expand Up @@ -174,12 +176,19 @@ def _build_base_metric_output_node(
"""Builds a node to compute a metric that is not defined from other metrics."""
metric_reference = metric_spec.reference
metric = self._metric_lookup.get_metric(metric_reference)
metric_input_measure_specs = self._measures_for_metric(
metric_input_measure_spec = self._build_input_measure_spec_for_base_metric(
metric_reference=metric_reference,
column_association_resolver=self._column_association_resolver,
query_contains_metric_time=queried_linkable_specs.contains_metric_time,
child_metric_offset_window=metric_spec.offset_window,
child_metric_offset_to_grain=metric_spec.offset_to_grain,
culmination_description=CumulativeMeasureDescription(
cumulative_window=metric.type_params.window,
cumulative_grain_to_date=metric.type_params.grain_to_date,
)
if metric.type is MetricType.CUMULATIVE
else None,
)
assert len(metric_input_measure_specs) == 1, "Simple and cumulative metrics must have one input measure."
metric_input_measure_spec = metric_input_measure_specs[0]

logger.info(
f"For {metric_spec}, needed measure is:\n"
Expand All @@ -190,19 +199,13 @@ def _build_base_metric_output_node(
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
)

aggregated_measures_node = self.build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
time_range_constraint=time_range_constraint,
cumulative=metric.type == MetricType.CUMULATIVE,
cumulative_window=metric.type_params.window if metric.type == MetricType.CUMULATIVE else None,
cumulative_grain_to_date=(
metric.type_params.grain_to_date if metric.type == MetricType.CUMULATIVE else None
),
)

return self.build_computed_metrics_node(
metric_spec=metric_spec,
aggregated_measures_node=aggregated_measures_node,
Expand All @@ -225,13 +228,42 @@ def _build_derived_metric_output_node(
f"For {metric.type} metric: {metric_spec}, needed metrics are:\n"
f"{pformat_big_objects(metric_input_specs=metric_input_specs)}"
)

parent_nodes: List[BaseOutput] = []

for metric_input_spec in metric_input_specs:
# TODO: See: https://github.com/dbt-labs/metricflow/issues/881
if (metric_spec.offset_window is not None or metric_spec.offset_to_grain is not None) and (
metric_input_spec.offset_window is not None or metric_input_spec.offset_to_grain is not None
):
raise NotImplementedError(
f"Multiple descendent metrics in a derived metric hierarchy are not yet supported. "
f"For {metric_spec}, the parent metric input is {metric_input_spec}"
)

parent_nodes.append(
self._build_any_metric_output_node(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
constraint=metric_input_spec.constraint,
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
)
)

if len(parent_nodes) == 1:
return ComputeMetricsNode(
parent_node=parent_nodes[0],
metric_specs=[metric_spec],
)

return ComputeMetricsNode(
parent_node=self._build_metrics_output_node(
metric_specs=metric_input_specs,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
),
parent_node=CombineAggregatedOutputsNode(parent_nodes=parent_nodes),
metric_specs=[metric_spec],
)

Expand Down Expand Up @@ -645,54 +677,88 @@ def build_computed_metrics_node(
metric_specs=[metric_spec],
)

def _measures_for_metric(
def _build_input_measure_spec_for_base_metric(
self,
metric_reference: MetricReference,
column_association_resolver: ColumnAssociationResolver,
) -> Sequence[MetricInputMeasureSpec]:
"""Return the measure specs required to compute the metric."""
child_metric_offset_window: Optional[MetricTimeWindow],
child_metric_offset_to_grain: Optional[TimeGranularity],
query_contains_metric_time: bool,
culmination_description: Optional[CumulativeMeasureDescription],
) -> MetricInputMeasureSpec:
"""Return the input measure spec required to compute the base metric."""
metric = self._metric_lookup.get_metric(metric_reference)
input_measure_specs: List[MetricInputMeasureSpec] = []

for input_measure in metric.input_measures:
measure_spec = MeasureSpec(
element_name=input_measure.name,
non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
input_measure.measure_reference
),
)
spec = MetricInputMeasureSpec(
measure_spec=measure_spec,
constraint=WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(input_measure.filter),
alias=input_measure.alias,
join_to_timespine=input_measure.join_to_timespine,
fill_nulls_with=input_measure.fill_nulls_with,
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
pass
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
raise ValueError("This should only be called for base metrics.")
else:
assert_values_exhausted(metric.type)

assert (
len(metric.input_measures) == 1
), f"A base metric should not have multiple measures. Got{metric.input_measures}"

input_measure = metric.input_measures[0]

measure_spec = MeasureSpec(
element_name=input_measure.name,
non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
input_measure.measure_reference
),
)

before_aggregation_time_spine_join_description = None
# If querying an offset metric, join to time spine.
if child_metric_offset_window is not None or child_metric_offset_to_grain is not None:
before_aggregation_time_spine_join_description = JoinToTimeSpineDescription(
join_type=SqlJoinType.INNER,
offset_window=child_metric_offset_window,
offset_to_grain=child_metric_offset_to_grain,
)
input_measure_specs.append(spec)

return tuple(input_measure_specs)
# Even if the measure is configured to join to time spine, if there's no metric_time in the query,
# there's no need to join to the time spine since all metric_time will be aggregated.
after_aggregation_time_spine_join_description = None
if input_measure.join_to_timespine and query_contains_metric_time:
after_aggregation_time_spine_join_description = JoinToTimeSpineDescription(
join_type=SqlJoinType.LEFT_OUTER,
offset_window=None,
offset_to_grain=None,
)

return MetricInputMeasureSpec(
measure_spec=measure_spec,
fill_nulls_with=input_measure.fill_nulls_with,
offset_window=child_metric_offset_window,
offset_to_grain=child_metric_offset_to_grain,
culmination_description=culmination_description,
constraint=WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(input_measure.filter),
alias=input_measure.alias,
before_aggregation_time_spine_join_description=before_aggregation_time_spine_join_description,
after_aggregation_time_spine_join_description=after_aggregation_time_spine_join_description,
)

def build_aggregated_measure(
self,
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
cumulative: Optional[bool] = False,
cumulative_window: Optional[MetricTimeWindow] = None,
cumulative_grain_to_date: Optional[TimeGranularity] = None,
) -> BaseOutput:
"""Returns a node where the measures are aggregated by the linkable specs and constrained appropriately.
This might be a node representing a single aggregation over one semantic model, or a node representing
a composite set of aggregations originating from multiple semantic models, and joined into a single
aggregated set of measures.
"""
measure_spec = metric_input_measure_spec.measure_spec
measure_constraint = metric_input_measure_spec.constraint
logger.info(f"Building aggregated measure: {metric_input_measure_spec} with constraint: {measure_constraint}")

logger.info(f"Building aggregated measure: {measure_spec} with constraint: {measure_constraint}")
if measure_constraint is None:
node_where_constraint = where_constraint
elif where_constraint is None:
Expand All @@ -702,33 +768,30 @@ def build_aggregated_measure(

return self._build_aggregated_measure_from_measure_source_node(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=node_where_constraint,
time_range_constraint=time_range_constraint,
cumulative=cumulative,
cumulative_window=cumulative_window,
cumulative_grain_to_date=cumulative_grain_to_date,
)

def _build_aggregated_measure_from_measure_source_node(
self,
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
cumulative: Optional[bool] = False,
cumulative_window: Optional[MetricTimeWindow] = None,
cumulative_grain_to_date: Optional[TimeGranularity] = None,
) -> BaseOutput:
metric_time_dimension_specs = [
time_dimension_spec
for time_dimension_spec in queried_linkable_specs.time_dimension_specs
if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name
]
metric_time_dimension_requested = len(metric_time_dimension_specs) > 0
measure_spec = metric_input_measure_spec.measure_spec
cumulative = metric_input_measure_spec.culmination_description is not None
cumulative_window = (
metric_input_measure_spec.culmination_description.cumulative_window
if metric_input_measure_spec.culmination_description is not None
else None
)
cumulative_grain_to_date = (
metric_input_measure_spec.culmination_description.cumulative_grain_to_date
if metric_input_measure_spec.culmination_description
else None
)
measure_properties = self._build_measure_spec_properties([measure_spec])
non_additive_dimension_spec = measure_properties.non_additive_dimension_spec

Expand Down Expand Up @@ -787,7 +850,7 @@ def _build_aggregated_measure_from_measure_source_node(
# If a cumulative metric is queried with metric_time, join over time range.
# Otherwise, the measure will be aggregated over all time.
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative and metric_time_dimension_requested:
if cumulative and queried_linkable_specs.contains_metric_time:
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.source_node,
window=cumulative_window,
Expand All @@ -797,15 +860,25 @@ def _build_aggregated_measure_from_measure_source_node(

# If querying an offset metric, join to time spine.
join_to_time_spine_node: Optional[JoinToTimeSpineNode] = None
if metric_spec.offset_window or metric_spec.offset_to_grain:
assert metric_time_dimension_specs, "Joining to time spine requires querying with metric time."

before_aggregation_time_spine_join_description = (
metric_input_measure_spec.before_aggregation_time_spine_join_description
)
if before_aggregation_time_spine_join_description is not None:
assert (
queried_linkable_specs.contains_metric_time
), "Joining to time spine requires querying with metric time."
assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
f"new use case."
)
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=time_range_node or measure_recipe.source_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs),
time_range_constraint=time_range_constraint,
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
)

# Only get the required measure and the local linkable instances so that aggregations work correctly.
Expand Down Expand Up @@ -839,7 +912,7 @@ def _build_aggregated_measure_from_measure_source_node(
if (
cumulative_metric_adjusted_time_constraint is not None
and time_range_constraint is not None
and metric_time_dimension_requested
and queried_linkable_specs.contains_metric_time
):
cumulative_metric_constrained_node = ConstrainTimeRangeNode(
unaggregated_measure_node, time_range_constraint
Expand Down Expand Up @@ -890,14 +963,21 @@ def _build_aggregated_measure_from_measure_source_node(
parent_node=pre_aggregate_node,
metric_input_measure_specs=(metric_input_measure_spec,),
)

# Only join to time spine if metric time was requested in the query.
if metric_input_measure_spec.join_to_timespine and metric_time_dimension_requested:
after_aggregation_time_spine_join_description = (
metric_input_measure_spec.after_aggregation_time_spine_join_description
)
if after_aggregation_time_spine_join_description is not None:
assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
return JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs),
join_type=after_aggregation_time_spine_join_description.join_type,
time_range_constraint=time_range_constraint,
join_type=SqlJoinType.LEFT_OUTER,
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
)
else:
return aggregate_measures_node
6 changes: 5 additions & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,11 @@ class AggregateMeasuresNode(AggregatedMeasuresOutput):
constraints applied to the measure.
"""

def __init__(self, parent_node: BaseOutput, metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...]) -> None:
def __init__(
self,
parent_node: BaseOutput,
metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...],
) -> None:
"""Initializer for AggregateMeasuresNode.
The input measure specs are required for downstream nodes to be aware of any input measures with
Expand Down
Loading

0 comments on commit 274d987

Please sign in to comment.