Skip to content

Commit

Permalink
Update NodeRelation implementation class to follow patterns used by o…
Browse files Browse the repository at this point in the history
…ther implementation classes
  • Loading branch information
courtneyholcomb committed Jun 17, 2024
1 parent 59286e0 commit 3056ea5
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
15 changes: 10 additions & 5 deletions dbt_semantic_interfaces/implementations/semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SemanticModelConfig,
SemanticModelDefaults,
)
from dbt_semantic_interfaces.protocols.semantic_model import NodeRelation
from dbt_semantic_interfaces.references import (
EntityReference,
LinkableElementReference,
Expand All @@ -28,14 +29,18 @@
from dsi_pydantic_shim import Field, validator


class NodeRelation(HashableBaseModel):
class PydanticNodeRelation(HashableBaseModel, ProtocolHint[NodeRelation]):
"""Path object to where the data should be."""

alias: str
schema_name: str
database: Optional[str] = None
relation_name: str = ""

@override
def _implements_protocol(self) -> NodeRelation: # noqa: D
return self

@validator("relation_name", always=True)
@classmethod
def __create_default_relation_name(cls, value: Any, values: Any) -> str: # type: ignore[misc]
Expand All @@ -57,12 +62,12 @@ def __create_default_relation_name(cls, value: Any, values: Any) -> str: # type
return value

@staticmethod
def from_string(sql_str: str) -> NodeRelation: # noqa: D
def from_string(sql_str: str) -> PydanticNodeRelation: # noqa: D
sql_str_split = sql_str.split(".")
if len(sql_str_split) == 2:
return NodeRelation(schema_name=sql_str_split[0], alias=sql_str_split[1])
return PydanticNodeRelation(schema_name=sql_str_split[0], alias=sql_str_split[1])
elif len(sql_str_split) == 3:
return NodeRelation(database=sql_str_split[0], schema_name=sql_str_split[1], alias=sql_str_split[2])
return PydanticNodeRelation(database=sql_str_split[0], schema_name=sql_str_split[1], alias=sql_str_split[2])
raise RuntimeError(
f"Invalid input for a SQL table, expected form '<schema>.<table>' or '<db>.<schema>.<table>' "
f"but got: {sql_str}"
Expand Down Expand Up @@ -95,7 +100,7 @@ def _implements_protocol(self) -> SemanticModel:
name: str
defaults: Optional[PydanticSemanticModelDefaults]
description: Optional[str]
node_relation: NodeRelation
node_relation: PydanticNodeRelation

primary_entity: Optional[str]
entities: Sequence[PydanticEntity] = []
Expand Down
4 changes: 1 addition & 3 deletions dbt_semantic_interfaces/implementations/time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
ModelWithMetadataParsing,
)
from dbt_semantic_interfaces.protocols import ProtocolHint
from dbt_semantic_interfaces.protocols.time_spine import (
TimeSpine,
)
from dbt_semantic_interfaces.protocols.time_spine import TimeSpine
from dbt_semantic_interfaces.type_enums import TimeGranularity


Expand Down
6 changes: 3 additions & 3 deletions dbt_semantic_interfaces/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
PydanticSemanticManifest,
)
from dbt_semantic_interfaces.implementations.semantic_model import (
NodeRelation,
PydanticNodeRelation,
PydanticSemanticModel,
)
from dbt_semantic_interfaces.parsing.objects import YamlConfigFile
Expand Down Expand Up @@ -141,7 +141,7 @@ def metric_with_guaranteed_meta(
def semantic_model_with_guaranteed_meta(
name: str,
description: Optional[str] = None,
node_relation: Optional[NodeRelation] = None,
node_relation: Optional[PydanticNodeRelation] = None,
metadata: PydanticMetadata = default_meta(),
entities: Sequence[PydanticEntity] = (),
measures: Sequence[PydanticMeasure] = (),
Expand All @@ -153,7 +153,7 @@ def semantic_model_with_guaranteed_meta(
"""
created_node_relation = node_relation
if created_node_relation is None:
created_node_relation = NodeRelation(
created_node_relation = PydanticNodeRelation(
schema_name="schema",
alias="table",
)
Expand Down
4 changes: 2 additions & 2 deletions tests/validations/test_reserved_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dbt_semantic_interfaces.implementations.semantic_manifest import (
PydanticSemanticManifest,
)
from dbt_semantic_interfaces.implementations.semantic_model import NodeRelation
from dbt_semantic_interfaces.implementations.semantic_model import PydanticNodeRelation
from dbt_semantic_interfaces.test_utils import find_semantic_model_with
from dbt_semantic_interfaces.validations.reserved_keywords import (
RESERVED_KEYWORDS,
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_reserved_keywords_in_node_relation( # noqa: D
(semantic_model_with_node_relation, _index) = find_semantic_model_with(
model=model, function=lambda semantic_model: semantic_model.node_relation is not None
)
semantic_model_with_node_relation.node_relation = NodeRelation(
semantic_model_with_node_relation.node_relation = PydanticNodeRelation(
alias=random_keyword(),
schema_name="some_schema",
)
Expand Down

0 comments on commit 3056ea5

Please sign in to comment.