From f654067b9db6bc70debcc91d17b109581e323550 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Fri, 17 Nov 2023 16:11:50 -0800 Subject: [PATCH] Add SpecPattern classes. Please see the following feature request: https://github.com/dbt-labs/metricflow/issues/887 --- .../specs/patterns/entity_link_pattern.py | 174 ++++++++++++++++ metricflow/specs/patterns/spec_pattern.py | 22 ++ metricflow/test/specs/__init__.py | 0 metricflow/test/specs/patterns/__init__.py | 0 .../patterns/test_entity_link_pattern.py | 189 ++++++++++++++++++ .../specs/patterns/test_typed_patterns.py | 133 ++++++++++++ 6 files changed, 518 insertions(+) create mode 100644 metricflow/specs/patterns/entity_link_pattern.py create mode 100644 metricflow/specs/patterns/spec_pattern.py create mode 100644 metricflow/test/specs/__init__.py create mode 100644 metricflow/test/specs/patterns/__init__.py create mode 100644 metricflow/test/specs/patterns/test_entity_link_pattern.py create mode 100644 metricflow/test/specs/patterns/test_typed_patterns.py diff --git a/metricflow/specs/patterns/entity_link_pattern.py b/metricflow/specs/patterns/entity_link_pattern.py new file mode 100644 index 0000000000..8ee55cdcd1 --- /dev/null +++ b/metricflow/specs/patterns/entity_link_pattern.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import logging +from dataclasses import asdict, dataclass +from enum import Enum +from typing import List, Optional, Sequence, Tuple + +from dbt_semantic_interfaces.call_parameter_sets import ( + DimensionCallParameterSet, + EntityCallParameterSet, + TimeDimensionCallParameterSet, +) +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart +from typing_extensions import override + +from metricflow.specs.patterns.spec_pattern import SpecPattern +from metricflow.specs.specs import LinkableInstanceSpec, LinkableSpecSet + +logger = 
class ParameterSetField(Enum):
    """Identifies the fields of EntityLinkPatternParameterSet that an EntityLinkPattern can compare.

    The value of each member is the attribute name of the field it stands for, which allows the field to be looked
    up by name.

    Considering moving this to be a part of the specs module / classes.
    """

    ELEMENT_NAME = "element_name"
    ENTITY_LINKS = "entity_links"
    TIME_GRANULARITY = "time_granularity"
    DATE_PART = "date_part"


@dataclass(frozen=True)
class EntityLinkPatternParameterSet:
    """Holds the values that an EntityLinkPattern compares against specs. See EntityLinkPattern for more details."""

    # The fields that take part in the comparison. This has to be listed explicitly: a None value can't mean
    # "skip this field" because a pattern sometimes needs to match a spec where that field really is None.
    fields_to_compare: Tuple[ParameterSetField, ...]

    # Name of the element in the semantic model.
    element_name: Optional[str] = None
    # Entities used for joining semantic models.
    entity_links: Optional[Tuple[EntityReference, ...]] = None
    # Time-dimension properties to match.
    time_granularity: Optional[TimeGranularity] = None
    date_part: Optional[DatePart] = None
