Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate Measure Time Spine Join Conditional Behavior Into MetricInputMeasureSpec #879

Merged
merged 1 commit into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 149 additions & 69 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
from metricflow.plan_conversion.node_processor import PreJoinNodeProcessor
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
CumulativeMeasureDescription,
InstanceSpecSet,
JoinToTimeSpineDescription,
LinkableInstanceSpec,
LinkableSpecSet,
LinklessEntitySpec,
Expand Down Expand Up @@ -174,12 +176,19 @@ def _build_base_metric_output_node(
"""Builds a node to compute a metric that is not defined from other metrics."""
metric_reference = metric_spec.reference
metric = self._metric_lookup.get_metric(metric_reference)
metric_input_measure_specs = self._measures_for_metric(
metric_input_measure_spec = self._build_input_measure_spec_for_base_metric(
metric_reference=metric_reference,
column_association_resolver=self._column_association_resolver,
query_contains_metric_time=queried_linkable_specs.contains_metric_time,
child_metric_offset_window=metric_spec.offset_window,
child_metric_offset_to_grain=metric_spec.offset_to_grain,
culmination_description=CumulativeMeasureDescription(
cumulative_window=metric.type_params.window,
cumulative_grain_to_date=metric.type_params.grain_to_date,
)
if metric.type is MetricType.CUMULATIVE
else None,
)
assert len(metric_input_measure_specs) == 1, "Simple and cumulative metrics must have one input measure."
metric_input_measure_spec = metric_input_measure_specs[0]

logger.info(
f"For {metric_spec}, needed measure is:\n"
Expand All @@ -190,19 +199,13 @@ def _build_base_metric_output_node(
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
)

aggregated_measures_node = self.build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
time_range_constraint=time_range_constraint,
cumulative=metric.type == MetricType.CUMULATIVE,
cumulative_window=metric.type_params.window if metric.type == MetricType.CUMULATIVE else None,
cumulative_grain_to_date=(
metric.type_params.grain_to_date if metric.type == MetricType.CUMULATIVE else None
),
)

return self.build_computed_metrics_node(
metric_spec=metric_spec,
aggregated_measures_node=aggregated_measures_node,
Expand All @@ -225,13 +228,42 @@ def _build_derived_metric_output_node(
f"For {metric.type} metric: {metric_spec}, needed metrics are:\n"
f"{pformat_big_objects(metric_input_specs=metric_input_specs)}"
)

parent_nodes: List[BaseOutput] = []

for metric_input_spec in metric_input_specs:
# TODO: See: https://github.com/dbt-labs/metricflow/issues/881
if (metric_spec.offset_window is not None or metric_spec.offset_to_grain is not None) and (
metric_input_spec.offset_window is not None or metric_input_spec.offset_to_grain is not None
):
raise NotImplementedError(
f"Multiple descendent metrics in a derived metric hierarchy are not yet supported. "
f"For {metric_spec}, the parent metric input is {metric_input_spec}"
)

parent_nodes.append(
self._build_any_metric_output_node(
metric_spec=MetricSpec(
element_name=metric_input_spec.element_name,
constraint=metric_input_spec.constraint,
alias=metric_input_spec.alias,
offset_window=metric_input_spec.offset_window,
offset_to_grain=metric_input_spec.offset_to_grain,
),
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
)
)

if len(parent_nodes) == 1:
return ComputeMetricsNode(
parent_node=parent_nodes[0],
metric_specs=[metric_spec],
)

return ComputeMetricsNode(
parent_node=self._build_metrics_output_node(
metric_specs=metric_input_specs,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
),
parent_node=CombineAggregatedOutputsNode(parent_nodes=parent_nodes),
metric_specs=[metric_spec],
)

Expand Down Expand Up @@ -645,54 +677,88 @@ def build_computed_metrics_node(
metric_specs=[metric_spec],
)

def _measures_for_metric(
def _build_input_measure_spec_for_base_metric(
self,
metric_reference: MetricReference,
column_association_resolver: ColumnAssociationResolver,
) -> Sequence[MetricInputMeasureSpec]:
"""Return the measure specs required to compute the metric."""
child_metric_offset_window: Optional[MetricTimeWindow],
child_metric_offset_to_grain: Optional[TimeGranularity],
query_contains_metric_time: bool,
culmination_description: Optional[CumulativeMeasureDescription],
) -> MetricInputMeasureSpec:
"""Return the input measure spec required to compute the base metric."""
metric = self._metric_lookup.get_metric(metric_reference)
input_measure_specs: List[MetricInputMeasureSpec] = []

for input_measure in metric.input_measures:
measure_spec = MeasureSpec(
element_name=input_measure.name,
non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
input_measure.measure_reference
),
)
spec = MetricInputMeasureSpec(
measure_spec=measure_spec,
constraint=WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(input_measure.filter),
alias=input_measure.alias,
join_to_timespine=input_measure.join_to_timespine,
fill_nulls_with=input_measure.fill_nulls_with,
if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE:
pass
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
raise ValueError("This should only be called for base metrics.")
else:
assert_values_exhausted(metric.type)

assert (
len(metric.input_measures) == 1
), f"A base metric should not have multiple measures. Got{metric.input_measures}"

input_measure = metric.input_measures[0]

measure_spec = MeasureSpec(
element_name=input_measure.name,
non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
input_measure.measure_reference
),
)

before_aggregation_time_spine_join_description = None
# If querying an offset metric, join to time spine.
if child_metric_offset_window is not None or child_metric_offset_to_grain is not None:
before_aggregation_time_spine_join_description = JoinToTimeSpineDescription(
join_type=SqlJoinType.INNER,
offset_window=child_metric_offset_window,
offset_to_grain=child_metric_offset_to_grain,
)
input_measure_specs.append(spec)

return tuple(input_measure_specs)
# Even if the measure is configured to join to time spine, if there's no metric_time in the query,
# there's no need to join to the time spine since all metric_time will be aggregated.
after_aggregation_time_spine_join_description = None
if input_measure.join_to_timespine and query_contains_metric_time:
after_aggregation_time_spine_join_description = JoinToTimeSpineDescription(
join_type=SqlJoinType.LEFT_OUTER,
offset_window=None,
offset_to_grain=None,
)

return MetricInputMeasureSpec(
measure_spec=measure_spec,
fill_nulls_with=input_measure.fill_nulls_with,
offset_window=child_metric_offset_window,
offset_to_grain=child_metric_offset_to_grain,
culmination_description=culmination_description,
constraint=WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter_intersection(input_measure.filter),
alias=input_measure.alias,
before_aggregation_time_spine_join_description=before_aggregation_time_spine_join_description,
after_aggregation_time_spine_join_description=after_aggregation_time_spine_join_description,
)

def build_aggregated_measure(
self,
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
cumulative: Optional[bool] = False,
cumulative_window: Optional[MetricTimeWindow] = None,
cumulative_grain_to_date: Optional[TimeGranularity] = None,
) -> BaseOutput:
"""Returns a node where the measures are aggregated by the linkable specs and constrained appropriately.

This might be a node representing a single aggregation over one semantic model, or a node representing
a composite set of aggregations originating from multiple semantic models, and joined into a single
aggregated set of measures.
"""
measure_spec = metric_input_measure_spec.measure_spec
measure_constraint = metric_input_measure_spec.constraint
logger.info(f"Building aggregated measure: {metric_input_measure_spec} with constraint: {measure_constraint}")

logger.info(f"Building aggregated measure: {measure_spec} with constraint: {measure_constraint}")
if measure_constraint is None:
node_where_constraint = where_constraint
elif where_constraint is None:
Expand All @@ -702,33 +768,30 @@ def build_aggregated_measure(

return self._build_aggregated_measure_from_measure_source_node(
metric_input_measure_spec=metric_input_measure_spec,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=node_where_constraint,
time_range_constraint=time_range_constraint,
cumulative=cumulative,
cumulative_window=cumulative_window,
cumulative_grain_to_date=cumulative_grain_to_date,
)

def _build_aggregated_measure_from_measure_source_node(
self,
metric_input_measure_spec: MetricInputMeasureSpec,
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
cumulative: Optional[bool] = False,
cumulative_window: Optional[MetricTimeWindow] = None,
cumulative_grain_to_date: Optional[TimeGranularity] = None,
) -> BaseOutput:
metric_time_dimension_specs = [
time_dimension_spec
for time_dimension_spec in queried_linkable_specs.time_dimension_specs
if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name
]
metric_time_dimension_requested = len(metric_time_dimension_specs) > 0
measure_spec = metric_input_measure_spec.measure_spec
cumulative = metric_input_measure_spec.culmination_description is not None
cumulative_window = (
metric_input_measure_spec.culmination_description.cumulative_window
if metric_input_measure_spec.culmination_description is not None
else None
)
cumulative_grain_to_date = (
metric_input_measure_spec.culmination_description.cumulative_grain_to_date
if metric_input_measure_spec.culmination_description
else None
)
measure_properties = self._build_measure_spec_properties([measure_spec])
non_additive_dimension_spec = measure_properties.non_additive_dimension_spec

Expand Down Expand Up @@ -787,7 +850,7 @@ def _build_aggregated_measure_from_measure_source_node(
# If a cumulative metric is queried with metric_time, join over time range.
# Otherwise, the measure will be aggregated over all time.
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative and metric_time_dimension_requested:
if cumulative and queried_linkable_specs.contains_metric_time:
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.source_node,
window=cumulative_window,
Expand All @@ -797,15 +860,25 @@ def _build_aggregated_measure_from_measure_source_node(

# If querying an offset metric, join to time spine.
join_to_time_spine_node: Optional[JoinToTimeSpineNode] = None
if metric_spec.offset_window or metric_spec.offset_to_grain:
assert metric_time_dimension_specs, "Joining to time spine requires querying with metric time."

before_aggregation_time_spine_join_description = (
metric_input_measure_spec.before_aggregation_time_spine_join_description
)
if before_aggregation_time_spine_join_description is not None:
assert (
queried_linkable_specs.contains_metric_time
), "Joining to time spine requires querying with metric time."
assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
f"new use case."
)
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=time_range_node or measure_recipe.source_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs),
time_range_constraint=time_range_constraint,
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
join_type=SqlJoinType.INNER,
offset_window=before_aggregation_time_spine_join_description.offset_window,
offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain,
join_type=before_aggregation_time_spine_join_description.join_type,
)

# Only get the required measure and the local linkable instances so that aggregations work correctly.
Expand Down Expand Up @@ -839,7 +912,7 @@ def _build_aggregated_measure_from_measure_source_node(
if (
cumulative_metric_adjusted_time_constraint is not None
and time_range_constraint is not None
and metric_time_dimension_requested
and queried_linkable_specs.contains_metric_time
):
cumulative_metric_constrained_node = ConstrainTimeRangeNode(
unaggregated_measure_node, time_range_constraint
Expand Down Expand Up @@ -890,14 +963,21 @@ def _build_aggregated_measure_from_measure_source_node(
parent_node=pre_aggregate_node,
metric_input_measure_specs=(metric_input_measure_spec,),
)

# Only join to time spine if metric time was requested in the query.
if metric_input_measure_spec.join_to_timespine and metric_time_dimension_requested:
after_aggregation_time_spine_join_description = (
metric_input_measure_spec.after_aggregation_time_spine_join_description
)
if after_aggregation_time_spine_join_description is not None:
assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, (
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
return JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_metric_time_dimension_specs=metric_time_dimension_specs,
requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs),
join_type=after_aggregation_time_spine_join_description.join_type,
time_range_constraint=time_range_constraint,
join_type=SqlJoinType.LEFT_OUTER,
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
)
else:
return aggregate_measures_node
6 changes: 5 additions & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,11 @@ class AggregateMeasuresNode(AggregatedMeasuresOutput):
constraints applied to the measure.
"""

def __init__(self, parent_node: BaseOutput, metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...]) -> None:
def __init__(
self,
parent_node: BaseOutput,
metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...],
) -> None:
"""Initializer for AggregateMeasuresNode.

The input measure specs are required for downstream nodes to be aware of any input measures with
Expand Down
Loading