From acae8fabe17326b5085af70014414fa6fa0b5dcb Mon Sep 17 00:00:00 2001
From: Devon Fulcher
Date: Tue, 7 Nov 2023 18:37:06 -0600
Subject: [PATCH] implemented date_part in where filter

---
Notes (this region is ignored when the patch is applied):

Adds date_part support to where filters, for both the builder syntax
`Dimension(...).date_part(...)` and `TimeDimension(..., date_part_name=...)`.
WhereFilterDimension now only records the requested grain / date_part names
and resolves its TimeDimensionSpec on access, so `.grain()` and `.date_part()`
can be chained in either order. The time-dimension factory guards the
granularity conversion on `time_granularity_name`, matching the guard used in
WhereFilterDimension.time_dimension_spec.

 .../unreleased/Features-20231107-180843.yaml  |  6 ++
 metricflow/specs/dimension_spec_resolver.py   | 13 ++-
 metricflow/specs/where_filter_dimension.py    | 39 +++++---
 .../specs/where_filter_time_dimension.py      |  8 +-
 metricflow/specs/where_filter_transform.py    |  6 +-
 .../test/model/test_where_filter_spec.py      | 89 +++++++++++++++++++
 6 files changed, 141 insertions(+), 20 deletions(-)
 create mode 100644 .changes/unreleased/Features-20231107-180843.yaml

diff --git a/.changes/unreleased/Features-20231107-180843.yaml b/.changes/unreleased/Features-20231107-180843.yaml
new file mode 100644
index 0000000000..59b3d218ea
--- /dev/null
+++ b/.changes/unreleased/Features-20231107-180843.yaml
@@ -0,0 +1,6 @@
+kind: Features
+body: Implemented date_part in where filter.
+time: 2023-11-07T18:08:43.67846-06:00
+custom:
+  Author: DevonFulcher
+  Issue: None
diff --git a/metricflow/specs/dimension_spec_resolver.py b/metricflow/specs/dimension_spec_resolver.py
index febbefa478..25caa0ac6c 100644
--- a/metricflow/specs/dimension_spec_resolver.py
+++ b/metricflow/specs/dimension_spec_resolver.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Sequence
+from typing import Optional, Sequence
 
 from dbt_semantic_interfaces.call_parameter_sets import (
     DimensionCallParameterSet,
@@ -10,6 +10,7 @@
 from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
 from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
 from dbt_semantic_interfaces.type_enums import TimeGranularity
+from dbt_semantic_interfaces.type_enums.date_part import DatePart
 
 from metricflow.specs.specs import DEFAULT_TIME_GRANULARITY, DimensionSpec, TimeDimensionSpec
 
@@ -35,16 +36,21 @@ def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> Dimen
         )
 
     def resolve_time_dimension_spec(
-        self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str]
+        self,
+        name: str,
+        time_granularity: Optional[TimeGranularity],
+        entity_path: Sequence[str],
+        date_part: Optional[DatePart],
     ) -> TimeDimensionSpec:
         """Resolve TimeDimension spec with the call_parameter_sets."""
         structured_name = DunderedNameFormatter.parse_name(name)
         call_parameter_set = TimeDimensionCallParameterSet(
             time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
-            time_granularity=time_granularity_name,
+            time_granularity=time_granularity,
             entity_path=(
                 tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
             ),
+            date_part=date_part,
         )
         assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets
         return TimeDimensionSpec(
@@ -56,4 +62,5 @@ def resolve_time_dimension_spec(
                 if call_parameter_set.time_granularity is not None
                 else DEFAULT_TIME_GRANULARITY
             ),
+            date_part=call_parameter_set.date_part,
         )
diff --git a/metricflow/specs/where_filter_dimension.py b/metricflow/specs/where_filter_dimension.py
index 9af2858d8f..f98df6d249 100644
--- a/metricflow/specs/where_filter_dimension.py
+++ b/metricflow/specs/where_filter_dimension.py
@@ -11,12 +11,13 @@
     QueryInterfaceDimensionFactory,
 )
 from dbt_semantic_interfaces.type_enums import TimeGranularity
+from dbt_semantic_interfaces.type_enums.date_part import DatePart
 from typing_extensions import override
 
 from metricflow.errors.errors import InvalidQuerySyntax
 from metricflow.specs.column_assoc import ColumnAssociationResolver
 from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
-from metricflow.specs.specs import TimeDimensionSpec
+from metricflow.specs.specs import DimensionSpec, InstanceSpec, TimeDimensionSpec
 
 
 class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
@@ -37,32 +38,48 @@ def __init__(  # noqa
         self._column_association_resolver = column_association_resolver
         self._name = name
         self._entity_path = entity_path
-        self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)
-        self.time_dimension_spec: Optional[TimeDimensionSpec] = None
+        self.dimension_spec: DimensionSpec = self._dimension_spec_resolver.resolve_dimension_spec(
+            self._name, self._entity_path
+        )
+        self.date_part_name: Optional[str] = None
+        self.time_granularity_name: Optional[str] = None
+
+    @property
+    def time_dimension_spec(self) -> TimeDimensionSpec:
+        """TimeDimensionSpec that results from the builder-pattern configuration."""
+        return self._dimension_spec_resolver.resolve_time_dimension_spec(
+            self._name,
+            TimeGranularity(self.time_granularity_name) if self.time_granularity_name else None,
+            self._entity_path,
+            DatePart(self.date_part_name) if self.date_part_name else None,
+        )
 
     def grain(self, time_granularity_name: str) -> QueryInterfaceDimension:
         """The time granularity."""
-        self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
-            self._name, TimeGranularity(time_granularity_name), self._entity_path
-        )
+        self.time_granularity_name = time_granularity_name
         return self
 
-    def date_part(self, _date_part: str) -> QueryInterfaceDimension:
+    def date_part(self, date_part_name: str) -> QueryInterfaceDimension:
         """The date_part requested to extract."""
-        raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter")
+        self.date_part_name = date_part_name
+        return self
 
     def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
         """Set the sort order for order-by."""
         raise InvalidQuerySyntax("descending is invalid in the where parameter")
 
+    def _get_spec(self) -> InstanceSpec:
+        """Get either the TimeDimensionSpec or DimensionSpec."""
+        if self.time_granularity_name or self.date_part_name:
+            return self.time_dimension_spec
+        return self.dimension_spec
+
     def __str__(self) -> str:
         """Returns the column name.
 
         Important in the Jinja sandbox.
         """
-        return self._column_association_resolver.resolve_spec(
-            self.time_dimension_spec or self.dimension_spec
-        ).column_name
+        return self._column_association_resolver.resolve_spec(self._get_spec()).column_name
 
 
 class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
diff --git a/metricflow/specs/where_filter_time_dimension.py b/metricflow/specs/where_filter_time_dimension.py
index 75d12d5fbe..a6f840dc92 100644
--- a/metricflow/specs/where_filter_time_dimension.py
+++ b/metricflow/specs/where_filter_time_dimension.py
@@ -9,6 +9,7 @@
     QueryInterfaceTimeDimensionFactory,
 )
 from dbt_semantic_interfaces.type_enums import TimeGranularity
+from dbt_semantic_interfaces.type_enums.date_part import DatePart
 from typing_extensions import override
 
 from metricflow.errors.errors import InvalidQuerySyntax
@@ -68,10 +69,11 @@ def create(
             raise InvalidQuerySyntax(
                 "Can't set descending in the where clause. Try setting descending in the order_by clause instead"
             )
-        if date_part_name:
-            raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter")
         time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
-            time_dimension_name, TimeGranularity(time_granularity_name), entity_path
+            time_dimension_name,
+            TimeGranularity(time_granularity_name) if time_granularity_name else None,
+            entity_path,
+            DatePart(date_part_name) if date_part_name else None,
         )
         self.time_dimension_specs.append(time_dimension_spec)
         column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py
index c039cee1f8..418737cead 100644
--- a/metricflow/specs/where_filter_transform.py
+++ b/metricflow/specs/where_filter_transform.py
@@ -82,12 +82,12 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec
             )
 
         """
-        Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
-        added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs
+        Dimensions that are created with a grain or date_part parameter, Dimension(...).grain(...), are
+        added to time_dimension_factory.time_dimension_specs; otherwise they are added to dimension_specs
         """
         dimension_specs = []
         for dimension in dimension_factory.created:
