From 09049a733bc3a37d1f3e8103d5ed2e7d9ef3937e Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Wed, 10 Jul 2024 13:11:43 -0700 Subject: [PATCH] Update node initialization callsites to use `.create()`. --- .../query/query_resolver.py | 2 +- .../test_matching_item_for_querying.py | 2 +- .../dataflow/builder/dataflow_plan_builder.py | 74 ++--- metricflow/dataflow/builder/node_evaluator.py | 2 +- metricflow/dataflow/builder/source_node.py | 8 +- .../optimizer/predicate_pushdown_optimizer.py | 4 +- .../source_scan/cm_branch_combiner.py | 6 +- .../source_scan/source_scan_optimizer.py | 4 +- metricflow/dataset/convert_semantic_model.py | 20 +- metricflow/execution/dataflow_to_execution.py | 14 +- metricflow/plan_conversion/dataflow_to_sql.py | 122 ++++---- .../plan_conversion/instance_converters.py | 12 +- metricflow/plan_conversion/node_processor.py | 10 +- .../sql_expression_builders.py | 6 +- .../plan_conversion/sql_join_builder.py | 64 +++-- metricflow/sql/optimizer/column_pruner.py | 10 +- .../optimizer/rewriting_sub_query_reducer.py | 28 +- metricflow/sql/optimizer/sub_query_reducer.py | 12 +- .../sql/optimizer/table_alias_simplifier.py | 8 +- .../data_warehouse_model_validator.py | 14 +- scripts/ci_tests/metricflow_package_test.py | 2 +- .../dataflow/builder/test_node_data_set.py | 14 +- .../source_scan/test_cm_branch_combiner.py | 6 +- tests_metricflow/examples/test_node_sql.py | 6 +- tests_metricflow/execution/noop_task.py | 39 +-- .../execution/test_sequential_executor.py | 16 +- tests_metricflow/execution/test_tasks.py | 11 +- .../fixtures/manifest_fixtures.py | 2 +- .../integration/test_configured_cases.py | 22 +- .../mf_logging/test_dag_to_text.py | 6 +- .../test_metric_time_dimension_to_sql.py | 4 +- .../test_dataflow_to_sql_plan.py | 120 ++++---- .../sql/optimizer/test_column_pruner.py | 194 +++++++------ .../test_rewriting_sub_query_reducer.py | 272 ++++++++++-------- .../sql/optimizer/test_sub_query_reducer.py | 72 +++-- .../optimizer/test_table_alias_simplifier.py | 32 ++- .../sql/test_engine_specific_rendering.py | 56 ++-- tests_metricflow/sql/test_sql_expr_render.py | 115 ++++---- tests_metricflow/sql/test_sql_plan_render.py | 104 ++++--- .../sql_clients/test_date_time_operations.py | 8 +- 40 files changed, 802 insertions(+), 721 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index a1517b5462..d91ed9c17b 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -381,7 +381,7 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met # Define a resolution path for issues where the input is considered to be the whole query. query_resolution_path = MetricFlowQueryResolutionPath.from_path_item( - QueryGroupByItemResolutionNode( + QueryGroupByItemResolutionNode.create( parent_nodes=(), metrics_in_query=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs), where_filter_intersection=query_level_filter_input.where_filter_intersection, diff --git a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py index 41c0cf6afc..7b9aa5a429 100644 --- a/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py +++ b/metricflow-semantics/tests_metricflow_semantics/query/group_by_item/test_matching_item_for_querying.py @@ -131,7 +131,7 @@ def test_missing_parent_for_metric( or measures). However, in the event of a validation gap upstream, we sometimes encounter inscrutable errors caused by missing parent nodes for these input types, so we add a more informative error and test for it here. """ - metric_node = MetricGroupByItemResolutionNode( + metric_node = MetricGroupByItemResolutionNode.create( metric_reference=MetricReference(element_name="bad_metric"), metric_input_location=None, parent_nodes=tuple() ) resolution_dag = GroupByItemResolutionDag(sink_node=metric_node) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 096d63edbf..7cea91608e 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -301,7 +301,7 @@ def _build_aggregated_conversion_node( # Build unaggregated conversions source node # Generate UUID column for conversion source to uniquely identify each row - unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode( + unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode.create( parent_node=conversion_measure_recipe.source_node ) @@ -333,10 +333,10 @@ def _build_aggregated_conversion_node( # Build the unaggregated base measure node for computing conversions unaggregated_base_measure_node = base_measure_recipe.source_node if base_measure_recipe.join_targets: - unaggregated_base_measure_node = JoinOnEntitiesNode( + unaggregated_base_measure_node = JoinOnEntitiesNode.create( left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets ) - filtered_unaggregated_base_node = FilterElementsNode( + filtered_unaggregated_base_node = FilterElementsNode.create( parent_node=unaggregated_base_measure_node, include_specs=group_specs_by_type(required_local_specs) .merge(base_required_linkable_specs.as_spec_set) @@ -347,7 +347,7 @@ def _build_aggregated_conversion_node( # The conversion events are joined by the base events which are already time constrained. However, this could # be still be constrained, where we adjust the time range to the window size similar to cumulative, but # adjusted in the opposite direction. - join_conversion_node = JoinConversionEventsNode( + join_conversion_node = JoinConversionEventsNode.create( base_node=filtered_unaggregated_base_node, base_time_dimension_spec=base_time_dimension_spec, conversion_node=unaggregated_conversion_measure_node, @@ -377,7 +377,9 @@ def _build_aggregated_conversion_node( ) # Combine the aggregated opportunities and conversion data sets - return CombineAggregatedOutputsNode(parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node)) + return CombineAggregatedOutputsNode.create( + parent_nodes=(aggregated_base_measure_node, aggregated_conversions_node) + ) def _build_conversion_metric_output_node( self, @@ -468,7 +470,7 @@ def _build_cumulative_metric_output_node( predicate_pushdown_state=predicate_pushdown_state, for_group_by_source_node=for_group_by_source_node, ) - return WindowReaggregationNode( + return WindowReaggregationNode.create( parent_node=compute_metrics_node, metric_spec=metric_spec, order_by_spec=default_metric_time, @@ -609,9 +611,11 @@ def _build_derived_metric_output_node( ) parent_node = ( - parent_nodes[0] if len(parent_nodes) == 1 else CombineAggregatedOutputsNode(parent_nodes=parent_nodes) + parent_nodes[0] + if len(parent_nodes) == 1 + else CombineAggregatedOutputsNode.create(parent_nodes=parent_nodes) ) - output_node: DataflowPlanNode = ComputeMetricsNode( + output_node: DataflowPlanNode = ComputeMetricsNode.create( parent_node=parent_node, metric_specs=[metric_spec], for_group_by_source_node=for_group_by_source_node, @@ -626,7 +630,7 @@ def _build_derived_metric_output_node( assert ( queried_agg_time_dimension_specs ), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension." - output_node = JoinToTimeSpineNode( + output_node = JoinToTimeSpineNode.create( parent_node=output_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -637,9 +641,9 @@ def _build_derived_metric_output_node( ) if len(metric_spec.filter_specs) > 0: - output_node = WhereConstraintNode(parent_node=output_node, where_specs=metric_spec.filter_specs) + output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=metric_spec.filter_specs) if not extraneous_linkable_specs.is_subset_of(queried_linkable_specs): - output_node = FilterElementsNode( + output_node = FilterElementsNode.create( parent_node=output_node, include_specs=InstanceSpecSet(metric_specs=(metric_spec,)).merge( queried_linkable_specs.as_spec_set @@ -734,7 +738,7 @@ def _build_metrics_output_node( if len(output_nodes) == 1: return output_nodes[0] - return CombineAggregatedOutputsNode(parent_nodes=output_nodes) + return CombineAggregatedOutputsNode.create(parent_nodes=output_nodes) def build_plan_for_distinct_values( self, query_spec: MetricFlowQuerySpec, optimizations: FrozenSet[DataflowPlanOptimization] = frozenset() @@ -779,21 +783,21 @@ def _build_plan_for_distinct_values( output_node = dataflow_recipe.source_node if dataflow_recipe.join_targets: - output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets) + output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets) if len(query_level_filter_specs) > 0: - output_node = WhereConstraintNode(parent_node=output_node, where_specs=query_level_filter_specs) + output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs) if query_spec.time_range_constraint: - output_node = ConstrainTimeRangeNode( + output_node = ConstrainTimeRangeNode.create( parent_node=output_node, time_range_constraint=query_spec.time_range_constraint ) - output_node = FilterElementsNode( + output_node = FilterElementsNode.create( parent_node=output_node, include_specs=query_spec.linkable_specs.as_spec_set, distinct=True ) if query_spec.min_max_only: - output_node = MinMaxNode(parent_node=output_node) + output_node = MinMaxNode.create(parent_node=output_node) sink_node = self.build_sink_node( parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit @@ -814,20 +818,20 @@ def build_sink_node( pre_result_node: Optional[DataflowPlanNode] = None if order_by_specs or limit: - pre_result_node = OrderByLimitNode( + pre_result_node = OrderByLimitNode.create( order_by_specs=list(order_by_specs), limit=limit, parent_node=parent_node ) if output_selection_specs: - pre_result_node = FilterElementsNode( + pre_result_node = FilterElementsNode.create( parent_node=pre_result_node or parent_node, include_specs=output_selection_specs ) write_result_node: DataflowPlanNode if not output_sql_table: - write_result_node = WriteToResultDataTableNode(parent_node=pre_result_node or parent_node) + write_result_node = WriteToResultDataTableNode.create(parent_node=pre_result_node or parent_node) else: - write_result_node = WriteToResultTableNode( + write_result_node = WriteToResultTableNode.create( parent_node=pre_result_node or parent_node, output_sql_table=output_sql_table ) @@ -1139,7 +1143,7 @@ def build_computed_metrics_node( for_group_by_source_node: bool = False, ) -> ComputeMetricsNode: """Builds a ComputeMetricsNode from aggregated measures.""" - return ComputeMetricsNode( + return ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], for_group_by_source_node=for_group_by_source_node, @@ -1449,7 +1453,7 @@ def _build_aggregated_measure_from_measure_source_node( # Otherwise, the measure will be aggregated over all time. time_range_node: Optional[JoinOverTimeRangeNode] = None if cumulative and queried_agg_time_dimension_specs: - time_range_node = JoinOverTimeRangeNode( + time_range_node = JoinOverTimeRangeNode.create( parent_node=measure_recipe.source_node, queried_agg_time_dimension_specs=tuple(queried_agg_time_dimension_specs), window=cumulative_window, @@ -1476,7 +1480,7 @@ def _build_aggregated_measure_from_measure_source_node( ) # This also uses the original time range constraint due to the application of the time window intervals # in join rendering - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=time_range_node or measure_recipe.source_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -1487,7 +1491,7 @@ def _build_aggregated_measure_from_measure_source_node( ) # Only get the required measure and the local linkable instances so that aggregations work correctly. - filtered_measure_source_node = FilterElementsNode( + filtered_measure_source_node = FilterElementsNode.create( parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( group_specs_by_type(measure_recipe.required_local_linkable_specs), @@ -1497,7 +1501,7 @@ def _build_aggregated_measure_from_measure_source_node( join_targets = measure_recipe.join_targets unaggregated_measure_node: DataflowPlanNode if len(join_targets) > 0: - filtered_measures_with_joined_elements = JoinOnEntitiesNode( + filtered_measures_with_joined_elements = JoinOnEntitiesNode.create( left_node=filtered_measure_source_node, join_targets=join_targets, ) @@ -1506,7 +1510,7 @@ def _build_aggregated_measure_from_measure_source_node( required_linkable_specs.as_spec_set, ) - after_join_filtered_node = FilterElementsNode( + after_join_filtered_node = FilterElementsNode.create( parent_node=filtered_measures_with_joined_elements, include_specs=specs_to_keep_after_join ) unaggregated_measure_node = after_join_filtered_node @@ -1524,14 +1528,14 @@ def _build_aggregated_measure_from_measure_source_node( assert ( queried_linkable_specs.contains_metric_time ), "Using time constraints currently requires querying with metric_time." - cumulative_metric_constrained_node = ConstrainTimeRangeNode( + cumulative_metric_constrained_node = ConstrainTimeRangeNode.create( unaggregated_measure_node, predicate_pushdown_state.time_range_constraint ) pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node if len(metric_input_measure_spec.filter_specs) > 0: # Apply where constraint on the node - pre_aggregate_node = WhereConstraintNode( + pre_aggregate_node = WhereConstraintNode.create( parent_node=pre_aggregate_node, where_specs=metric_input_measure_spec.filter_specs, ) @@ -1550,7 +1554,7 @@ def _build_aggregated_measure_from_measure_source_node( window_groupings = tuple( LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings ) - pre_aggregate_node = SemiAdditiveJoinNode( + pre_aggregate_node = SemiAdditiveJoinNode.create( parent_node=pre_aggregate_node, entity_specs=window_groupings, time_dimension_spec=time_dimension_spec, @@ -1564,12 +1568,12 @@ def _build_aggregated_measure_from_measure_source_node( # show up in the final result. # # e.g. for "bookings" by "ds" where "is_instant", "is_instant" should not be in the results. - pre_aggregate_node = FilterElementsNode( + pre_aggregate_node = FilterElementsNode.create( parent_node=pre_aggregate_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(queried_linkable_specs.as_spec_set), ) - aggregate_measures_node = AggregateMeasuresNode( + aggregate_measures_node = AggregateMeasuresNode.create( parent_node=pre_aggregate_node, metric_input_measure_specs=(metric_input_measure_spec,), ) @@ -1583,7 +1587,7 @@ def _build_aggregated_measure_from_measure_source_node( f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if " f"there's a new use case." ) - output_node: DataflowPlanNode = JoinToTimeSpineNode( + output_node: DataflowPlanNode = JoinToTimeSpineNode.create( parent_node=aggregate_measures_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time, @@ -1602,14 +1606,14 @@ def _build_aggregated_measure_from_measure_source_node( if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple)) ] if len(queried_filter_specs) > 0: - output_node = WhereConstraintNode( + output_node = WhereConstraintNode.create( parent_node=output_node, where_specs=queried_filter_specs, always_apply=True ) # TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time. # To fix when enabling time range constraints for agg_time_dimension. if queried_agg_time_dimension_specs and predicate_pushdown_state.time_range_constraint is not None: - output_node = ConstrainTimeRangeNode( + output_node = ConstrainTimeRangeNode.create( parent_node=output_node, time_range_constraint=predicate_pushdown_state.time_range_constraint ) return output_node diff --git a/metricflow/dataflow/builder/node_evaluator.py b/metricflow/dataflow/builder/node_evaluator.py index 262d26b854..971425eee6 100644 --- a/metricflow/dataflow/builder/node_evaluator.py +++ b/metricflow/dataflow/builder/node_evaluator.py @@ -122,7 +122,7 @@ def join_description(self) -> JoinDescription: ] ) - filtered_node_to_join = FilterElementsNode( + filtered_node_to_join = FilterElementsNode.create( parent_node=self.node_to_join, include_specs=group_specs_by_type(include_specs) ) diff --git a/metricflow/dataflow/builder/source_node.py b/metricflow/dataflow/builder/source_node.py index 7b23d58cbf..0054f57bca 100644 --- a/metricflow/dataflow/builder/source_node.py +++ b/metricflow/dataflow/builder/source_node.py @@ -59,8 +59,8 @@ def __init__( # noqa: D107 time_spine_source = TimeSpineSource.create_from_manifest(semantic_manifest_lookup.semantic_manifest) time_spine_data_set = data_set_converter.build_time_spine_source_data_set(time_spine_source) time_dim_reference = TimeDimensionReference(element_name=time_spine_source.time_column_name) - self._time_spine_source_node = MetricTimeDimensionTransformNode( - parent_node=ReadSqlSourceNode(data_set=time_spine_data_set), + self._time_spine_source_node = MetricTimeDimensionTransformNode.create( + parent_node=ReadSqlSourceNode.create(data_set=time_spine_data_set), aggregation_time_dimension_reference=time_dim_reference, ) self._query_parser = MetricFlowQueryParser(semantic_manifest_lookup) @@ -71,7 +71,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So source_nodes_for_metric_queries: List[DataflowPlanNode] = [] for data_set in data_sets: - read_node = ReadSqlSourceNode(data_set) + read_node = ReadSqlSourceNode.create(data_set) group_by_item_source_nodes.append(read_node) agg_time_dim_to_measures_grouper = ( self._semantic_manifest_lookup.semantic_model_lookup.get_aggregation_time_dimensions_with_measures( @@ -86,7 +86,7 @@ def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> So else: # Splits the measures by distinct aggregate time dimension. for time_dimension_reference in time_dimension_references: - metric_time_transform_node = MetricTimeDimensionTransformNode( + metric_time_transform_node = MetricTimeDimensionTransformNode.create( parent_node=read_node, aggregation_time_dimension_reference=time_dimension_reference, ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index aed67dea7e..41ad2c2acb 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -312,7 +312,7 @@ def _push_down_where_filters( optimized_node = self._default_handler(node=node, pushdown_state=updated_pushdown_state) if len(filters_to_apply) > 0: return OptimizeBranchResult( - optimized_branch=WhereConstraintNode( + optimized_branch=WhereConstraintNode.create( parent_node=optimized_node.optimized_branch, where_specs=filters_to_apply ) ) @@ -397,7 +397,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran ) elif len(filter_specs_to_apply) > 0: optimized_node = OptimizeBranchResult( - optimized_branch=WhereConstraintNode( + optimized_branch=WhereConstraintNode.create( parent_node=optimized_parent.optimized_branch, where_specs=filter_specs_to_apply ) ) diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 1124f00498..732b41452c 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -259,7 +259,7 @@ def visit_aggregate_measures_node( # noqa: D102 ) return ComputeMetricsBranchCombinerResult() - combined_node = AggregateMeasuresNode( + combined_node = AggregateMeasuresNode.create( parent_node=combined_parent_node, metric_input_measure_specs=combined_metric_input_measure_specs, ) @@ -305,7 +305,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ComputeMetrics if metric_spec not in unique_metric_specs: unique_metric_specs.append(metric_spec) - combined_node = ComputeMetricsNode( + combined_node = ComputeMetricsNode.create( parent_node=combined_parent_node, metric_specs=unique_metric_specs, aggregated_to_elements=current_right_node.aggregated_to_elements, @@ -389,7 +389,7 @@ def visit_filter_elements_node(self, node: FilterElementsNode) -> ComputeMetrics # De-dupe so that we don't see the same spec twice in include specs. For example, this can happen with dimension # specs since any branch that is merged together needs to output the same set of dimensions. - combined_node = FilterElementsNode( + combined_node = FilterElementsNode.create( parent_node=combined_parent_node, include_specs=self._current_left_node.include_specs.merge(current_right_node.include_specs).dedupe(), ) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index e7f87aa992..5fa9fc4602 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -148,7 +148,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranch optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self) if optimized_parent_result.optimized_branch is not None: return OptimizeBranchResult( - optimized_branch=ComputeMetricsNode( + optimized_branch=ComputeMetricsNode.create( parent_node=optimized_parent_result.optimized_branch, metric_specs=node.metric_specs, for_group_by_source_node=node.for_group_by_source_node, @@ -264,7 +264,7 @@ def visit_combine_aggregated_outputs_node( # noqa: D102 return OptimizeBranchResult(optimized_branch=combined_parent_branches[0]) return OptimizeBranchResult( - optimized_branch=CombineAggregatedOutputsNode(parent_nodes=combined_parent_branches) + optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches) ) def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult: # noqa: D102 diff --git a/metricflow/dataset/convert_semantic_model.py b/metricflow/dataset/convert_semantic_model.py index 670e768827..2c76394129 100644 --- a/metricflow/dataset/convert_semantic_model.py +++ b/metricflow/dataset/convert_semantic_model.py @@ -169,15 +169,15 @@ def _make_element_sql_expr( "FALSE", "NULL", ): - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=element_expr, ) ) - return SqlStringExpression(sql_expr=element_expr) + return SqlStringExpression.create(sql_expr=element_expr) - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=element_name, @@ -368,7 +368,7 @@ def _build_time_dimension_instances_and_columns( select_columns.append( SqlSelectColumn( - expr=SqlExtractExpression(date_part=date_part, arg=dimension_select_expr), + expr=SqlExtractExpression.create(date_part=date_part, arg=dimension_select_expr), column_alias=time_dimension_instance.associated_column.column_name, ) ) @@ -379,7 +379,7 @@ def _build_column_for_time_granularity( self, time_granularity: TimeGranularity, expr: SqlExpressionNode, column_alias: str ) -> SqlSelectColumn: return SqlSelectColumn( - expr=SqlDateTruncExpression(time_granularity=time_granularity, arg=expr), column_alias=column_alias + expr=SqlDateTruncExpression.create(time_granularity=time_granularity, arg=expr), column_alias=column_alias ) def _create_entity_instances( @@ -493,9 +493,11 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM all_select_columns.extend(select_columns) # Generate the "from" clause depending on whether it's an SQL query or an SQL table. - from_source = SqlTableFromClauseNode(sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name)) + from_source = SqlTableFromClauseNode.create( + sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name) + ) - select_statement_node = SqlSelectStatementNode( + select_statement_node = SqlSelectStatementNode.create( description=f"Read Elements From Semantic Model '{semantic_model.name}'", select_columns=tuple(all_select_columns), from_source=from_source, @@ -549,10 +551,10 @@ def build_time_spine_source_data_set(self, time_spine_source: TimeSpineSource) - return SqlDataSet( instance_set=InstanceSet(time_dimension_instances=tuple(time_dimension_instances)), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=TIME_SPINE_DATA_SET_DESCRIPTION, select_columns=tuple(select_columns), - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), + from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table), from_source_alias=from_source_alias, ), ) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index b553590465..4df6cc0f21 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -33,6 +33,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, + SqlQuery, ) from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter @@ -80,10 +81,9 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode render_sql_result = self._render_sql(convert_to_sql_plan_result) execution_plan = ExecutionPlan( leaf_tasks=( - SelectSqlQueryToDataTableTask( + SelectSqlQueryToDataTableTask.create( sql_client=self._sql_client, - sql_query=render_sql_result.sql, - bind_parameters=render_sql_result.bind_parameters, + sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameters), ), ) ) @@ -99,10 +99,12 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv render_sql_result = self._render_sql(convert_to_sql_plan_result) execution_plan = ExecutionPlan( leaf_tasks=( - SelectSqlQueryToTableTask( + SelectSqlQueryToTableTask.create( sql_client=self._sql_client, - sql_query=render_sql_result.sql, - bind_parameters=render_sql_result.bind_parameters, + sql_query=SqlQuery( + sql_query=render_sql_result.sql, + bind_parameters=render_sql_result.bind_parameters, + ), output_table=node.output_sql_table, ), ), diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index c3ad1d542f..ad2d989989 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -150,17 +150,17 @@ def _make_time_range_comparison_expr( ) -> SqlExpressionNode: """Build an expression like "ds BETWEEN CAST('2020-01-01' AS TIMESTAMP) AND CAST('2020-01-02' AS TIMESTAMP).""" # TODO: Update when adding < day granularity support. - return SqlBetweenExpression( - column_arg=SqlColumnReferenceExpression( + return SqlBetweenExpression.create( + column_arg=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias, column_name=column_alias, ) ), - start_expr=SqlStringLiteralExpression( + start_expr=SqlStringLiteralExpression.create( literal_value=time_range_constraint.start_time.strftime(ISO8601_PYTHON_FORMAT), ), - end_expr=SqlStringLiteralExpression( + end_expr=SqlStringLiteralExpression.create( literal_value=time_range_constraint.end_time.strftime(ISO8601_PYTHON_FORMAT), ), ) @@ -254,7 +254,7 @@ def _make_time_spine_data_set( else: select_columns += ( SqlSelectColumn( - expr=SqlDateTruncExpression( + expr=SqlDateTruncExpression.create( time_granularity=agg_time_dimension_instance.spec.time_granularity, arg=column_expr ), column_alias=column_alias, @@ -264,10 +264,10 @@ def _make_time_spine_data_set( return SqlDataSet( instance_set=time_spine_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=TIME_SPINE_DATA_SET_DESCRIPTION, select_columns=select_columns, - from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table), + from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table), from_source_alias=time_spine_table_alias, group_bys=select_columns if apply_group_by else (), where=( @@ -353,14 +353,14 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=create_select_columns_for_instance_sets( self._column_association_resolver, table_alias_to_instance_set ), from_source=time_spine_data_set.checked_sql_select_node, from_source_alias=time_spine_data_set_alias, - joins_descs=(join_desc,), + join_descs=(join_desc,), ), ) @@ -443,14 +443,14 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet: # clauses. return SqlDataSet( instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=create_select_columns_for_instance_sets( self._column_association_resolver, table_alias_to_instance_set ), from_source=from_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - joins_descs=tuple(sql_join_descs), + join_descs=tuple(sql_join_descs), ), ) @@ -518,7 +518,7 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataS return SqlDataSet( instance_set=aggregated_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This will generate expressions with the appropriate aggregation functions e.g. SUM() select_columns=select_column_set.as_tuple(), @@ -576,14 +576,14 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: MetricSpec.from_reference(denominator.post_aggregation_reference) ).column_name - metric_expr = SqlRatioComputationExpression( - numerator=SqlColumnReferenceExpression( + metric_expr = SqlRatioComputationExpression.create( + numerator=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=numerator_column_name, ) ), - denominator=SqlColumnReferenceExpression( + denominator=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=denominator_column_name, @@ -619,7 +619,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: assert ( metric.type_params.expr ), "Derived metrics are required to have an `expr` in their YAML definition." - metric_expr = SqlStringExpression(sql_expr=metric.type_params.expr) + metric_expr = SqlStringExpression.create(sql_expr=metric.type_params.expr) elif metric.type == MetricType.CONVERSION: conversion_type_params = metric.type_params.conversion_type_params assert ( @@ -635,20 +635,20 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: ).column_name calculation_type = conversion_type_params.calculation - conversion_column_reference = SqlColumnReferenceExpression( + conversion_column_reference = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=conversion_measure_column, ) ) - base_column_reference = SqlColumnReferenceExpression( + base_column_reference = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=from_data_set_alias, column_name=base_measure_column, ) ) if calculation_type == ConversionCalculationType.CONVERSION_RATE: - metric_expr = SqlRatioComputationExpression( + metric_expr = SqlRatioComputationExpression.create( numerator=conversion_column_reference, denominator=base_column_reference, ) @@ -699,7 +699,7 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=combined_select_column_set.as_tuple(), from_source=from_data_set.checked_sql_select_node, @@ -711,14 +711,14 @@ def __make_col_reference_or_coalesce_expr( self, column_name: str, input_measure: Optional[MetricInputMeasure], from_data_set_alias: str ) -> SqlExpressionNode: # Use a column reference to improve query optimization. - metric_expr: SqlExpressionNode = SqlColumnReferenceExpression( + metric_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=from_data_set_alias, column_name=column_name) ) # Coalesce nulls to requested integer value, if requested. if input_measure and input_measure.fill_nulls_with is not None: - metric_expr = SqlAggregateFunctionExpression( + metric_expr = SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, - sql_function_args=[metric_expr, SqlStringExpression(str(input_measure.fill_nulls_with))], + sql_function_args=[metric_expr, SqlStringExpression.create(str(input_measure.fill_nulls_with))], ) return metric_expr @@ -734,7 +734,7 @@ def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: # no for order_by_spec in node.order_by_specs: order_by_descriptions.append( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=from_data_set_alias, column_name=self._column_association_resolver.resolve_spec( @@ -748,7 +748,7 @@ def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: # no return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. select_columns=output_instance_set.transform( @@ -770,7 +770,7 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlD input_instance_set: InstanceSet = input_data_set.instance_set return SqlDataSet( instance_set=input_instance_set, - sql_node=SqlCreateTableAsNode( + sql_node=SqlCreateTableAsNode.create( sql_table=node.output_sql_table, parent_node=input_data_set.checked_sql_select_node, ), @@ -794,7 +794,7 @@ def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet: group_bys = select_columns if node.distinct else () return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=select_columns, from_source=from_data_set.checked_sql_select_node, @@ -819,7 +819,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. select_columns=output_instance_set.transform( @@ -827,7 +827,7 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: ).as_tuple(), from_source=parent_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - where=SqlStringExpression( + where=SqlStringExpression.create( sql_expr=node.where.where_sql, used_columns=tuple( column_association.column_name for column_association in column_associations_in_where_sql @@ -940,12 +940,12 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=combined_select_column_set.as_tuple(), from_source=from_data_set.data_set.checked_sql_select_node, from_source_alias=from_data_set.alias, - joins_descs=tuple(joins_descriptions), + join_descs=tuple(joins_descriptions), group_bys=linkable_select_column_set.as_tuple(), ), ) @@ -987,7 +987,7 @@ def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDa return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. select_columns=output_instance_set.transform( @@ -1073,7 +1073,7 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, # This creates select expressions for all columns referenced in the instance set. select_columns=CreateSelectColumnsForInstances( @@ -1116,7 +1116,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe time_dimension_select_column = SqlSelectColumn( expr=SqlFunctionExpression.build_expression_from_aggregation_type( aggregation_type=node.agg_by_function, - sql_column_expression=SqlColumnReferenceExpression( + sql_column_expression=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=time_dimension_column_name, @@ -1138,7 +1138,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe entity_column_name = self.column_association_resolver.resolve_spec(entity_spec).column_name entity_select_columns.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=entity_column_name, @@ -1161,7 +1161,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe node.queried_time_dimension_spec ).column_name queried_time_dimension_select_column = SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=inner_join_data_set_alias, column_name=query_time_dimension_column_name, @@ -1174,7 +1174,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe if queried_time_dimension_select_column: row_filter_group_bys += (queried_time_dimension_select_column,) # Construct SelectNode for Row filtering - row_filter_sql_select_node = SqlSelectStatementNode( + row_filter_sql_select_node = SqlSelectStatementNode.create( description=f"Filter row on {node.agg_by_function.name}({time_dimension_column_name})", select_columns=row_filter_group_bys + (time_dimension_select_column,), from_source=from_data_set.checked_sql_select_node, @@ -1192,14 +1192,14 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=output_instance_set.transform( CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver) ).as_tuple(), from_source=from_data_set.checked_sql_select_node, from_source_alias=from_data_set_alias, - joins_descs=(sql_join_desc,), + join_descs=(sql_join_desc,), ), ) @@ -1292,7 +1292,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet original_time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0] time_spine_column_select_expr: Union[ SqlColumnReferenceExpression, SqlDateTruncExpression - ] = SqlColumnReferenceExpression( + ] = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name ) @@ -1329,25 +1329,25 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet select_expr: SqlExpressionNode = ( time_spine_column_select_expr if time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity - else SqlDateTruncExpression( + else SqlDateTruncExpression.create( time_granularity=time_dimension_spec.time_granularity, arg=time_spine_column_select_expr ) ) # Filter down to one row per granularity period requested in the group by. Any other granularities # included here will be filtered out in later nodes so should not be included in where filter. if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs: - new_where_filter = SqlComparisonExpression( + new_where_filter = SqlComparisonExpression.create( left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr ) where_filter = ( - SqlLogicalExpression(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)) + SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)) if where_filter else new_where_filter ) # Apply date_part to time spine column select expression. if time_dimension_spec.date_part: - select_expr = SqlExtractExpression(date_part=time_dimension_spec.date_part, arg=select_expr) + select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr) time_dim_spec = TimeDimensionSpec( element_name=original_time_spine_dim_instance.spec.element_name, entity_links=original_time_spine_dim_instance.spec.entity_links, @@ -1368,12 +1368,12 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet return SqlDataSet( instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=tuple(time_spine_select_columns) + parent_select_columns, from_source=time_spine_dataset.checked_sql_select_node, from_source_alias=time_spine_alias, - joins_descs=(join_description,), + join_descs=(join_description,), where=where_filter, ), ) @@ -1395,7 +1395,7 @@ def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 SqlSelectColumn( expr=SqlFunctionExpression.build_expression_from_aggregation_type( aggregation_type=agg_type, - sql_column_expression=SqlColumnReferenceExpression( + sql_column_expression=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=parent_table_alias, column_name=parent_column_alias) ), ), @@ -1408,7 +1408,7 @@ def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 return SqlDataSet( instance_set=parent_data_set.instance_set.transform(ConvertToMetadata(metadata_instances)), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=tuple(select_columns), from_source=parent_data_set.checked_sql_select_node, @@ -1438,12 +1438,12 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) ) ) gen_uuid_sql_select_column = SqlSelectColumn( - expr=SqlGenerateUuidExpression(), column_alias=output_column_association.column_name + expr=SqlGenerateUuidExpression.create(), column_alias=output_column_association.column_name ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="Add column with generated UUID", select_columns=input_data_set.instance_set.transform( CreateSelectColumnsForInstances(input_data_set_alias, self._column_association_resolver) @@ -1531,10 +1531,10 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) base_sql_select_columns = tuple( SqlSelectColumn( - expr=SqlWindowFunctionExpression( + expr=SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set_alias, column_name=base_sql_column_reference.col_ref.column_name, @@ -1542,7 +1542,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) ], partition_by_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=conversion_data_set_alias, column_name=column, @@ -1552,7 +1552,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ], order_by_args=[ SqlWindowOrderByArgument( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set_alias, column_name=base_time_dimension_column_name, @@ -1574,7 +1574,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S # Deduplicate the fanout results conversion_unique_key_select_columns = tuple( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=conversion_data_set_alias, column_name=column_name, @@ -1587,14 +1587,14 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S additional_conversion_select_columns = conversion_data_set_output_instance_set.transform( CreateSelectColumnsForInstances(conversion_data_set_alias, self._column_association_resolver) ).as_tuple() - deduped_sql_select_node = SqlSelectStatementNode( + deduped_sql_select_node = SqlSelectStatementNode.create( description=f"Dedupe the fanout with {','.join(spec.qualified_name for spec in node.unique_identifier_keys)} in the conversion data set", select_columns=base_sql_select_columns + conversion_unique_key_select_columns + additional_conversion_select_columns, from_source=base_data_set.checked_sql_select_node, from_source_alias=base_data_set_alias, - joins_descs=(sql_join_description,), + join_descs=(sql_join_description,), distinct=True, ) @@ -1605,7 +1605,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S ) return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=output_instance_set.transform( CreateSelectColumnsForInstances(output_data_set_alias, self._column_association_resolver) @@ -1655,7 +1655,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ) ) metric_select_column = SqlSelectColumn( - expr=SqlWindowFunctionExpression( + expr=SqlWindowFunctionExpression.create( sql_function=sql_window_function, sql_function_args=[ SqlColumnReferenceExpression.from_table_and_column_names( @@ -1687,7 +1687,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ).as_tuple() + ( metric_select_column, ) - subquery = SqlSelectStatementNode( + subquery = SqlSelectStatementNode.create( description="Window Function for Metric Re-aggregation", select_columns=subquery_select_columns, from_source=from_data_set.checked_sql_select_node, @@ -1700,7 +1700,7 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlD ).as_tuple() return SqlDataSet( instance_set=output_instance_set, - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="Re-aggregate Metric via Group By", select_columns=outer_query_select_columns, from_source=subquery, diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 0d816209ed..c33fc5bbb3 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -162,7 +162,7 @@ def _make_sql_column_expression( input_column_name = self._output_to_input_column_mapping[output_column_name] select_columns.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference(self._table_alias, input_column_name)), + expr=SqlColumnReferenceExpression.create(SqlColumnReference(self._table_alias, input_column_name)), column_alias=output_column_name, ) ) @@ -223,7 +223,7 @@ def _make_sql_column_expression_to_aggregate_measure( measure = self._semantic_model_lookup.get_measure(measure_instance.spec.reference) aggregation_type = measure.agg - expression_to_get_measure = SqlColumnReferenceExpression( + expression_to_get_measure = SqlColumnReferenceExpression.create( SqlColumnReference(self._table_alias, column_name_in_table) ) @@ -824,7 +824,7 @@ def transform(self, instance_set: InstanceSet) -> Tuple[SqlColumnReferenceExpres self._column_resolver.resolve_spec(spec).column_name for spec in instance_set.spec_set.all_specs ] return tuple( - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=self._table_alias, column_name=column_name, @@ -854,7 +854,7 @@ def __init__( # noqa: D107 def _create_select_column(self, spec: InstanceSpec, fill_nulls_with: Optional[int] = None) -> SqlSelectColumn: """Creates the select column for the given spec and the fill value.""" column_name = self._column_resolver.resolve_spec(spec).column_name - column_reference_expression = SqlColumnReferenceExpression( + column_reference_expression = SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=self._table_alias, column_name=column_name, @@ -864,11 +864,11 @@ def _create_select_column(self, spec: InstanceSpec, fill_nulls_with: Optional[in aggregation_type=AggregationType.MAX, sql_column_expression=column_reference_expression ) if fill_nulls_with is not None: - select_expression = SqlAggregateFunctionExpression( + select_expression = SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, sql_function_args=[ select_expression, - SqlStringExpression(str(fill_nulls_with)), + SqlStringExpression.create(str(fill_nulls_with)), ], ) return SqlSelectColumn( diff --git a/metricflow/plan_conversion/node_processor.py b/metricflow/plan_conversion/node_processor.py index 8b68e22705..57f4f1d8ea 100644 --- a/metricflow/plan_conversion/node_processor.py +++ b/metricflow/plan_conversion/node_processor.py @@ -379,7 +379,9 @@ def _add_time_range_constraint( break if constrain_time: processed_nodes.append( - ConstrainTimeRangeNode(parent_node=source_node, time_range_constraint=time_range_constraint) + ConstrainTimeRangeNode.create( + parent_node=source_node, time_range_constraint=time_range_constraint + ) ) else: processed_nodes.append(source_node) @@ -421,7 +423,7 @@ def _add_where_constraint( filtered_nodes.append(source_node) else: filtered_nodes.append( - WhereConstraintNode(parent_node=source_node, where_specs=matching_filter_specs) + WhereConstraintNode.create(parent_node=source_node, where_specs=matching_filter_specs) ) else: filtered_nodes.append(source_node) @@ -531,7 +533,7 @@ def _get_candidates_nodes_for_multi_hop( # filter measures out of joinable_node specs = data_set_of_second_node_that_can_be_joined.instance_set.spec_set - filtered_joinable_node = FilterElementsNode( + filtered_joinable_node = FilterElementsNode.create( parent_node=second_node_that_could_be_joined, include_specs=group_specs_by_type( specs.dimension_specs @@ -552,7 +554,7 @@ def _get_candidates_nodes_for_multi_hop( multi_hop_join_candidates.append( MultiHopJoinCandidate( - node_with_multi_hop_elements=JoinOnEntitiesNode( + node_with_multi_hop_elements=JoinOnEntitiesNode.create( left_node=first_node_that_could_be_joined, join_targets=[ JoinDescription( diff --git a/metricflow/plan_conversion/sql_expression_builders.py b/metricflow/plan_conversion/sql_expression_builders.py index c3567ab163..e5ed18d463 100644 --- a/metricflow/plan_conversion/sql_expression_builders.py +++ b/metricflow/plan_conversion/sql_expression_builders.py @@ -25,7 +25,7 @@ def make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> SqlE COALESCE(a.is_instant, b.is_instant) """ if len(table_aliases) == 1: - return SqlColumnReferenceExpression( + return SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=table_aliases[0], column_name=column_alias, @@ -35,14 +35,14 @@ def make_coalesced_expr(table_aliases: Sequence[str], column_alias: str) -> SqlE columns_to_coalesce: List[SqlExpressionNode] = [] for table_alias in table_aliases: columns_to_coalesce.append( - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=table_alias, column_name=column_alias, ) ) ) - return SqlAggregateFunctionExpression( + return SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COALESCE, sql_function_args=columns_to_coalesce, ) diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 7a58689b67..7bc24b7f11 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -97,30 +97,30 @@ def make_column_equality_sql_join_description( and_conditions: List[SqlExpressionNode] = [] for column_equality_description in column_equality_descriptions: - left_column = SqlColumnReferenceExpression( + left_column = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=left_source_alias, column_name=column_equality_description.left_column_alias, ) ) - right_column = SqlColumnReferenceExpression( + right_column = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=right_source_alias, column_name=column_equality_description.right_column_alias, ) ) - column_equality_expression = SqlComparisonExpression( + column_equality_expression = SqlComparisonExpression.create( left_expr=left_column, comparison=SqlComparison.EQUALS, right_expr=right_column, ) if column_equality_description.treat_nulls_as_equal: - null_comparison_expression = SqlLogicalExpression( + null_comparison_expression = SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, - args=(SqlIsNullExpression(arg=left_column), SqlIsNullExpression(arg=right_column)), + args=(SqlIsNullExpression.create(arg=left_column), SqlIsNullExpression.create(arg=right_column)), ) and_conditions.append( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.OR, args=(column_equality_expression, null_comparison_expression) ) ) @@ -135,7 +135,7 @@ def make_column_equality_sql_join_description( elif len(and_conditions) == 1: on_condition = and_conditions[0] else: - on_condition = SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=tuple(and_conditions)) + on_condition = SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=tuple(and_conditions)) return SqlJoinDescription( right_source=right_source_node, @@ -287,32 +287,32 @@ def _make_time_window_join_condition( {start_dimension_name} >= metric_time AND ({end_dimension_name} < metric_time OR {end_dimension_name} IS NULL) """ - left_time_column_expr = SqlColumnReferenceExpression( + left_time_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=left_source_alias, column_name=left_source_time_dimension_name) ) - window_start_column_expr = SqlColumnReferenceExpression( + window_start_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=right_source_alias, column_name=window_start_dimension_name) ) - window_end_column_expr = SqlColumnReferenceExpression( + window_end_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference(table_alias=right_source_alias, column_name=window_end_dimension_name) ) - window_start_condition = SqlComparisonExpression( + window_start_condition = SqlComparisonExpression.create( left_expr=left_time_column_expr, comparison=SqlComparison.GREATER_THAN_OR_EQUALS, right_expr=window_start_column_expr, ) - window_end_by_time = SqlComparisonExpression( + window_end_by_time = SqlComparisonExpression.create( left_expr=left_time_column_expr, comparison=SqlComparison.LESS_THAN, right_expr=window_end_column_expr, ) - window_end_is_null = SqlIsNullExpression(window_end_column_expr) - window_end_condition = SqlLogicalExpression( + window_end_is_null = SqlIsNullExpression.create(window_end_column_expr) + window_end_condition = SqlLogicalExpression.create( operator=SqlLogicalOperator.OR, args=(window_end_by_time, window_end_is_null) ) - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=(window_start_condition, window_end_condition) ) @@ -346,7 +346,7 @@ def make_join_description_for_combining_datasets( for colname in column_names ] on_condition = ( - SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=tuple(equality_exprs)) + SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=tuple(equality_exprs)) if len(equality_exprs) > 1 else equality_exprs[0] ) @@ -403,10 +403,10 @@ def _make_equality_expression_for_full_outer_join( The latter scenario consolidates the rows keyed by 'c' into a single entry. """ - return SqlComparisonExpression( + return SqlComparisonExpression.create( left_expr=make_coalesced_expr(table_aliases_in_coalesce, column_alias), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias=right_table_alias, column_name=column_alias, @@ -430,13 +430,13 @@ def _make_time_range_window_join_condition( """ if window or grain_to_date: assert_exactly_one_arg_set(window=window, grain_to_date=grain_to_date) - base_column_expr = SqlColumnReferenceExpression( + base_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=base_data_set.alias, column_name=base_data_set.metric_time_column_name, ) ) - time_comparison_column_expr = SqlColumnReferenceExpression( + time_comparison_column_expr = SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=time_comparison_dataset.alias, column_name=time_comparison_dataset.metric_time_column_name, @@ -445,7 +445,7 @@ def _make_time_range_window_join_condition( # Comparison expression against the endpoint of the cumulative time range, # meaning the base metrc time must always be BEFORE the comparison metric time - end_of_range_comparison_expression = SqlComparisonExpression( + end_of_range_comparison_expression = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.LESS_THAN_OR_EQUALS, right_expr=time_comparison_column_expr, @@ -453,10 +453,10 @@ def _make_time_range_window_join_condition( comparison_expressions: List[SqlComparisonExpression] = [end_of_range_comparison_expression] if window: - start_of_range_comparison_expr = SqlComparisonExpression( + start_of_range_comparison_expr = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.GREATER_THAN, - right_expr=SqlSubtractTimeIntervalExpression( + right_expr=SqlSubtractTimeIntervalExpression.create( arg=time_comparison_column_expr, count=window.count, granularity=window.granularity, @@ -464,14 +464,16 @@ def _make_time_range_window_join_condition( ) comparison_expressions.append(start_of_range_comparison_expr) elif grain_to_date: - start_of_range_comparison_expr = SqlComparisonExpression( + start_of_range_comparison_expr = SqlComparisonExpression.create( left_expr=base_column_expr, comparison=SqlComparison.GREATER_THAN_OR_EQUALS, - right_expr=SqlDateTruncExpression(arg=time_comparison_column_expr, time_granularity=grain_to_date), + right_expr=SqlDateTruncExpression.create( + arg=time_comparison_column_expr, time_granularity=grain_to_date + ), ) comparison_expressions.append(start_of_range_comparison_expr) - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=tuple(comparison_expressions), ) @@ -537,23 +539,23 @@ def make_join_to_time_spine_join_description( parent_alias: str, ) -> SqlJoinDescription: """Build join expression used to join a metric to a time spine dataset.""" - left_expr: SqlExpressionNode = SqlColumnReferenceExpression( + left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name) ) if node.offset_window: - left_expr = SqlSubtractTimeIntervalExpression( + left_expr = SqlSubtractTimeIntervalExpression.create( arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity ) elif node.offset_to_grain: - left_expr = SqlDateTruncExpression(time_granularity=node.offset_to_grain, arg=left_expr) + left_expr = SqlDateTruncExpression.create(time_granularity=node.offset_to_grain, arg=left_expr) return SqlJoinDescription( right_source=parent_sql_select_node, right_source_alias=parent_alias, - on_condition=SqlComparisonExpression( + on_condition=SqlComparisonExpression.create( left_expr=left_expr, comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias=parent_alias, column_name=agg_time_dimension_column_name) ), ), diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 75c62ca7b8..61bd283bf1 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -98,12 +98,12 @@ def _prune_columns_from_grandparents( else: pruned_join_descriptions.append(join_description) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=pruned_select_columns, from_source=pruned_from_source, from_source_alias=node.from_source_alias, - joins_descs=tuple(pruned_join_descriptions), + join_descs=tuple(pruned_join_descriptions), group_bys=node.group_bys, order_bys=node.order_bys, where=node.where, @@ -178,12 +178,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ) ) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple(pruned_select_columns), from_source=pruned_from_source, from_source_alias=node.from_source_alias, - joins_descs=tuple(pruned_join_descriptions), + join_descs=tuple(pruned_join_descriptions), group_bys=node.group_bys, order_bys=node.order_bys, where=node.where, @@ -200,7 +200,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 18ffb95a99..e587b6510f 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -65,7 +65,7 @@ def combine_wheres(self, additional_where_clauses: List[SqlExpressionNode]) -> O if len(all_where_clauses) == 1: return all_where_clauses[0] elif len(all_where_clauses) > 1: - return SqlLogicalExpression( + return SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=tuple(all_where_clauses), ) @@ -93,12 +93,12 @@ def _reduce_parents( node: SqlSelectStatementNode, ) -> SqlSelectStatementNode: """Apply the reducing operation to the parent select statements.""" - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -401,7 +401,7 @@ def _rewrite_where( # For type checking. The above conditionals should ensure the below. assert node_where assert parent_node_where - return SqlLogicalExpression(operator=SqlLogicalOperator.AND, args=(node_where, parent_node_where)) + return SqlLogicalExpression.create(operator=SqlLogicalOperator.AND, args=(node_where, parent_node_where)) @staticmethod def _find_matching_select_column( @@ -568,12 +568,12 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN for x in new_join_descs ] - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple(clauses_to_rewrite.select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(new_join_descs), + join_descs=tuple(new_join_descs), group_bys=tuple(clauses_to_rewrite.group_bys), order_bys=tuple(clauses_to_rewrite.order_bys), where=clauses_to_rewrite.combine_wheres(additional_where_clauses), @@ -656,7 +656,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP new_order_bys.append( SqlOrderByDescription( - expr=SqlColumnAliasReferenceExpression(column_alias=matching_select_column.column_alias), + expr=SqlColumnAliasReferenceExpression.create(column_alias=matching_select_column.column_alias), desc=order_by_item.desc, ) ) @@ -681,14 +681,14 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP elif parent_select_node.group_bys: new_group_bys = parent_select_node.group_bys - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="\n".join([parent_select_node.description, node_with_reduced_parents.description]), select_columns=SqlRewritingSubQueryReducerVisitor._rewrite_select_columns( old_select_columns=node.select_columns, column_replacements=column_replacements ), from_source=parent_select_node.from_source, from_source_alias=parent_select_node.from_source_alias, - joins_descs=parent_select_node.join_descs, + join_descs=parent_select_node.join_descs, group_bys=new_group_bys, order_bys=tuple(new_order_bys), where=SqlRewritingSubQueryReducerVisitor._rewrite_where( @@ -707,7 +707,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) @@ -735,7 +735,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP if matching_select_column: new_group_bys.append( SqlSelectColumn( - expr=SqlColumnAliasReferenceExpression(column_alias=matching_select_column.column_alias), + expr=SqlColumnAliasReferenceExpression.create(column_alias=matching_select_column.column_alias), column_alias=matching_select_column.column_alias, ) ) @@ -743,12 +743,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP logger.info(f"Did not find matching select for {group_by} in:\n{indent(node.structure_text())}") new_group_bys.append(group_by) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -771,7 +771,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index be649142ff..a3b440cc67 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -27,12 +27,12 @@ def _reduce_parents( node: SqlSelectStatementNode, ) -> SqlSelectStatementNode: """Apply the reducing operation to the parent select statements.""" - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -158,7 +158,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP return node_with_reduced_parents new_order_by.append( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias=table_alias_in_parent, column_name=order_by_item_expr.col_ref.column_name, @@ -175,12 +175,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP elif parent_select_node.limit is not None: new_limit = min(new_limit, parent_select_node.limit) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="\n".join([parent_select_node.description, node_with_reduced_parents.description]), select_columns=parent_select_node.select_columns, from_source=parent_select_node.from_source, from_source_alias=parent_select_node.from_source_alias, - joins_descs=parent_select_node.join_descs, + join_descs=parent_select_node.join_descs, group_bys=parent_select_node.group_bys, order_bys=tuple(new_order_by), where=parent_select_node.where, @@ -195,7 +195,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index baebd8eefd..9f32cbefa1 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -26,7 +26,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP should_simplify_table_aliases = len(node.parent_nodes) <= 1 if should_simplify_table_aliases: - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=tuple( SqlSelectColumn(expr=x.expr.rewrite(should_render_table_alias=False), column_alias=x.column_alias) @@ -47,12 +47,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP distinct=node.distinct, ) - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description=node.description, select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - joins_descs=tuple( + join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), right_source_alias=x.right_source_alias, @@ -75,7 +75,7 @@ def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> Sq return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102 - return SqlCreateTableAsNode( + return SqlCreateTableAsNode.create( sql_table=node.sql_table, parent_node=node.parent_node.accept(self), ) diff --git a/metricflow/validation/data_warehouse_model_validator.py b/metricflow/validation/data_warehouse_model_validator.py index 8ffd7a2a0a..a9d076278f 100644 --- a/metricflow/validation/data_warehouse_model_validator.py +++ b/metricflow/validation/data_warehouse_model_validator.py @@ -201,7 +201,7 @@ def gen_dimension_tasks( spec_filter_tuples.append( ( spec, - FilterElementsNode( + FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(dimension_specs=(spec,)) ), ) @@ -214,7 +214,7 @@ def gen_dimension_tasks( spec_filter_tuples.append( ( spec, - FilterElementsNode( + FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(time_dimension_specs=(spec,)) ), ) @@ -241,7 +241,7 @@ def gen_dimension_tasks( ) ) - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet( dimension_specs=dimension_specs, @@ -299,7 +299,7 @@ def gen_entity_tasks( dataset.instance_set.spec_set.entity_specs ) for spec in semantic_model_specs: - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(entity_specs=(spec,)) ) semantic_model_sub_tasks.append( @@ -322,7 +322,7 @@ def gen_entity_tasks( ) ) - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet( entity_specs=tuple(semantic_model_specs), @@ -392,7 +392,7 @@ def gen_measure_tasks( obtained_source_node = source_node_by_measure_spec.get(spec) assert obtained_source_node, f"Unable to find generated source node for measure: {spec.element_name}" - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=obtained_source_node, include_specs=InstanceSpecSet( measure_specs=(spec,), @@ -419,7 +419,7 @@ def gen_measure_tasks( ) for measure_specs, source_node in measure_specs_source_node_pair: - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=measure_specs) ) tasks.append( diff --git a/scripts/ci_tests/metricflow_package_test.py b/scripts/ci_tests/metricflow_package_test.py index 5ca4ef645f..ab8bfb0f49 100644 --- a/scripts/ci_tests/metricflow_package_test.py +++ b/scripts/ci_tests/metricflow_package_test.py @@ -36,7 +36,7 @@ def _data_set_to_read_nodes(data_sets: OrderedDict[str, SemanticModelDataSet]) - # Moved from model_fixtures.py. return_dict: OrderedDict[str, ReadSqlSourceNode] = OrderedDict() for semantic_model_name, data_set in data_sets.items(): - return_dict[semantic_model_name] = ReadSqlSourceNode(data_set) + return_dict[semantic_model_name] = ReadSqlSourceNode.create(data_set) return return_dict diff --git a/tests_metricflow/dataflow/builder/test_node_data_set.py b/tests_metricflow/dataflow/builder/test_node_data_set.py index ee668769f8..8c05459111 100644 --- a/tests_metricflow/dataflow/builder/test_node_data_set.py +++ b/tests_metricflow/dataflow/builder/test_node_data_set.py @@ -68,20 +68,24 @@ def test_no_parent_node_data_set( time_dimension_instances=(), entity_instances=(), ), - sql_select_node=SqlSelectStatementNode( + sql_select_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference(table_alias="src", column_name="bookings")), + expr=SqlColumnReferenceExpression.create( + SqlColumnReference(table_alias="src", column_name="bookings") + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src", ), ) - node = ReadSqlSourceNode(data_set=data_set) + node = ReadSqlSourceNode.create(data_set=data_set) assert resolver.get_output_data_set(node).instance_set == data_set.instance_set @@ -102,7 +106,7 @@ def test_joined_node_data_set( # Join "revenue" with "users_latest" to get "user__home_state_latest" revenue_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["revenue"] users_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["users_latest"] - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=revenue_node, join_targets=[ JoinDescription( diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py b/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py index eaceb6465f..096dd30fa9 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_cm_branch_combiner.py @@ -27,7 +27,7 @@ def make_dataflow_plan(node: DataflowPlanNode) -> DataflowPlan: # noqa: D103 return DataflowPlan( - sink_nodes=[WriteToResultDataTableNode(node)], + sink_nodes=[WriteToResultDataTableNode.create(node)], plan_id=DagId.from_id_prefix(StaticIdPrefix.OPTIMIZED_DATAFLOW_PLAN_PREFIX), ) @@ -69,11 +69,11 @@ def test_filter_combination( ) -> None: """Tests combining a single node.""" source0 = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["bookings_source"] - filter0 = FilterElementsNode( + filter0 = FilterElementsNode.create( parent_node=source0, include_specs=InstanceSpecSet(measure_specs=(MeasureSpec(element_name="bookings"),)) ) source1 = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping["bookings_source"] - filter1 = FilterElementsNode( + filter1 = FilterElementsNode.create( parent_node=source1, include_specs=InstanceSpecSet( measure_specs=(MeasureSpec(element_name="booking_value"),), diff --git a/tests_metricflow/examples/test_node_sql.py b/tests_metricflow/examples/test_node_sql.py index d3c96fffc6..aa88d93868 100644 --- a/tests_metricflow/examples/test_node_sql.py +++ b/tests_metricflow/examples/test_node_sql.py @@ -51,7 +51,7 @@ def test_view_sql_generated_at_a_node( # Show SQL and spec set at a source node. bookings_source_data_set = to_data_set_converter.create_sql_source_data_set(bookings_semantic_model) - read_source_node = ReadSqlSourceNode(bookings_source_data_set) + read_source_node = ReadSqlSourceNode.create(bookings_source_data_set) conversion_result = to_sql_plan_converter.convert_to_sql_query_plan( sql_engine_type=sql_client.sql_engine_type, dataflow_plan_node=read_source_node, @@ -63,13 +63,13 @@ def test_view_sql_generated_at_a_node( logger.info(f"SQL generated at {read_source_node} is:\n\n{sql_at_read_node}") logger.info(f"Spec set at {read_source_node} is:\n\n{mf_pformat(spec_set_at_read_node)}") - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=read_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) # Show SQL and spec set at a filter node. - filter_elements_node = FilterElementsNode( + filter_elements_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( time_dimension_specs=( diff --git a/tests_metricflow/execution/noop_task.py b/tests_metricflow/execution/noop_task.py index 563d16feaa..8f65ec049a 100644 --- a/tests_metricflow/execution/noop_task.py +++ b/tests_metricflow/execution/noop_task.py @@ -2,13 +2,13 @@ import logging import time -from typing import Optional, Sequence +from dataclasses import dataclass +from typing import ClassVar, Sequence from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow.execution.execution_plan import ( ExecutionPlanTask, - SqlQuery, TaskExecutionError, TaskExecutionResult, ) @@ -16,21 +16,28 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) class NoOpExecutionPlanTask(ExecutionPlanTask): - """A no-op task for testing executors.""" + """A no-op task for testing executors. - # Error to return if should_error is set. - EXAMPLE_ERROR = TaskExecutionError("Expected Error") + Attributes: + should_error: If true, test the error flow by intentionally returning an error in the results. + """ - def __init__(self, parent_tasks: Sequence[ExecutionPlanTask] = (), should_error: bool = False) -> None: - """Constructor. + EXAMPLE_ERROR: ClassVar[TaskExecutionError] = TaskExecutionError("Expected Error") - Args: - parent_tasks: Self-explanatory. - should_error: if true, return an error in the results. - """ - self._should_error = should_error - super().__init__(task_id=self.create_unique_id(), parent_nodes=list(parent_tasks)) + should_error: bool = False + + @staticmethod + def create( # noqa: D102 + parent_tasks: Sequence[ExecutionPlanTask] = (), + should_error: bool = False, + ) -> NoOpExecutionPlanTask: + return NoOpExecutionPlanTask( + parent_nodes=tuple(parent_tasks), + sql_query=None, + should_error=should_error, + ) @property def description(self) -> str: # noqa: D102 @@ -45,9 +52,5 @@ def execute(self) -> TaskExecutionResult: # noqa: D102 time.sleep(0.01) end_time = time.time() return TaskExecutionResult( - start_time=start_time, end_time=end_time, errors=(self.EXAMPLE_ERROR,) if self._should_error else () + start_time=start_time, end_time=end_time, errors=(self.EXAMPLE_ERROR,) if self.should_error else () ) - - @property - def sql_query(self) -> Optional[SqlQuery]: # noqa: D102 - return None diff --git a/tests_metricflow/execution/test_sequential_executor.py b/tests_metricflow/execution/test_sequential_executor.py index 84f29b6461..998f842da8 100644 --- a/tests_metricflow/execution/test_sequential_executor.py +++ b/tests_metricflow/execution/test_sequential_executor.py @@ -9,7 +9,7 @@ def test_single_task() -> None: """Tests running an execution plan with a single task.""" - task = NoOpExecutionPlanTask() + task = NoOpExecutionPlanTask.create() execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) assert results.get_result(task.task_id) @@ -17,7 +17,7 @@ def test_single_task() -> None: def test_single_task_error() -> None: """Check that an error is properly returned in the results if a task errors out.""" - task = NoOpExecutionPlanTask(should_error=True) + task = NoOpExecutionPlanTask.create(should_error=True) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) executor = SequentialPlanExecutor() results = executor.execute_plan(execution_plan) @@ -27,9 +27,9 @@ def test_single_task_error() -> None: def test_task_with_parents() -> None: """Tests a plan with a task that has 2 direct parents.""" - parent_task1 = NoOpExecutionPlanTask() - parent_task2 = NoOpExecutionPlanTask() - leaf_task = NoOpExecutionPlanTask(parent_tasks=[parent_task1, parent_task2]) + parent_task1 = NoOpExecutionPlanTask.create() + parent_task2 = NoOpExecutionPlanTask.create() + leaf_task = NoOpExecutionPlanTask.create(parent_tasks=[parent_task1, parent_task2]) execution_plan = ExecutionPlan(leaf_tasks=[leaf_task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -47,9 +47,9 @@ def test_task_with_parents() -> None: def test_parent_task_error() -> None: """Check that a child task is not run if a parent task fails.""" - parent_task1 = NoOpExecutionPlanTask(should_error=True) - parent_task2 = NoOpExecutionPlanTask() - leaf_task = NoOpExecutionPlanTask(parent_tasks=[parent_task1, parent_task2]) + parent_task1 = NoOpExecutionPlanTask.create(should_error=True) + parent_task2 = NoOpExecutionPlanTask.create() + leaf_task = NoOpExecutionPlanTask.create(parent_tasks=[parent_task1, parent_task2]) execution_plan = ExecutionPlan(leaf_tasks=[leaf_task], dag_id=DagId.from_str("plan0")) executor = SequentialPlanExecutor() diff --git a/tests_metricflow/execution/test_tasks.py b/tests_metricflow/execution/test_tasks.py index 17fe99f2ef..6e24ec85de 100644 --- a/tests_metricflow/execution/test_tasks.py +++ b/tests_metricflow/execution/test_tasks.py @@ -10,6 +10,7 @@ ExecutionPlan, SelectSqlQueryToDataTableTask, SelectSqlQueryToTableTask, + SqlQuery, ) from metricflow.execution.executor import SequentialPlanExecutor from metricflow.protocols.sql_client import SqlClient, SqlEngine @@ -18,7 +19,7 @@ def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103 - task = SelectSqlQueryToDataTableTask(sql_client, "SELECT 1 AS foo", SqlBindParameters()) + task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameters())) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) results = SequentialPlanExecutor().execute_plan(execution_plan) @@ -41,10 +42,12 @@ def test_write_table_task( # noqa: D103 mf_test_configuration: MetricFlowTestConfiguration, sql_client: SqlClient ) -> None: # noqa: D103 output_table = SqlTable(schema_name=mf_test_configuration.mf_system_schema, table_name=f"test_table_{random_id()}") - task = SelectSqlQueryToTableTask( + task = SelectSqlQueryToTableTask.create( sql_client=sql_client, - sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", - bind_parameters=SqlBindParameters(), + sql_query=SqlQuery( + sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo", + bind_parameters=SqlBindParameters(), + ), output_table=output_table, ) execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0")) diff --git a/tests_metricflow/fixtures/manifest_fixtures.py b/tests_metricflow/fixtures/manifest_fixtures.py index 6ab1061a42..f34be3fc83 100644 --- a/tests_metricflow/fixtures/manifest_fixtures.py +++ b/tests_metricflow/fixtures/manifest_fixtures.py @@ -221,7 +221,7 @@ def _data_set_to_read_nodes( # Moved from model_fixtures.py. return_dict: OrderedDict[str, ReadSqlSourceNode] = OrderedDict() for semantic_model_name, data_set in data_sets.items(): - return_dict[semantic_model_name] = ReadSqlSourceNode(data_set) + return_dict[semantic_model_name] = ReadSqlSourceNode.create(data_set) logger.debug( f"For semantic model {semantic_model_name}, creating node_id {return_dict[semantic_model_name].node_id}" ) diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index f381fde20b..4f9c563d6b 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -97,8 +97,8 @@ def render_date_sub( granularity: TimeGranularity, ) -> str: """Renders a date subtract expression.""" - expr = SqlSubtractTimeIntervalExpression( - arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias, column_alias)), + expr = SqlSubtractTimeIntervalExpression.create( + arg=SqlColumnReferenceExpression.create(SqlColumnReference(table_alias, column_alias)), count=count, granularity=granularity, ) @@ -106,10 +106,10 @@ def render_date_sub( def render_date_trunc(self, expr: str, granularity: TimeGranularity) -> str: """Return the DATE_TRUNC() call that can be used for converting the given expr to the granularity.""" - renderable_expr = SqlDateTruncExpression( + renderable_expr = SqlDateTruncExpression.create( time_granularity=granularity, - arg=SqlCastToTimestampExpression( - arg=SqlStringExpression( + arg=SqlCastToTimestampExpression.create( + arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ) @@ -119,10 +119,10 @@ def render_date_trunc(self, expr: str, granularity: TimeGranularity) -> str: def render_extract(self, expr: str, date_part: DatePart) -> str: """Return the EXTRACT call that can be used for converting the given expr to the date_part.""" - renderable_expr = SqlExtractExpression( + renderable_expr = SqlExtractExpression.create( date_part=date_part, - arg=SqlCastToTimestampExpression( - arg=SqlStringExpression( + arg=SqlCastToTimestampExpression.create( + arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ) @@ -142,8 +142,8 @@ def render_percentile_expr( ) ) - renderable_expr = SqlPercentileExpression( - order_by_arg=SqlStringExpression( + renderable_expr = SqlPercentileExpression.create( + order_by_arg=SqlStringExpression.create( sql_expr=expr, requires_parenthesis=False, ), @@ -191,7 +191,7 @@ def render_time_dimension_template( def generate_random_uuid(self) -> str: """Returns the generate random UUID SQL function.""" - expr = SqlGenerateUuidExpression() + expr = SqlGenerateUuidExpression.create() return self._sql_client.sql_query_plan_renderer.expr_renderer.render_sql_expr(expr).sql diff --git a/tests_metricflow/mf_logging/test_dag_to_text.py b/tests_metricflow/mf_logging/test_dag_to_text.py index 06086744c5..079d31bd90 100644 --- a/tests_metricflow/mf_logging/test_dag_to_text.py +++ b/tests_metricflow/mf_logging/test_dag_to_text.py @@ -33,15 +33,15 @@ def test_multithread_dag_to_text() -> None: dag_to_text_formatter = MetricFlowDagTextFormatter(max_width=1) dag = SqlQueryPlan( plan_id=DagId("plan"), - render_node=SqlSelectStatementNode( + render_node=SqlSelectStatementNode.create( description="test", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("'foo'"), + expr=SqlStringExpression.create("'foo'"), column_alias="bar", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="schema", table_name="table")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="schema", table_name="table")), from_source_alias="src", ), ) diff --git a/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py b/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py index 15e8da4a9f..3a4b460772 100644 --- a/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py +++ b/tests_metricflow/plan_conversion/dataflow_to_sql/test_metric_time_dimension_to_sql.py @@ -30,7 +30,7 @@ def test_metric_time_dimension_transform_node_using_primary_time( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_dimension_transform_node = MetricTimeDimensionTransformNode( + metric_time_dimension_transform_node = MetricTimeDimensionTransformNode.create( parent_node=source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds") ) convert_and_check( @@ -54,7 +54,7 @@ def test_metric_time_dimension_transform_node_using_non_primary_time( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_dimension_transform_node = MetricTimeDimensionTransformNode( + metric_time_dimension_transform_node = MetricTimeDimensionTransformNode.create( parent_node=source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="paid_at"), ) diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py index 15ba3620d5..88fd000fda 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_sql_plan.py @@ -153,7 +153,7 @@ def test_filter_node( source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filter_node = FilterElementsNode( + filter_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,)) ) @@ -184,11 +184,11 @@ def test_filter_with_where_constraint_node( ] ds_spec = TimeDimensionSpec(element_name="ds", entity_links=(), time_granularity=TimeGranularity.DAY) - filter_node = FilterElementsNode( + filter_node = FilterElementsNode.create( parent_node=source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), time_dimension_specs=(ds_spec,)), ) # need to include ds_spec because where constraint operates on ds - where_constraint_node = WhereConstraintNode( + where_constraint_node = WhereConstraintNode.create( parent_node=filter_node, where_specs=( WhereFilterSpec( @@ -258,12 +258,12 @@ def test_measure_aggregation_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs)), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) @@ -292,7 +292,7 @@ def test_single_join_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -307,7 +307,7 @@ def test_single_join_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -315,7 +315,7 @@ def test_single_join_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -353,7 +353,7 @@ def test_multi_join_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -365,7 +365,7 @@ def test_multi_join_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -373,7 +373,7 @@ def test_multi_join_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -419,7 +419,7 @@ def test_compute_metrics_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -434,7 +434,7 @@ def test_compute_metrics_node( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -442,7 +442,7 @@ def test_compute_metrics_node( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -455,12 +455,12 @@ def test_compute_metrics_node( ], ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measure_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, @@ -492,7 +492,7 @@ def test_compute_metrics_node_simple_expr( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(measure_spec,), entity_specs=(entity_spec,)), ) @@ -504,7 +504,7 @@ def test_compute_metrics_node_simple_expr( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -512,7 +512,7 @@ def test_compute_metrics_node_simple_expr( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measure_node, join_targets=[ JoinDescription( @@ -525,17 +525,17 @@ def test_compute_metrics_node_simple_expr( ], ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, ) - sink_node = WriteToResultDataTableNode(compute_metrics_node) + sink_node = WriteToResultDataTableNode.create(compute_metrics_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -578,27 +578,27 @@ def test_join_to_time_spine_node_without_offset( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -608,7 +608,7 @@ def test_join_to_time_spine_node_without_offset( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -651,26 +651,26 @@ def test_join_to_time_spine_node_with_offset_window( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, metric_time_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -681,7 +681,7 @@ def test_join_to_time_spine_node_with_offset_window( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -724,26 +724,26 @@ def test_join_to_time_spine_node_with_offset_to_grain( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=measure_source_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=metric_time_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), entity_specs=(entity_spec,), dimension_specs=(metric_time_spec,) ), ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="booking_fees") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, metric_time_spec}, ) - join_to_time_spine_node = JoinToTimeSpineNode( + join_to_time_spine_node = JoinToTimeSpineNode.create( parent_node=compute_metrics_node, requested_agg_time_dimension_specs=[MTD_SPEC_DAY], use_custom_agg_time_dimension=False, @@ -755,7 +755,7 @@ def test_join_to_time_spine_node_with_offset_to_grain( join_type=SqlJoinType.INNER, ) - sink_node = WriteToResultDataTableNode(join_to_time_spine_node) + sink_node = WriteToResultDataTableNode.create(join_to_time_spine_node) dataflow_plan = DataflowPlan(sink_nodes=[sink_node], plan_id=DagId.from_str("plan0")) assert_plan_snapshot_text_equal( @@ -803,7 +803,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measures_node = FilterElementsNode( + filtered_measures_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=(numerator_spec, denominator_spec), entity_specs=(entity_spec,)), ) @@ -815,7 +815,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( dimension_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "listings_latest" ] - filtered_dimension_node = FilterElementsNode( + filtered_dimension_node = FilterElementsNode.create( parent_node=dimension_source_node, include_specs=InstanceSpecSet( entity_specs=(entity_spec,), @@ -823,7 +823,7 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ), ) - join_node = JoinOnEntitiesNode( + join_node = JoinOnEntitiesNode.create( left_node=filtered_measures_node, join_targets=[ JoinDescription( @@ -836,11 +836,11 @@ def test_compute_metrics_node_ratio_from_single_semantic_model( ], ) - aggregated_measures_node = AggregateMeasuresNode( + aggregated_measures_node = AggregateMeasuresNode.create( parent_node=join_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings_per_booker") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measures_node, metric_specs=[metric_spec], aggregated_to_elements={entity_spec, dimension_spec}, @@ -882,7 +882,7 @@ def test_order_by_node( "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=(measure_spec,), @@ -891,18 +891,18 @@ def test_order_by_node( ), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=metric_input_measure_specs ) metric_spec = MetricSpec(element_name="bookings") - compute_metrics_node = ComputeMetricsNode( + compute_metrics_node = ComputeMetricsNode.create( parent_node=aggregated_measure_node, metric_specs=[metric_spec], aggregated_to_elements={dimension_spec, time_dimension_spec}, ) - order_by_node = OrderByLimitNode( + order_by_node = OrderByLimitNode.create( order_by_specs=[ OrderBySpec( instance_spec=time_dimension_spec, @@ -940,7 +940,7 @@ def test_semi_additive_join_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -974,7 +974,7 @@ def test_semi_additive_join_node_with_queried_group_by( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=tuple(), time_dimension_spec=time_dimension_spec, @@ -1010,7 +1010,7 @@ def test_semi_additive_join_node_with_grouping( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "accounts_source" ] - semi_additive_join_node = SemiAdditiveJoinNode( + semi_additive_join_node = SemiAdditiveJoinNode.create( parent_node=measure_source_node, entity_specs=(entity_spec,), time_dimension_spec=time_dimension_spec, @@ -1037,7 +1037,7 @@ def test_constrain_time_range_node( measure_source_node = mf_engine_test_fixture_mapping[SemanticManifestSetup.SIMPLE_MANIFEST].read_node_mapping[ "bookings_source" ] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet( measure_specs=( @@ -1050,12 +1050,12 @@ def test_constrain_time_range_node( ), ), ) - metric_time_node = MetricTimeDimensionTransformNode( + metric_time_node = MetricTimeDimensionTransformNode.create( parent_node=filtered_measure_node, aggregation_time_dimension_reference=TimeDimensionReference(element_name="ds"), ) - constrain_time_node = ConstrainTimeRangeNode( + constrain_time_node = ConstrainTimeRangeNode.create( parent_node=metric_time_node, time_range_constraint=TimeRangeConstraint( start_time=as_datetime("2020-01-01"), @@ -1136,29 +1136,29 @@ def test_combine_output_node( # Build compute measures node measure_specs: List[MeasureSpec] = [sum_spec] - filtered_measure_node = FilterElementsNode( + filtered_measure_node = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs), dimension_specs=(dimension_spec,)), ) - aggregated_measure_node = AggregateMeasuresNode( + aggregated_measure_node = AggregateMeasuresNode.create( parent_node=filtered_measure_node, metric_input_measure_specs=tuple(MetricInputMeasureSpec(measure_spec=x) for x in measure_specs), ) # Build agg measures node measure_specs_2 = [sum_boolean_spec, count_distinct_spec] - filtered_measure_node_2 = FilterElementsNode( + filtered_measure_node_2 = FilterElementsNode.create( parent_node=measure_source_node, include_specs=InstanceSpecSet(measure_specs=tuple(measure_specs_2), dimension_specs=(dimension_spec,)), ) - aggregated_measure_node_2 = AggregateMeasuresNode( + aggregated_measure_node_2 = AggregateMeasuresNode.create( parent_node=filtered_measure_node_2, metric_input_measure_specs=tuple( MetricInputMeasureSpec(measure_spec=x, fill_nulls_with=1) for x in measure_specs_2 ), ) - combine_output_node = CombineAggregatedOutputsNode([aggregated_measure_node, aggregated_measure_node_2]) + combine_output_node = CombineAggregatedOutputsNode.create([aggregated_measure_node, aggregated_measure_node_2]) convert_and_check( request=request, mf_test_configuration=mf_test_configuration, diff --git a/tests_metricflow/sql/optimizer/test_column_pruner.py b/tests_metricflow/sql/optimizer/test_column_pruner.py index 443c58fc80..a73feb22a5 100644 --- a/tests_metricflow/sql/optimizer/test_column_pruner.py +++ b/tests_metricflow/sql/optimizer/test_column_pruner.py @@ -70,51 +70,51 @@ def base_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") ), column_alias="from_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), column_alias="from_source_join_col", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col1") ), column_alias="joined_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), column_alias="joined_source_join_col", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -123,7 +123,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col1", @@ -132,7 +132,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -141,17 +141,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col0", @@ -160,7 +162,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col1", @@ -169,7 +171,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -178,18 +180,18 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), @@ -228,29 +230,29 @@ def test_prune_from_source( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests a case where columns should be pruned from the FROM clause.""" - select_statement_with_some_from_source_column_removed = SqlSelectStatementNode( + select_statement_with_some_from_source_column_removed = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col1") ), column_alias="joined_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), column_alias="joined_source_join_col", @@ -258,7 +260,7 @@ def test_prune_from_source( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, where=base_select_statement.where, @@ -287,29 +289,29 @@ def test_prune_joined_source( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests a case where columns should be pruned from the JOIN clause.""" - select_statement_with_some_joined_source_column_removed = SqlSelectStatementNode( + select_statement_with_some_joined_source_column_removed = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") ), column_alias="from_source_col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), column_alias="from_source_join_col", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", @@ -317,7 +319,7 @@ def test_prune_joined_source( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, where=base_select_statement.where, @@ -346,11 +348,11 @@ def test_dont_prune_if_in_where( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns aren't pruned from parent sources if columns are used in a where.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", @@ -358,9 +360,11 @@ def test_dont_prune_if_in_where( ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, - where=SqlIsNullExpression( - SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="from_source", column_name="col1")) + join_descs=base_select_statement.join_descs, + where=SqlIsNullExpression.create( + SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="from_source", column_name="col1") + ) ), group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -389,17 +393,17 @@ def test_dont_prune_with_str_expr( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns aren't pruned from parent sources if there's a string expression in the select.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("from_source.col0", requires_parenthesis=False), + expr=SqlStringExpression.create("from_source.col0", requires_parenthesis=False), column_alias="some_string_expr", ), ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, where=base_select_statement.where, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -428,17 +432,17 @@ def test_prune_with_str_expr( base_select_statement: SqlSelectStatementNode, ) -> None: """Tests that columns are from parent sources if there's a string expression in the select with known cols.""" - select_statement_with_other_exprs = SqlSelectStatementNode( + select_statement_with_other_exprs = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression("from_source.col0", requires_parenthesis=False, used_columns=("col0",)), + expr=SqlStringExpression.create("from_source.col0", requires_parenthesis=False, used_columns=("col0",)), column_alias="some_string_expr", ), ), from_source=base_select_statement.from_source, from_source_alias=base_select_statement.from_source_alias, - joins_descs=base_select_statement.join_descs, + join_descs=base_select_statement.join_descs, where=base_select_statement.where, group_bys=base_select_statement.group_bys, order_bys=base_select_statement.order_bys, @@ -486,19 +490,19 @@ def string_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0", used_columns=("col0",)), + expr=SqlStringExpression.create(sql_expr="col0", used_columns=("col0",)), column_alias="from_source_col0", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -507,7 +511,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col1", @@ -516,7 +520,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -525,17 +529,19 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col2", @@ -544,7 +550,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col3", @@ -553,7 +559,7 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -562,18 +568,18 @@ def string_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), @@ -631,19 +637,19 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: ) src1 ) src2 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0"), + expr=SqlStringExpression.create(sql_expr="col0"), column_alias="col0", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col0", @@ -652,7 +658,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col1", @@ -661,11 +667,11 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -674,7 +680,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -683,7 +689,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col2", @@ -692,7 +698,7 @@ def grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col2", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="src0")), from_source_alias="src0", ), from_source_alias="src1", @@ -751,23 +757,25 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: ON src3.join_col = src4.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="4", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="col0"), + expr=SqlStringExpression.create(sql_expr="col0"), column_alias="col0", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="src3", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="col0", @@ -776,7 +784,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src1", column_name="join_col", @@ -785,11 +793,11 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -798,7 +806,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -807,7 +815,7 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="join_col", @@ -816,18 +824,20 @@ def join_grandparent_pruning_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="src0") + ), from_source_alias="src0", ), from_source_alias="src1", ), right_source_alias="src4", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src4", column_name="join_col") ), ), @@ -866,33 +876,35 @@ def test_prune_distinct_select( column_pruner: SqlColumnPrunerOptimizer, ) -> None: """Test that distinct select node shouldn't be pruned.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py index 87f5fd10a8..1536f30e66 100644 --- a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -54,14 +54,14 @@ def base_select_statement() -> SqlSelectStatementNode: GROUP BY src2.ds ORDER BY src2.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ) ], @@ -69,29 +69,33 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="ds")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="ds") + ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src1", column_name="ds")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src1", column_name="ds") + ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="bookings", @@ -100,7 +104,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="ds", @@ -109,27 +113,29 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src0", limit=2, ), from_source_alias="src1", - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src1", column_name="ds", ) ), comparison=SqlComparison.GREATER_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-01"), + right_expr=SqlStringLiteralExpression.create("2020-01-01"), ), limit=1, ), from_source_alias="src2", group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", @@ -138,19 +144,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="ds", @@ -210,14 +216,14 @@ def join_select_statement() -> SqlSelectStatementNode: ON bookings_src.listing = listings_src.listing GROUP BY bookings_src.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -225,72 +231,74 @@ def join_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src", column_name="country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create(sql_expr="1", requires_parenthesis=False, used_columns=()), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="fct_bookings_src", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src", ), right_source_alias="listings_src", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src", column_name="listing"), ), ), @@ -299,7 +307,7 @@ def join_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -308,15 +316,15 @@ def join_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), ) @@ -369,14 +377,14 @@ def colliding_select_statement() -> SqlSelectStatementNode: ON bookings_src.listing = listings_src.listing GROUP BY bookings_src.ds """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -384,74 +392,76 @@ def colliding_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src", column_name="listing__country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="ds") ), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="colliding_alias", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="colliding_alias", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="colliding_alias", ), right_source_alias="listings_src", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src", column_name="listing"), ), ), @@ -460,7 +470,7 @@ def colliding_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -469,15 +479,15 @@ def colliding_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), ), - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", ) ), comparison=SqlComparison.LESS_THAN_OR_EQUALS, - right_expr=SqlStringLiteralExpression("2020-01-05"), + right_expr=SqlStringLiteralExpression.create("2020-01-05"), ), ) @@ -538,14 +548,14 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: ON listing_src1.listing = listings_src2.listing GROUP BY bookings_src.ds, listings_src1.country_latest, listings_src2.capacity_latest """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="bookings") ) ], @@ -553,114 +563,116 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src1", column_name="country_latest") ), column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="listings_src2", column_name="capacity_latest") ), column_alias="listing__capacity_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="bookings_src", column_name="ds") ), column_alias="ds", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="bookings_src", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="booking") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="ds") ), column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="fct_bookings_src", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="fct_bookings_src", ), from_source_alias="bookings_src", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src1", column_name="country") ), column_alias="country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src1", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src1", ), right_source_alias="listings_src1", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="bookings_src", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src1", column_name="listing"), ), ), join_type=SqlJoinType.LEFT_OUTER, ), SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="listings_src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src2", column_name="capacity") ), column_alias="capacity_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="dim_listings_src2", column_name="listing_id") ), column_alias="listing", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="dim_listings") ), from_source_alias="dim_listings_src2", ), right_source_alias="listings_src2", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src1", column_name="listing"), ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( SqlColumnReference(table_alias="listings_src2", column_name="listing"), ), ), @@ -669,7 +681,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: ), group_bys=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="bookings_src", column_name="ds", @@ -678,7 +690,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="ds", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="listings_src1", column_name="country_latest", @@ -687,7 +699,7 @@ def reduce_all_join_select_statement() -> SqlSelectStatementNode: column_alias="listing__country_latest", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="listings_src2", column_name="capacity_latest", @@ -748,30 +760,30 @@ def reducing_join_statement() -> SqlSelectStatementNode: FROM demo.fct_listings src4 ) src3 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="listings") ), column_alias="listings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ) ], @@ -779,30 +791,32 @@ def reducing_join_statement() -> SqlSelectStatementNode: column_alias="bookings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create(sql_expr="1", requires_parenthesis=False, used_columns=()), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="src0", ), from_source_alias="src1", ), from_source_alias="src2", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src4", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src4", column_name="listings") ) ], @@ -810,7 +824,7 @@ def reducing_join_statement() -> SqlSelectStatementNode: column_alias="listings", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="fct_listings") ), from_source_alias="src4", @@ -872,30 +886,30 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: FROM demo.fct_listings src4 ) src3 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="query", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src2", column_name="bookings") ), column_alias="bookings", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src3", column_name="listings") ), column_alias="listings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src4", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src4", column_name="listings") ) ], @@ -903,20 +917,22 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: column_alias="listings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_listings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_listings") + ), from_source_alias="src4", ), from_source_alias="src2", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlAggregateFunctionExpression( + expr=SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression( + SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="bookings") ) ], @@ -924,15 +940,17 @@ def reducing_join_left_node_statement() -> SqlSelectStatementNode: column_alias="bookings", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlStringExpression(sql_expr="1", requires_parenthesis=False, used_columns=()), + expr=SqlStringExpression.create( + sql_expr="1", requires_parenthesis=False, used_columns=() + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") ), from_source_alias="src0", @@ -975,33 +993,35 @@ def test_rewriting_distinct_select_node_is_not_reduced( mf_test_configuration: MetricFlowTestConfiguration, ) -> None: """Tests to ensure distinct select node doesn't get overwritten.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py index ce0255beac..dc2bc157ee 100644 --- a/tests_metricflow/sql/optimizer/test_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_sub_query_reducer.py @@ -46,39 +46,43 @@ def base_select_statement() -> SqlSelectStatementNode: ) src2 ORDER BY src2.col0 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col0")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col0") + ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col1")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col1") + ), column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col0") ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col1") ), column_alias="col1", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="src1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col0", @@ -87,7 +91,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="src0", column_name="col1", @@ -96,7 +100,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col1", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="from_source_table") ), from_source_alias="src0", @@ -108,7 +112,7 @@ def base_select_statement() -> SqlSelectStatementNode: from_source_alias="src2", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="col0", @@ -163,47 +167,53 @@ def rewrite_order_by_statement() -> SqlSelectStatementNode: ORDER BY src2.col1 """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="src3", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col0")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col0") + ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="src2", column_name="col1")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="src2", column_name="col1") + ), column_alias="col1", ), ), from_source=( - SqlSelectStatementNode( + SqlSelectStatementNode.create( description="src2", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src0", column_name="col0") ), column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="col1") ), column_alias="col1", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src0")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="src0")), from_source_alias="src0", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="src1")), + right_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="src1") + ), right_source_alias="src1", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src0", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="src1", column_name="join_col") ), ), @@ -215,7 +225,7 @@ def rewrite_order_by_statement() -> SqlSelectStatementNode: from_source_alias="src2", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( SqlColumnReference( table_alias="src2", column_name="col1", @@ -255,33 +265,35 @@ def test_distinct_select_node_is_not_reduced( mf_test_configuration: MetricFlowTestConfiguration, ) -> None: """Tests to ensure distinct select node doesn't get overwritten.""" - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="test1", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", distinct=True, ), diff --git a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py index 6dff3fcd3d..d603054f38 100644 --- a/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py +++ b/tests_metricflow/sql/optimizer/test_table_alias_simplifier.py @@ -53,27 +53,27 @@ def base_select_statement() -> SqlSelectStatementNode: ON from_source.join_col = joined_source.join_col """ - return SqlSelectStatementNode( + return SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="col0") ), column_alias="from_source_col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="col0") ), column_alias="joined_source_col0", ), ), - from_source=SqlSelectStatementNode( + from_source=SqlSelectStatementNode.create( description="from_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="col0", @@ -82,7 +82,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="from_source_table", column_name="join_col", @@ -91,17 +91,19 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="from_source_table")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="from_source_table") + ), from_source_alias="from_source_table", ), from_source_alias="from_source", - joins_descs=( + join_descs=( SqlJoinDescription( - right_source=SqlSelectStatementNode( + right_source=SqlSelectStatementNode.create( description="joined_source", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="col0", @@ -110,7 +112,7 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="col0", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference( table_alias="joined_source_table", column_name="join_col", @@ -119,18 +121,18 @@ def base_select_statement() -> SqlSelectStatementNode: column_alias="join_col", ), ), - from_source=SqlTableFromClauseNode( + from_source=SqlTableFromClauseNode.create( sql_table=SqlTable(schema_name="demo", table_name="joined_source_table") ), from_source_alias="joined_source_table", ), right_source_alias="joined_source", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="from_source", column_name="join_col") ), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression( + right_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="joined_source", column_name="join_col") ), ), diff --git a/tests_metricflow/sql/test_engine_specific_rendering.py b/tests_metricflow/sql/test_engine_specific_rendering.py index 71b0383620..6ba8f8fc15 100644 --- a/tests_metricflow/sql/test_engine_specific_rendering.py +++ b/tests_metricflow/sql/test_engine_specific_rendering.py @@ -37,8 +37,8 @@ def test_cast_to_timestamp( """Tests rendering of the cast to timestamp expression in a query.""" select_columns = [ SqlSelectColumn( - expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-01", ) ), @@ -46,7 +46,7 @@ def test_cast_to_timestamp( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -56,12 +56,12 @@ def test_cast_to_timestamp( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Cast to Timestamp Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -80,17 +80,17 @@ def test_generate_uuid( """Tests rendering of the generate uuid expression in a query.""" select_columns = [ SqlSelectColumn( - expr=SqlGenerateUuidExpression(), + expr=SqlGenerateUuidExpression.create(), column_alias="uuid", ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Generate UUID Expression", select_columns=tuple(select_columns), from_source=from_source, @@ -115,8 +115,8 @@ def test_continuous_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.CONTINUOUS ), @@ -125,7 +125,7 @@ def test_continuous_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -135,12 +135,12 @@ def test_continuous_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Continuous Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -164,8 +164,8 @@ def test_discrete_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.DISCRETE ), @@ -174,7 +174,7 @@ def test_discrete_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -184,12 +184,12 @@ def test_discrete_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Discrete Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -213,8 +213,8 @@ def test_approximate_continuous_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS ), @@ -223,7 +223,7 @@ def test_approximate_continuous_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -233,12 +233,12 @@ def test_approximate_continuous_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Approximate Continuous Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -262,8 +262,8 @@ def test_approximate_discrete_percentile_expr( select_columns = [ SqlSelectColumn( - expr=SqlPercentileExpression( - order_by_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlPercentileExpression.create( + order_by_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), percentile_args=SqlPercentileExpressionArgument( percentile=0.5, function_type=SqlPercentileFunctionType.APPROXIMATE_DISCRETE ), @@ -272,7 +272,7 @@ def test_approximate_discrete_percentile_expr( ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) from_source_alias = "a" joins_descs: List[SqlJoinDescription] = [] where = None @@ -282,12 +282,12 @@ def test_approximate_discrete_percentile_expr( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="Test Approximate Discrete Percentile Expression", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), diff --git a/tests_metricflow/sql/test_sql_expr_render.py b/tests_metricflow/sql/test_sql_expr_render.py index 7e6ba4cb6e..05b4103ea2 100644 --- a/tests_metricflow/sql/test_sql_expr_render.py +++ b/tests_metricflow/sql/test_sql_expr_render.py @@ -45,14 +45,14 @@ def default_expr_renderer() -> DefaultSqlExpressionRenderer: # noqa: D103 def test_str_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlStringExpression("a + b")).sql + actual = default_expr_renderer.render_sql_expr(SqlStringExpression.create("a + b")).sql expected = "a + b" assert actual == expected def test_col_ref_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlColumnReferenceExpression(SqlColumnReference("my_table", "my_col")) + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "my_col")) ).sql expected = "my_table.my_col" assert actual == expected @@ -60,10 +60,10 @@ def test_col_ref_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> No def test_comparison_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("my_table", "my_col")), + SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "my_col")), comparison=SqlComparison.EQUALS, - right_expr=SqlStringExpression("a + b"), + right_expr=SqlStringExpression.create("a + b"), ) ).sql assert actual == "my_table.my_col = (a + b)" @@ -71,10 +71,10 @@ def test_comparison_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> def test_require_parenthesis(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "booking_value")), + SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "booking_value")), comparison=SqlComparison.GREATER_THAN, - right_expr=SqlStringExpression("100", requires_parenthesis=False), + right_expr=SqlStringExpression.create("100", requires_parenthesis=False), ) ).sql @@ -83,11 +83,11 @@ def test_require_parenthesis(default_expr_renderer: DefaultSqlExpressionRenderer def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.SUM, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), ], ) ).sql @@ -97,11 +97,11 @@ def test_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> N def test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: """Distinct aggregation functions require the insertion of the DISTINCT keyword in the rendered function expr.""" actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.COUNT_DISTINCT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), ], ) ).sql @@ -111,15 +111,15 @@ def test_distinct_agg_expr(default_expr_renderer: DefaultSqlExpressionRenderer) def test_nested_function_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlAggregateFunctionExpression( + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.CONCAT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "a")), - SqlAggregateFunctionExpression( + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "a")), + SqlAggregateFunctionExpression.create( sql_function=SqlFunction.CONCAT, sql_function_args=[ - SqlColumnReferenceExpression(SqlColumnReference("my_table", "b")), - SqlColumnReferenceExpression(SqlColumnReference("my_table", "c")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "b")), + SqlColumnReferenceExpression.create(SqlColumnReference("my_table", "c")), ], ), ], @@ -129,17 +129,17 @@ def test_nested_function_expr(default_expr_renderer: DefaultSqlExpressionRendere def test_null_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlNullExpression()).sql + actual = default_expr_renderer.render_sql_expr(SqlNullExpression.create()).sql assert actual == "NULL" def test_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlStringExpression("1 < 2", requires_parenthesis=True), - SqlStringExpression("foo", requires_parenthesis=False), + SqlStringExpression.create("1 < 2", requires_parenthesis=True), + SqlStringExpression.create("foo", requires_parenthesis=False), ), ) ).sql @@ -155,12 +155,12 @@ def test_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: def test_long_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlLogicalExpression( + SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlStringExpression("some_long_expression1"), - SqlStringExpression("some_long_expression2"), - SqlStringExpression("some_long_expression3"), + SqlStringExpression.create("some_long_expression1"), + SqlStringExpression.create("some_long_expression2"), + SqlStringExpression.create("some_long_expression3"), ), ) ).sql @@ -181,56 +181,59 @@ def test_long_and_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> N def test_string_literal_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - actual = default_expr_renderer.render_sql_expr(SqlStringLiteralExpression("foo")).sql + actual = default_expr_renderer.render_sql_expr(SqlStringLiteralExpression.create("foo")).sql assert actual == "'foo'" def test_is_null_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlIsNullExpression(SqlStringExpression("foo", requires_parenthesis=False)) + SqlIsNullExpression.create(SqlStringExpression.create("foo", requires_parenthesis=False)) ).sql assert actual == "foo IS NULL" def test_date_trunc_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlDateTruncExpression(time_granularity=TimeGranularity.MONTH, arg=SqlStringExpression("ds")) + SqlDateTruncExpression.create(time_granularity=TimeGranularity.MONTH, arg=SqlStringExpression.create("ds")) ).sql assert actual == "DATE_TRUNC('month', ds)" def test_extract_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlExtractExpression(date_part=DatePart.DOY, arg=SqlStringExpression("ds")) + SqlExtractExpression.create(date_part=DatePart.DOY, arg=SqlStringExpression.create("ds")) ).sql assert actual == "EXTRACT(doy FROM ds)" def test_ratio_computation_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlRatioComputationExpression( - numerator=SqlAggregateFunctionExpression( - SqlFunction.SUM, sql_function_args=[SqlStringExpression(sql_expr="1", requires_parenthesis=False)] + SqlRatioComputationExpression.create( + numerator=SqlAggregateFunctionExpression.create( + SqlFunction.SUM, + sql_function_args=[SqlStringExpression.create(sql_expr="1", requires_parenthesis=False)], + ), + denominator=SqlColumnReferenceExpression.create( + SqlColumnReference(column_name="divide_by_me", table_alias="a") ), - denominator=SqlColumnReferenceExpression(SqlColumnReference(column_name="divide_by_me", table_alias="a")), ), ).sql assert actual == "CAST(SUM(1) AS DOUBLE) / CAST(NULLIF(a.divide_by_me, 0) AS DOUBLE)" def test_expr_rewrite(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 - expr = SqlLogicalExpression( + expr = SqlLogicalExpression.create( operator=SqlLogicalOperator.AND, args=( - SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), - SqlColumnReferenceExpression(SqlColumnReference("a", "col1")), + SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), + SqlColumnReferenceExpression.create(SqlColumnReference("a", "col1")), ), ) column_replacements = SqlColumnReplacements( { - SqlColumnReference("a", "col0"): SqlStringExpression("foo", requires_parenthesis=False), - SqlColumnReference("a", "col1"): SqlStringExpression("bar", requires_parenthesis=False), + SqlColumnReference("a", "col0"): SqlStringExpression.create("foo", requires_parenthesis=False), + SqlColumnReference("a", "col1"): SqlStringExpression.create("bar", requires_parenthesis=False), } ) expr_rewritten = expr.rewrite(column_replacements) @@ -239,15 +242,15 @@ def test_expr_rewrite(default_expr_renderer: DefaultSqlExpressionRenderer) -> No def test_between_expr(default_expr_renderer: DefaultSqlExpressionRenderer) -> None: # noqa: D103 actual = default_expr_renderer.render_sql_expr( - SqlBetweenExpression( - column_arg=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), - start_expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + SqlBetweenExpression.create( + column_arg=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), + start_expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-01", ) ), - end_expr=SqlCastToTimestampExpression( - arg=SqlStringLiteralExpression( + end_expr=SqlCastToTimestampExpression.create( + arg=SqlStringLiteralExpression.create( literal_value="2020-01-10", ) ), @@ -262,17 +265,17 @@ def test_window_function_expr( # noqa: D103 default_expr_renderer: DefaultSqlExpressionRenderer, ) -> None: partition_by_args = ( - SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), - SqlColumnReferenceExpression(SqlColumnReference("b", "col1")), + SqlColumnReferenceExpression.create(SqlColumnReference("b", "col0")), + SqlColumnReferenceExpression.create(SqlColumnReference("b", "col1")), ) order_by_args = ( SqlWindowOrderByArgument( - expr=SqlColumnReferenceExpression(SqlColumnReference("a", "col0")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0")), descending=True, nulls_last=False, ), SqlWindowOrderByArgument( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "col0")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "col0")), descending=False, nulls_last=True, ), @@ -284,9 +287,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append(f"-- Window function with {num_partition_by_args} PARTITION BY items(s)") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=partition_by_args[:num_partition_by_args], order_by_args=(), ) @@ -298,9 +301,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append(f"-- Window function with {num_order_by_args} ORDER BY items(s)") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=(), order_by_args=order_by_args[:num_order_by_args], ) @@ -311,9 +314,9 @@ def test_window_function_expr( # noqa: D103 rendered_sql_lines.append("-- Window function with PARTITION BY and ORDER BY items") rendered_sql_lines.append( default_expr_renderer.render_sql_expr( - SqlWindowFunctionExpression( + SqlWindowFunctionExpression.create( sql_function=SqlWindowFunction.FIRST_VALUE, - sql_function_args=[SqlColumnReferenceExpression(SqlColumnReference("a", "col0"))], + sql_function_args=[SqlColumnReferenceExpression.create(SqlColumnReference("a", "col0"))], partition_by_args=partition_by_args, order_by_args=order_by_args, ) diff --git a/tests_metricflow/sql/test_sql_plan_render.py b/tests_metricflow/sql/test_sql_plan_render.py index f844cee00d..14dd164a89 100644 --- a/tests_metricflow/sql/test_sql_plan_render.py +++ b/tests_metricflow/sql/test_sql_plan_render.py @@ -42,14 +42,14 @@ def test_component_rendering( # Test single SELECT column select_columns = [ SqlSelectColumn( - expr=SqlAggregateFunctionExpression( - sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression("1")] + expr=SqlAggregateFunctionExpression.create( + sql_function=SqlFunction.SUM, sql_function_args=[SqlStringExpression.create("1")] ), column_alias="bookings", ), ] - from_source = SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")) + from_source = SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")) from_source = from_source from_source_alias = "a" @@ -61,12 +61,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -79,11 +79,11 @@ def test_component_rendering( select_columns.extend( [ SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "country")), column_alias="user__country", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("c", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "country")), column_alias="listing__country", ), ] @@ -92,12 +92,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test1", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -109,12 +109,12 @@ def test_component_rendering( # Test single join joins_descs.append( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="dim_users")), + right_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="dim_users")), right_source_alias="b", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "user_id")), + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "user_id")), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression(SqlColumnReference("b", "user_id")), + right_expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "user_id")), ), join_type=SqlJoinType.LEFT_OUTER, ) @@ -123,12 +123,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test2", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -140,12 +140,14 @@ def test_component_rendering( # Test multiple join joins_descs.append( SqlJoinDescription( - right_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="dim_listings")), + right_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="dim_listings") + ), right_source_alias="c", - on_condition=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression(SqlColumnReference("a", "user_id")), + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create(SqlColumnReference("a", "user_id")), comparison=SqlComparison.EQUALS, - right_expr=SqlColumnReferenceExpression(SqlColumnReference("c", "user_id")), + right_expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "user_id")), ), join_type=SqlJoinType.LEFT_OUTER, ) @@ -154,12 +156,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test3", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -171,7 +173,7 @@ def test_component_rendering( # Test single group by group_bys.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("b", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("b", "country")), column_alias="user__country", ), ) @@ -179,12 +181,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test4", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -196,7 +198,7 @@ def test_component_rendering( # Test multiple group bys group_bys.append( SqlSelectColumn( - expr=SqlColumnReferenceExpression(SqlColumnReference("c", "country")), + expr=SqlColumnReferenceExpression.create(SqlColumnReference("c", "country")), column_alias="listing__country", ), ) @@ -204,12 +206,12 @@ def test_component_rendering( assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test5", select_columns=tuple(select_columns), from_source=from_source, from_source_alias=from_source_alias, - joins_descs=tuple(joins_descs), + join_descs=tuple(joins_descs), where=where, group_bys=tuple(group_bys), order_bys=tuple(order_bys), @@ -228,24 +230,26 @@ def test_render_where( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", - where=SqlComparisonExpression( - left_expr=SqlColumnReferenceExpression( + where=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), comparison=SqlComparison.GREATER_THAN, - right_expr=SqlStringExpression( + right_expr=SqlStringExpression.create( sql_expr="100", requires_parenthesis=False, used_columns=(), @@ -266,33 +270,35 @@ def test_render_order_by( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), column_alias="booking_value", ), SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", order_bys=( SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="booking_value") ), desc=False, ), SqlOrderByDescription( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), desc=True, @@ -313,17 +319,19 @@ def test_render_limit( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlSelectStatementNode( + sql_plan_node=SqlSelectStatementNode.create( description="test0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression( + expr=SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias="a", column_name="bookings") ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create( + sql_table=SqlTable(schema_name="demo", table_name="fct_bookings") + ), from_source_alias="a", limit=1, ), @@ -338,22 +346,24 @@ def test_render_create_table_as( # noqa: D103 mf_test_configuration: MetricFlowTestConfiguration, sql_client: SqlClient, ) -> None: - select_node = SqlSelectStatementNode( + select_node = SqlSelectStatementNode.create( description="select_0", select_columns=( SqlSelectColumn( - expr=SqlColumnReferenceExpression(col_ref=SqlColumnReference(table_alias="a", column_name="bookings")), + expr=SqlColumnReferenceExpression.create( + col_ref=SqlColumnReference(table_alias="a", column_name="bookings") + ), column_alias="bookings", ), ), - from_source=SqlTableFromClauseNode(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), + from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")), from_source_alias="a", limit=1, ) assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlCreateTableAsNode( + sql_plan_node=SqlCreateTableAsNode.create( sql_table=SqlTable( schema_name="schema_name", table_name="table_name", @@ -367,7 +377,7 @@ def test_render_create_table_as( # noqa: D103 assert_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration, - sql_plan_node=SqlCreateTableAsNode( + sql_plan_node=SqlCreateTableAsNode.create( sql_table=SqlTable( schema_name="schema_name", table_name="table_name", diff --git a/tests_metricflow/sql_clients/test_date_time_operations.py b/tests_metricflow/sql_clients/test_date_time_operations.py index a99b4e58bd..0370ade60e 100644 --- a/tests_metricflow/sql_clients/test_date_time_operations.py +++ b/tests_metricflow/sql_clients/test_date_time_operations.py @@ -44,8 +44,8 @@ def _extract_data_table_value(df: MetricFlowDataTable) -> Any: # type: ignore[m def _build_date_trunc_expression(date_string: str, time_granularity: TimeGranularity) -> SqlDateTruncExpression: - cast_expr = SqlCastToTimestampExpression(SqlStringLiteralExpression(literal_value=date_string)) - return SqlDateTruncExpression(time_granularity=time_granularity, arg=cast_expr) + cast_expr = SqlCastToTimestampExpression.create(SqlStringLiteralExpression.create(literal_value=date_string)) + return SqlDateTruncExpression.create(time_granularity=time_granularity, arg=cast_expr) def test_date_trunc_to_year(sql_client: SqlClient) -> None: @@ -118,8 +118,8 @@ def test_date_trunc_to_week(sql_client: SqlClient, input: str, expected: datetim def _build_extract_expression(date_string: str, date_part: DatePart) -> SqlExtractExpression: - cast_expr = SqlCastToTimestampExpression(SqlStringLiteralExpression(literal_value=date_string)) - return SqlExtractExpression(date_part=date_part, arg=cast_expr) + cast_expr = SqlCastToTimestampExpression.create(SqlStringLiteralExpression.create(literal_value=date_string)) + return SqlExtractExpression.create(date_part=date_part, arg=cast_expr) def test_date_part_year(sql_client: SqlClient) -> None: