Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assorted Fixes for the sql Module #1081

Merged
merged 7 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions metricflow/sql/sql_column.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from dbt_semantic_interfaces.implementations.base import FrozenBaseModel

from metricflow.sql.sql_table import SqlTable


class SqlColumn(FrozenBaseModel):
@dataclass(frozen=True, order=True)
class SqlColumn:
"""Represents a reference to a SQL column."""

table: SqlTable
Expand Down
27 changes: 4 additions & 23 deletions metricflow/sql/sql_table.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,16 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple, Union

from dbt_semantic_interfaces.implementations.base import (
FrozenBaseModel,
PydanticCustomInputParser,
PydanticParseableValueType,
)


class SqlTable(PydanticCustomInputParser, FrozenBaseModel):
@dataclass(frozen=True, order=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice, thanks! There was some reason we couldn't do this originally and I guess we forgot to clean it up later.

class SqlTable:
"""Represents a reference to a SQL table."""

db_name: Optional[str] = None
schema_name: str
table_name: str

@classmethod
def _from_yaml_value(cls, input: PydanticParseableValueType) -> SqlTable:
"""Parses a SqlTable from string input found in a user-provided model specification.

Raises a ValueError on any non-string input, as all user-provided specifications of table entities
should be strings conforming to the expectations defined in the from_string method.
"""
if isinstance(input, str):
return SqlTable.from_string(input)
else:
raise ValueError(
f"SqlTable inputs from model configs are expected to always be of type string, but got type "
f"{type(input)} with value: {input}"
)
db_name: Optional[str] = None

@staticmethod
def from_string(sql_str: str) -> SqlTable: # noqa: D
Expand Down
5 changes: 0 additions & 5 deletions metricflow/test/dataflow/test_sql_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@ def test_sql_column() -> None: # noqa: D
assert sql_column.db_name == sql_column.table.db_name
assert sql_column.schema_name == sql_column.table.schema_name
assert sql_column.table_name == sql_column.table.table_name

json_serialized_column = sql_column.json()
deserialized_column = SqlColumn.parse_raw(json_serialized_column)

assert sql_column == deserialized_column
10 changes: 0 additions & 10 deletions metricflow/test/dataflow/test_sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,13 @@ def test_sql_table() -> None: # noqa: D
assert sql_table.sql == "foo.bar"
assert SqlTable.from_string("foo.bar") == sql_table

json_serialized_table = sql_table.json()
deserialized_table = SqlTable.parse_raw(json_serialized_table)

assert sql_table == deserialized_table


def test_sql_table_with_db() -> None: # noqa: D
sql_table = SqlTable(db_name="db", schema_name="foo", table_name="bar")

assert sql_table.sql == "db.foo.bar"
assert SqlTable.from_string("db.foo.bar") == sql_table

json_serialized_table = sql_table.json()
deserialized_table = SqlTable.parse_raw(json_serialized_table)

assert sql_table == deserialized_table


def test_invalid_sql_table() -> None: # noqa: D
with pytest.raises(RuntimeError):
Expand Down