From a6533efd0009cd9231de0c68b9bae18521c0f554 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Wed, 16 Oct 2024 11:31:24 -0700 Subject: [PATCH] Make `SqlTable.schema_name` optional (#1458) This PR makes `SqlTable.schema_name` optional as `SqlTable` can later be used to reference a CTE, which does not have a schema name. --- .../metricflow_semantics/sql/sql_table.py | 28 ++++++++++--------- .../time/time_spine_source.py | 2 +- .../sql/test_sql_table.py | 13 +++++++++ metricflow/sql/sql_column.py | 2 +- 4 files changed, 30 insertions(+), 15 deletions(-) create mode 100644 metricflow-semantics/tests_metricflow_semantics/sql/test_sql_table.py diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_table.py b/metricflow-semantics/metricflow_semantics/sql/sql_table.py index 78e2a66d1a..bedd9afe83 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_table.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_table.py @@ -2,23 +2,28 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Union +from typing import Optional class SqlTableType(Enum): # noqa: D101 TABLE = "table" VIEW = "view" + # CTE type may be added later. @dataclass(frozen=True, order=True) class SqlTable: """Represents a reference to a SQL table.""" - schema_name: str + schema_name: Optional[str] table_name: str db_name: Optional[str] = None table_type: SqlTableType = SqlTableType.TABLE + def __post_init__(self) -> None: # noqa: D105 + if self.db_name is not None and self.schema_name is None: + raise ValueError(f"{self.db_name=} when it should be specified with {self.schema_name=}") + @staticmethod def from_string(sql_str: str) -> SqlTable: # noqa: D102 sql_str_split = sql_str.split(".") @@ -34,14 +39,11 @@ def from_string(sql_str: str) -> SqlTable: # noqa: D102 @property def sql(self) -> str: """Return the snippet that can be used for use in SQL queries.""" - if self.db_name: - return f"{self.db_name}.{self.schema_name}.{self.table_name}" - return f"{self.schema_name}.{self.table_name}" - - @property - def parts_tuple(self) -> Union[Tuple[str, str], Tuple[str, str, str]]: - """Return a tuple of the sql table parts.""" - if self.db_name: - return (self.db_name, self.schema_name, self.table_name) - else: - return (self.schema_name, self.table_name) + items = [] + if self.db_name is not None: + items.append(self.db_name) + if self.schema_name is not None: + items.append(self.schema_name) + items.append(self.table_name) + + return ".".join(items) diff --git a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py index ddc6e17534..f55967d3eb 100644 --- a/metricflow-semantics/metricflow_semantics/time/time_spine_source.py +++ b/metricflow-semantics/metricflow_semantics/time/time_spine_source.py @@ -24,7 +24,7 @@ class TimeSpineSource: Dates should be contiguous. May also contain custom granularity columns. """ - schema_name: str + schema_name: Optional[str] table_name: str = "mf_time_spine" # Name of the column in the table that contains date/time values that map to a standard granularity. base_column: str = "ds" diff --git a/metricflow-semantics/tests_metricflow_semantics/sql/test_sql_table.py b/metricflow-semantics/tests_metricflow_semantics/sql/test_sql_table.py new file mode 100644 index 0000000000..4735f7e3a9 --- /dev/null +++ b/metricflow-semantics/tests_metricflow_semantics/sql/test_sql_table.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pytest +from metricflow_semantics.sql.sql_table import SqlTable + + +def test_sql_table() -> None: # noqa: D103 + assert SqlTable(schema_name=None, table_name="table", db_name=None).sql == "table" + assert SqlTable(schema_name="schema", table_name="table", db_name=None).sql == "schema.table" + assert SqlTable(schema_name="schema", table_name="table", db_name="db").sql == "db.schema.table" + + with pytest.raises(ValueError): + SqlTable(schema_name=None, table_name="table", db_name="db") diff --git a/metricflow/sql/sql_column.py b/metricflow/sql/sql_column.py index 2765d36a7d..ff515bbcb5 100644 --- a/metricflow/sql/sql_column.py +++ b/metricflow/sql/sql_column.py @@ -35,7 +35,7 @@ def db_name(self) -> Optional[str]: # noqa: D102 return self.table.db_name @property - def schema_name(self) -> str: # noqa: D102 + def schema_name(self) -> Optional[str]: # noqa: D102 return self.table.schema_name @property