@dataclass(frozen=True)
class EntityLinkPattern(SpecPattern):
    """A pattern that matches group-by-items using the entity-link-path specification.

    The entity link path specifies how a group-by-item for a metric query should be constructed. The group-by-item
    is obtained by joining the semantic model containing the measure to a semantic model containing the group-by-
    item using a specified entity. Additional semantic models can be joined using additional entities to obtain the
    group-by-item. The series of entities that are used form the entity path. Since the entity path does not specify
    which semantic models need to be used, additional resolution is done in later stages to generate the necessary SQL.

    Only the fields named in `parameter_set.fields_to_compare` take part in matching.
    """

    parameter_set: EntityLinkPatternParameterSet

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Return the candidate specs whose compared fields equal the values in the parameter set.

        Args:
            candidate_specs: The specs to test against this pattern.

        Returns:
            The matching specs, in the order in which they appear in `candidate_specs`. If there are no fields to
            compare, every candidate is returned.

        Raises:
            AttributeError: If the value of a field to compare does not name an attribute of the parameter set.
        """
        # De-duplicate the attribute names while keeping their declared order, so that the two value tuples built
        # below are guaranteed to line up position by position.
        keys_to_check = tuple(
            dict.fromkeys(field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare)
        )
        # getattr() without a default doubles as the check that each ParameterSetField is valid with respect to the
        # parameter set. This is loop-invariant, so it is computed once up front.
        parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in keys_to_check)

        matching_specs: List[LinkableInstanceSpec] = [
            spec
            for spec in candidate_specs
            # A spec that does not have the attribute at all is compared as if the attribute were None.
            if tuple(getattr(spec, key_to_check, None) for key_to_check in keys_to_check) == parameter_set_values
        ]
        return matching_specs
@dataclass(frozen=True)
class DimensionPattern(EntityLinkPattern):
    """Similar to EntityLinkPattern but only matches dimensions / time dimensions.

    Analogous pattern for Dimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Apply the entity-link match to just the dimension and time dimension specs among the candidates."""
        grouped_specs = LinkableSpecSet.from_specs(tuple(candidate_specs))
        # Dimensions are placed ahead of time dimensions before delegating to the parent match.
        dimension_like_specs: Sequence[LinkableInstanceSpec] = (
            grouped_specs.dimension_specs + grouped_specs.time_dimension_specs
        )
        return super().match(dimension_like_specs)

    @staticmethod
    def from_call_parameter_set(
        dimension_call_parameter_set: DimensionCallParameterSet,
    ) -> DimensionPattern:
        """Build a pattern that compares the element name and entity links taken from the call parameter set."""
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(
                ParameterSetField.ELEMENT_NAME,
                ParameterSetField.ENTITY_LINKS,
            ),
            element_name=dimension_call_parameter_set.dimension_reference.element_name,
            entity_links=dimension_call_parameter_set.entity_path,
        )
        return DimensionPattern(parameter_set=parameter_set)
@dataclass(frozen=True)
class TimeDimensionPattern(EntityLinkPattern):
    """Similar to EntityLinkPattern but only matches time dimensions.

    Analogous pattern for TimeDimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Apply the entity-link match to just the time dimension specs among the candidates."""
        time_dimension_specs = LinkableSpecSet.from_specs(tuple(candidate_specs)).time_dimension_specs
        return super().match(time_dimension_specs)

    @staticmethod
    def from_call_parameter_set(
        time_dimension_call_parameter_set: TimeDimensionCallParameterSet,
    ) -> TimeDimensionPattern:
        """Build a pattern from the call parameter set.

        The element name, entity links and date part are always compared. The time granularity is compared only
        when the call parameter set specifies one.
        """
        always_compared: Tuple[ParameterSetField, ...] = (
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.DATE_PART,
        )
        requested_granularity = time_dimension_call_parameter_set.time_granularity
        if requested_granularity is None:
            fields_to_compare = always_compared
        else:
            fields_to_compare = always_compared + (ParameterSetField.TIME_GRANULARITY,)

        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=fields_to_compare,
            element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
            entity_links=time_dimension_call_parameter_set.entity_path,
            time_granularity=requested_granularity,
            date_part=time_dimension_call_parameter_set.date_part,
        )
        return TimeDimensionPattern(parameter_set=parameter_set)
@dataclass(frozen=True)
class EntityPattern(EntityLinkPattern):
    """Similar to EntityLinkPattern but only matches entities.

    Analogous pattern for Entity() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Apply the entity-link match to just the entity specs among the candidates."""
        entity_specs = LinkableSpecSet.from_specs(tuple(candidate_specs)).entity_specs
        return super().match(entity_specs)

    @staticmethod
    def from_call_parameter_set(entity_call_parameter_set: EntityCallParameterSet) -> EntityPattern:
        """Build a pattern that compares the element name and entity links taken from the call parameter set."""
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(
                ParameterSetField.ELEMENT_NAME,
                ParameterSetField.ENTITY_LINKS,
            ),
            element_name=entity_call_parameter_set.entity_reference.element_name,
            entity_links=entity_call_parameter_set.entity_path,
        )
        return EntityPattern(parameter_set=parameter_set)
class SpecPattern(ABC):
    """Selects specs out of a group of candidate specs according to criteria defined by the implementing class.

    This could be named SpecFilter as well, but a filter is often used in the context of the WhereFilter.
    """

    @abstractmethod
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Return the candidate specs that match this pattern. Implementing classes must override this."""
        raise NotImplementedError

    def matches_any(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> bool:
        """Return True if this pattern matches at least one of the given specs."""
        matched_specs = self.match(candidate_specs)
        return len(matched_specs) != 0
@pytest.fixture(scope="session")
def specs() -> Sequence[LinkableInstanceSpec]:
    """Candidate specs used by the tests in this module: time dimensions, then dimensions, then entities."""
    booking = EntityReference(element_name="booking")
    listing = EntityReference(element_name="listing")

    time_dimension_specs = (
        MTD_SPEC_WEEK,
        MTD_SPEC_MONTH,
        MTD_SPEC_YEAR,
        TimeDimensionSpec(
            element_name="creation_time",
            entity_links=(booking, listing),
            time_granularity=TimeGranularity.MONTH,
            date_part=DatePart.YEAR,
        ),
    )
    dimension_specs = (
        DimensionSpec(
            element_name="country",
            entity_links=(listing, EntityReference(element_name="user")),
        ),
        DimensionSpec(element_name="is_instant", entity_links=(booking,)),
    )
    entity_specs = (
        EntitySpec(element_name="listing", entity_links=(booking,)),
        EntitySpec(element_name="host", entity_links=(booking,)),
    )
    return time_dimension_specs + dimension_specs + entity_specs


def test_valid_parameter_fields() -> None:
    """Tests that ParameterSetField.value maps to a valid field in EntityLinkPatternParameterSet."""
    # The field values are irrelevant here; only the field names of the dataclass are inspected.
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(),
        element_name=None,
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )
    valid_field_names = set(asdict(parameter_set))
    for spec_field in ParameterSetField:
        assert spec_field.value in valid_field_names, f"{spec_field} is not a valid field for {parameter_set}"


def test_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that comparing element name and entity links selects the `is_instant` dimension."""
    booking_links = (EntityReference(element_name="booking"),)
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name="is_instant",
        entity_links=booking_links,
        time_granularity=None,
        date_part=None,
    )
    expected_specs = (DimensionSpec(element_name="is_instant", entity_links=booking_links),)

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == expected_specs
def test_entity_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that comparing element name and entity links selects the `listing` entity."""
    booking_links = (EntityReference(element_name="booking"),)
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name="listing",
        entity_links=booking_links,
        time_granularity=None,
        date_part=None,
    )
    expected_specs = (EntitySpec(element_name="listing", entity_links=booking_links),)

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == expected_specs


def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that including the time granularity in the comparison selects only the weekly metric time spec."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.TIME_GRANULARITY,
        ),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=TimeGranularity.WEEK,
        date_part=None,
    )

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == (MTD_SPEC_WEEK,)


def test_time_dimension_match_without_grain_specified(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that leaving the time granularity out of the comparison selects every metric time spec."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=None,
        date_part=None,
    )
    expected_specs = (
        MTD_SPEC_WEEK,
        MTD_SPEC_MONTH,
        MTD_SPEC_YEAR,
    )

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == expected_specs


def test_time_dimension_date_part_mismatch(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a None for the date_part field does not match to a non-None value."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.DATE_PART,
        ),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == ()
def test_time_dimension_date_part_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a correct date_part field produces a match."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.DATE_PART,
        ),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=DatePart.YEAR,
    )
    # Entity links and granularity are not compared, so the expected spec keeps its own values for both.
    expected_specs = (
        TimeDimensionSpec(
            element_name="creation_time",
            entity_links=(EntityReference("booking"), EntityReference("listing")),
            time_granularity=TimeGranularity.MONTH,
            date_part=DatePart.YEAR,
        ),
    )

    matched_specs = EntityLinkPattern(parameter_set).match(specs)

    assert tuple(matched_specs) == expected_specs
@pytest.fixture(scope="session")
def specs() -> Sequence[LinkableInstanceSpec]:
    """Candidate specs that share one element name and entity path but differ in spec type and date part."""
    shared_links = (EntityReference("booking"), EntityReference("listing"))

    time_dimension_specs = (
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=None,
        ),
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=DatePart.MONTH,
        ),
    )
    dimension_specs = (DimensionSpec(element_name="common_name", entity_links=shared_links),)
    entity_specs = (EntitySpec(element_name="common_name", entity_links=shared_links),)
    return time_dimension_specs + dimension_specs + entity_specs


def test_dimension_pattern(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that DimensionPattern returns the dimension followed by both time dimensions, and no entity."""
    shared_links = (EntityReference("booking"), EntityReference("listing"))
    call_parameter_set = DimensionCallParameterSet(
        entity_path=shared_links,
        dimension_reference=DimensionReference(element_name="common_name"),
    )
    expected_specs = (
        DimensionSpec(element_name="common_name", entity_links=shared_links),
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=None,
        ),
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=DatePart.MONTH,
        ),
    )

    pattern = DimensionPattern.from_call_parameter_set(call_parameter_set)

    assert tuple(pattern.match(specs)) == expected_specs


def test_time_dimension_pattern(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a call parameter set with no date part given selects only the spec whose date part is None."""
    shared_links = (EntityReference("booking"), EntityReference("listing"))
    call_parameter_set = TimeDimensionCallParameterSet(
        entity_path=shared_links,
        time_dimension_reference=TimeDimensionReference(element_name="common_name"),
    )
    expected_specs = (
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=None,
        ),
    )

    pattern = TimeDimensionPattern.from_call_parameter_set(call_parameter_set)

    assert tuple(pattern.match(specs)) == expected_specs
def test_time_dimension_pattern_with_date_part(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a call parameter set with a MONTH date part selects only the spec with that date part."""
    shared_links = (EntityReference("booking"), EntityReference("listing"))
    call_parameter_set = TimeDimensionCallParameterSet(
        entity_path=shared_links,
        time_dimension_reference=TimeDimensionReference(element_name="common_name"),
        date_part=DatePart.MONTH,
    )
    expected_specs = (
        TimeDimensionSpec(
            element_name="common_name",
            entity_links=shared_links,
            time_granularity=TimeGranularity.DAY,
            date_part=DatePart.MONTH,
        ),
    )

    pattern = TimeDimensionPattern.from_call_parameter_set(call_parameter_set)

    assert tuple(pattern.match(specs)) == expected_specs


def test_entity_pattern(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that EntityPattern selects only the entity spec among candidates sharing a name and entity path."""
    shared_links = (EntityReference("booking"), EntityReference("listing"))
    call_parameter_set = EntityCallParameterSet(
        entity_path=shared_links,
        entity_reference=EntityReference(element_name="common_name"),
    )
    expected_specs = (EntitySpec(element_name="common_name", entity_links=shared_links),)

    pattern = EntityPattern.from_call_parameter_set(call_parameter_set)

    assert tuple(pattern.match(specs)) == expected_specs