diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index db4ccac9b6..fd46f4af39 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -190,7 +190,7 @@ def _render_from_section(self, from_source: SqlQueryPlanNode, from_source_alias: from_render_result = self._render_node(from_source) from_section_lines = [] - if from_source.is_table: + if from_source.as_sql_table_node is not None: from_section_lines.append(f"FROM {from_render_result.sql} {from_source_alias}") else: from_section_lines.append("FROM (") @@ -228,7 +228,7 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription]) on_condition_rendered = self.EXPR_RENDERER.render_sql_expr(join_description.on_condition) params = params.merge(on_condition_rendered.bind_parameter_set) - if join_description.right_source.is_table: + if join_description.right_source.as_sql_table_node is not None: join_section_lines.append(join_description.join_type.value) join_section_lines.append( textwrap.indent( diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 18f2e2839f..eb2e4137f7 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -41,14 +41,14 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut @property @abstractmethod - def is_table(self) -> bool: - """Returns whether this node resolves to a table (vs. a query).""" + def as_select_node(self) -> Optional[SqlSelectStatementNode]: + """If possible, return this as a select statement node.""" raise NotImplementedError @property @abstractmethod - def as_select_node(self) -> Optional[SqlSelectStatementNode]: - """If possible, return this as a select statement node.""" + def as_sql_table_node(self) -> Optional[SqlTableNode]: + """If possible, return this as SQL table node.""" raise NotImplementedError @abstractmethod @@ -208,14 +208,15 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_select_statement_node(self) - @property - def is_table(self) -> bool: # noqa: D102 - return False - @property def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return self + @property + @override + def as_sql_table_node(self) -> Optional[SqlTableNode]: + return None + @property @override def description(self) -> str: @@ -271,10 +272,6 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_table_node(self) - @property - def is_table(self) -> bool: # noqa: D102 - return True - @property def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None @@ -289,6 +286,11 @@ def nearest_select_columns( return cte_node.nearest_select_columns(cte_source_mapping) return None + @property + @override + def as_sql_table_node(self) -> Optional[SqlTableNode]: + return self + @dataclass(frozen=True, eq=False) class SqlSelectQueryFromClauseNode(SqlQueryPlanNode): @@ -318,10 +320,6 @@ def description(self) -> str: # noqa: D102 def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 return visitor.visit_query_from_clause_node(self) - @property - def is_table(self) -> bool: # noqa: D102 - return False - @property def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None @@ -332,6 +330,11 @@ def nearest_select_columns( ) -> Optional[Sequence[SqlSelectColumn]]: return None + @property + @override + def as_sql_table_node(self) -> Optional[SqlTableNode]: + return None + @dataclass(frozen=True, eq=False) class SqlCreateTableAsNode(SqlQueryPlanNode): @@ -339,7 +342,6 @@ class SqlCreateTableAsNode(SqlQueryPlanNode): Attributes: sql_table: The SQL table to create. - parent_node: The parent query plan node. """ sql_table: SqlTable @@ -361,12 +363,12 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut @property @override - def is_table(self) -> bool: - return False + def as_select_node(self) -> Optional[SqlSelectStatementNode]: + return None @property @override - def as_select_node(self) -> Optional[SqlSelectStatementNode]: + def as_sql_table_node(self) -> Optional[SqlTableNode]: return None @property @@ -415,7 +417,7 @@ def render_node(self) -> SqlQueryPlanNode: # noqa: D102 class SqlCteNode(SqlQueryPlanNode): """Represents a single common table expression.""" - select_statement: SqlSelectStatementNode + select_statement: SqlQueryPlanNode cte_alias: str def __post_init__(self) -> None: # noqa: D105 @@ -423,7 +425,7 @@ def __post_init__(self) -> None: # noqa: D105 assert len(self.parent_nodes) == 1 @staticmethod - def create(select_statement: SqlSelectStatementNode, cte_alias: str) -> SqlCteNode: # noqa: D102 + def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: # noqa: D102 return SqlCteNode( parent_nodes=(select_statement,), select_statement=select_statement, @@ -436,12 +438,12 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut @property @override - def is_table(self) -> bool: - return False + def as_select_node(self) -> Optional[SqlSelectStatementNode]: + return None @property @override - def as_select_node(self) -> Optional[SqlSelectStatementNode]: + def as_sql_table_node(self) -> Optional[SqlTableNode]: return None @property