From 355f79f310a7cee09bdd99f64cf03ef03b809b61 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 13 Mar 2024 11:11:51 -0700 Subject: [PATCH] WIP --- metricflow/specs/where_filter_metric.py | 107 ++++++++++++++++++ metricflow/specs/where_filter_transform.py | 36 +++--- metricflow/test/fixtures/manifest_fixtures.py | 8 +- .../simple_manifest/metrics.yaml | 8 ++ .../integration/test_cases/itest_metrics.yaml | 8 ++ .../test/integration/test_configured_cases.py | 51 +++++---- 6 files changed, 176 insertions(+), 42 deletions(-) create mode 100644 metricflow/specs/where_filter_metric.py diff --git a/metricflow/specs/where_filter_metric.py b/metricflow/specs/where_filter_metric.py new file mode 100644 index 0000000000..0309cffdc6 --- /dev/null +++ b/metricflow/specs/where_filter_metric.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from typing import Optional, Sequence + +from dbt_semantic_interfaces.call_parameter_sets import ( + MetricCallParameterSet, +) +from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint +from dbt_semantic_interfaces.protocols.query_interface import QueryInterfaceMetric, QueryInterfaceMetricFactory +from dbt_semantic_interfaces.references import MetricReference, LinkableElementReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart +from typing_extensions import override + +from metricflow.errors.errors import InvalidQuerySyntax +from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation +from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import ( + FilterSpecResolutionLookUp, + ResolvedSpecLookUpKey, +) +from metricflow.specs.column_assoc import ColumnAssociationResolver +from metricflow.specs.rendered_spec_tracker import RenderedSpecTracker + + +class WhereFilterMetric(ProtocolHint[QueryInterfaceMetric]): + """A metric that is passed in through the where filter parameter.""" + + @override + def _implements_protocol(self) -> QueryInterfaceMetric: + return self + + def __init__( # noqa + self, + column_association_resolver: ColumnAssociationResolver, + resolved_spec_lookup: FilterSpecResolutionLookUp, + where_filter_location: WhereFilterLocation, + rendered_spec_tracker: RenderedSpecTracker, + element_name: str, + group_by: Sequence[LinkableElementReference], + ) -> None: + self._column_association_resolver = column_association_resolver + self._resolved_spec_lookup = resolved_spec_lookup + self._where_filter_location = where_filter_location + self._rendered_spec_tracker = rendered_spec_tracker + self._element_name = element_name + self._group_by = tuple(group_by) + + def descending(self, _is_descending: bool) -> QueryInterfaceMetric: + """Set the sort order for order-by.""" + raise InvalidQuerySyntax( + "Can't set descending in the where clause. Try setting descending in the order_by clause instead" + ) + + def __str__(self) -> str: + """Returns the column name. + + Important in the Jinja sandbox. + """ + call_parameter_set = MetricCallParameterSet( + group_by=self._group_by, + metric_reference=MetricReference(self._element_name), + ) + resolved_spec = self._resolved_spec_lookup.checked_resolved_spec( + ResolvedSpecLookUpKey( + filter_location=self._where_filter_location, + call_parameter_set=call_parameter_set, + ) + ) + self._rendered_spec_tracker.record_rendered_spec(resolved_spec) + column_association = self._column_association_resolver.resolve_spec(resolved_spec) + + return column_association.column_name + + +class WhereFilterMetricFactory(ProtocolHint[QueryInterfaceMetricFactory]): + """Creates a WhereFilterMetric. + + Each call to `create` adds a MetricSpec to metric_specs. + """ + + @override + def _implements_protocol(self) -> QueryInterfaceMetricFactory: + return self + + def __init__( # noqa + self, + column_association_resolver: ColumnAssociationResolver, + spec_resolution_lookup: FilterSpecResolutionLookUp, + where_filter_location: WhereFilterLocation, + rendered_spec_tracker: RenderedSpecTracker, + ): + self._column_association_resolver = column_association_resolver + self._resolved_spec_lookup = spec_resolution_lookup + self._where_filter_location = where_filter_location + self._rendered_spec_tracker = rendered_spec_tracker + + def create(self, metric_name: str, group_by: Sequence[str] = ()) -> WhereFilterMetric: + """Create a WhereFilterMetric.""" + return WhereFilterMetric( + column_association_resolver=self._column_association_resolver, + resolved_spec_lookup=self._resolved_spec_lookup, + where_filter_location=self._where_filter_location, + rendered_spec_tracker=self._rendered_spec_tracker, + element_name=metric_name, + group_by=tuple(LinkableElementReference(group_by_name.lower()) for group_by_name in group_by), + ) diff --git a/metricflow/specs/where_filter_transform.py b/metricflow/specs/where_filter_transform.py index cb38f50614..2420e53838 100644 --- a/metricflow/specs/where_filter_transform.py +++ b/metricflow/specs/where_filter_transform.py @@ -14,6 +14,7 @@ from metricflow.specs.specs import LinkableSpecSet, WhereFilterSpec from metricflow.specs.where_filter_dimension import WhereFilterDimensionFactory from metricflow.specs.where_filter_entity import WhereFilterEntityFactory +from metricflow.specs.where_filter_metric import WhereFilterMetricFactory from metricflow.specs.where_filter_time_dimension import WhereFilterTimeDimensionFactory from metricflow.sql.sql_bind_parameters import SqlBindParameters @@ -75,20 +76,27 @@ def create_from_where_filter_intersection( # noqa: D where_filter_location=filter_location, rendered_spec_tracker=rendered_spec_tracker, ) - try: - # If there was an error with the template, it should have been caught while resolving the specs for - # the filters during query resolution. - where_sql = jinja2.Template(where_filter.where_sql_template, undefined=jinja2.StrictUndefined).render( - { - "Dimension": dimension_factory.create, - "TimeDimension": time_dimension_factory.create, - "Entity": entity_factory.create, - } - ) - except (jinja2.exceptions.UndefinedError, jinja2.exceptions.TemplateSyntaxError) as e: - raise RenderSqlTemplateException( - f"Error while rendering Jinja template:\n{where_filter.where_sql_template}" - ) from e + metric_factory = WhereFilterMetricFactory( + column_association_resolver=self._column_association_resolver, + spec_resolution_lookup=self._spec_resolution_lookup, + where_filter_location=filter_location, + rendered_spec_tracker=rendered_spec_tracker, + ) + # try: + # If there was an error with the template, it should have been caught while resolving the specs for + # the filters during query resolution. + where_sql = jinja2.Template(where_filter.where_sql_template, undefined=jinja2.StrictUndefined).render( + { + "Dimension": dimension_factory.create, + "TimeDimension": time_dimension_factory.create, + "Entity": entity_factory.create, + "Metric": metric_factory.create, + } + ) + # except (jinja2.exceptions.UndefinedError, jinja2.exceptions.TemplateSyntaxError) as e: + # raise RenderSqlTemplateException( + # f"Error while rendering Jinja template:\n{where_filter.where_sql_template}" + # ) from e filter_specs.append( WhereFilterSpec( where_sql=where_sql, diff --git a/metricflow/test/fixtures/manifest_fixtures.py b/metricflow/test/fixtures/manifest_fixtures.py index c49e1ac81e..b5c186a7be 100644 --- a/metricflow/test/fixtures/manifest_fixtures.py +++ b/metricflow/test/fixtures/manifest_fixtures.py @@ -224,10 +224,10 @@ def mf_engine_test_fixture_mapping( fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {} for semantic_manifest_setup in SemanticManifestSetup: with patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value): - try: - build_result = load_semantic_manifest(semantic_manifest_setup.semantic_manifest_name, template_mapping) - except Exception as e: - raise RuntimeError(f"Error while loading semantic manifest: {semantic_manifest_setup}") from e + # try: + build_result = load_semantic_manifest(semantic_manifest_setup.semantic_manifest_name, template_mapping) + # except Exception as e: + # raise RuntimeError(f"Error while loading semantic manifest: {semantic_manifest_setup}") from e fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters( sql_client, build_result.semantic_manifest diff --git a/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/metrics.yaml b/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/metrics.yaml index 3648c71ec6..b63b75a6c3 100644 --- a/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/metrics.yaml +++ b/metricflow/test/fixtures/semantic_manifest_yamls/simple_manifest/metrics.yaml @@ -753,3 +753,11 @@ metric: window: 7 days entity: user calculation: conversion_rate +--- +metric: + name: active_listings + description: Listings with at least 2 bookings + type: simple + type_params: + measure: listings + filter: "{{ Metric('bookings', ['listing']) }} > 2" diff --git a/metricflow/test/integration/test_cases/itest_metrics.yaml b/metricflow/test/integration/test_cases/itest_metrics.yaml index 266148b72b..0e4920d26b 100644 --- a/metricflow/test/integration/test_cases/itest_metrics.yaml +++ b/metricflow/test/integration/test_cases/itest_metrics.yaml @@ -1889,3 +1889,11 @@ integration_test: ds ) b ON {{ render_date_sub("a", "ds", 5, TimeGranularity.DAY) }} = b.ds +--- +integration_test: + name: active_listings + description: Query a metric that has a filter containing a metric + model: SIMPLE_MODEL + metrics: ["active_listings"] + check_query: | + SELECT 1 diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py index 1eae8106c7..ba27533e13 100644 --- a/metricflow/test/integration/test_configured_cases.py +++ b/metricflow/test/integration/test_configured_cases.py @@ -230,7 +230,8 @@ def filter_not_supported_features( @pytest.mark.parametrize( "name", - CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names, + # CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names, + ["itest_metrics.yaml/active_listings"], ids=lambda name: f"name={name}", ) def test_case( @@ -301,29 +302,31 @@ def test_case( limit=case.limit, time_constraint_start=parser.parse(case.time_constraint[0]) if case.time_constraint else None, time_constraint_end=parser.parse(case.time_constraint[1]) if case.time_constraint else None, - where_constraint=jinja2.Template( - case.where_filter, - undefined=jinja2.StrictUndefined, - ).render( - source_schema=mf_test_session_state.mf_source_schema, - render_time_constraint=check_query_helpers.render_time_constraint, - render_between_time_constraint=check_query_helpers.render_between_time_constraint, - TimeGranularity=TimeGranularity, - DatePart=DatePart, - render_date_sub=check_query_helpers.render_date_sub, - render_date_trunc=check_query_helpers.render_date_trunc, - render_extract=check_query_helpers.render_extract, - render_percentile_expr=check_query_helpers.render_percentile_expr, - mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql, - double_data_type_name=check_query_helpers.double_data_type_name, - render_dimension_template=check_query_helpers.render_dimension_template, - render_entity_template=check_query_helpers.render_entity_template, - render_time_dimension_template=check_query_helpers.render_time_dimension_template, - generate_random_uuid=check_query_helpers.generate_random_uuid, - cast_to_ts=check_query_helpers.cast_to_ts, - ) - if case.where_filter - else None, + where_constraint=( + jinja2.Template( + case.where_filter, + undefined=jinja2.StrictUndefined, + ).render( + source_schema=mf_test_session_state.mf_source_schema, + render_time_constraint=check_query_helpers.render_time_constraint, + render_between_time_constraint=check_query_helpers.render_between_time_constraint, + TimeGranularity=TimeGranularity, + DatePart=DatePart, + render_date_sub=check_query_helpers.render_date_sub, + render_date_trunc=check_query_helpers.render_date_trunc, + render_extract=check_query_helpers.render_extract, + render_percentile_expr=check_query_helpers.render_percentile_expr, + mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql, + double_data_type_name=check_query_helpers.double_data_type_name, + render_dimension_template=check_query_helpers.render_dimension_template, + render_entity_template=check_query_helpers.render_entity_template, + render_time_dimension_template=check_query_helpers.render_time_dimension_template, + generate_random_uuid=check_query_helpers.generate_random_uuid, + cast_to_ts=check_query_helpers.cast_to_ts, + ) + if case.where_filter + else None + ), order_by_names=case.order_bys, min_max_only=case.min_max_only, )