Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Mar 13, 2024
1 parent a337407 commit 355f79f
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 42 deletions.
107 changes: 107 additions & 0 deletions metricflow/specs/where_filter_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

from typing import Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
MetricCallParameterSet,
)
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import QueryInterfaceMetric, QueryInterfaceMetricFactory
from dbt_semantic_interfaces.references import MetricReference, LinkableElementReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
ResolvedSpecLookUpKey,
)
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.rendered_spec_tracker import RenderedSpecTracker


class WhereFilterMetric(ProtocolHint[QueryInterfaceMetric]):
"""A metric that is passed in through the where filter parameter."""

@override
def _implements_protocol(self) -> QueryInterfaceMetric:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
element_name: str,
group_by: Sequence[LinkableElementReference],
) -> None:
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = resolved_spec_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._element_name = element_name
self._group_by = tuple(group_by)

def descending(self, _is_descending: bool) -> QueryInterfaceMetric:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)

def __str__(self) -> str:
"""Returns the column name.
Important in the Jinja sandbox.
"""
call_parameter_set = MetricCallParameterSet(
group_by=self._group_by,
metric_reference=MetricReference(self._element_name),
)
resolved_spec = self._resolved_spec_lookup.checked_resolved_spec(
ResolvedSpecLookUpKey(
filter_location=self._where_filter_location,
call_parameter_set=call_parameter_set,
)
)
self._rendered_spec_tracker.record_rendered_spec(resolved_spec)
column_association = self._column_association_resolver.resolve_spec(resolved_spec)

return column_association.column_name


class WhereFilterMetricFactory(ProtocolHint[QueryInterfaceMetricFactory]):
"""Creates a WhereFilterMetric.
Each call to `create` adds a MetricSpec to metric_specs.
"""

@override
def _implements_protocol(self) -> QueryInterfaceMetricFactory:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker

def create(self, metric_name: str, group_by: Sequence[str] = ()) -> WhereFilterMetric:
"""Create a WhereFilterMetric."""
return WhereFilterMetric(
column_association_resolver=self._column_association_resolver,
resolved_spec_lookup=self._resolved_spec_lookup,
where_filter_location=self._where_filter_location,
rendered_spec_tracker=self._rendered_spec_tracker,
element_name=metric_name,
group_by=tuple(LinkableElementReference(group_by_name.lower()) for group_by_name in group_by),
)
36 changes: 22 additions & 14 deletions metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from metricflow.specs.specs import LinkableSpecSet, WhereFilterSpec
from metricflow.specs.where_filter_dimension import WhereFilterDimensionFactory
from metricflow.specs.where_filter_entity import WhereFilterEntityFactory
from metricflow.specs.where_filter_metric import WhereFilterMetricFactory
from metricflow.specs.where_filter_time_dimension import WhereFilterTimeDimensionFactory
from metricflow.sql.sql_bind_parameters import SqlBindParameters

Expand Down Expand Up @@ -75,20 +76,27 @@ def create_from_where_filter_intersection( # noqa: D
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
)
try:
# If there was an error with the template, it should have been caught while resolving the specs for
# the filters during query resolution.
where_sql = jinja2.Template(where_filter.where_sql_template, undefined=jinja2.StrictUndefined).render(
{
"Dimension": dimension_factory.create,
"TimeDimension": time_dimension_factory.create,
"Entity": entity_factory.create,
}
)
except (jinja2.exceptions.UndefinedError, jinja2.exceptions.TemplateSyntaxError) as e:
raise RenderSqlTemplateException(
f"Error while rendering Jinja template:\n{where_filter.where_sql_template}"
) from e
metric_factory = WhereFilterMetricFactory(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
)
# try:
# If there was an error with the template, it should have been caught while resolving the specs for
# the filters during query resolution.
where_sql = jinja2.Template(where_filter.where_sql_template, undefined=jinja2.StrictUndefined).render(
{
"Dimension": dimension_factory.create,
"TimeDimension": time_dimension_factory.create,
"Entity": entity_factory.create,
"Metric": metric_factory.create,
}
)
# except (jinja2.exceptions.UndefinedError, jinja2.exceptions.TemplateSyntaxError) as e:
# raise RenderSqlTemplateException(
# f"Error while rendering Jinja template:\n{where_filter.where_sql_template}"
# ) from e
filter_specs.append(
WhereFilterSpec(
where_sql=where_sql,
Expand Down
8 changes: 4 additions & 4 deletions metricflow/test/fixtures/manifest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def mf_engine_test_fixture_mapping(
fixture_mapping: Dict[SemanticManifestSetup, MetricFlowEngineTestFixture] = {}
for semantic_manifest_setup in SemanticManifestSetup:
with patch_id_generators_helper(semantic_manifest_setup.id_number_space.start_value):
try:
build_result = load_semantic_manifest(semantic_manifest_setup.semantic_manifest_name, template_mapping)
except Exception as e:
raise RuntimeError(f"Error while loading semantic manifest: {semantic_manifest_setup}") from e
# try:
build_result = load_semantic_manifest(semantic_manifest_setup.semantic_manifest_name, template_mapping)
# except Exception as e:
# raise RuntimeError(f"Error while loading semantic manifest: {semantic_manifest_setup}") from e

fixture_mapping[semantic_manifest_setup] = MetricFlowEngineTestFixture.from_parameters(
sql_client, build_result.semantic_manifest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,11 @@ metric:
window: 7 days
entity: user
calculation: conversion_rate
---
metric:
name: active_listings
description: Listings with at least 2 bookings
type: simple
type_params:
measure: listings
filter: "{{ Metric('bookings', ['listing']) }} > 2"
8 changes: 8 additions & 0 deletions metricflow/test/integration/test_cases/itest_metrics.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1889,3 +1889,11 @@ integration_test:
ds
) b
ON {{ render_date_sub("a", "ds", 5, TimeGranularity.DAY) }} = b.ds
---
integration_test:
name: active_listings
description: Query a metric that has a filter containing a metric
model: SIMPLE_MODEL
metrics: ["active_listings"]
check_query: |
SELECT 1
51 changes: 27 additions & 24 deletions metricflow/test/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def filter_not_supported_features(

@pytest.mark.parametrize(
"name",
CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
# CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
["itest_metrics.yaml/active_listings"],
ids=lambda name: f"name={name}",
)
def test_case(
Expand Down Expand Up @@ -301,29 +302,31 @@ def test_case(
limit=case.limit,
time_constraint_start=parser.parse(case.time_constraint[0]) if case.time_constraint else None,
time_constraint_end=parser.parse(case.time_constraint[1]) if case.time_constraint else None,
where_constraint=jinja2.Template(
case.where_filter,
undefined=jinja2.StrictUndefined,
).render(
source_schema=mf_test_session_state.mf_source_schema,
render_time_constraint=check_query_helpers.render_time_constraint,
render_between_time_constraint=check_query_helpers.render_between_time_constraint,
TimeGranularity=TimeGranularity,
DatePart=DatePart,
render_date_sub=check_query_helpers.render_date_sub,
render_date_trunc=check_query_helpers.render_date_trunc,
render_extract=check_query_helpers.render_extract,
render_percentile_expr=check_query_helpers.render_percentile_expr,
mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql,
double_data_type_name=check_query_helpers.double_data_type_name,
render_dimension_template=check_query_helpers.render_dimension_template,
render_entity_template=check_query_helpers.render_entity_template,
render_time_dimension_template=check_query_helpers.render_time_dimension_template,
generate_random_uuid=check_query_helpers.generate_random_uuid,
cast_to_ts=check_query_helpers.cast_to_ts,
)
if case.where_filter
else None,
where_constraint=(
jinja2.Template(
case.where_filter,
undefined=jinja2.StrictUndefined,
).render(
source_schema=mf_test_session_state.mf_source_schema,
render_time_constraint=check_query_helpers.render_time_constraint,
render_between_time_constraint=check_query_helpers.render_between_time_constraint,
TimeGranularity=TimeGranularity,
DatePart=DatePart,
render_date_sub=check_query_helpers.render_date_sub,
render_date_trunc=check_query_helpers.render_date_trunc,
render_extract=check_query_helpers.render_extract,
render_percentile_expr=check_query_helpers.render_percentile_expr,
mf_time_spine_source=semantic_manifest_lookup.time_spine_source.spine_table.sql,
double_data_type_name=check_query_helpers.double_data_type_name,
render_dimension_template=check_query_helpers.render_dimension_template,
render_entity_template=check_query_helpers.render_entity_template,
render_time_dimension_template=check_query_helpers.render_time_dimension_template,
generate_random_uuid=check_query_helpers.generate_random_uuid,
cast_to_ts=check_query_helpers.cast_to_ts,
)
if case.where_filter
else None
),
order_by_names=case.order_bys,
min_max_only=case.min_max_only,
)
Expand Down

0 comments on commit 355f79f

Please sign in to comment.