-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Please see the following feature request: #887
- Loading branch information
Showing
6 changed files
with
518 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import asdict, dataclass | ||
from enum import Enum | ||
from typing import List, Optional, Sequence, Tuple | ||
|
||
from dbt_semantic_interfaces.call_parameter_sets import ( | ||
DimensionCallParameterSet, | ||
EntityCallParameterSet, | ||
TimeDimensionCallParameterSet, | ||
) | ||
from dbt_semantic_interfaces.references import EntityReference | ||
from dbt_semantic_interfaces.type_enums import TimeGranularity | ||
from dbt_semantic_interfaces.type_enums.date_part import DatePart | ||
from typing_extensions import override | ||
|
||
from metricflow.specs.patterns.spec_pattern import SpecPattern | ||
from metricflow.specs.specs import LinkableInstanceSpec, LinkableSpecSet | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ParameterSetField(Enum):
    """Names of the EntityLinkPatternParameterSet fields that an EntityLinkPattern can compare.

    Each member's value is the attribute name that is looked up on both the parameter set and the candidate spec
    when matching. Considering moving this to be a part of the specs module / classes.
    """

    ELEMENT_NAME = "element_name"
    ENTITY_LINKS = "entity_links"
    TIME_GRANULARITY = "time_granularity"
    DATE_PART = "date_part"
|
||
|
||
@dataclass(frozen=True)
class EntityLinkPatternParameterSet:
    """The values an EntityLinkPattern matches against, plus the fields that take part in the match.

    See EntityLinkPattern for more details.
    """

    # The fields whose values are compared against a candidate spec. An explicit list is required because None
    # can't be used to signal "don't compare": sometimes a pattern needs to match a spec where the field is None.
    fields_to_compare: Tuple[ParameterSetField, ...]

    # The name of the element in the semantic model.
    element_name: Optional[str] = None
    # The entities used for joining semantic models.
    entity_links: Optional[Tuple[EntityReference, ...]] = None
    # Properties of time dimensions to match.
    time_granularity: Optional[TimeGranularity] = None
    date_part: Optional[DatePart] = None
|
||
|
||
@dataclass(frozen=True)
class EntityLinkPattern(SpecPattern):
    """A pattern that matches group-by-items using the entity-link-path specification.

    The entity link path specifies how a group-by-item for a metric query should be constructed. The group-by-item
    is obtained by joining the semantic model containing the measure to a semantic model containing the group-by-
    item using a specified entity. Additional semantic models can be joined using additional entities to obtain the
    group-by-item. The series of entities that are used form the entity path. Since the entity path does not specify
    which semantic models need to be used, additional resolution is done in later stages to generate the necessary SQL.
    """

    parameter_set: EntityLinkPatternParameterSet

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Return the candidate specs whose compared fields equal the values in the parameter set.

        Only the fields listed in `parameter_set.fields_to_compare` take part in the comparison. A candidate spec
        that lacks one of those attributes is treated as having None for it.

        Args:
            candidate_specs: The specs to test against this pattern.

        Returns:
            A list of the matching specs, in the same order as they appear in `candidate_specs`.

        Raises:
            AttributeError: If a listed field's value does not name an attribute of the parameter set.
        """
        # Attribute names to compare, de-duplicated while keeping the declared order.
        keys_to_check = tuple(
            dict.fromkeys(field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare)
        )
        # Reading the attributes directly doubles as the validity check: getattr() without a default raises if a
        # ParameterSetField value is not a field of the parameter set.
        parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in keys_to_check)

        return [
            spec
            for spec in candidate_specs
            if tuple(getattr(spec, key_to_check, None) for key_to_check in keys_to_check) == parameter_set_values
        ]
|
||
|
||
@dataclass(frozen=True)
class DimensionPattern(EntityLinkPattern):
    """An EntityLinkPattern variant that only matches dimensions / time dimensions.

    Analogous pattern for Dimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match only against the dimension and time dimension specs found among the candidates."""
        grouped_specs = LinkableSpecSet.from_specs(tuple(candidate_specs))
        dimension_like_specs: Sequence[LinkableInstanceSpec] = (
            grouped_specs.dimension_specs + grouped_specs.time_dimension_specs
        )
        return super().match(dimension_like_specs)

    @staticmethod
    def from_call_parameter_set(  # noqa: D
        dimension_call_parameter_set: DimensionCallParameterSet,
    ) -> DimensionPattern:
        # A Dimension() call is identified by its element name and entity path only.
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
            element_name=dimension_call_parameter_set.dimension_reference.element_name,
            entity_links=dimension_call_parameter_set.entity_path,
        )
        return DimensionPattern(parameter_set=parameter_set)
|
||
|
||
@dataclass(frozen=True)
class TimeDimensionPattern(EntityLinkPattern):
    """An EntityLinkPattern variant that only matches time dimensions.

    Analogous pattern for TimeDimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match only against the time dimension specs found among the candidates."""
        time_dimension_specs = LinkableSpecSet.from_specs(tuple(candidate_specs)).time_dimension_specs
        return super().match(time_dimension_specs)

    @staticmethod
    def from_call_parameter_set(  # noqa: D
        time_dimension_call_parameter_set: TimeDimensionCallParameterSet,
    ) -> TimeDimensionPattern:
        requested_granularity = time_dimension_call_parameter_set.time_granularity

        # The date part is always compared; the granularity is compared only when the call specified one.
        fields_to_compare: Tuple[ParameterSetField, ...] = (
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.DATE_PART,
        )
        if requested_granularity is not None:
            fields_to_compare = fields_to_compare + (ParameterSetField.TIME_GRANULARITY,)

        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=fields_to_compare,
            element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
            entity_links=time_dimension_call_parameter_set.entity_path,
            time_granularity=requested_granularity,
            date_part=time_dimension_call_parameter_set.date_part,
        )
        return TimeDimensionPattern(parameter_set=parameter_set)
|
||
|
||
@dataclass(frozen=True)
class EntityPattern(EntityLinkPattern):
    """An EntityLinkPattern variant that only matches entities.

    Analogous pattern for Entity() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match only against the entity specs found among the candidates."""
        entity_specs = LinkableSpecSet.from_specs(tuple(candidate_specs)).entity_specs
        return super().match(entity_specs)

    @staticmethod
    def from_call_parameter_set(entity_call_parameter_set: EntityCallParameterSet) -> EntityPattern:  # noqa: D
        # An Entity() call is identified by its element name and entity path only.
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
            element_name=entity_call_parameter_set.entity_reference.element_name,
            entity_links=entity_call_parameter_set.entity_path,
        )
        return EntityPattern(parameter_set=parameter_set)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Sequence | ||
|
||
from metricflow.specs.specs import LinkableInstanceSpec | ||
|
||
|
||
class SpecPattern(ABC):
    """Selects specs from a group of candidate specs based on class-defined criteria.

    This could be named SpecFilter as well, but a filter is often used in the context of the WhereFilter.
    """

    @abstractmethod
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Given candidate specs, return the ones that match this pattern."""
        raise NotImplementedError

    def matches_any(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> bool:
        """Returns true if this pattern matches at least one of the given specs."""
        matched_specs = self.match(candidate_specs)
        return len(matched_specs) > 0
Empty file.
Empty file.
189 changes: 189 additions & 0 deletions
189
metricflow/test/specs/patterns/test_entity_link_pattern.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from dataclasses import asdict | ||
from typing import Sequence | ||
|
||
import pytest | ||
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME | ||
from dbt_semantic_interfaces.references import EntityReference | ||
from dbt_semantic_interfaces.type_enums import TimeGranularity | ||
from dbt_semantic_interfaces.type_enums.date_part import DatePart | ||
|
||
from metricflow.specs.patterns.entity_link_pattern import ( | ||
EntityLinkPattern, | ||
EntityLinkPatternParameterSet, | ||
ParameterSetField, | ||
) | ||
from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec, TimeDimensionSpec | ||
from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture(scope="session")
def specs() -> Sequence[LinkableInstanceSpec]:  # noqa: D
    # Candidate specs shared by the tests below: time dimensions first, then dimensions, then entities.
    booking_link = (EntityReference(element_name="booking"),)

    time_dimension_specs = (
        MTD_SPEC_WEEK,
        MTD_SPEC_MONTH,
        MTD_SPEC_YEAR,
        TimeDimensionSpec(
            element_name="creation_time",
            entity_links=(EntityReference("booking"), EntityReference("listing")),
            time_granularity=TimeGranularity.MONTH,
            date_part=DatePart.YEAR,
        ),
    )
    dimension_specs = (
        DimensionSpec(
            element_name="country",
            entity_links=(
                EntityReference(element_name="listing"),
                EntityReference(element_name="user"),
            ),
        ),
        DimensionSpec(element_name="is_instant", entity_links=booking_link),
    )
    entity_specs = (
        EntitySpec(element_name="listing", entity_links=booking_link),
        EntitySpec(element_name="host", entity_links=booking_link),
    )
    return time_dimension_specs + dimension_specs + entity_specs
|
||
|
||
def test_valid_parameter_fields() -> None:
    """Tests that every ParameterSetField value names a field of EntityLinkPatternParameterSet."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(),
        element_name=None,
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )
    # Iterating the dict produced by asdict() yields the dataclass field names.
    known_field_names = set(asdict(parameter_set))
    for spec_field in ParameterSetField:
        assert spec_field.value in known_field_names, f"{spec_field} is not a valid field for {parameter_set}"
|
||
|
||
def test_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:  # noqa: D
    # Comparing element name and entity links should select only the `is_instant` dimension.
    booking_link = (EntityReference(element_name="booking"),)
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
        element_name="is_instant",
        entity_links=booking_link,
        time_granularity=None,
        date_part=None,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == (DimensionSpec(element_name="is_instant", entity_links=booking_link),)
|
||
|
||
def test_entity_match(specs: Sequence[LinkableInstanceSpec]) -> None:  # noqa: D
    # Comparing element name and entity links should select only the `listing` entity.
    booking_link = (EntityReference(element_name="booking"),)
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
        element_name="listing",
        entity_links=booking_link,
        time_granularity=None,
        date_part=None,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == (EntitySpec(element_name="listing", entity_links=booking_link),)
|
||
|
||
def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:  # noqa: D
    # With the granularity included in the comparison, only the weekly metric time spec should be selected.
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.TIME_GRANULARITY,
        ),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=TimeGranularity.WEEK,
        date_part=None,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == (MTD_SPEC_WEEK,)
|
||
|
||
def test_time_dimension_match_without_grain_specified(specs: Sequence[LinkableInstanceSpec]) -> None:  # noqa: D
    # With the granularity left out of the comparison, every metric time spec should be selected.
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=None,
        date_part=None,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_specs = (MTD_SPEC_WEEK, MTD_SPEC_MONTH, MTD_SPEC_YEAR)
    assert matched_specs == expected_specs
|
||
|
||
def test_time_dimension_date_part_mismatch(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a None for the date_part field does not match a spec with a non-None date_part."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.DATE_PART),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == ()
|
||
|
||
def test_time_dimension_date_part_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a date_part equal to the spec's date_part produces a match."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.DATE_PART),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=DatePart.YEAR,
    )

    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_spec = TimeDimensionSpec(
        element_name="creation_time",
        entity_links=(EntityReference("booking"), EntityReference("listing")),
        time_granularity=TimeGranularity.MONTH,
        date_part=DatePart.YEAR,
    )
    assert matched_specs == (expected_spec,)
Oops, something went wrong.