diff --git a/dbt_semantic_interfaces/implementations/semantic_model.py b/dbt_semantic_interfaces/implementations/semantic_model.py index 90b4a89e..f119bdef 100644 --- a/dbt_semantic_interfaces/implementations/semantic_model.py +++ b/dbt_semantic_interfaces/implementations/semantic_model.py @@ -18,6 +18,7 @@ SemanticModelConfig, SemanticModelDefaults, ) +from dbt_semantic_interfaces.protocols.semantic_model import NodeRelation from dbt_semantic_interfaces.references import ( EntityReference, LinkableElementReference, @@ -28,7 +29,7 @@ from dsi_pydantic_shim import Field, validator -class NodeRelation(HashableBaseModel): +class PydanticNodeRelation(HashableBaseModel, ProtocolHint[NodeRelation]): """Path object to where the data should be.""" alias: str @@ -36,6 +37,10 @@ class NodeRelation(HashableBaseModel): database: Optional[str] = None relation_name: str = "" + @override + def _implements_protocol(self) -> NodeRelation: # noqa: D + return self + @validator("relation_name", always=True) @classmethod def __create_default_relation_name(cls, value: Any, values: Any) -> str: # type: ignore[misc] @@ -57,12 +62,12 @@ def __create_default_relation_name(cls, value: Any, values: Any) -> str: # type return value @staticmethod - def from_string(sql_str: str) -> NodeRelation: # noqa: D + def from_string(sql_str: str) -> PydanticNodeRelation: # noqa: D sql_str_split = sql_str.split(".") if len(sql_str_split) == 2: - return NodeRelation(schema_name=sql_str_split[0], alias=sql_str_split[1]) + return PydanticNodeRelation(schema_name=sql_str_split[0], alias=sql_str_split[1]) elif len(sql_str_split) == 3: - return NodeRelation(database=sql_str_split[0], schema_name=sql_str_split[1], alias=sql_str_split[2]) + return PydanticNodeRelation(database=sql_str_split[0], schema_name=sql_str_split[1], alias=sql_str_split[2]) raise RuntimeError( f"Invalid input for a SQL table, expected form '.' or '..
' " f"but got: {sql_str}" @@ -95,7 +100,7 @@ def _implements_protocol(self) -> SemanticModel: name: str defaults: Optional[PydanticSemanticModelDefaults] description: Optional[str] - node_relation: NodeRelation + node_relation: PydanticNodeRelation primary_entity: Optional[str] entities: Sequence[PydanticEntity] = [] diff --git a/dbt_semantic_interfaces/implementations/time_spine.py b/dbt_semantic_interfaces/implementations/time_spine.py index 8a8bafee..02a787a2 100644 --- a/dbt_semantic_interfaces/implementations/time_spine.py +++ b/dbt_semantic_interfaces/implementations/time_spine.py @@ -7,9 +7,7 @@ ModelWithMetadataParsing, ) from dbt_semantic_interfaces.protocols import ProtocolHint -from dbt_semantic_interfaces.protocols.time_spine import ( - TimeSpine, -) +from dbt_semantic_interfaces.protocols.time_spine import TimeSpine from dbt_semantic_interfaces.type_enums import TimeGranularity diff --git a/dbt_semantic_interfaces/test_utils.py b/dbt_semantic_interfaces/test_utils.py index addd1b93..0b2fefbe 100644 --- a/dbt_semantic_interfaces/test_utils.py +++ b/dbt_semantic_interfaces/test_utils.py @@ -20,7 +20,7 @@ PydanticSemanticManifest, ) from dbt_semantic_interfaces.implementations.semantic_model import ( - NodeRelation, + PydanticNodeRelation, PydanticSemanticModel, ) from dbt_semantic_interfaces.parsing.objects import YamlConfigFile @@ -141,7 +141,7 @@ def metric_with_guaranteed_meta( def semantic_model_with_guaranteed_meta( name: str, description: Optional[str] = None, - node_relation: Optional[NodeRelation] = None, + node_relation: Optional[PydanticNodeRelation] = None, metadata: PydanticMetadata = default_meta(), entities: Sequence[PydanticEntity] = (), measures: Sequence[PydanticMeasure] = (), @@ -153,7 +153,7 @@ def semantic_model_with_guaranteed_meta( """ created_node_relation = node_relation if created_node_relation is None: - created_node_relation = NodeRelation( + created_node_relation = PydanticNodeRelation( schema_name="schema", alias="table", ) diff --git a/tests/validations/test_reserved_keywords.py b/tests/validations/test_reserved_keywords.py index e1d477d5..85cd8b02 100644 --- a/tests/validations/test_reserved_keywords.py +++ b/tests/validations/test_reserved_keywords.py @@ -4,7 +4,7 @@ from dbt_semantic_interfaces.implementations.semantic_manifest import ( PydanticSemanticManifest, ) -from dbt_semantic_interfaces.implementations.semantic_model import NodeRelation +from dbt_semantic_interfaces.implementations.semantic_model import PydanticNodeRelation from dbt_semantic_interfaces.test_utils import find_semantic_model_with from dbt_semantic_interfaces.validations.reserved_keywords import ( RESERVED_KEYWORDS, @@ -76,7 +76,7 @@ def test_reserved_keywords_in_node_relation( # noqa: D (semantic_model_with_node_relation, _index) = find_semantic_model_with( model=model, function=lambda semantic_model: semantic_model.node_relation is not None ) - semantic_model_with_node_relation.node_relation = NodeRelation( + semantic_model_with_node_relation.node_relation = PydanticNodeRelation( alias=random_keyword(), schema_name="some_schema", )