diff --git a/.changes/unreleased/Under the Hood-20231017-155210.yaml b/.changes/unreleased/Under the Hood-20231017-155210.yaml new file mode 100644 index 0000000000..787ea45ab6 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20231017-155210.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: re-categorize `TypeErrors` that arise from `create_from_where_filter` into `InvalidQueryException` +time: 2023-10-17T15:52:10.948956-05:00 +custom: + Author: DevonFulcher + Issue: None diff --git a/metricflow/query/query_parser.py b/metricflow/query/query_parser.py index 406cb92645..665295405c 100644 --- a/metricflow/query/query_parser.py +++ b/metricflow/query/query_parser.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Sequence, Tuple -from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, PydanticWhereFilterIntersection, @@ -445,14 +444,9 @@ def _parse_and_validate_query( ) where_filter_spec: Optional[WhereFilterSpec] = None if where_filter is not None: - try: - where_filter_spec = WhereSpecFactory( - column_association_resolver=self._column_association_resolver, - ).create_from_where_filter(where_filter) - except ParseWhereFilterException as e: - raise InvalidQueryException( - f"Error parsing the where filter: {where_filter.where_sql_template}. {e}" - ) from e + where_filter_spec = WhereSpecFactory( + column_association_resolver=self._column_association_resolver, + ).create_from_where_filter(where_filter) where_spec_set = QueryTimeLinkableSpecSet.create_from_linkable_spec_set(where_filter_spec.linkable_spec_set) requested_linkable_specs_with_requested_filter_specs = QueryTimeLinkableSpecSet.combine( diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py index 4b0954bc53..c039cee1f8 100644 --- a/metricflow/specs/where_filter_transform.py +++ b/metricflow/specs/where_filter_transform.py @@ -4,10 +4,12 @@ from typing import Optional import jinja2 +from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException from dbt_semantic_interfaces.protocols import WhereFilterIntersection from dbt_semantic_interfaces.protocols.where_filter import WhereFilter from metricflow.filters.merge_where import merge_to_single_where_filter +from metricflow.query.query_exceptions import InvalidQueryException from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import LinkableSpecSet, WhereFilterSpec from metricflow.specs.where_filter_dimension import WhereFilterDimensionFactory @@ -45,17 +47,15 @@ def create_from_where_filter_intersection( # noqa: D return self.create_from_where_filter(where_filter) - def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec: # noqa: D - # Used to check that call parameter sets generated by DSI match those generated below. - call_parameter_sets = where_filter.call_parameter_sets - - dimension_factory = WhereFilterDimensionFactory(call_parameter_sets, self._column_association_resolver) - time_dimension_factory = WhereFilterTimeDimensionFactory(call_parameter_sets, self._column_association_resolver) - entity_factory = WhereFilterEntityFactory(call_parameter_sets, self._column_association_resolver) + def _render_sql_template( + self, + where_filter: WhereFilter, + dimension_factory: WhereFilterDimensionFactory, + time_dimension_factory: WhereFilterTimeDimensionFactory, + entity_factory: WhereFilterEntityFactory, + ) -> str: try: - rendered_sql_template = jinja2.Template( - where_filter.where_sql_template, undefined=jinja2.StrictUndefined - ).render( + return jinja2.Template(where_filter.where_sql_template, undefined=jinja2.StrictUndefined).render( { "Dimension": dimension_factory.create, "TimeDimension": time_dimension_factory.create, @@ -67,23 +67,41 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec f"Error while rendering Jinja template:\n{where_filter.where_sql_template}" ) from e - """ - Dimensions that are created with a grain parameter, Dimension(...).grain(...), are - added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs - """ - dimension_specs = [] - for dimension in dimension_factory.created: - if dimension.time_dimension_spec: - time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec) - else: - dimension_specs.append(dimension.dimension_spec) - - return WhereFilterSpec( - where_sql=rendered_sql_template, - bind_parameters=self._bind_parameters, - linkable_spec_set=LinkableSpecSet( - dimension_specs=tuple(dimension_specs), - time_dimension_specs=tuple(time_dimension_factory.time_dimension_specs), - entity_specs=tuple(entity_factory.entity_specs), - ), - ) + def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec: + """Generates WhereFilterSpec using Jinja.""" + try: + call_parameter_sets = where_filter.call_parameter_sets + + dimension_factory = WhereFilterDimensionFactory(call_parameter_sets, self._column_association_resolver) + time_dimension_factory = WhereFilterTimeDimensionFactory( + call_parameter_sets, self._column_association_resolver + ) + entity_factory = WhereFilterEntityFactory(call_parameter_sets, self._column_association_resolver) + rendered_sql_template = self._render_sql_template( + where_filter, dimension_factory, time_dimension_factory, entity_factory + ) + + """ + Dimensions that are created with a grain parameter, Dimension(...).grain(...), are + added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs + """ + dimension_specs = [] + for dimension in dimension_factory.created: + if dimension.time_dimension_spec: + time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec) + else: + dimension_specs.append(dimension.dimension_spec) + + return WhereFilterSpec( + where_sql=rendered_sql_template, + bind_parameters=self._bind_parameters, + linkable_spec_set=LinkableSpecSet( + dimension_specs=tuple(dimension_specs), + time_dimension_specs=tuple(time_dimension_factory.time_dimension_specs), + entity_specs=tuple(entity_factory.entity_specs), + ), + ) + except (ParseWhereFilterException, TypeError) as e: + raise InvalidQueryException( + f"Error parsing the where filter: {where_filter.where_sql_template}. {e}" + ) from e diff --git a/metricflow/test/model/test_where_filter_spec.py b/metricflow/test/model/test_where_filter_spec.py index 428fa17c21..80dd4cc0fa 100644 --- a/metricflow/test/model/test_where_filter_spec.py +++ b/metricflow/test/model/test_where_filter_spec.py @@ -2,10 +2,12 @@ import logging +import pytest from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter from dbt_semantic_interfaces.references import EntityReference from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from metricflow.query.query_exceptions import InvalidQueryException from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( DimensionSpec, @@ -62,6 +64,15 @@ def test_dimension_in_filter_with_grain( # noqa: D ) +def test_time_dimension_without_grain(column_association_resolver: ColumnAssociationResolver) -> None: # noqa + where_filter = PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time') }} > '2023-10-17'") + + with pytest.raises(InvalidQueryException): + WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter(where_filter) + + def test_time_dimension_in_filter( # noqa: D column_association_resolver: ColumnAssociationResolver, ) -> None: