Replace TransformTimeDimensionsNode with AliasSpecsNode
Also fixes a small issue where columns were being prematurely filtered out in that node; this accounts for the snapshot changes. It should not impact optimized queries.
courtneyholcomb committed Dec 11, 2024
1 parent 1beeb58 commit da044a3
Showing 81 changed files with 1,411 additions and 556 deletions.
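
To make the intent of the change concrete, here is a minimal, self-contained sketch (toy code, not from this repository; all names are hypothetical) of what the new node does: it carries pairs of (column as produced by the parent node, column as requested by the query) and renames only the matched columns, passing everything else through — which is also the behavior the premature-filtering fix restores.

from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass(frozen=True)
class ToyAliasNode:
    """Toy stand-in for AliasSpecsNode: rename matched columns, keep the rest."""

    # Each pair maps a column name as the parent node produces it
    # to the column name the query requested.
    change_specs: Sequence[Tuple[str, str]]

    def apply(self, row: Dict[str, object]) -> Dict[str, object]:
        renames = dict(self.change_specs)
        # Columns not listed in change_specs are passed through unchanged
        # rather than filtered out.
        return {renames.get(name, name): value for name, value in row.items()}


if __name__ == "__main__":
    node = ToyAliasNode(change_specs=(("ds__day", "metric_time__day"),))
    print(node.apply({"ds__day": "2024-12-11", "revenue": 100}))
    # {'metric_time__day': '2024-12-11', 'revenue': 100}
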
1 change: 1 addition & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
@@ -55,6 +55,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_ADD_UUID_COLUMN_PREFIX = "auid"
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
74 changes: 74 additions & 0 deletions metricflow-semantics/metricflow_semantics/instances.py
@@ -58,6 +58,10 @@ def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
"""See Visitable."""
raise NotImplementedError()

def with_new_spec(self, new_spec: SpecT, column_association_resolver: ColumnAssociationResolver) -> MdoInstance:
"""Returns a new instance with the spec replaced."""
raise NotImplementedError()


class LinkableInstance(MdoInstance, Generic[SpecT]):
"""An MdoInstance whose spec is linkable (i.e., it can have entity links)."""
@@ -105,6 +109,17 @@ class MeasureInstance(MdoInstance[MeasureSpec], SemanticModelElementInstance):
def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_measure_instance(self)

def with_new_spec(
self, new_spec: MeasureSpec, column_association_resolver: ColumnAssociationResolver
) -> MeasureInstance:
"""Returns a new instance with the spec replaced."""
return MeasureInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
aggregation_state=self.aggregation_state,
)


@dataclass(frozen=True)
class DimensionInstance(LinkableInstance[DimensionSpec], SemanticModelElementInstance): # noqa: D101
@@ -125,6 +140,16 @@ def with_entity_prefix(
spec=transformed_spec,
)

def with_new_spec(
self, new_spec: DimensionSpec, column_association_resolver: ColumnAssociationResolver
) -> DimensionInstance:
"""Returns a new instance with the spec replaced."""
return DimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
)


@dataclass(frozen=True)
class TimeDimensionInstance(LinkableInstance[TimeDimensionSpec], SemanticModelElementInstance): # noqa: D101
@@ -151,6 +176,16 @@ def with_new_defined_from(self, defined_from: Sequence[SemanticModelElementRefer
associated_columns=self.associated_columns, defined_from=tuple(defined_from), spec=self.spec
)

def with_new_spec(
self, new_spec: TimeDimensionSpec, column_association_resolver: ColumnAssociationResolver
) -> TimeDimensionInstance:
"""Returns a new instance with the spec replaced."""
return TimeDimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
)


@dataclass(frozen=True)
class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
@@ -171,6 +206,16 @@ def with_entity_prefix(
spec=transformed_spec,
)

def with_new_spec(
self, new_spec: EntitySpec, column_association_resolver: ColumnAssociationResolver
) -> EntityInstance:
"""Returns a new instance with the spec replaced."""
return EntityInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
)


@dataclass(frozen=True)
class GroupByMetricInstance(LinkableInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
@@ -192,6 +237,16 @@ def with_entity_prefix(
spec=transformed_spec,
)

def with_new_spec(
self, new_spec: GroupByMetricSpec, column_association_resolver: ColumnAssociationResolver
) -> GroupByMetricInstance:
"""Returns a new instance with the spec replaced."""
return GroupByMetricInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
)


@dataclass(frozen=True)
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D101
@@ -202,6 +257,16 @@ class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D
def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metric_instance(self)

def with_new_spec(
self, new_spec: MetricSpec, column_association_resolver: ColumnAssociationResolver
) -> MetricInstance:
"""Returns a new instance with the spec replaced."""
return MetricInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
defined_from=self.defined_from,
spec=new_spec,
)


@dataclass(frozen=True)
class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D101
@@ -211,6 +276,15 @@ class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noq
def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metadata_instance(self)

def with_new_spec(
self, new_spec: MetadataSpec, column_association_resolver: ColumnAssociationResolver
) -> MetadataInstance:
"""Returns a new instance with the spec replaced."""
return MetadataInstance(
associated_columns=(column_association_resolver.resolve_spec(new_spec),),
spec=new_spec,
)


# Output type of transform function
TransformOutputT = TypeVar("TransformOutputT")
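
The with_new_spec methods added above all follow the same pattern: keep the instance's other fields, swap in the new spec, and re-resolve the associated column from that spec. Below is a minimal toy sketch of that pattern (hypothetical names; not the real ColumnAssociationResolver or spec classes).

from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToySpec:
    """Toy stand-in for an InstanceSpec."""

    qualified_name: str


class ToyColumnResolver:
    """Toy stand-in for ColumnAssociationResolver: maps a spec to a column name."""

    def resolve_spec(self, spec: ToySpec) -> str:
        return spec.qualified_name


@dataclass(frozen=True)
class ToyInstance:
    """Toy stand-in for an MdoInstance subclass."""

    associated_columns: Tuple[str, ...]
    spec: ToySpec

    def with_new_spec(self, new_spec: ToySpec, resolver: ToyColumnResolver) -> ToyInstance:
        # Same shape as the methods above: other fields are preserved,
        # and the associated column is re-resolved from the new spec.
        return ToyInstance(associated_columns=(resolver.resolve_spec(new_spec),), spec=new_spec)


if __name__ == "__main__":
    resolver = ToyColumnResolver()
    old = ToyInstance(associated_columns=("ds__day",), spec=ToySpec("ds__day"))
    new = old.with_new_spec(ToySpec("metric_time__day"), resolver)
    print(new.spec.qualified_name, new.associated_columns)  # metric_time__day ('metric_time__day',)
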
18 changes: 14 additions & 4 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -80,6 +80,7 @@
)
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.alias_specs import AliasSpecsNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
@@ -94,7 +95,6 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -1873,9 +1873,19 @@ def _build_time_spine_node(
# TODO: support multiple time spines here. Build node on the one with the smallest base grain.
# Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
time_spine_node = TransformTimeDimensionsNode.create(
parent_node=self._choose_time_spine_read_node(time_spine_source),
requested_time_dimension_specs=required_time_spine_specs,
read_node = self._choose_time_spine_read_node(time_spine_source)
time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)

# Change the column aliases to match the specs that were requested in the query.
time_spine_node = AliasSpecsNode.create(
parent_node=read_node,
change_specs=tuple(
(
time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec,
required_spec,
)
for required_spec in required_time_spine_specs
),
)

# If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
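
The pairing logic in _build_time_spine_node above matches each requested time-spine spec to the column the read node already exposes and hands those pairs to AliasSpecsNode. A rough standalone sketch of that pairing step follows (hypothetical column names and lookup; the real code resolves specs by grain and date part via instance_from_time_dimension_grain_and_date_part).

from typing import Dict, Tuple

# Hypothetical: columns the time spine read node already exposes, keyed by grain.
available_time_spine_columns: Dict[str, str] = {"day": "ds__day", "month": "ds__month"}

# Hypothetical: (grain, column name requested by the query) for each required spec.
required_time_spine_specs: Tuple[Tuple[str, str], ...] = (("day", "metric_time__day"),)

# Build (existing column, requested column) pairs, mirroring the change_specs tuple above.
change_specs = tuple(
    (available_time_spine_columns[grain], requested_name)
    for grain, requested_name in required_time_spine_specs
)
print(change_specs)  # (('ds__day', 'metric_time__day'),)
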
6 changes: 3 additions & 3 deletions metricflow/dataflow/dataflow_plan_visitor.py
@@ -11,6 +11,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.alias_specs import AliasSpecsNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
@@ -25,7 +26,6 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -123,7 +123,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
raise NotImplementedError

@abstractmethod
def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@@ -220,5 +220,5 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
return self._default_handler(node)

@override
def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)
59 changes: 59 additions & 0 deletions metricflow/dataflow/nodes/alias_specs.py
@@ -0,0 +1,59 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence, Tuple

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class AliasSpecsNode(DataflowPlanNode, ABC):
"""Change the columns matching the key specs to match the value specs."""

change_specs: Sequence[Tuple[InstanceSpec, InstanceSpec]]

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.change_specs) > 0, "Must have at least one value in change_specs for AliasSpecsNode."

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, change_specs: Sequence[Tuple[InstanceSpec, InstanceSpec]]
) -> AliasSpecsNode:
return AliasSpecsNode(parent_nodes=(parent_node,), change_specs=change_specs)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_alias_specs_node(self)

@property
def description(self) -> str: # noqa: D102
return """Change Column Aliases"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("change_specs", self.change_specs),)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return isinstance(other_node, self.__class__) and other_node.change_specs == self.change_specs

def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> AliasSpecsNode: # noqa: D102
assert len(new_parent_nodes) == 1, "AliasSpecsNode accepts exactly one parent node."
return AliasSpecsNode.create(
parent_node=new_parent_nodes[0],
change_specs=self.change_specs,
)
74 changes: 0 additions & 74 deletions metricflow/dataflow/nodes/transform_time_dimensions.py

This file was deleted.

6 changes: 2 additions & 4 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -19,6 +19,7 @@
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.alias_specs import AliasSpecsNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
@@ -33,7 +34,6 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -469,9 +469,7 @@ def visit_join_to_custom_granularity_node( # noqa: D102
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_transform_time_dimensions_node( # noqa: D102
self, node: TransformTimeDimensionsNode
) -> OptimizeBranchResult:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
@@ -13,6 +13,7 @@
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.alias_specs import AliasSpecsNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
@@ -27,7 +28,6 @@
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.transform_time_dimensions import TransformTimeDimensionsNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
@@ -469,8 +469,6 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_transform_time_dimensions_node( # noqa: D102
self, node: TransformTimeDimensionsNode
) -> ComputeMetricsBranchCombinerResult:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)