Simplify nodes that join to time spine #1535

Closed · wants to merge 4 commits
4 changes: 3 additions & 1 deletion metricflow-semantics/metricflow_semantics/instances.py
@@ -49,7 +49,9 @@ class MdoInstance(ABC, Generic[SpecT]):
     @property
     def associated_column(self) -> ColumnAssociation:
         """Helper for getting the associated column until support for multiple associated columns is added."""
-        assert len(self.associated_columns) == 1
+        assert (
+            len(self.associated_columns) == 1
+        ), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}"
         return self.associated_columns[0]

     def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
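Side note on the assert change above: the multi-line form trades a bare `AssertionError` for a message naming the class and the offending columns. A quick self-contained toy (the `FakeInstance` class is hypothetical, not part of this PR) showing what a failure now reads like:

```python
class FakeInstance:
    """Hypothetical stand-in with two associated columns, to trigger the assert."""

    associated_columns = ("col_a", "col_b")

    @property
    def associated_column(self) -> str:
        assert (
            len(self.associated_columns) == 1
        ), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}"
        return self.associated_columns[0]


try:
    FakeInstance().associated_column
except AssertionError as error:
    # Prints: Expected exactly one column for FakeInstance, but got ('col_a', 'col_b')
    print(error)
```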
@@ -102,13 +102,28 @@ def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]:  # noqa: D102
             )
         )

+    def add_specs(
+        self,
+        dimension_specs: Tuple[DimensionSpec, ...] = (),
+        time_dimension_specs: Tuple[TimeDimensionSpec, ...] = (),
+        entity_specs: Tuple[EntitySpec, ...] = (),
+        group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = (),
+    ) -> LinkableSpecSet:
+        """Return a new set with the new specs in addition to the existing ones."""
+        return LinkableSpecSet(
+            dimension_specs=self.dimension_specs + dimension_specs,
+            time_dimension_specs=self.time_dimension_specs + time_dimension_specs,
+            entity_specs=self.entity_specs + entity_specs,
+            group_by_metric_specs=self.group_by_metric_specs + group_by_metric_specs,
+        )
+
     @override
     def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
-        return LinkableSpecSet(
-            dimension_specs=self.dimension_specs + other.dimension_specs,
-            time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
-            entity_specs=self.entity_specs + other.entity_specs,
-            group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs,
+        return self.add_specs(
+            dimension_specs=other.dimension_specs,
+            time_dimension_specs=other.time_dimension_specs,
+            entity_specs=other.entity_specs,
+            group_by_metric_specs=other.group_by_metric_specs,
         )

     @classmethod
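`add_specs` lets callers extend a `LinkableSpecSet` without building a throwaway set just to `merge` it, and `merge` itself now delegates to it. A minimal runnable sketch of the pattern with toy types (MetricFlow's real spec classes are elided):

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToySpecSet:
    """Toy stand-in for LinkableSpecSet, mirroring the add_specs/merge relationship."""

    dimension_specs: Tuple[str, ...] = ()
    time_dimension_specs: Tuple[str, ...] = ()

    def add_specs(
        self,
        dimension_specs: Tuple[str, ...] = (),
        time_dimension_specs: Tuple[str, ...] = (),
    ) -> ToySpecSet:
        # Return a new set with the new specs in addition to the existing ones.
        return ToySpecSet(
            dimension_specs=self.dimension_specs + dimension_specs,
            time_dimension_specs=self.time_dimension_specs + time_dimension_specs,
        )

    def merge(self, other: ToySpecSet) -> ToySpecSet:
        # merge is now just add_specs applied to the other set's fields.
        return self.add_specs(
            dimension_specs=other.dimension_specs,
            time_dimension_specs=other.time_dimension_specs,
        )


base = ToySpecSet(time_dimension_specs=("metric_time__month",))
# Before this change, adding one spec meant wrapping it in a new set and merging.
extended = base.add_specs(time_dimension_specs=("metric_time__day",))
print(extended.time_dimension_specs)  # ('metric_time__month', 'metric_time__day')
```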
46 changes: 20 additions & 26 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -182,7 +182,6 @@ def _build_query_output_node(
             where_filter_specs=(),
             pushdown_enabled_types=frozenset({PredicateInputType.TIME_RANGE_CONSTRAINT}),
         )
-
         return self._build_metrics_output_node(
             metric_specs=tuple(
                 MetricSpec(
@@ -236,6 +235,13 @@ def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPl

         return plan

+    def _get_minimum_metric_time_spec_for_metric(self, metric_reference: MetricReference) -> TimeDimensionSpec:
+        """Gets the minimum metric time spec for the given metric reference."""
+        min_granularity = ExpandedTimeGranularity.from_time_granularity(
+            self._metric_lookup.get_min_queryable_time_granularity(metric_reference)
+        )
+        return DataSet.metric_time_dimension_spec(min_granularity)
+
     def _build_aggregated_conversion_node(
         self,
         metric_spec: MetricSpec,
@@ -307,14 +313,11 @@ def _build_aggregated_conversion_node(
         # Get the time dimension used to calculate the conversion window
         # Currently, both the base/conversion measure uses metric_time as it's the default agg time dimension.
         # However, eventually, there can be user-specified time dimensions used for this calculation.
-        default_granularity = ExpandedTimeGranularity.from_time_granularity(
-            self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference)
-        )
-        metric_time_dimension_spec = DataSet.metric_time_dimension_spec(default_granularity)
+        min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference)

         # Filter the source nodes with only the required specs needed for the calculation
         constant_property_specs = []
-        required_local_specs = [base_measure_spec.measure_spec, entity_spec, metric_time_dimension_spec] + list(
+        required_local_specs = [base_measure_spec.measure_spec, entity_spec, min_metric_time_spec] + list(
             base_measure_recipe.required_local_linkable_specs.as_tuple
         )
         for constant_property in constant_properties or []:
@@ -345,10 +348,10 @@ def _build_aggregated_conversion_node(
         # adjusted in the opposite direction.
         join_conversion_node = JoinConversionEventsNode.create(
             base_node=unaggregated_base_measure_node,
-            base_time_dimension_spec=metric_time_dimension_spec,
+            base_time_dimension_spec=min_metric_time_spec,
             conversion_node=unaggregated_conversion_measure_node,
             conversion_measure_spec=conversion_measure_spec.measure_spec,
-            conversion_time_dimension_spec=metric_time_dimension_spec,
+            conversion_time_dimension_spec=min_metric_time_spec,
             unique_identifier_keys=(MetadataSpec(MetricFlowReservedKeywords.MF_INTERNAL_UUID.value),),
             entity_spec=entity_spec,
             window=window,
@@ -444,21 +447,19 @@ def _build_cumulative_metric_output_node(
         predicate_pushdown_state: PredicatePushdownState,
         for_group_by_source_node: bool = False,
     ) -> DataflowPlanNode:
-        # TODO: [custom granularity] Figure out how to support custom granularities as defaults
-        default_granularity = ExpandedTimeGranularity.from_time_granularity(
-            self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference)
-        )
+        min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference)
+        min_granularity = min_metric_time_spec.time_granularity

         queried_agg_time_dimensions = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
             metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
         )
-        query_includes_agg_time_dimension_with_default_granularity = False
+        query_includes_agg_time_dimension_with_min_granularity = False
         for time_dimension_spec in queried_agg_time_dimensions:
-            if time_dimension_spec.time_granularity == default_granularity:
-                query_includes_agg_time_dimension_with_default_granularity = True
+            if time_dimension_spec.time_granularity == min_granularity:
+                query_includes_agg_time_dimension_with_min_granularity = True
+                break

-        if query_includes_agg_time_dimension_with_default_granularity or not queried_agg_time_dimensions:
+        if query_includes_agg_time_dimension_with_min_granularity or len(queried_agg_time_dimensions) == 0:
             return self._build_base_metric_output_node(
                 metric_spec=metric_spec,
                 queried_linkable_specs=queried_linkable_specs,
@@ -467,14 +468,11 @@ def _build_cumulative_metric_output_node(
                 for_group_by_source_node=for_group_by_source_node,
             )

-        # If a cumulative metric is queried without default granularity, it will need to be aggregated twice -
+        # If a cumulative metric is queried without its minimum granularity, it will need to be aggregated twice:
         # once as a normal metric, and again using a window function to narrow down to one row per granularity period.
         # In this case, add metric time at the default granularity to the linkable specs. It will be used in the order by
         # clause of the window function and later excluded from the output selections.
-        default_metric_time = DataSet.metric_time_dimension_spec(default_granularity)
-        include_linkable_specs = queried_linkable_specs.merge(
-            LinkableSpecSet(time_dimension_specs=(default_metric_time,))
-        )
+        include_linkable_specs = queried_linkable_specs.add_specs(time_dimension_specs=(min_metric_time_spec,))
         compute_metrics_node = self._build_base_metric_output_node(
             metric_spec=metric_spec,
             queried_linkable_specs=include_linkable_specs,
@@ -485,7 +483,7 @@ def _build_cumulative_metric_output_node(
         return WindowReaggregationNode.create(
             parent_node=compute_metrics_node,
             metric_spec=metric_spec,
-            order_by_spec=default_metric_time,
+            order_by_spec=min_metric_time_spec,
             partition_by_specs=queried_linkable_specs.as_tuple,
         )

@@ -1613,10 +1611,6 @@ def _build_aggregated_measure_from_measure_source_node(

         # If querying an offset metric, join to time spine before aggregation.
         if before_aggregation_time_spine_join_description is not None:
-            assert queried_agg_time_dimension_specs, (
-                "Joining to time spine requires querying with metric time or the appropriate agg_time_dimension."
-                "This should have been caught by validations."
-            )
             assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
                 f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
                 f"new use case."
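For context on the cumulative-metric branch above: window reaggregation is skipped only when the query already includes an agg time dimension at the metric's minimum queryable granularity, or no agg time dimension at all. A self-contained sketch of that decision, with toy granularity values standing in for MetricFlow's types:

```python
from enum import IntEnum
from typing import Sequence


class Granularity(IntEnum):
    """Toy granularities standing in for MetricFlow's time granularities."""

    DAY = 1
    MONTH = 2


def needs_window_reaggregation(
    queried: Sequence[Granularity], min_granularity: Granularity
) -> bool:
    """Cumulative metrics skip reaggregation when the minimum granularity is queried (or none is)."""
    if not queried or min_granularity in queried:
        return False
    return True


assert needs_window_reaggregation([Granularity.MONTH], Granularity.DAY)
assert not needs_window_reaggregation([Granularity.DAY, Granularity.MONTH], Granularity.DAY)
assert not needs_window_reaggregation([], Granularity.DAY)
```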
6 changes: 3 additions & 3 deletions metricflow/dataflow/nodes/join_over_time.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple

 from dbt_semantic_interfaces.protocols import MetricTimeWindow
 from dbt_semantic_interfaces.type_enums import TimeGranularity
@@ -26,7 +26,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode):
         time_range_constraint: Time range to aggregate over.
     """

-    queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
+    queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...]
     window: Optional[MetricTimeWindow]
     grain_to_date: Optional[TimeGranularity]
     time_range_constraint: Optional[TimeRangeConstraint]
@@ -38,7 +38,7 @@ def __post_init__(self) -> None:  # noqa: D105
     @staticmethod
     def create(  # noqa: D102
         parent_node: DataflowPlanNode,
-        queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
+        queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...],
         window: Optional[MetricTimeWindow] = None,
         grain_to_date: Optional[TimeGranularity] = None,
         time_range_constraint: Optional[TimeRangeConstraint] = None,
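One plausible motivation for tightening `Sequence[TimeDimensionSpec]` to `Tuple[TimeDimensionSpec, ...]` on a frozen dataclass node (the PR does not state it): tuples are hashable, while a list passed in as a `Sequence` makes the node unhashable. A self-contained illustration with toy classes:

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class TupleNode:
    specs: Tuple[str, ...]


@dataclass(frozen=True)
class SequenceNode:
    specs: Sequence[str]  # nothing stops a caller from passing a list here


# Frozen dataclasses derive __hash__ from their fields, so tuple fields hash fine.
print(hash(TupleNode(specs=("metric_time__day",))))

try:
    hash(SequenceNode(specs=["metric_time__day"]))  # a list field breaks hashing
except TypeError as error:
    print(f"unhashable: {error}")
```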
2 changes: 2 additions & 0 deletions metricflow/dataflow/nodes/join_to_time_spine.py
@@ -31,7 +31,9 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC):
         offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
     """

+    # TODO: rename property to required_agg_time_dimension_specs
     requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
+    # TODO remove this property
     use_custom_agg_time_dimension: bool
     join_type: SqlJoinType
     time_range_constraint: Optional[TimeRangeConstraint]
62 changes: 44 additions & 18 deletions metricflow/dataset/sql_dataset.py
@@ -1,10 +1,11 @@
 from __future__ import annotations

-from typing import List, Optional, Sequence
+from dataclasses import dataclass
+from typing import List, Optional, Sequence, Tuple

 from dbt_semantic_interfaces.references import SemanticModelReference
 from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
-from metricflow_semantics.instances import EntityInstance, InstanceSet
+from metricflow_semantics.instances import EntityInstance, InstanceSet, TimeDimensionInstance
 from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
 from metricflow_semantics.specs.column_assoc import ColumnAssociation
 from metricflow_semantics.specs.dimension_spec import DimensionSpec
@@ -122,32 +123,57 @@ def column_association_for_dimension(

         return column_associations_to_return[0]

-    def column_association_for_time_dimension(
-        self,
-        time_dimension_spec: TimeDimensionSpec,
-    ) -> ColumnAssociation:
-        """Given the name of the time dimension, return the set of columns associated with it in the data set."""
+    def instances_for_time_dimensions(
+        self, time_dimension_specs: Sequence[TimeDimensionSpec]
+    ) -> Tuple[TimeDimensionInstance, ...]:
+        """Return the instances associated with these specs in the data set."""
+        time_dimension_specs_set = set(time_dimension_specs)
         matching_instances = 0
-        column_associations_to_return = None
+        instances_to_return: Tuple[TimeDimensionInstance, ...] = ()
         for time_dimension_instance in self.instance_set.time_dimension_instances:
-            if time_dimension_instance.spec == time_dimension_spec:
-                column_associations_to_return = time_dimension_instance.associated_columns
+            if time_dimension_instance.spec in time_dimension_specs_set:
+                instances_to_return += (time_dimension_instance,)
                 matching_instances += 1

-        if matching_instances > 1:
+        if matching_instances != len(time_dimension_specs_set):
             raise RuntimeError(
-                f"More than one time dimension instance with spec {time_dimension_spec} in "
-                f"instance set: {self.instance_set}"
+                f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs_set}\n"
+                f"Instances: {instances_to_return}"
             )

-        if not column_associations_to_return:
-            raise RuntimeError(
-                f"No time dimension instances with spec {time_dimension_spec} in instance set: {self.instance_set}"
-            )
+        return instances_to_return

-        return column_associations_to_return[0]
+    def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance:
+        """Given the name of the time dimension, return the instance associated with it in the data set."""
+        return self.instances_for_time_dimensions((time_dimension_spec,))[0]
+
+    def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation:
+        """Given the name of the time dimension, return the set of columns associated with it in the data set."""
+        return self.instance_for_time_dimension(time_dimension_spec).associated_column

     @property
     @override
     def semantic_model_reference(self) -> Optional[SemanticModelReference]:
         return None

+    def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet:
+        """Convert to an AnnotatedSqlDataSet with specified metadata."""
+        metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name
+        return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name)
+
+
+@dataclass(frozen=True)
+class AnnotatedSqlDataSet:
+    """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""
+
+    data_set: SqlDataSet
+    alias: str
+    _metric_time_column_name: Optional[str] = None
+
+    @property
+    def metric_time_column_name(self) -> str:
+        """Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
+        assert (
+            self._metric_time_column_name
+        ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
+        return self._metric_time_column_name
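The batched `instances_for_time_dimensions` resolves several specs in one pass and fails loudly when the match count is off; the single-spec and column accessors become thin wrappers over it. A runnable sketch of the same lookup pattern with toy types (not MetricFlow's):

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass(frozen=True)
class Spec:
    name: str


@dataclass(frozen=True)
class Instance:
    spec: Spec
    column_name: str


def instances_for_specs(instances: Sequence[Instance], specs: Sequence[Spec]) -> Tuple[Instance, ...]:
    """Resolve all requested specs in one pass, raising if the match count is unexpected."""
    spec_set = set(specs)
    matched = tuple(instance for instance in instances if instance.spec in spec_set)
    if len(matched) != len(spec_set):
        raise RuntimeError(f"Expected {len(spec_set)} matches, got {len(matched)}: {matched}")
    return matched


data = [
    Instance(Spec("metric_time__day"), "ds"),
    Instance(Spec("metric_time__month"), "ds_month"),
]
print(instances_for_specs(data, [Spec("metric_time__day")]))  # one pass, one match
```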