Skip to content

Commit

Permalink
Add SpecPattern classes.
Browse files Browse the repository at this point in the history
Please see the following feature request:

#887
  • Loading branch information
plypaul committed Nov 21, 2023
1 parent eb4c1d0 commit f654067
Show file tree
Hide file tree
Showing 6 changed files with 518 additions and 0 deletions.
174 changes: 174 additions & 0 deletions metricflow/specs/patterns/entity_link_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from __future__ import annotations

import logging
from dataclasses import asdict, dataclass
from enum import Enum
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
EntityCallParameterSet,
TimeDimensionCallParameterSet,
)
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.specs.patterns.spec_pattern import SpecPattern
from metricflow.specs.specs import LinkableInstanceSpec, LinkableSpecSet

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class ParameterSetField(Enum):
    """Names the EntityLinkPatternParameterSet fields that an EntityLinkPattern can compare.

    Each member's value is the attribute name that is looked up (via getattr) on both the parameter set and the
    candidate spec during matching, so values must stay in sync with the parameter set's field names.

    This may later move into the specs module / classes.
    """

    # Name of the element in the semantic model.
    ELEMENT_NAME = "element_name"
    # Entities used to join semantic models.
    ENTITY_LINKS = "entity_links"
    # Time-dimension-only properties.
    TIME_GRANULARITY = "time_granularity"
    DATE_PART = "date_part"


@dataclass(frozen=True)
class EntityLinkPatternParameterSet:
    """The values an EntityLinkPattern compares against candidate specs.

    See EntityLinkPattern for how these values are used.
    """

    # Which of the fields below take part in the comparison. This is explicit because None is a legitimate value
    # to match on (a pattern sometimes needs to match a spec whose field is None), so None can't double as
    # "don't compare".
    fields_to_compare: Tuple[ParameterSetField, ...]

    # Name of the element in the semantic model.
    element_name: Optional[str] = None
    # Entities used for joining semantic models.
    entity_links: Optional[Tuple[EntityReference, ...]] = None
    # Time dimension properties to match.
    time_granularity: Optional[TimeGranularity] = None
    date_part: Optional[DatePart] = None


