Skip to content

Commit

Permalink
fixup! Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Dec 18, 2024
1 parent aeb12d0 commit c19072d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
24 changes: 19 additions & 5 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.type_enums import DatePart
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
Expand All @@ -12,6 +13,7 @@
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from typing_extensions import override

from metricflow.dataset.dataset_classes import DataSet
Expand Down Expand Up @@ -165,18 +167,30 @@ def instance_for_spec(self, spec: InstanceSpec) -> MdoInstance:
)

def instance_from_time_dimension_grain_and_date_part(
self, time_dimension_spec: TimeDimensionSpec
self, time_granularity_name: str, date_part: Optional[DatePart]
) -> TimeDimensionInstance:
"""Find instance in dataset that matches the grain and date part of the given time dimension spec."""
"""Find instance in dataset that matches the given grain and date part."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity
and time_dimension_instance.spec.date_part == time_dimension_spec.date_part
time_dimension_instance.spec.time_granularity.name == time_granularity_name
and time_dimension_instance.spec.date_part == date_part
and time_dimension_instance.spec.window_function is None
):
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n"
f"Did not find a time dimension instance with grain '{time_granularity_name}' and date part {date_part}\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

def instance_from_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionInstance:
"""Find instance in dataset that matches the given window function."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if time_dimension_instance.spec.window_function is window_function:
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with window function {window_function}.\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

Expand Down
6 changes: 3 additions & 3 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,9 +1276,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down
12 changes: 11 additions & 1 deletion metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag
from metricflow_semantics.sql.sql_exprs import SqlExpressionNode
from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT
Expand Down Expand Up @@ -102,6 +102,16 @@ class SqlSelectColumn:
# Always require a column alias for simplicity.
column_alias: str

@staticmethod
def from_table_and_column_names(table_alias: str, column_name: str) -> SqlSelectColumn:
"""Create a column that selects a column from a table by name."""
return SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
column_name=column_name, table_alias=table_alias
),
column_alias=column_name,
)


@dataclass(frozen=True)
class SqlJoinDescription:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -32,6 +33,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -114,6 +116,12 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> int: # noqa: D102
return self._sum_parents(node)

def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int: # noqa: D102
return self._sum_parents(node)

def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> int: # noqa: D102
return self._sum_parents(node)

def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102
return dataflow_plan.sink_node.accept(self)

Expand Down

0 comments on commit c19072d

Please sign in to comment.