Skip to content

Commit

Permalink
Finish
Browse files (browse the repository at this point in the history)
  • Loading branch information
courtneyholcomb committed Jan 25, 2024
1 parent cf88ff6 commit 8ea4578
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 249 deletions.
77 changes: 32 additions & 45 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,9 +796,7 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
f"semantic models: {semantic_models}. This suggests the measure_specs were not correctly filtered."
)

agg_time_dimension = agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure(
measure_specs[0].reference
)
agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure(measure_specs[0].reference)
non_additive_dimension_spec = measure_specs[0].non_additive_dimension_spec
for measure_spec in measure_specs:
if non_additive_dimension_spec != measure_spec.non_additive_dimension_spec:
Expand Down Expand Up @@ -1322,47 +1320,38 @@ def _build_aggregated_measure_from_measure_source_node(
f"Recipe not found for measure spec: {measure_spec} and linkable specs: {required_linkable_specs}"
)

queried_agg_time_dimension_specs = list(queried_linkable_specs.metric_time_specs)
if not queried_agg_time_dimension_specs:
valid_agg_time_dimensions = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_spec.reference
)
queried_agg_time_dimension_specs = list(
set(queried_linkable_specs.time_dimension_specs).intersection(set(valid_agg_time_dimensions))
)

# If a cumulative metric is queried with agg_time_dimension, join over time range.
# Otherwise, the measure will be aggregated over all time.
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative:
queried_metric_time_spec = queried_linkable_specs.metric_time_spec_with_smallest_granularity
if not queried_metric_time_spec:
valid_agg_time_dimensions = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_spec.reference
)
# TODO: will it be a problem if we get one with date part or diff granularity? Write test case
queried_agg_time_dims = sorted(
set(queried_linkable_specs.time_dimension_specs).intersection(set(valid_agg_time_dimensions)),
key=lambda x: x.time_granularity.to_int(),
)
if queried_agg_time_dims:
queried_metric_time_spec = queried_agg_time_dims[0]

if queried_metric_time_spec:
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.source_node,
metric_time_dimension_spec=queried_metric_time_spec,
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
time_range_constraint=time_range_constraint
if not before_aggregation_time_spine_join_description
else None,
)
if cumulative and queried_agg_time_dimension_specs:
# TODO: will it be a problem if we get one with date part or diff granularity? Write test case to confirm
# Use the time dimension spec with the smallest granularity.
agg_time_dimension_spec_for_join = sorted(
queried_agg_time_dimension_specs, key=lambda spec: spec.time_granularity.to_int()
)[0]
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.source_node,
# TODO: rename param
metric_time_dimension_spec=agg_time_dimension_spec_for_join,
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
time_range_constraint=time_range_constraint
if not before_aggregation_time_spine_join_description
else None,
)

# If querying an offset metric, join to time spine before aggregation.
join_to_time_spine_node: Optional[JoinToTimeSpineNode] = None
if before_aggregation_time_spine_join_description is not None:
# TODO: below logic is somewhat duplicated
queried_agg_time_dimension_specs = list(queried_linkable_specs.metric_time_specs)
if not queried_agg_time_dimension_specs:
valid_agg_time_dimensions = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_spec.reference
)
queried_agg_time_dimension_specs = list(
set(queried_linkable_specs.time_dimension_specs).intersection(set(valid_agg_time_dimensions))
)

assert queried_agg_time_dimension_specs, (
"Joining to time spine requires querying with metric time or the appropriate agg_time_dimension."
"This should have been caught by validations."
Expand Down Expand Up @@ -1408,19 +1397,17 @@ def _build_aggregated_measure_from_measure_source_node(
else:
unaggregated_measure_node = filtered_measure_source_node

query_contains_metric_time_or_agg_time_dimension = queried_linkable_specs.contains_metric_time
if not query_contains_metric_time_or_agg_time_dimension:
pass # check for agg_time_dimension and update accordingly
# Write a test case for this scenario

# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
if (
cumulative_metric_adjusted_time_constraint is not None
and time_range_constraint is not None
and query_contains_metric_time_or_agg_time_dimension
and queried_agg_time_dimension_specs
):
cumulative_metric_constrained_node = ConstrainTimeRangeNode(
unaggregated_measure_node, time_range_constraint
parent_node=unaggregated_measure_node,
time_range_constraint=time_range_constraint,
)

pre_aggregate_node: BaseOutput = cumulative_metric_constrained_node or unaggregated_measure_node
Expand All @@ -1439,7 +1426,7 @@ def _build_aggregated_measure_from_measure_source_node(
queried_time_dimension_spec: Optional[
TimeDimensionSpec
] = self._find_non_additive_dimension_in_linkable_specs(
agg_time_dimension=agg_time_dimension,
agg_time_dimension=TimeDimensionReference(agg_time_dimension.element_name),
linkable_specs=queried_linkable_specs.as_tuple,
non_additive_dimension_spec=non_additive_dimension_spec,
)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDa
instead of this: DATE_TRUNC('month', ds) >= '2020-01-01' AND DATE_TRUNC('month', ds) <= '2020-02-01'
"""
from_data_set: SqlDataSet = node.parent_node.accept(self)
from_data_set = node.parent_node.accept(self)
from_data_set_alias = self._next_unique_table_alias()

time_dimension_instances_for_metric_time = sorted(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ integration_test:
metrics: ["trailing_2_months_revenue"]
group_bys: ["revenue_instance__ds__day"]
order_bys: ["revenue_instance__ds__day"]
time_constraint: ["2020-03-05", "2021-01-04"]
where_filter: '{{ render_time_constraint("revenue_instance__ds__day", "2020-03-05", "2021-01-04") }}'
check_query: |
SELECT
SUM(b.txn_revenue) as trailing_2_months_revenue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,6 @@ def test_cumulative_metric_with_agg_time_dimension(
time_dimension_specs=(
TimeDimensionSpec(element_name="ds", entity_links=(EntityReference("revenue_instance"),)),
),
time_range_constraint=TimeRangeConstraint(
start_time=as_datetime("2020-03-05"), end_time=as_datetime("2021-01-04")
),
)
)

Expand Down
Loading

0 comments on commit 8ea4578

Please sign in to comment.