@dataclass(frozen=True)
class EntityLinkPattern(SpecPattern):
    """A pattern that matches group-by-items using the entity-link-path specification.

    The entity link path specifies how a group-by-item for a metric query should be constructed. The group-by-item
    is obtained by joining the semantic model containing the measure to a semantic model containing the group-by-
    item using a specified entity. Additional semantic models can be joined using additional entities to obtain the
    group-by-item. The series of entities that are used form the entity path. Since the entity path does not specify
    which semantic models need to be used, additional resolution is done in later stages to generate the necessary SQL.
    """

    parameter_set: EntityLinkPatternParameterSet

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Return the candidate specs whose compared fields equal those of the parameter set.

        Only the fields listed in `parameter_set.fields_to_compare` take part in the comparison. A candidate spec
        that does not have one of those attributes is treated as having None for it. Matches are returned in the
        order they appear in `candidate_specs`.

        Raises:
            AttributeError: if a ParameterSetField value does not name an attribute of the parameter set.
        """
        # Attribute names to compare, de-duplicated while keeping the declared order so that the value tuples
        # below are built deterministically.
        keys_to_check = tuple(
            dict.fromkeys(field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare)
        )
        # No default is passed to getattr here on purpose: a ParameterSetField that is not a real field of the
        # parameter set should fail loudly rather than silently compare as None.
        parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in keys_to_check)

        return [
            spec
            for spec in candidate_specs
            # Specs that lack an attribute (e.g. not every spec type carries every field) compare as None.
            if tuple(getattr(spec, key_to_check, None) for key_to_check in keys_to_check) == parameter_set_values
        ]


@dataclass(frozen=True)
class DimensionPattern(EntityLinkPattern):
    """An EntityLinkPattern that only considers dimension and time dimension specs.

    Analogous pattern for Dimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match against the dimension and time dimension specs found in `candidate_specs`."""
        grouped_specs = LinkableSpecSet.from_specs(tuple(candidate_specs))
        # Dimensions come first, then time dimensions; the parent preserves this order in its result.
        dimension_like_specs: Sequence[LinkableInstanceSpec] = (
            grouped_specs.dimension_specs + grouped_specs.time_dimension_specs
        )
        return super().match(dimension_like_specs)

    @staticmethod
    def from_call_parameter_set(
        dimension_call_parameter_set: DimensionCallParameterSet,
    ) -> DimensionPattern:
        """Build a pattern that compares the element name and entity links of the given call parameter set."""
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
            element_name=dimension_call_parameter_set.dimension_reference.element_name,
            entity_links=dimension_call_parameter_set.entity_path,
        )
        return DimensionPattern(parameter_set=parameter_set)


@dataclass(frozen=True)
class TimeDimensionPattern(EntityLinkPattern):
    """An EntityLinkPattern that only considers time dimension specs.

    Analogous pattern for TimeDimension() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match against the time dimension specs found in `candidate_specs`."""
        time_dimension_candidates = LinkableSpecSet.from_specs(tuple(candidate_specs)).time_dimension_specs
        return super().match(time_dimension_candidates)

    @staticmethod
    def from_call_parameter_set(
        time_dimension_call_parameter_set: TimeDimensionCallParameterSet,
    ) -> TimeDimensionPattern:
        """Build a pattern from the given call parameter set.

        The element name, entity links and date part are always compared. The time granularity is compared only
        when the call parameter set specifies one.
        """
        requested_granularity = time_dimension_call_parameter_set.time_granularity

        compared_fields: Tuple[ParameterSetField, ...] = (
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.DATE_PART,
        )
        if requested_granularity is not None:
            compared_fields += (ParameterSetField.TIME_GRANULARITY,)

        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=compared_fields,
            element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
            entity_links=time_dimension_call_parameter_set.entity_path,
            time_granularity=requested_granularity,
            date_part=time_dimension_call_parameter_set.date_part,
        )
        return TimeDimensionPattern(parameter_set=parameter_set)


@dataclass(frozen=True)
class EntityPattern(EntityLinkPattern):
    """An EntityLinkPattern that only considers entity specs.

    Analogous pattern for Entity() in the object builder naming scheme.
    """

    @override
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Match against the entity specs found in `candidate_specs`."""
        entity_candidates = LinkableSpecSet.from_specs(tuple(candidate_specs)).entity_specs
        return super().match(entity_candidates)

    @staticmethod
    def from_call_parameter_set(entity_call_parameter_set: EntityCallParameterSet) -> EntityPattern:
        """Build a pattern that compares the element name and entity links of the given call parameter set."""
        parameter_set = EntityLinkPatternParameterSet(
            fields_to_compare=(ParameterSetField.ELEMENT_NAME, ParameterSetField.ENTITY_LINKS),
            element_name=entity_call_parameter_set.entity_reference.element_name,
            entity_links=entity_call_parameter_set.entity_path,
        )
        return EntityPattern(parameter_set=parameter_set)
22 changes: 22 additions & 0 deletions metricflow/specs/patterns/spec_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Sequence

from metricflow.specs.specs import LinkableInstanceSpec


