Skip to content

Commit

Permalink
implemented date_part in where filter
Browse files Browse the repository at this point in the history
  • Loading branch information
DevonFulcher committed Nov 8, 2023
1 parent 043bbbb commit acae8fa
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 20 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231107-180843.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Implemented date_part in where filter.
time: 2023-11-07T18:08:43.67846-06:00
custom:
Author: DevonFulcher
Issue: None
13 changes: 10 additions & 3 deletions metricflow/specs/dimension_spec_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
Expand All @@ -10,6 +10,7 @@
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart

from metricflow.specs.specs import DEFAULT_TIME_GRANULARITY, DimensionSpec, TimeDimensionSpec

Expand All @@ -35,16 +36,21 @@ def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> Dimen
)

def resolve_time_dimension_spec(
self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str]
self,
name: str,
time_granularity: Optional[TimeGranularity],
entity_path: Sequence[str],
date_part: Optional[DatePart],
) -> TimeDimensionSpec:
"""Resolve TimeDimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=time_granularity_name,
time_granularity=time_granularity,
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
date_part=date_part,
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets
return TimeDimensionSpec(
Expand All @@ -56,4 +62,5 @@ def resolve_time_dimension_spec(
if call_parameter_set.time_granularity is not None
else DEFAULT_TIME_GRANULARITY
),
date_part=call_parameter_set.date_part,
)
39 changes: 28 additions & 11 deletions metricflow/specs/where_filter_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
QueryInterfaceDimensionFactory,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec
from metricflow.specs.specs import DimensionSpec, InstanceSpec, TimeDimensionSpec


class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
Expand All @@ -37,32 +38,48 @@ def __init__( # noqa
self._column_association_resolver = column_association_resolver
self._name = name
self._entity_path = entity_path
self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)
self.time_dimension_spec: Optional[TimeDimensionSpec] = None
self.dimension_spec: DimensionSpec = self._dimension_spec_resolver.resolve_dimension_spec(
self._name, self._entity_path
)
self.date_part_name: Optional[str] = None
self.time_granularity_name: Optional[str] = None

@property
def time_dimension_spec(self) -> TimeDimensionSpec:
"""TimeDimensionSpec that results from the builder-pattern configuration."""
return self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name,
TimeGranularity(self.time_granularity_name) if self.time_granularity_name else None,
self._entity_path,
DatePart(self.date_part_name) if self.date_part_name else None,
)

def grain(self, time_granularity_name: str) -> QueryInterfaceDimension:
"""The time granularity."""
self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name, TimeGranularity(time_granularity_name), self._entity_path
)
self.time_granularity_name = time_granularity_name
return self

def date_part(self, _date_part: str) -> QueryInterfaceDimension:
def date_part(self, date_part_name: str) -> QueryInterfaceDimension:
"""The date_part requested to extract."""
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter")
self.date_part_name = date_part_name
return self

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax("descending is invalid in the where parameter")

def _get_spec(self) -> InstanceSpec:
"""Get either the TimeDimensionSpec or DimensionSpec."""
if self.time_granularity_name or self.date_part_name:
return self.time_dimension_spec
return self.dimension_spec

def __str__(self) -> str:
"""Returns the column name.
Important in the Jinja sandbox.
"""
return self._column_association_resolver.resolve_spec(
self.time_dimension_spec or self.dimension_spec
).column_name
return self._column_association_resolver.resolve_spec(self._get_spec()).column_name


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
Expand Down
8 changes: 5 additions & 3 deletions metricflow/specs/where_filter_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QueryInterfaceTimeDimensionFactory,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
Expand Down Expand Up @@ -68,10 +69,11 @@ def create(
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)
if date_part_name:
raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter")
time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
time_dimension_name, TimeGranularity(time_granularity_name), entity_path
time_dimension_name,
TimeGranularity(time_granularity_name) if time_dimension_name else None,
entity_path,
DatePart(date_part_name) if date_part_name else None,
)
self.time_dimension_specs.append(time_dimension_spec)
column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
Expand Down
6 changes: 3 additions & 3 deletions metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec
)

"""
Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs
Dimensions that are created with a grain or date_part parameter, Dimension(...).grain(...), are
added to time_dimension_factory.time_dimension_specs otherwise they are add to dimension_specs
"""
dimension_specs = []
for dimension in dimension_factory.created:
if dimension.time_dimension_spec:
if dimension.time_granularity_name or dimension.date_part_name:
time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec)
else:
dimension_specs.append(dimension.dimension_spec)
Expand Down
89 changes: 89 additions & 0 deletions metricflow/test/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.query.query_exceptions import InvalidQueryException
Expand Down Expand Up @@ -98,6 +99,94 @@ def test_time_dimension_in_filter( # noqa: D
)


def test_date_part_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('metric_time').date_part('year') }} = '2020'")

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.DAY,
date_part=DatePart.YEAR,
),
),
entity_specs=(),
)


@pytest.mark.parametrize(
"where_sql",
(
("{{ TimeDimension('metric_time', 'WEEK', date_part_name='year') }} = '2020'"),
("{{ Dimension('metric_time').date_part('year').grain('WEEK') }} = '2020'"),
("{{ Dimension('metric_time').grain('WEEK').date_part('year') }} = '2020'"),
),
)
def test_date_part_and_grain_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver, where_sql: str
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.WEEK,
date_part=DatePart.YEAR,
),
),
entity_specs=(),
)


@pytest.mark.parametrize(
"where_sql",
(
("{{ TimeDimension('metric_time', 'WEEK', date_part_name='day') }} = '2020'"),
("{{ Dimension('metric_time').date_part('day').grain('WEEK') }} = '2020'"),
("{{ Dimension('metric_time').grain('WEEK').date_part('day') }} = '2020'"),
),
)
def test_date_part_less_than_grain_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver, where_sql: str
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.WEEK,
date_part=DatePart.DAY,
),
),
entity_specs=(),
)


def test_entity_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
Expand Down

0 comments on commit acae8fa

Please sign in to comment.