Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TimeSpineSource additions & rename SqlTableFromClauseNode -> SqlTableNode #1399

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Dict, Optional, Sequence

from dbt_semantic_interfaces.implementations.time_spine import PydanticTimeSpineCustomGranularityColumn
from dbt_semantic_interfaces.protocols import SemanticManifest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

Expand All @@ -29,7 +30,7 @@ class TimeSpineSource:
# The time granularity of the base column.
base_granularity: TimeGranularity = DEFAULT_TIME_GRANULARITY
db_name: Optional[str] = None
custom_granularities: Sequence[str] = ()
custom_granularities: Sequence[PydanticTimeSpineCustomGranularityColumn] = ()

@property
def spine_table(self) -> SqlTable:
Expand All @@ -48,7 +49,14 @@ def build_standard_time_spine_sources(
db_name=time_spine.node_relation.database,
base_column=time_spine.primary_column.name,
base_granularity=time_spine.primary_column.time_granularity,
custom_granularities=[column.name for column in time_spine.custom_granularities],
custom_granularities=tuple(
[
PydanticTimeSpineCustomGranularityColumn(
name=custom_granularity.name, column_name=custom_granularity.column_name
)
for custom_granularity in time_spine.custom_granularities
]
),
)
for time_spine in semantic_manifest.project_configuration.time_spines
}
Expand All @@ -74,3 +82,12 @@ def build_standard_time_spine_sources(
)

return time_spine_sources

@staticmethod
def build_custom_time_spine_sources(time_spine_sources: Sequence[TimeSpineSource]) -> Dict[str, TimeSpineSource]:
"""Creates a set of time spine sources with custom granularities based on what's in the manifest."""
return {
custom_granularity.name: time_spine_source
for time_spine_source in time_spine_sources
for custom_granularity in time_spine_source.custom_granularities
}
8 changes: 3 additions & 5 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -491,9 +491,7 @@ def create_sql_source_data_set(self, semantic_model: SemanticModel) -> SemanticM
all_select_columns.extend(select_columns)

# Generate the "from" clause depending on whether it's an SQL query or an SQL table.
from_source = SqlTableFromClauseNode.create(
sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name)
)
from_source = SqlTableNode.create(sql_table=SqlTable.from_string(semantic_model.node_relation.relation_name))

