From d8640fc9bae20d20a6e054b1ff5883b5a998da7a Mon Sep 17 00:00:00 2001 From: Devon Fulcher Date: Thu, 21 Sep 2023 17:01:12 -0500 Subject: [PATCH] not mutating time_dimension_factory.time_dimension_specs in WhereFilterDimension --- metricflow/protocols/query_interface.py | 4 --- metricflow/specs/where_filter_dimension.py | 34 +++++++------------ metricflow/specs/where_filter_transform.py | 12 +++---- .../test/specs/test_where_filter_dimension.py | 11 ------ 4 files changed, 18 insertions(+), 43 deletions(-) delete mode 100644 metricflow/test/specs/test_where_filter_dimension.py diff --git a/metricflow/protocols/query_interface.py b/metricflow/protocols/query_interface.py index ec120699c9..5d5797cf8f 100644 --- a/metricflow/protocols/query_interface.py +++ b/metricflow/protocols/query_interface.py @@ -18,10 +18,6 @@ def grain(self, _grain: str) -> QueryInterfaceDimension: """The time granularity.""" raise NotImplementedError - def alias(self, _alias: str) -> QueryInterfaceDimension: - """Renaming the column.""" - raise NotImplementedError - def descending(self, _is_descending: bool) -> QueryInterfaceDimension: """Set the sort order for order-by.""" diff --git a/metricflow/specs/where_filter_dimension.py b/metricflow/specs/where_filter_dimension.py index 5dac469586..8682a1d532 100644 --- a/metricflow/specs/where_filter_dimension.py +++ b/metricflow/specs/where_filter_dimension.py @@ -32,45 +32,37 @@ def __init__( # noqa entity_path: Sequence[str], call_parameter_sets: FilterCallParameterSets, column_association_resolver: ColumnAssociationResolver, - time_dimension_specs: List[TimeDimensionSpec], ) -> None: self._dimension_spec_resolver = DimensionSpecResolver(call_parameter_sets) - self.name = name - self.spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path) self._column_association_resolver = column_association_resolver - self.entity_path = entity_path - self.time_granularity: Optional[TimeGranularity] = None - self._time_dimension_specs = time_dimension_specs + self._name = name + self._entity_path = entity_path + self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path) + self.time_dimension_spec: Optional[TimeDimensionSpec] = None def grain(self, time_granularity_name: str) -> QueryInterfaceDimension: """The time granularity.""" - self.time_granularity = TimeGranularity(time_granularity_name) - self.spec = self._dimension_spec_resolver.resolve_time_dimension_spec( - self.name, self.time_granularity, self.entity_path + self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec( + self._name, TimeGranularity(time_granularity_name), self._entity_path ) - self._time_dimension_specs.append(self.spec) return self def date_part(self, _date_part: str) -> QueryInterfaceDimension: """The date_part requested to extract.""" - raise NotImplementedError - - def alias(self, _alias: str) -> QueryInterfaceDimension: - """Renaming the column.""" - raise NotImplementedError + raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter and filter spec") def descending(self, _is_descending: bool) -> QueryInterfaceDimension: """Set the sort order for order-by.""" - raise InvalidQuerySyntax( - "Can't set descending in the where clause. Try setting descending in the order_by clause instead" - ) + raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") def __str__(self) -> str: """Returns the column name. Important in the Jinja sandbox. """ - return self._column_association_resolver.resolve_spec(self.spec).column_name + return self._column_association_resolver.resolve_spec( + self.time_dimension_spec or self.dimension_spec + ).column_name class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]): @@ -87,17 +79,15 @@ def __init__( # noqa self, call_parameter_sets: FilterCallParameterSets, column_association_resolver: ColumnAssociationResolver, - time_dimension_specs: List[TimeDimensionSpec], ): self._call_parameter_sets = call_parameter_sets self._column_association_resolver = column_association_resolver - self._time_dimension_specs = time_dimension_specs self.created: List[WhereFilterDimension] = [] def create(self, name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension: """Create a WhereFilterDimension.""" dimension = WhereFilterDimension( - name, entity_path, self._call_parameter_sets, self._column_association_resolver, self._time_dimension_specs + name, entity_path, self._call_parameter_sets, self._column_association_resolver ) self.created.append(dimension) return dimension diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py index 741216dd68..13d6ad492b 100644 --- a/metricflow/specs/where_filter_transform.py +++ b/metricflow/specs/where_filter_transform.py @@ -6,7 +6,7 @@ from dbt_semantic_interfaces.protocols.where_filter import WhereFilter from metricflow.specs.column_assoc import ColumnAssociationResolver -from metricflow.specs.specs import LinkableSpecSet, TimeDimensionSpec, WhereFilterSpec +from metricflow.specs.specs import LinkableSpecSet, WhereFilterSpec from metricflow.specs.where_filter_dimension import WhereFilterDimensionFactory from metricflow.specs.where_filter_entity import WhereFilterEntityFactory from metricflow.specs.where_filter_time_dimension import WhereFilterTimeDimensionFactory @@ -35,9 +35,7 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec call_parameter_sets = where_filter.call_parameter_sets time_dimension_factory = WhereFilterTimeDimensionFactory(call_parameter_sets, self._column_association_resolver) - dimension_factory = WhereFilterDimensionFactory( - call_parameter_sets, self._column_association_resolver, time_dimension_factory.time_dimension_specs - ) + dimension_factory = WhereFilterDimensionFactory(call_parameter_sets, self._column_association_resolver) entity_factory = WhereFilterEntityFactory(call_parameter_sets, self._column_association_resolver) try: rendered_sql_template = jinja2.Template( @@ -56,8 +54,10 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec dimension_specs = [] for dimension in dimension_factory.created: - if not dimension.time_granularity: - dimension_specs.append(dimension.spec) + if dimension.time_dimension_spec: + time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec) + else: + dimension_specs.append(dimension.dimension_spec) return WhereFilterSpec( where_sql=rendered_sql_template, diff --git a/metricflow/test/specs/test_where_filter_dimension.py b/metricflow/test/specs/test_where_filter_dimension.py deleted file mode 100644 index a3c673daae..0000000000 --- a/metricflow/test/specs/test_where_filter_dimension.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -import pytest - -from metricflow.errors.errors import InvalidQuerySyntax -from metricflow.specs.where_filter_dimension import WhereFilterDimension - - -def test_descending_cannot_be_set() -> None: # noqa - with pytest.raises(InvalidQuerySyntax): - WhereFilterDimension("bookings").descending(True)