Skip to content

Commit

Permalink
Add WhereFilterMetric
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Mar 27, 2024
1 parent 501cbfa commit c3611f9
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
104 changes: 104 additions & 0 deletions metricflow/specs/where_filter_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

from typing import Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
MetricCallParameterSet,
)
from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint
from dbt_semantic_interfaces.protocols.query_interface import QueryInterfaceMetric, QueryInterfaceMetricFactory
from dbt_semantic_interfaces.references import EntityReference, LinkableElementReference, MetricReference
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.query.group_by_item.filter_spec_resolution.filter_location import WhereFilterLocation
from metricflow.query.group_by_item.filter_spec_resolution.filter_spec_lookup import (
FilterSpecResolutionLookUp,
ResolvedSpecLookUpKey,
)
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.rendered_spec_tracker import RenderedSpecTracker


class WhereFilterMetric(ProtocolHint[QueryInterfaceMetric]):
"""A metric that is passed in through the where filter parameter."""

@override
def _implements_protocol(self) -> QueryInterfaceMetric:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
element_name: str,
group_by: Sequence[LinkableElementReference],
) -> None:
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = resolved_spec_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker
self._element_name = element_name
self._group_by = tuple(group_by)

def descending(self, _is_descending: bool) -> QueryInterfaceMetric:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)

def __str__(self) -> str:
"""Returns the column name.
Important in the Jinja sandbox.
"""
call_parameter_set = MetricCallParameterSet(
group_by=tuple(EntityReference(element_name=group_by_ref.element_name) for group_by_ref in self._group_by),
metric_reference=MetricReference(self._element_name),
)
resolved_spec = self._resolved_spec_lookup.checked_resolved_spec(
ResolvedSpecLookUpKey(
filter_location=self._where_filter_location,
call_parameter_set=call_parameter_set,
)
)
self._rendered_spec_tracker.record_rendered_spec(resolved_spec)
column_association = self._column_association_resolver.resolve_spec(resolved_spec)

return column_association.column_name


class WhereFilterMetricFactory(ProtocolHint[QueryInterfaceMetricFactory]):
"""Creates a WhereFilterMetric.
Each call to `create` adds a MetricSpec to metric_specs.
"""

@override
def _implements_protocol(self) -> QueryInterfaceMetricFactory:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
spec_resolution_lookup: FilterSpecResolutionLookUp,
where_filter_location: WhereFilterLocation,
rendered_spec_tracker: RenderedSpecTracker,
):
self._column_association_resolver = column_association_resolver
self._resolved_spec_lookup = spec_resolution_lookup
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker

def create(self, metric_name: str, group_by: Sequence[str] = ()) -> WhereFilterMetric:
"""Create a WhereFilterMetric."""
return WhereFilterMetric(
column_association_resolver=self._column_association_resolver,
resolved_spec_lookup=self._resolved_spec_lookup,
where_filter_location=self._where_filter_location,
rendered_spec_tracker=self._rendered_spec_tracker,
element_name=metric_name,
group_by=tuple(LinkableElementReference(group_by_name.lower()) for group_by_name in group_by),
)
8 changes: 8 additions & 0 deletions metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from metricflow.specs.specs import LinkableSpecSet, WhereFilterSpec
from metricflow.specs.where_filter_dimension import WhereFilterDimensionFactory
from metricflow.specs.where_filter_entity import WhereFilterEntityFactory
from metricflow.specs.where_filter_metric import WhereFilterMetricFactory
from metricflow.specs.where_filter_time_dimension import WhereFilterTimeDimensionFactory
from metricflow.sql.sql_bind_parameters import SqlBindParameters

Expand Down Expand Up @@ -75,6 +76,12 @@ def create_from_where_filter_intersection( # noqa: D102
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
)
metric_factory = WhereFilterMetricFactory(
column_association_resolver=self._column_association_resolver,
spec_resolution_lookup=self._spec_resolution_lookup,
where_filter_location=filter_location,
rendered_spec_tracker=rendered_spec_tracker,
)
try:
# If there was an error with the template, it should have been caught while resolving the specs for
# the filters during query resolution.
Expand All @@ -83,6 +90,7 @@ def create_from_where_filter_intersection( # noqa: D102
"Dimension": dimension_factory.create,
"TimeDimension": time_dimension_factory.create,
"Entity": entity_factory.create,
"Metric": metric_factory.create,
}
)
except (jinja2.exceptions.UndefinedError, jinja2.exceptions.TemplateSyntaxError) as e:
Expand Down
36 changes: 36 additions & 0 deletions tests/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
MetricCallParameterSet,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.implementations.filters.where_filter import (
Expand Down Expand Up @@ -37,6 +38,7 @@
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
GroupByMetricSpec,
LinkableInstanceSpec,
LinkableSpecSet,
TimeDimensionSpec,
Expand Down Expand Up @@ -101,6 +103,7 @@ def test_dimension_in_filter( # noqa: D103
),
time_dimension_specs=(),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -140,6 +143,7 @@ def test_dimension_in_filter_with_grain( # noqa: D103
),
),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -179,6 +183,7 @@ def test_time_dimension_in_filter( # noqa: D103
),
),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -220,6 +225,7 @@ def test_date_part_in_filter( # noqa: D103
),
),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -289,6 +295,7 @@ def test_date_part_and_grain_in_filter( # noqa: D103
),
),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -325,6 +332,7 @@ def test_date_part_less_than_grain_in_filter( # noqa: D103
),
),
entity_specs=(),
group_by_metric_specs=(),
)


Expand Down Expand Up @@ -352,6 +360,34 @@ def test_entity_in_filter( # noqa: D103
dimension_specs=(),
time_dimension_specs=(),
entity_specs=(EntitySpec(element_name="user", entity_links=(EntityReference(element_name="listing"),)),),
group_by_metric_specs=(),
)


def test_metric_in_filter( # noqa: D103
column_association_resolver: ColumnAssociationResolver,
resolved_spec_lookup: FilterSpecResolutionLookUp,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template="{{ Metric('bookings', group_by=['listing']) }} > 2")

group_by_metric_spec = GroupByMetricSpec(element_name="bookings", entity_links=(EntityReference("listing"),))
where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
spec_resolution_lookup=create_spec_lookup(
call_parameter_set=MetricCallParameterSet(
group_by=(EntityReference("listing"),),
metric_reference=MetricReference("bookings"),
),
resolved_spec=group_by_metric_spec,
),
).create_from_where_filter(filter_location=EXAMPLE_FILTER_LOCATION, where_filter=where_filter)

assert where_filter_spec.where_sql == "listing__bookings > 2"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(),
entity_specs=(),
group_by_metric_specs=(group_by_metric_spec,),
)


Expand Down

0 comments on commit c3611f9

Please sign in to comment.