Skip to content

Commit

Permalink
Update node initialization callsites to use .create().
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jul 10, 2024
1 parent a153235 commit b68a0fd
Show file tree
Hide file tree
Showing 40 changed files with 802 additions and 721 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met

# Define a resolution path for issues where the input is considered to be the whole query.
query_resolution_path = MetricFlowQueryResolutionPath.from_path_item(
QueryGroupByItemResolutionNode(
QueryGroupByItemResolutionNode.create(
parent_nodes=(),
metrics_in_query=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs),
where_filter_intersection=query_level_filter_input.where_filter_intersection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_missing_parent_for_metric(
or measures). However, in the event of a validation gap upstream, we sometimes encounter inscrutable errors
caused by missing parent nodes for these input types, so we add a more informative error and test for it here.
"""
metric_node = MetricGroupByItemResolutionNode(
metric_node = MetricGroupByItemResolutionNode.create(
metric_reference=MetricReference(element_name="bad_metric"), metric_input_location=None, parent_nodes=tuple()
)
resolution_dag = GroupByItemResolutionDag(sink_node=metric_node)
Expand Down
74 changes: 39 additions & 35 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _build_aggregated_conversion_node(

# Build unaggregated conversions source node
# Generate UUID column for conversion source to uniquely identify each row
unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode(
unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode.create(
parent_node=conversion_measure_recipe.source_node
)

Expand Down Expand Up @@ -333,10 +333,10 @@ def _build_aggregated_conversion_node(
# Build the unaggregated base measure node for computing conversions
unaggregated_base_measure_node = base_measure_recipe.source_node
if base_measure_recipe.join_targets:
unaggregated_base_measure_node = JoinOnEntitiesNode(
unaggregated_base_measure_node = JoinOnEntitiesNode.create(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
filtered_unaggregated_base_node = FilterElementsNode(
filtered_unaggregated_base_node = FilterElementsNode.create(
parent_node=unaggregated_base_measure_node,
include_specs=group_specs_by_type(required_local_specs)
.merge(base_required_linkable_specs.as_spec_set)
Expand All @@ -347,7 +347,7 @@ def _build_aggregated_conversion_node(
# The conversion events are joined by the base events which are already time constrained. However, this could
# be still be constrained, where we adjust the time range to the window size similar to cumulative, but
# adjusted in the opposite direction.
join_conversion_node = JoinConversionEventsNode(
join_conversion_node = JoinConversionEventsNode.create(
base_node=filtered_unaggregated_base_node,
base_time_dimension_spec=base_time_dimension_spec,
conversion_node=unaggregated_conversion_measure_node,
Expand Down Expand Up @@ -377,7 +377,9 @@ def _build_aggregated_conversion_node(
)

# Combine the aggregated opportunities and conversion data sets
return CombineAggregatedOutputsNode(parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node))
return CombineAggregatedOutputsNode.create(
parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node)
)

def _build_conversion_metric_output_node(
self,
Expand Down Expand Up @@ -468,7 +470,7 @@ def _build_cumulative_metric_output_node(
predicate_pushdown_state=predicate_pushdown_state,
for_group_by_source_node=for_group_by_source_node,
)
return WindowReaggregationNode(
return WindowReaggregationNode.create(
parent_node=compute_metrics_node,
metric_spec=metric_spec,
order_by_spec=default_metric_time,
Expand Down Expand Up @@ -609,9 +611,11 @@ def _build_derived_metric_output_node(
)

parent_node = (
parent_nodes[0] if len(parent_nodes) == 1 else CombineAggregatedOutputsNode(parent_nodes=parent_nodes)
parent_nodes[0]
if len(parent_nodes) == 1
else CombineAggregatedOutputsNode.create(parent_nodes=parent_nodes)
)
output_node: DataflowPlanNode = ComputeMetricsNode(
output_node: DataflowPlanNode = ComputeMetricsNode.create(
parent_node=parent_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
Expand All @@ -626,7 +630,7 @@ def _build_derived_metric_output_node(
assert (
queried_agg_time_dimension_specs
), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension."
output_node = JoinToTimeSpineNode(
output_node = JoinToTimeSpineNode.create(
parent_node=output_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
Expand All @@ -637,9 +641,9 @@ def _build_derived_metric_output_node(
)

if len(metric_spec.filter_specs) > 0:
output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs)
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=metric_spec.filter_specs)
if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs):
output_node = FilterElementsNode(
output_node = FilterElementsNode.create(
parent_node=output_node,
include_specs=InstanceSpecSet(metric_specs=(metric_spec,)).merge(
queried_linkable_specs.as_spec_set
Expand Down Expand Up @@ -734,7 +738,7 @@ def _build_metrics_output_node(
if len(output_nodes) == 1:
return output_nodes[0]

return CombineAggregatedOutputsNode(parent_nodes=output_nodes)
return CombineAggregatedOutputsNode.create(parent_nodes=output_nodes)

def build_plan_for_distinct_values(
self, query_spec: MetricFlowQuerySpec, optimizations: FrozenSet[DataflowPlanOptimization] = frozenset()
Expand Down Expand Up @@ -779,21 +783,21 @@ def _build_plan_for_distinct_values(

output_node = dataflow_recipe.source_node
if dataflow_recipe.join_targets:
output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs)
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
output_node = ConstrainTimeRangeNode(
output_node = ConstrainTimeRangeNode.create(
parent_node=output_node, time_range_constraint=query_spec.time_range_constraint
)

output_node = FilterElementsNode(
output_node = FilterElementsNode.create(
parent_node=output_node, include_specs=query_spec.linkable_specs.as_spec_set, distinct=True
)

if query_spec.min_max_only:
output_node = MinMaxNode(parent_node=output_node)
output_node = MinMaxNode.create(parent_node=output_node)

sink_node = self.build_sink_node(
parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit
Expand All @@ -814,20 +818,20 @@ def build_sink_node(
pre_result_node: Optional[DataflowPlanNode] = None

if order_by_specs or limit:
pre_result_node = OrderByLimitNode(
pre_result_node = OrderByLimitNode.create(
order_by_specs=list(order_by_specs), limit=limit, parent_node=parent_node
)

if output_selection_specs:
pre_result_node = FilterElementsNode(
pre_result_node = FilterElementsNode.create(
parent_node=pre_result_node or parent_node, include_specs=output_selection_specs
)

write_result_node: DataflowPlanNode
if not output_sql_table:
write_result_node = WriteToResultDataTableNode(parent_node=pre_result_node or parent_node)
write_result_node = WriteToResultDataTableNode.create(parent_node=pre_result_node or parent_node)
else:
write_result_node = WriteToResultTableNode(
write_result_node = WriteToResultTableNode.create(
parent_node=pre_result_node or parent_node, output_sql_table=output_sql_table
)

Expand Down Expand Up @@ -1139,7 +1143,7 @@ def build_computed_metrics_node(
for_group_by_source_node: bool = False,
) -> ComputeMetricsNode:
"""Builds a ComputeMetricsNode from aggregated measures."""
return ComputeMetricsNode(
return ComputeMetricsNode.create(
parent_node=aggregated_measures_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
Expand Down Expand Up @@ -1449,7 +1453,7 @@ def _build_aggregated_measure_from_measure_source_node(
# Otherwise, the measure will be aggregated over all time.
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative and queried_agg_time_dimension_specs:
time_range_node = JoinOverTimeRangeNode(
time_range_node = JoinOverTimeRangeNode.create(
parent_node=measure_recipe.source_node,
queried_agg_time_dimension_specs=tuple(queried_agg_time_dimension_specs),
window=cumulative_window,
Expand All @@ -1476,7 +1480,7 @@ def _build_aggregated_measure_from_measure_source_node(
)
# This also uses the original time range constraint due to the application of the time window intervals
# in join rendering
join_to_time_spine_node = JoinToTimeSpineNode(
join_to_time_spine_node = JoinToTimeSpineNode.create(
parent_node=time_range_node or measure_recipe.source_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
Expand All @@ -1487,7 +1491,7 @@ def _build_aggregated_measure_from_measure_source_node(
)

# Only get the required measure and the local linkable instances so that aggregations work correctly.
filtered_measure_source_node = FilterElementsNode(
filtered_measure_source_node = FilterElementsNode.create(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
group_specs_by_type(measure_recipe.required_local_linkable_specs),
Expand All @@ -1497,7 +1501,7 @@ def _build_aggregated_measure_from_measure_source_node(
join_targets = measure_recipe.join_targets
unaggregated_measure_node: DataflowPlanNode
if len(join_targets) > 0:
filtered_measures_with_joined_elements = JoinOnEntitiesNode(
filtered_measures_with_joined_elements = JoinOnEntitiesNode.create(
left_node=filtered_measure_source_node,
join_targets=join_targets,
)
Expand All @@ -1506,7 +1510,7 @@ def _build_aggregated_measure_from_measure_source_node(
required_linkable_specs.as_spec_set,
)

after_join_filtered_node = FilterElementsNode(
after_join_filtered_node = FilterElementsNode.create(
parent_node=filtered_measures_with_joined_elements, include_specs=specs_to_keep_after_join
)
unaggregated_measure_node = after_join_filtered_node
Expand All @@ -1524,14 +1528,14 @@ def _build_aggregated_measure_from_measure_source_node(
assert (
queried_linkable_specs.contains_metric_time
), "Using time constraints currently requires querying with metric_time."
cumulative_metric_constrained_node = ConstrainTimeRangeNode(
cumulative_metric_constrained_node = ConstrainTimeRangeNode.create(
unaggregated_measure_node, predicate_pushdown_state.time_range_constraint
)

pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node
if len(metric_input_measure_spec.filter_specs) > 0:
# Apply where constraint on the node
pre_aggregate_node = WhereConstraintNode(
pre_aggregate_node = WhereConstraintNode.create(
parent_node=pre_aggregate_node,
where_specs=metric_input_measure_spec.filter_specs,
)
Expand All @@ -1550,7 +1554,7 @@ def _build_aggregated_measure_from_measure_source_node(
window_groupings = tuple(
LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings
)
pre_aggregate_node = SemiAdditiveJoinNode(
pre_aggregate_node = SemiAdditiveJoinNode.create(
parent_node=pre_aggregate_node,
entity_specs=window_groupings,
time_dimension_spec=time_dimension_spec,
Expand All @@ -1564,12 +1568,12 @@ def _build_aggregated_measure_from_measure_source_node(
# show up in the final result.
#
# e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results.
pre_aggregate_node = FilterElementsNode(
pre_aggregate_node = FilterElementsNode.create(
parent_node=pre_aggregate_node,
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set),
)

aggregate_measures_node = AggregateMeasuresNode(
aggregate_measures_node = AggregateMeasuresNode.create(
parent_node=pre_aggregate_node,
metric_input_measure_specs=(metric_input_measure_spec,),
)
Expand All @@ -1583,7 +1587,7 @@ def _build_aggregated_measure_from_measure_source_node(
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
output_node: DataflowPlanNode = JoinToTimeSpineNode(
output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
Expand All @@ -1602,14 +1606,14 @@ def _build_aggregated_measure_from_measure_source_node(
if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple))
]
if len(queried_filter_specs) > 0:
output_node = WhereConstraintNode(
output_node = WhereConstraintNode.create(
parent_node=output_node, where_specs=queried_filter_specs, always_apply=True
)

# TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time.
# To fix when enabling time range constraints for agg_time_dimension.
if queried_agg_time_dimension_specs and predicate_pushdown_state.time_range_constraint is not None:
output_node = ConstrainTimeRangeNode(
output_node = ConstrainTimeRangeNode.create(
parent_node=output_node, time_range_constraint=predicate_pushdown_state.time_range_constraint
)
return output_node
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def join_description(self) -> JoinDescription:
]
)

filtered_node_to_join = FilterElementsNode(
filtered_node_to_join = FilterElementsNode.create(
parent_node=self.node_to_join, include_specs=group_specs_by_type(include_specs)
)

Expand Down
8 changes: 4 additions & 4 deletions metricflow/dataflow/builder/source_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def __init__( # noqa: D107
time_spine_source = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest)
time_spine_data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source)
time_dim_reference = TimeDimensionReference(element_name=time_spine_source.time_column_name)
self._time_spine_source_node = MetricTimeDimensionTransformNode(
parent_node=ReadSqlSourceNode(data_set=time_spine_data_set),
self._time_spine_source_node = MetricTimeDimensionTransformNode.create(
parent_node=ReadSqlSourceNode.create(data_set=time_spine_data_set),
aggregation_time_dimension_reference=time_dim_reference,
)
self._query_parser = MetricFlowQueryParser(semantic_manifest_lookup)
Expand All @@ -71,7 +71,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So
source_nodes_for_metric_queries: List[DataflowPlanNode] = []

for data_set in data_sets:
read_node = ReadSqlSourceNode(data_set)
read_node = ReadSqlSourceNode.create(data_set)
group_by_item_source_nodes.append(read_node)
agg_time_dim_to_measures_grouper = (
self._semantic_manifest_lookup.semantic_model_lookup.get_aggregation_time_dimensions_with_measures(
Expand All @@ -86,7 +86,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So
else:
# Splits the measures by distinct aggregate time dimension.
for time_dimension_reference in time_dimension_references:
metric_time_transform_node = MetricTimeDimensionTransformNode(
metric_time_transform_node = MetricTimeDimensionTransformNode.create(
parent_node=read_node,
aggregation_time_dimension_reference=time_dimension_reference,
)
Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _push_down_where_filters(
optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state)
if len(filters_to_apply) > 0:
return OptimizeBranchResult(
optimized_branch=WhereConstraintNode(
optimized_branch=WhereConstraintNode.create(
parent_node=optimized_node.optimized_branch, where_specs=filters_to_apply
)
)
Expand Down Expand Up @@ -397,7 +397,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
)
elif len(filter_specs_to_apply) > 0:
optimized_node = OptimizeBranchResult(
optimized_branch=WhereConstraintNode(
optimized_branch=WhereConstraintNode.create(
parent_node=optimized_parent.optimized_branch, where_specs=filter_specs_to_apply
)
)
Expand Down
Loading

0 comments on commit b68a0fd

Please sign in to comment.