diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 48e2212461..301d02e539 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -1438,23 +1438,32 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod if instance.spec == node.time_dimension_spec.with_base_grain(): parent_time_dimension_instance = instance break + parent_column: Optional[SqlSelectColumn] = None assert parent_time_dimension_instance, ( "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. " - "This indicates internal misconfiguration." + f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain}; " + f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}" + ) + for select_column in parent_data_set.checked_sql_select_node.select_columns: + if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name: + parent_column = select_column + break + assert parent_column, ( + "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent columns. " + f"This indicates internal misconfiguration. Expected: " + f"{parent_time_dimension_instance.associated_column.column_name}; Got: " + f"{[column.column_alias for column in parent_data_set.checked_sql_select_node.select_columns]}" ) # Build join expression. time_spine_alias = self._next_unique_table_alias() custom_granularity_name = node.time_dimension_spec.time_granularity.name time_spine_source = self._get_time_spine_for_custom_granularity(custom_granularity_name) - left_expr_for_join: SqlExpressionNode = SqlColumnReferenceExpression.from_table_and_column_names( - table_alias=parent_alias, column_name=parent_time_dimension_instance.associated_column.column_name - ) join_description = SqlJoinDescription( right_source=SqlTableNode.create(sql_table=time_spine_source.spine_table), right_source_alias=time_spine_alias, on_condition=SqlComparisonExpression.create( - left_expr=left_expr_for_join, + left_expr=parent_column.expr, comparison=SqlComparison.EQUALS, right_expr=SqlColumnReferenceExpression.from_table_and_column_names( table_alias=time_spine_alias, column_name=time_spine_source.base_column diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index d8a4cf5340..2f44a1277f 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -122,6 +122,8 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP or select_column in node.group_bys or node.distinct ) + # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any + # SqlExpressionNode. if len(pruned_select_columns) == 0: raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.")