From ed608fbd20b48db1aa7a42ff34fb5e6483341f4a Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Fri, 15 Mar 2024 19:18:16 -0700 Subject: [PATCH] Allow metrics in filters --- .../call_parameter_sets.py | 11 +++++ .../implementations/filters/where_filter.py | 8 ++-- .../where_filter/parameter_set_factory.py | 11 +++++ .../where_filter/where_filter_entity.py | 40 ++++++++++++++++++- .../where_filter/where_filter_parser.py | 27 ++++++------- .../protocols/query_interface.py | 12 ++++++ 6 files changed, 89 insertions(+), 20 deletions(-) diff --git a/dbt_semantic_interfaces/call_parameter_sets.py b/dbt_semantic_interfaces/call_parameter_sets.py index 5a3b23ac..49c663c7 100644 --- a/dbt_semantic_interfaces/call_parameter_sets.py +++ b/dbt_semantic_interfaces/call_parameter_sets.py @@ -6,7 +6,9 @@ from dbt_semantic_interfaces.references import ( DimensionReference, EntityReference, + MetricReference, TimeDimensionReference, + LinkableElementReference, ) from dbt_semantic_interfaces.type_enums import TimeGranularity from dbt_semantic_interfaces.type_enums.date_part import DatePart @@ -38,6 +40,14 @@ class EntityCallParameterSet: entity_reference: EntityReference +@dataclass(frozen=True) +class MetricCallParameterSet: + """When 'Metric(...)' is used in the Jinja template of the where filter, the parameters to that call.""" + + metric_reference: MetricReference + group_by: Tuple[LinkableElementReference, ...] + + @dataclass(frozen=True) class FilterCallParameterSets: """The calls for metric items made in the Jinja template of the where filter.""" @@ -45,6 +55,7 @@ class FilterCallParameterSets: dimension_call_parameter_sets: Tuple[DimensionCallParameterSet, ...] = () time_dimension_call_parameter_sets: Tuple[TimeDimensionCallParameterSet, ...] = () entity_call_parameter_sets: Tuple[EntityCallParameterSet, ...] = () + metric_call_parameter_sets: Tuple[MetricCallParameterSet, ...] = () class ParseWhereFilterException(Exception): # noqa: D diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index 96a74581..a6518d39 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -121,10 +121,10 @@ def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParamete filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = [] invalid_filter_expressions: List[Tuple[str, Exception]] = [] for where_filter in self.where_filters: - try: - filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) - except Exception as e: - invalid_filter_expressions.append((where_filter.where_sql_template, e)) + # try: + filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) + # except Exception as e: + # invalid_filter_expressions.append((where_filter.where_sql_template, e)) if invalid_filter_expressions: lines = ["Encountered error(s) while parsing:\n"] diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py index 128b0553..1bbdc215 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -3,6 +3,7 @@ from dbt_semantic_interfaces.call_parameter_sets import ( DimensionCallParameterSet, EntityCallParameterSet, + MetricCallParameterSet, ParseWhereFilterException, TimeDimensionCallParameterSet, ) @@ -14,7 +15,9 @@ from dbt_semantic_interfaces.references import ( DimensionReference, EntityReference, + MetricReference, TimeDimensionReference, + LinkableElementReference, ) from dbt_semantic_interfaces.type_enums import TimeGranularity from dbt_semantic_interfaces.type_enums.date_part import DatePart @@ -101,3 +104,11 @@ def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCa entity_path=additional_entity_path_elements + structured_dundered_name.entity_links, entity_reference=EntityReference(element_name=structured_dundered_name.element_name), ) + + @staticmethod + def create_metric(metric_name: str, group_by: Sequence[str] = ()) -> MetricCallParameterSet: + """Gets called by Jinja when rendering {{ Metric(...) }}.""" + return MetricCallParameterSet( + metric_reference=MetricReference(element_name=metric_name), + group_by=tuple([LinkableElementReference(element_name=group_by_name) for group_by_name in group_by]), + ) diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py index 83e5fe8c..670b0968 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_entity.py @@ -4,7 +4,11 @@ from typing_extensions import override -from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet +from dbt_semantic_interfaces.call_parameter_sets import ( + EntityCallParameterSet, + MetricCallParameterSet, +) +from dbt_semantic_interfaces.errors import InvalidQuerySyntax from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( ParameterSetFactory, ) @@ -12,6 +16,8 @@ from dbt_semantic_interfaces.protocols.query_interface import ( QueryInterfaceEntity, QueryInterfaceEntityFactory, + QueryInterfaceMetric, + QueryInterfaceMetricFactory, ) @@ -27,6 +33,21 @@ def _implements_protocol(self) -> QueryInterfaceEntity: return self +class MetricStub(ProtocolHint[QueryInterfaceMetric]): + """A Metric implementation that just satisfies the protocol. + + QueryInterfaceMetric currently has no methods and the parameter set is created in the factory. + So, there is nothing to do here. + """ + + @override + def _implements_protocol(self) -> QueryInterfaceMetric: + return self + + def descending(self, _is_descending: bool) -> QueryInterfaceMetric: # noqa: D + raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") + + class WhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]): """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" @@ -41,3 +62,20 @@ def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> EntityStu """Gets called by Jinja when rendering {{ Entity(...) }}.""" self.entity_call_parameter_sets.append(ParameterSetFactory.create_entity(entity_name, entity_path)) return EntityStub() + + +class WhereFilterMetricFactory(ProtocolHint[QueryInterfaceMetricFactory]): + """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" + + @override + def _implements_protocol(self) -> QueryInterfaceMetricFactory: + return self + + def __init__(self) -> None: # noqa: D + self.metric_call_parameter_sets: List[MetricCallParameterSet] = [] + + def create(self, metric_name: str, group_by: Sequence[str] = ()) -> MetricStub: # noqa: D + self.metric_call_parameter_sets.append( + ParameterSetFactory.create_metric(metric_name=metric_name, group_by=group_by) + ) + return MetricStub() diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py index 083612ce..bce28375 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py @@ -1,13 +1,9 @@ from __future__ import annotations from jinja2 import StrictUndefined -from jinja2.exceptions import SecurityError, TemplateSyntaxError, UndefinedError from jinja2.sandbox import SandboxedEnvironment -from dbt_semantic_interfaces.call_parameter_sets import ( - FilterCallParameterSets, - ParseWhereFilterException, -) +from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( ParameterSetFactory, ) @@ -16,6 +12,7 @@ ) from dbt_semantic_interfaces.parsing.where_filter.where_filter_entity import ( WhereFilterEntityFactory, + WhereFilterMetricFactory, ) from dbt_semantic_interfaces.parsing.where_filter.where_filter_time_dimension import ( WhereFilterTimeDimensionFactory, @@ -31,16 +28,15 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet time_dimension_factory = WhereFilterTimeDimensionFactory() dimension_factory = WhereFilterDimensionFactory() entity_factory = WhereFilterEntityFactory() - - try: - # the string that the sandbox renders is unused - SandboxedEnvironment(undefined=StrictUndefined).from_string(where_sql_template).render( - Dimension=dimension_factory.create, - TimeDimension=time_dimension_factory.create, - Entity=entity_factory.create, - ) - except (UndefinedError, TemplateSyntaxError, SecurityError) as e: - raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e + metric_factory = WhereFilterMetricFactory() + + # the string that the sandbox renders is unused + SandboxedEnvironment(undefined=StrictUndefined).from_string(where_sql_template).render( + Dimension=dimension_factory.create, + TimeDimension=time_dimension_factory.create, + Entity=entity_factory.create, + Metric=metric_factory.create, + ) """ Dimensions that are created with a grain or date_part parameter, for instance Dimension(...).grain(...), are @@ -63,4 +59,5 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet dimension_call_parameter_sets=tuple(dimension_call_parameter_sets), time_dimension_call_parameter_sets=tuple(time_dimension_factory.time_dimension_call_parameter_sets), entity_call_parameter_sets=tuple(entity_factory.entity_call_parameter_sets), + metric_call_parameter_sets=tuple(metric_factory.metric_call_parameter_sets), ) diff --git a/dbt_semantic_interfaces/protocols/query_interface.py b/dbt_semantic_interfaces/protocols/query_interface.py index 2831740b..70225450 100644 --- a/dbt_semantic_interfaces/protocols/query_interface.py +++ b/dbt_semantic_interfaces/protocols/query_interface.py @@ -85,3 +85,15 @@ class QueryInterfaceEntityFactory(Protocol): def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> QueryInterfaceEntity: """Create an Entity.""" pass + + +class QueryInterfaceMetricFactory(Protocol): + """Creates an Metric for the query interface. + + Represented as the Metric constructor in the Jinja sandbox. + """ + + @abstractmethod + def create(self, metric_name: str, group_by: Sequence[str] = ()) -> QueryInterfaceMetric: + """Create a Metric.""" + pass