select_statement_node = SqlSelectStatementNode.create(
description=f"Read Elements From Semantic Model '{semantic_model.name}'",
Expand Down Expand Up @@ -552,7 +550,7 @@ def build_time_spine_source_data_set(self, time_spine_source: TimeSpineSource) -
sql_select_node=SqlSelectStatementNode.create(
description=TIME_SPINE_DATA_SET_DESCRIPTION,
select_columns=tuple(select_columns),
from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table),
from_source=SqlTableNode.create(sql_table=time_spine_source.spine_table),
from_source_alias=from_source_alias,
),
)
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
SqlQueryPlanNode,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -312,7 +312,7 @@ def _make_time_spine_data_set(
sql_select_node=SqlSelectStatementNode.create(
description=TIME_SPINE_DATA_SET_DESCRIPTION,
select_columns=select_columns,
from_source=SqlTableFromClauseNode.create(sql_table=time_spine_source.spine_table),
from_source=SqlTableNode.create(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
group_bys=select_columns if apply_group_by else (),
where=(
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -191,8 +191,8 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode:
"""This node is effectively a FROM statement inside a SELECT statement node, so pruning cannot apply."""
def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode:
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -700,7 +700,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down Expand Up @@ -764,7 +764,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down
4 changes: 2 additions & 2 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SqlQueryPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -188,7 +188,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down
4 changes: 2 additions & 2 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,7 +68,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down
4 changes: 2 additions & 2 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -301,7 +301,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanRe
bind_parameters=combined_params,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlPlanRenderResult: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: D102
return SqlPlanRenderResult(
sql=node.sql_table.sql,
bind_parameters=SqlBindParameters(),
Expand Down
14 changes: 7 additions & 7 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SqlQueryPlanNode(DagNode["SqlQueryPlanNode"], ABC):
* Statements like ALTER TABLE don't fit well, but they could be modeled as just a single sink node.
* SQL queries in where conditions could be modeled as another SqlQueryPlan.
* SqlRenderableNode() indicates nodes where plan generation can begin. Generally, this will be all nodes except
the SqlTableFromClauseNode() since my_table.my_column wouldn't be a valid SQL query.
the SqlTableNode() since my_table.my_column wouldn't be a valid SQL query.

Is there an existing library that can do this?
"""
Expand Down Expand Up @@ -63,7 +63,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> VisitorOu
raise NotImplementedError

@abstractmethod
def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> VisitorOutputT: # noqa: D102
def visit_table_node(self, node: SqlTableNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -191,14 +191,14 @@ def description(self) -> str:


@dataclass(frozen=True)
class SqlTableFromClauseNode(SqlQueryPlanNode):
"""An SQL table that can go in the FROM clause."""
class SqlTableNode(SqlQueryPlanNode):
"""An SQL table that can go in the FROM clause or the JOIN clause."""

sql_table: SqlTable

@staticmethod
def create(sql_table: SqlTable) -> SqlTableFromClauseNode: # noqa: D102
return SqlTableFromClauseNode(
def create(sql_table: SqlTable) -> SqlTableNode: # noqa: D102
return SqlTableNode(
parent_nodes=(),
sql_table=sql_table,
)
Expand All @@ -216,7 +216,7 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("table_id", self.sql_table.sql),)

def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_table_from_clause_node(self)
return visitor.visit_table_node(self)

@property
def is_table(self) -> bool: # noqa: D102
Expand Down
6 changes: 2 additions & 4 deletions tests_metricflow/dataflow/builder/test_node_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)
from tests_metricflow.fixtures.manifest_fixtures import MetricFlowEngineTestFixture, SemanticManifestSetup

Expand Down Expand Up @@ -77,9 +77,7 @@ def test_no_parent_node_data_set(
column_alias="bookings",
),
),
from_source=SqlTableFromClauseNode.create(
sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")
),
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name="demo", table_name="fct_bookings")),
from_source_alias="src",
),
)
Expand Down
12 changes: 6 additions & 6 deletions tests_metricflow/mf_logging/test_dag_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SqlQueryPlan,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableFromClauseNode,
SqlTableNode,
)

logger = logging.getLogger(__name__)
Expand All @@ -41,7 +41,7 @@ def test_multithread_dag_to_text() -> None:
column_alias="bar",
),
),
from_source=SqlTableFromClauseNode.create(sql_table=SqlTable(schema_name="schema", table_name="table")),
from_source=SqlTableNode.create(sql_table=SqlTable(schema_name="schema", table_name="table")),
from_source_alias="src",
),
)
Expand Down Expand Up @@ -80,13 +80,13 @@ def _run_mf_pformat() -> None:
<!-- expr=SqlStringExpression(node_id=str_0 sql_expr='foo'), -->
<!-- column_alias='bar', -->
<!-- ) -->
<!-- from_source = -->
<!-- SqlTableFromClauseNode(node_id=tfc_0) -->
<!-- from_source = -->
<!-- SqlTableNode(node_id=tfc_0) -->
<!-- where = -->
<!-- None -->
<!-- distinct = -->
<!-- False -->
<SqlTableFromClauseNode>
<SqlTableNode>
<!-- description = -->
<!-- ('Read ' -->
<!-- 'from ' -->
Expand All @@ -97,7 +97,7 @@ def _run_mf_pformat() -> None:
<!-- ) -->
<!-- table_id = -->
<!-- 'schema.table' -->
</SqlTableFromClauseNode>
</SqlTableNode>
</SqlSelectStatementNode>
</SqlQueryPlan>
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,14 @@
<!-- expr=SqlColumnReferenceExpression(node_id=cr_28121), -->
<!-- column_alias='visit__session', -->
<!-- ) -->
<!-- from_source = SqlTableFromClauseNode(node_id=tfc_28011) -->
<!-- from_source = SqlTableNode(node_id=tfc_28011) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlTableFromClauseNode>
<SqlTableNode>
<!-- description = 'Read from ***************************.fct_visits' -->
<!-- node_id = NodeId(id_str='tfc_28011') -->
<!-- table_id = '***************************.fct_visits' -->
</SqlTableFromClauseNode>
</SqlTableNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
Expand Down Expand Up @@ -842,14 +842,14 @@
<!-- expr=SqlColumnReferenceExpression(node_id=cr_28121), -->
<!-- column_alias='visit__session', -->
<!-- ) -->
<!-- from_source = SqlTableFromClauseNode(node_id=tfc_28011) -->
<!-- from_source = SqlTableNode(node_id=tfc_28011) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlTableFromClauseNode>
<SqlTableNode>
<!-- description = 'Read from ***************************.fct_visits' -->
<!-- node_id = NodeId(id_str='tfc_28011') -->
<!-- table_id = '***************************.fct_visits' -->
</SqlTableFromClauseNode>
</SqlTableNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
Expand Down Expand Up @@ -1403,14 +1403,14 @@
<!-- expr=SqlColumnReferenceExpression(node_id=cr_28043), -->
<!-- column_alias='buy__session_id', -->
<!-- ) -->
<!-- from_source = SqlTableFromClauseNode(node_id=tfc_28002) -->
<!-- from_source = SqlTableNode(node_id=tfc_28002) -->
<!-- where = None -->
<!-- distinct = False -->
<SqlTableFromClauseNode>
<SqlTableNode>
<!-- description = 'Read from ***************************.fct_buys' -->
<!-- node_id = NodeId(id_str='tfc_28002') -->
<!-- table_id = '***************************.fct_buys' -->
</SqlTableFromClauseNode>
</SqlTableNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
</SqlSelectStatementNode>
Expand Down
Loading
Loading