From 828752c1da35f538c0c9f94348349ec761d2c91d Mon Sep 17 00:00:00 2001 From: Devon Fulcher Date: Mon, 2 Oct 2023 17:11:38 -0500 Subject: [PATCH] query interface updates to add Dimension(...).grain(...) support --- .../unreleased/Features-20230929-123932.yaml | 6 + .../call_parameter_sets.py | 6 +- dbt_semantic_interfaces/errors.py | 7 + .../where_filter/parameter_set_factory.py | 95 ++++++++++++ .../where_filter/where_filter_dimension.py | 62 ++++++++ .../where_filter/where_filter_entity.py | 43 ++++++ .../where_filter_time_dimension.py | 57 +++++++ .../parsing/where_filter_parser.py | 143 +++++------------- .../protocols/query_interface.py | 87 +++++++++++ .../where_filter/test_parse_calls.py | 22 +++ 10 files changed, 420 insertions(+), 108 deletions(-) create mode 100644 .changes/unreleased/Features-20230929-123932.yaml create mode 100644 dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py create mode 100644 dbt_semantic_interfaces/parsing/where_filter/where_filter_dimension.py create mode 100644 dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py create mode 100644 dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py create mode 100644 dbt_semantic_interfaces/protocols/query_interface.py diff --git a/.changes/unreleased/Features-20230929-123932.yaml b/.changes/unreleased/Features-20230929-123932.yaml new file mode 100644 index 00000000..83d896c1 --- /dev/null +++ b/.changes/unreleased/Features-20230929-123932.yaml @@ -0,0 +1,6 @@ +kind: Features +body: 'Backport: Add support for Dimension(...).grain(...) syntax in where parameter' +time: 2023-09-29T12:39:32.834352-05:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/dbt_semantic_interfaces/call_parameter_sets.py b/dbt_semantic_interfaces/call_parameter_sets.py index 193bc892..e9cae4f6 100644 --- a/dbt_semantic_interfaces/call_parameter_sets.py +++ b/dbt_semantic_interfaces/call_parameter_sets.py @@ -13,7 +13,7 @@ @dataclass(frozen=True) class DimensionCallParameterSet: - """When 'dimension(...)' is used in the Jinja template of the where filter, the parameters to that call.""" + """When 'Dimension(...)' is used in the Jinja template of the where filter, the parameters to that call.""" entity_path: Tuple[EntityReference, ...] dimension_reference: DimensionReference @@ -21,7 +21,7 @@ class DimensionCallParameterSet: @dataclass(frozen=True) class TimeDimensionCallParameterSet: - """When 'time_dimension(...)' is used in the Jinja template of the where filter, the parameters to that call.""" + """When 'TimeDimension(...)' is used in the Jinja template of the where filter, the parameters to that call.""" entity_path: Tuple[EntityReference, ...] time_dimension_reference: TimeDimensionReference @@ -30,7 +30,7 @@ class TimeDimensionCallParameterSet: @dataclass(frozen=True) class EntityCallParameterSet: - """When 'entity(...)' is used in the Jinja template of the where filter, the parameters to that call.""" + """When 'Entity(...)' is used in the Jinja template of the where filter, the parameters to that call.""" entity_path: Tuple[EntityReference, ...] entity_reference: EntityReference diff --git a/dbt_semantic_interfaces/errors.py b/dbt_semantic_interfaces/errors.py index 322e60ab..5cc9e631 100644 --- a/dbt_semantic_interfaces/errors.py +++ b/dbt_semantic_interfaces/errors.py @@ -22,3 +22,10 @@ class ModelTransformError(Exception): """Exception to represent errors related to model transformations.""" pass + + +class InvalidQuerySyntax(Exception): + """Raised when query syntax is invalid.""" + + def __init__(self, msg: str) -> None: # noqa: D + super().__init__(msg) diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py new file mode 100644 index 00000000..bbcfc722 --- /dev/null +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -0,0 +1,95 @@ +from typing import Sequence + +from dbt_semantic_interfaces.call_parameter_sets import ( + DimensionCallParameterSet, + EntityCallParameterSet, + ParseWhereFilterException, + TimeDimensionCallParameterSet, +) +from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from dbt_semantic_interfaces.naming.keywords import ( + METRIC_TIME_ELEMENT_NAME, + is_metric_time_name, +) +from dbt_semantic_interfaces.references import ( + DimensionReference, + EntityReference, + TimeDimensionReference, +) +from dbt_semantic_interfaces.type_enums import TimeGranularity + + +class ParameterSetFactory: + """Creates parameter sets for use in the Jinja sandbox.""" + + @staticmethod + def _exception_message_for_incorrect_format(element_name: str) -> str: + return ( + f"Name is in an incorrect format: '{element_name}'. It should be of the form: " + f"__" + ) + + @staticmethod + def create_time_dimension( + time_dimension_name: str, time_granularity_name: str, entity_path: Sequence[str] = () + ) -> TimeDimensionCallParameterSet: + """Gets called by Jinja when rendering {{ TimeDimension(...) }}.""" + group_by_item_name = DunderedNameFormatter.parse_name(time_dimension_name) + + # metric_time is the only time dimension that does not have an associated primary entity, so the + # GroupByItemName would not have any entity links. + if is_metric_time_name(group_by_item_name.element_name): + if len(group_by_item_name.entity_links) != 0 or group_by_item_name.time_granularity is not None: + raise ParseWhereFilterException( + f"Name is in an incorrect format: {time_dimension_name} " + f"When referencing {METRIC_TIME_ELEMENT_NAME}," + "the name should not have any dunders (double underscores, or __)." + ) + else: + if len(group_by_item_name.entity_links) != 1 or group_by_item_name.time_granularity is not None: + raise ParseWhereFilterException( + ParameterSetFactory._exception_message_for_incorrect_format(time_dimension_name) + ) + + return TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference(element_name=group_by_item_name.element_name), + entity_path=( + tuple(EntityReference(element_name=arg) for arg in entity_path) + group_by_item_name.entity_links + ), + time_granularity=TimeGranularity(time_granularity_name), + ) + + @staticmethod + def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> DimensionCallParameterSet: + """Gets called by Jinja when rendering {{ Dimension(...) }}.""" + group_by_item_name = DunderedNameFormatter.parse_name(dimension_name) + if is_metric_time_name(group_by_item_name.element_name): + raise ParseWhereFilterException( + f"{METRIC_TIME_ELEMENT_NAME} is a time dimension, so it should be referenced using " + f"TimeDimension(...) or Dimension(...).grain(...)" + ) + + if len(group_by_item_name.entity_links) != 1: + raise ParseWhereFilterException(ParameterSetFactory._exception_message_for_incorrect_format(dimension_name)) + + return DimensionCallParameterSet( + dimension_reference=DimensionReference(element_name=group_by_item_name.element_name), + entity_path=( + tuple(EntityReference(element_name=arg) for arg in entity_path) + group_by_item_name.entity_links + ), + ) + + @staticmethod + def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet: + """Gets called by Jinja when rendering {{ Entity(...) }}.""" + group_by_item_name = DunderedNameFormatter.parse_name(entity_name) + if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None: + ParameterSetFactory._exception_message_for_incorrect_format( + f"Name is in an incorrect format: {entity_name} " + f"When referencing entities, the name should not have any dunders (double underscores, or __)." + ) + + return EntityCallParameterSet( + entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path), + entity_reference=EntityReference(element_name=entity_name), + ) diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_dimension.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_dimension.py new file mode 100644 index 00000000..fbd884d4 --- /dev/null +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_dimension.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import List, Optional, Sequence + +from typing_extensions import override + +from dbt_semantic_interfaces.errors import InvalidQuerySyntax +from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint +from dbt_semantic_interfaces.protocols.query_interface import ( + QueryInterfaceDimension, + QueryInterfaceDimensionFactory, +) + + +class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]): + """A dimension that is passed in through the where filter parameter.""" + + @override + def _implements_protocol(self) -> QueryInterfaceDimension: + return self + + def __init__( # noqa + self, + name: str, + entity_path: Sequence[str], + ) -> None: + self.name = name + self.entity_path = entity_path + self.time_granularity_name: Optional[str] = None + + def grain(self, time_granularity: str) -> QueryInterfaceDimension: + """The time granularity.""" + self.time_granularity_name = time_granularity + return self + + def descending(self, _is_descending: bool) -> QueryInterfaceDimension: + """Set the sort order for order-by.""" + raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") + + def date_part(self, _date_part: str) -> QueryInterfaceDimension: + """Date part to extract from the dimension.""" + raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter and filter spec") + + +class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]): + """Creates a WhereFilterDimension. + + Each call to `create` adds a WhereFilterDimension to `created`. + """ + + @override + def _implements_protocol(self) -> QueryInterfaceDimensionFactory: + return self + + def __init__(self) -> None: # noqa + self.created: List[WhereFilterDimension] = [] + + def create(self, dimension_name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension: + """Gets called by Jinja when rendering {{ Dimension(...) }}.""" + dimension = WhereFilterDimension(dimension_name, entity_path) + self.created.append(dimension) + return dimension diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py new file mode 100644 index 00000000..83e5fe8c --- /dev/null +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import List, Sequence + +from typing_extensions import override + +from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet +from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( + ParameterSetFactory, +) +from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint +from dbt_semantic_interfaces.protocols.query_interface import ( + QueryInterfaceEntity, + QueryInterfaceEntityFactory, +) + + +class EntityStub(ProtocolHint[QueryInterfaceEntity]): + """An Entity implementation that just satisfies the protocol. + + QueryInterfaceEntity currently has no methods and the parameter set is created in the factory. + So, there is nothing to do here. + """ + + @override + def _implements_protocol(self) -> QueryInterfaceEntity: + return self + + +class WhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]): + """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" + + @override + def _implements_protocol(self) -> QueryInterfaceEntityFactory: + return self + + def __init__(self) -> None: # noqa + self.entity_call_parameter_sets: List[EntityCallParameterSet] = [] + + def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> EntityStub: + """Gets called by Jinja when rendering {{ Entity(...) }}.""" + self.entity_call_parameter_sets.append(ParameterSetFactory.create_entity(entity_name, entity_path)) + return EntityStub() diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py new file mode 100644 index 00000000..f8022f45 --- /dev/null +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import List, Optional, Sequence + +from typing_extensions import override + +from dbt_semantic_interfaces.call_parameter_sets import TimeDimensionCallParameterSet +from dbt_semantic_interfaces.errors import InvalidQuerySyntax +from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( + ParameterSetFactory, +) +from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint +from dbt_semantic_interfaces.protocols.query_interface import ( + QueryInterfaceTimeDimension, + QueryInterfaceTimeDimensionFactory, +) + + +class TimeDimensionStub(ProtocolHint[QueryInterfaceTimeDimension]): + """A TimeDimension implementation that just satisfies the protocol. + + QueryInterfaceTimeDimension currently has no methods and the parameter set is created in the factory. + So, there is nothing to do here. + """ + + @override + def _implements_protocol(self) -> QueryInterfaceTimeDimension: + return self + + +class WhereFilterTimeDimensionFactory(ProtocolHint[QueryInterfaceTimeDimensionFactory]): + """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" + + @override + def _implements_protocol(self) -> QueryInterfaceTimeDimensionFactory: + return self + + def __init__(self) -> None: # noqa + self.time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet] = [] + + def create( + self, + time_dimension_name: str, + time_granularity_name: str, + entity_path: Sequence[str] = (), + descending: Optional[bool] = None, + date_part_name: Optional[str] = None, + ) -> TimeDimensionStub: + """Gets called by Jinja when rendering {{ TimeDimension(...) }}.""" + if descending is not None: + raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") + if date_part_name is not None: + raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter and filter spec") + self.time_dimension_call_parameter_sets.append( + ParameterSetFactory.create_time_dimension(time_dimension_name, time_granularity_name, entity_path) + ) + return TimeDimensionStub() diff --git a/dbt_semantic_interfaces/parsing/where_filter_parser.py b/dbt_semantic_interfaces/parsing/where_filter_parser.py index 40a367ec..0303572a 100644 --- a/dbt_semantic_interfaces/parsing/where_filter_parser.py +++ b/dbt_semantic_interfaces/parsing/where_filter_parser.py @@ -1,135 +1,68 @@ from __future__ import annotations -from typing import List, Sequence - from jinja2 import StrictUndefined from jinja2.exceptions import SecurityError, TemplateSyntaxError, UndefinedError from jinja2.sandbox import SandboxedEnvironment from dbt_semantic_interfaces.call_parameter_sets import ( - DimensionCallParameterSet, - EntityCallParameterSet, FilterCallParameterSets, ParseWhereFilterException, - TimeDimensionCallParameterSet, ) -from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter -from dbt_semantic_interfaces.naming.keywords import ( - METRIC_TIME_ELEMENT_NAME, - is_metric_time_name, +from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( + ParameterSetFactory, +) +from dbt_semantic_interfaces.parsing.where_filter.where_filter_dimension import ( + WhereFilterDimensionFactory, ) -from dbt_semantic_interfaces.references import ( - DimensionReference, - EntityReference, - TimeDimensionReference, +from dbt_semantic_interfaces.parsing.where_filter.where_filter_entity import ( + WhereFilterEntityFactory, +) +from dbt_semantic_interfaces.parsing.where_filter.where_filter_time_dimension import ( + WhereFilterTimeDimensionFactory, ) -from dbt_semantic_interfaces.type_enums import TimeGranularity class WhereFilterParser: """Parses the template in the WhereFilter into FilterCallParameterSets.""" - @staticmethod - def _exception_message_for_incorrect_format(element_name: str) -> str: - return ( - f"Name is in an incorrect format: '{element_name}'. It should be of the form: " - f"__" - ) - @staticmethod def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSets: """Return the result of extracting the semantic objects referenced in the where SQL template string.""" - # To extract the parameters to the calls, we use a function to record the parameters while rendering the Jinja - # template. The rendered result is not used, but since Jinja has to render something, using this as a - # placeholder. An alternative approach would have been to use the Jinja AST API, but this seemed simpler. - _DUMMY_PLACEHOLDER = "DUMMY_PLACEHOLDER" - - dimension_call_parameter_sets: List[DimensionCallParameterSet] = [] - time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet] = [] - entity_call_parameter_sets: List[EntityCallParameterSet] = [] - - def _dimension_call(dimension_name: str, entity_path: Sequence[str] = ()) -> str: - """Gets called by Jinja when rendering {{ dimension(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(dimension_name) - if len(group_by_item_name.entity_links) != 1: - raise ParseWhereFilterException( - WhereFilterParser._exception_message_for_incorrect_format(dimension_name) - ) - - dimension_call_parameter_sets.append( - DimensionCallParameterSet( - dimension_reference=DimensionReference(element_name=group_by_item_name.element_name), - entity_path=( - tuple(EntityReference(element_name=arg) for arg in entity_path) - + group_by_item_name.entity_links - ), - ) - ) - return _DUMMY_PLACEHOLDER - - def _time_dimension_call( - time_dimension_name: str, time_granularity_name: str, entity_path: Sequence[str] = () - ) -> str: - """Gets called by Jinja when rendering {{ time_dimension(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(time_dimension_name) - - # metric_time is the only time dimension that does not have an associated primary entity, so the - # GroupByItemName would not have any entity links. - if is_metric_time_name(group_by_item_name.element_name): - if len(group_by_item_name.entity_links) != 0 or group_by_item_name.time_granularity is not None: - raise ParseWhereFilterException( - WhereFilterParser._exception_message_for_incorrect_format( - f"Name is in an incorrect format: {time_dimension_name} " - f"When referencing {METRIC_TIME_ELEMENT_NAME}, the name should not have any dunders." - ) - ) - - else: - if len(group_by_item_name.entity_links) != 1 or group_by_item_name.time_granularity is not None: - raise ParseWhereFilterException( - WhereFilterParser._exception_message_for_incorrect_format(time_dimension_name) - ) - - time_dimension_call_parameter_sets.append( - TimeDimensionCallParameterSet( - time_dimension_reference=TimeDimensionReference(element_name=group_by_item_name.element_name), - entity_path=( - tuple(EntityReference(element_name=arg) for arg in entity_path) - + group_by_item_name.entity_links - ), - time_granularity=TimeGranularity(time_granularity_name), - ) - ) - return _DUMMY_PLACEHOLDER - - def _entity_call(entity_name: str, entity_path: Sequence[str] = ()) -> str: - """Gets called by Jinja when rendering {{ entity(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(entity_name) - if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None: - WhereFilterParser._exception_message_for_incorrect_format( - f"Name is in an incorrect format: {entity_name} " - f"When referencing entities, the name should not have any dunders." - ) - - entity_call_parameter_sets.append( - EntityCallParameterSet( - entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path), - entity_reference=EntityReference(element_name=entity_name), - ) - ) - return _DUMMY_PLACEHOLDER + time_dimension_factory = WhereFilterTimeDimensionFactory() + dimension_factory = WhereFilterDimensionFactory() + entity_factory = WhereFilterEntityFactory() try: + # the string that the sandbox renders is unused SandboxedEnvironment(undefined=StrictUndefined).from_string(where_sql_template).render( - Dimension=_dimension_call, - TimeDimension=_time_dimension_call, - Entity=_entity_call, + Dimension=dimension_factory.create, + TimeDimension=time_dimension_factory.create, + Entity=entity_factory.create, ) except (UndefinedError, TemplateSyntaxError, SecurityError) as e: raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e + """ + Dimensions that are created with a grain parameter, Dimension(...).grain(...), are + added to time_dimension_call_parameter_sets otherwise they are add to dimension_call_parameter_sets + """ + dimension_call_parameter_sets = [] + for dimension in dimension_factory.created: + if dimension.time_granularity_name: + time_dimension_factory.time_dimension_call_parameter_sets.append( + ParameterSetFactory.create_time_dimension( + dimension.name, + dimension.time_granularity_name, + dimension.entity_path, + ) + ) + else: + dimension_call_parameter_sets.append( + ParameterSetFactory.create_dimension(dimension.name, dimension.entity_path) + ) + return FilterCallParameterSets( dimension_call_parameter_sets=tuple(dimension_call_parameter_sets), - time_dimension_call_parameter_sets=tuple(time_dimension_call_parameter_sets), - entity_call_parameter_sets=tuple(entity_call_parameter_sets), + time_dimension_call_parameter_sets=tuple(time_dimension_factory.time_dimension_call_parameter_sets), + entity_call_parameter_sets=tuple(entity_factory.entity_call_parameter_sets), ) diff --git a/dbt_semantic_interfaces/protocols/query_interface.py b/dbt_semantic_interfaces/protocols/query_interface.py new file mode 100644 index 00000000..2831740b --- /dev/null +++ b/dbt_semantic_interfaces/protocols/query_interface.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from abc import abstractmethod +from typing import Optional, Protocol, Sequence + + +class QueryInterfaceMetric(Protocol): + """Represents the interface for Metric in the query interface.""" + + @abstractmethod + def descending(self, _is_descending: bool) -> QueryInterfaceMetric: + """Set the sort order for order-by.""" + pass + + +class QueryInterfaceDimension(Protocol): + """Represents the interface for Dimension in the query interface.""" + + @abstractmethod + def grain(self, _grain: str) -> QueryInterfaceDimension: + """The time granularity.""" + pass + + @abstractmethod + def descending(self, _is_descending: bool) -> QueryInterfaceDimension: + """Set the sort order for order-by.""" + pass + + @abstractmethod + def date_part(self, _date_part: str) -> QueryInterfaceDimension: + """Date part to extract from the dimension.""" + pass + + +class QueryInterfaceDimensionFactory(Protocol): + """Creates a Dimension for the query interface. + + Represented as the Dimension constructor in the Jinja sandbox. + """ + + @abstractmethod + def create(self, name: str, entity_path: Sequence[str] = ()) -> QueryInterfaceDimension: + """Create a QueryInterfaceDimension.""" + pass + + +class QueryInterfaceTimeDimension(Protocol): + """Represents the interface for TimeDimension in the query interface.""" + + pass + + +class QueryInterfaceTimeDimensionFactory(Protocol): + """Creates a TimeDimension for the query interface. + + Represented as the TimeDimension constructor in the Jinja sandbox. + """ + + @abstractmethod + def create( + self, + time_dimension_name: str, + time_granularity_name: str, + entity_path: Sequence[str] = (), + descending: Optional[bool] = None, + date_part_name: Optional[str] = None, + ) -> QueryInterfaceTimeDimension: + """Create a TimeDimension.""" + pass + + +class QueryInterfaceEntity(Protocol): + """Represents the interface for Entity in the query interface.""" + + pass + + +class QueryInterfaceEntityFactory(Protocol): + """Creates an Entity for the query interface. + + Represented as the Entity constructor in the Jinja sandbox. + """ + + @abstractmethod + def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> QueryInterfaceEntity: + """Create an Entity.""" + pass diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 0fb7d04b..32d6cdbf 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -47,6 +47,28 @@ def test_extract_dimension_call_parameter_sets() -> None: # noqa: D ) +def test_extract_dimension_with_grain_call_parameter_sets() -> None: # noqa: D + parse_result = PydanticWhereFilter( + where_sql_template=( + """ + {{ Dimension('metric_time').grain('WEEK') }} > 2023-09-18 + """ + ) + ).call_parameter_sets + + assert parse_result == FilterCallParameterSets( + dimension_call_parameter_sets=(), + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + entity_path=(), + time_dimension_reference=TimeDimensionReference(element_name="metric_time"), + time_granularity=TimeGranularity.WEEK, + ), + ), + entity_call_parameter_sets=(), + ) + + def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=(