diff --git a/dbt_semantic_interfaces/implementations/saved_query.py b/dbt_semantic_interfaces/implementations/saved_query.py index 6ff709b9..53de5038 100644 --- a/dbt_semantic_interfaces/implementations/saved_query.py +++ b/dbt_semantic_interfaces/implementations/saved_query.py @@ -9,7 +9,7 @@ ModelWithMetadataParsing, ) from dbt_semantic_interfaces.implementations.filters.where_filter import ( - PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.metadata import PydanticMetadata from dbt_semantic_interfaces.protocols import ProtocolHint @@ -26,7 +26,7 @@ def _implements_protocol(self) -> SavedQuery: name: str metrics: List[str] group_bys: List[str] = [] - where: List[PydanticWhereFilter] = [] + where: Optional[PydanticWhereFilterIntersection] = None description: Optional[str] = None metadata: Optional[PydanticMetadata] = None diff --git a/dbt_semantic_interfaces/protocols/saved_query.py b/dbt_semantic_interfaces/protocols/saved_query.py index 2018b164..3bd739d9 100644 --- a/dbt_semantic_interfaces/protocols/saved_query.py +++ b/dbt_semantic_interfaces/protocols/saved_query.py @@ -2,7 +2,7 @@ from typing import Optional, Protocol, Sequence from dbt_semantic_interfaces.protocols.metadata import Metadata -from dbt_semantic_interfaces.protocols.where_filter import WhereFilter +from dbt_semantic_interfaces.protocols.where_filter import WhereFilterIntersection class SavedQuery(Protocol): @@ -35,7 +35,8 @@ def group_bys(self) -> Sequence[str]: # noqa: D @property @abstractmethod - def where(self) -> Sequence[WhereFilter]: # noqa: D + def where(self) -> Optional[WhereFilterIntersection]: + """Returns the intersection class containing any where filters specified in the saved query.""" pass @property diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index 0b2ecd4c..f9abd7f4 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -101,7 +101,9 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq @validate_safely("Validate the where field in a saved query.") def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] - for where_filter in saved_query.where: + if saved_query.where is None: + return issues + for where_filter in saved_query.where.where_filters: try: where_filter.call_parameter_sets except Exception as e: diff --git a/tests/parsing/test_saved_query_parsing.py b/tests/parsing/test_saved_query_parsing.py index 596ee66a..95b0e6aa 100644 --- a/tests/parsing/test_saved_query_parsing.py +++ b/tests/parsing/test_saved_query_parsing.py @@ -131,5 +131,6 @@ def test_saved_query_where() -> None: build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) assert len(build_result.semantic_manifest.saved_queries) == 1 saved_query = build_result.semantic_manifest.saved_queries[0] - assert len(saved_query.where) == 1 - assert where == saved_query.where[0].where_sql_template + assert saved_query.where is not None + assert len(saved_query.where.where_filters) == 1 + assert where == saved_query.where.where_filters[0].where_sql_template diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index c6ae46f4..89ba8289 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -3,6 +3,7 @@ from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, + PydanticWhereFilterIntersection, ) from dbt_semantic_interfaces.implementations.saved_query import PydanticSavedQuery from dbt_semantic_interfaces.implementations.semantic_manifest import ( @@ -44,7 +45,9 @@ def test_invalid_metric_in_saved_query( # noqa: D description="Example description.", metrics=["invalid_metric"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -64,7 +67,9 @@ def test_invalid_where_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__is_instant')"], - where=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + ), ), ] @@ -85,7 +90,9 @@ def test_invalid_group_by_element_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["Dimension('booking__invalid_dimension')"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ] @@ -106,7 +113,9 @@ def test_invalid_group_by_format_in_saved_query( # noqa: D description="Example description.", metrics=["bookings"], group_bys=["invalid_format"], - where=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Dimension('booking__is_instant') }}")], + ), ), ]