class SpecPattern(ABC):
    """Selects specs from a group of candidate specs according to criteria defined by the subclass.

    "SpecFilter" would also be a fitting name, but "filter" is already commonly used in the context of the
    WhereFilter.
    """

    @abstractmethod
    def match(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
        """Return the subset of `candidate_specs` that this pattern matches."""
        raise NotImplementedError

    def matches_any(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> bool:
        """Return True if this pattern matches at least one of the given specs."""
        matched_specs = self.match(candidate_specs)
        return len(matched_specs) != 0
Empty file.
Empty file.
189 changes: 189 additions & 0 deletions metricflow/test/specs/patterns/test_entity_link_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from __future__ import annotations

import logging
from dataclasses import asdict
from typing import Sequence

import pytest
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart

from metricflow.specs.patterns.entity_link_pattern import (
EntityLinkPattern,
EntityLinkPatternParameterSet,
ParameterSetField,
)
from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec, TimeDimensionSpec
from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR

# Module-level logger, named after this test module via __name__.
logger = logging.getLogger(__name__)


@pytest.fixture(scope="session")
def specs() -> Sequence[LinkableInstanceSpec]:
    """Candidate specs shared by the tests: time dimensions first, then dimensions, then entities.

    The ordering matters because the tests compare against the order in which matches are returned.
    """
    time_dimension_specs = (
        MTD_SPEC_WEEK,
        MTD_SPEC_MONTH,
        MTD_SPEC_YEAR,
        TimeDimensionSpec(
            element_name="creation_time",
            entity_links=(EntityReference("booking"), EntityReference("listing")),
            time_granularity=TimeGranularity.MONTH,
            date_part=DatePart.YEAR,
        ),
    )
    dimension_specs = (
        DimensionSpec(
            element_name="country",
            entity_links=(
                EntityReference(element_name="listing"),
                EntityReference(element_name="user"),
            ),
        ),
        DimensionSpec(element_name="is_instant", entity_links=(EntityReference(element_name="booking"),)),
    )
    entity_specs = (
        EntitySpec(
            element_name="listing",
            entity_links=(EntityReference(element_name="booking"),),
        ),
        EntitySpec(
            element_name="host",
            entity_links=(EntityReference(element_name="booking"),),
        ),
    )
    return time_dimension_specs + dimension_specs + entity_specs


def test_valid_parameter_fields() -> None:
    """Tests that every ParameterSetField.value names a field of EntityLinkPatternParameterSet."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(),
        element_name=None,
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )
    # asdict() is keyed by field name, so its keys are the valid attribute names.
    valid_field_names = set(asdict(parameter_set))
    for spec_field in ParameterSetField:
        assert spec_field.value in valid_field_names, f"{spec_field} is not a valid field for {parameter_set}"


def test_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Comparing element name and entity links selects the single matching dimension."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name="is_instant",
        entity_links=(EntityReference(element_name="booking"),),
        time_granularity=None,
        date_part=None,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_specs = (
        DimensionSpec(element_name="is_instant", entity_links=(EntityReference(element_name="booking"),)),
    )
    assert matched_specs == expected_specs


def test_entity_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Comparing element name and entity links selects the single matching entity."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name="listing",
        entity_links=(EntityReference(element_name="booking"),),
        time_granularity=None,
        date_part=None,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_specs = (
        EntitySpec(element_name="listing", entity_links=(EntityReference(element_name="booking"),)),
    )
    assert matched_specs == expected_specs


def test_time_dimension_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Including the time granularity in the comparison narrows metric_time down to the weekly spec."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
            ParameterSetField.TIME_GRANULARITY,
        ),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=TimeGranularity.WEEK,
        date_part=None,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == (MTD_SPEC_WEEK,)


def test_time_dimension_match_without_grain_specified(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Leaving the time granularity out of the comparison matches metric_time at every grain."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.ENTITY_LINKS,
        ),
        element_name=METRIC_TIME_ELEMENT_NAME,
        entity_links=(),
        time_granularity=None,
        date_part=None,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_specs = (
        MTD_SPEC_WEEK,
        MTD_SPEC_MONTH,
        MTD_SPEC_YEAR,
    )
    assert matched_specs == expected_specs


def test_time_dimension_date_part_mismatch(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a date_part of None in the pattern does not match a spec with a non-None date_part."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.DATE_PART,
        ),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=None,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    assert matched_specs == ()


def test_time_dimension_date_part_match(specs: Sequence[LinkableInstanceSpec]) -> None:
    """Checks that a pattern with the same date_part as a spec produces a match."""
    parameter_set = EntityLinkPatternParameterSet(
        fields_to_compare=(
            ParameterSetField.ELEMENT_NAME,
            ParameterSetField.DATE_PART,
        ),
        element_name="creation_time",
        entity_links=None,
        time_granularity=None,
        date_part=DatePart.YEAR,
    )
    matched_specs = tuple(EntityLinkPattern(parameter_set).match(specs))

    expected_spec = TimeDimensionSpec(
        element_name="creation_time",
        entity_links=(EntityReference("booking"), EntityReference("listing")),
        time_granularity=TimeGranularity.MONTH,
        date_part=DatePart.YEAR,
    )
    assert matched_specs == (expected_spec,)
Loading

0 comments on commit f654067

Please sign in to comment.