-            if dimension.time_dimension_spec:
+            if dimension.time_granularity_name or dimension.date_part_name:
                 time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec)
             else:
                 dimension_specs.append(dimension.dimension_spec)
diff --git a/metricflow/test/model/test_where_filter_spec.py b/metricflow/test/model/test_where_filter_spec.py
index 80dd4cc0fa..3d6e905259 100644
--- a/metricflow/test/model/test_where_filter_spec.py
+++ b/metricflow/test/model/test_where_filter_spec.py
@@ -5,6 +5,7 @@
 import pytest
 from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
 from dbt_semantic_interfaces.references import EntityReference
+from dbt_semantic_interfaces.type_enums.date_part import DatePart
 from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
 
 from metricflow.query.query_exceptions import InvalidQueryException
@@ -98,6 +99,94 @@ def test_time_dimension_in_filter(  # noqa: D
     )
 
 
+def test_date_part_in_filter(  # noqa: D
+    column_association_resolver: ColumnAssociationResolver,
+) -> None:
+    where_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('metric_time').date_part('year') }} = '2020'")
+
+    where_filter_spec = WhereSpecFactory(
+        column_association_resolver=column_association_resolver,
+    ).create_from_where_filter(where_filter)
+
+    assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
+    assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
+        dimension_specs=(),
+        time_dimension_specs=(
+            TimeDimensionSpec(
+                element_name="metric_time",
+                entity_links=(),
+                time_granularity=TimeGranularity.DAY,
+                date_part=DatePart.YEAR,
+            ),
+        ),
+        entity_specs=(),
+    )
+
+
+@pytest.mark.parametrize(
+    "where_sql",
+    (
+        ("{{ TimeDimension('metric_time', 'WEEK', date_part_name='year') }} = '2020'"),
+        ("{{ Dimension('metric_time').date_part('year').grain('WEEK') }} = '2020'"),
+        ("{{ Dimension('metric_time').grain('WEEK').date_part('year') }} = '2020'"),
+    ),
+)
+def test_date_part_and_grain_in_filter(  # noqa: D
+    column_association_resolver: ColumnAssociationResolver, where_sql: str
+) -> None:
+    where_filter = PydanticWhereFilter(where_sql_template=where_sql)
+
+    where_filter_spec = WhereSpecFactory(
+        column_association_resolver=column_association_resolver,
+    ).create_from_where_filter(where_filter)
+
+    assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
+    assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
+        dimension_specs=(),
+        time_dimension_specs=(
+            TimeDimensionSpec(
+                element_name="metric_time",
+                entity_links=(),
+                time_granularity=TimeGranularity.WEEK,
+                date_part=DatePart.YEAR,
+            ),
+        ),
+        entity_specs=(),
+    )
+
+
+@pytest.mark.parametrize(
+    "where_sql",
+    (
+        ("{{ TimeDimension('metric_time', 'WEEK', date_part_name='day') }} = '2020'"),
+        ("{{ Dimension('metric_time').date_part('day').grain('WEEK') }} = '2020'"),
+        ("{{ Dimension('metric_time').grain('WEEK').date_part('day') }} = '2020'"),
+    ),
+)
+def test_date_part_less_than_grain_in_filter(  # noqa: D
+    column_association_resolver: ColumnAssociationResolver, where_sql: str
+) -> None:
+    where_filter = PydanticWhereFilter(where_sql_template=where_sql)
+
+    where_filter_spec = WhereSpecFactory(
+        column_association_resolver=column_association_resolver,
+    ).create_from_where_filter(where_filter)
+
+    assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'"
+    assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
+        dimension_specs=(),
+        time_dimension_specs=(
+            TimeDimensionSpec(
+                element_name="metric_time",
+                entity_links=(),
+                time_granularity=TimeGranularity.WEEK,
+                date_part=DatePart.DAY,
+            ),
+        ),
+        entity_specs=(),
+    )
+
+
 def test_entity_in_filter(  # noqa: D
     column_association_resolver: ColumnAssociationResolver,
 ) -> None: