From 7992b74e0fb9f46ae811db9f2a38cb9767406c1c Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Fri, 17 Nov 2023 16:12:15 -0800 Subject: [PATCH] Add naming schemes. These naming scheme classes will be used in the query parser to convert string inputs into patterns. The patterns will be used later to resolve ambiguous group-by-items. --- metricflow/naming/dunder_scheme.py | 169 +++++++++++++++ metricflow/naming/metric_scheme.py | 43 ++++ metricflow/naming/naming_scheme.py | 44 ++++ metricflow/naming/object_builder_scheme.py | 203 ++++++++++++++++++ metricflow/specs/patterns/metric_pattern.py | 28 +++ metricflow/test/naming/__init__.py | 0 metricflow/test/naming/conftest.py | 52 +++++ .../test/naming/test_dunder_naming_scheme.py | 112 ++++++++++ .../test/naming/test_metric_name_scheme.py | 34 +++ .../test_object_builder_naming_scheme.py | 109 ++++++++++ 10 files changed, 794 insertions(+) create mode 100644 metricflow/naming/dunder_scheme.py create mode 100644 metricflow/naming/metric_scheme.py create mode 100644 metricflow/naming/naming_scheme.py create mode 100644 metricflow/naming/object_builder_scheme.py create mode 100644 metricflow/specs/patterns/metric_pattern.py create mode 100644 metricflow/test/naming/__init__.py create mode 100644 metricflow/test/naming/conftest.py create mode 100644 metricflow/test/naming/test_dunder_naming_scheme.py create mode 100644 metricflow/test/naming/test_metric_name_scheme.py create mode 100644 metricflow/test/naming/test_object_builder_naming_scheme.py diff --git a/metricflow/naming/dunder_scheme.py b/metricflow/naming/dunder_scheme.py new file mode 100644 index 0000000000..35eb1f12e9 --- /dev/null +++ b/metricflow/naming/dunder_scheme.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import re +from typing import Optional, Sequence, Tuple + +from dbt_semantic_interfaces.naming.keywords import DUNDER +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart +from typing_extensions import override + +from metricflow.naming.naming_scheme import QueryItemNamingScheme +from metricflow.specs.patterns.entity_link_pattern import ( + EntityLinkPattern, + EntityLinkPatternParameterSet, + ParameterSetField, +) +from metricflow.specs.specs import ( + InstanceSpec, + InstanceSpecSet, + InstanceSpecSetTransform, +) + + +class DunderNamingScheme(QueryItemNamingScheme): + """A naming scheme using the dundered name syntax. + + TODO: Consolidate with StructuredLinkableSpecName / DunderedNameFormatter. + """ + + _INPUT_REGEX = re.compile(r"\A[a-z]([a-z0-9_])*[a-z0-9]\Z") + + @staticmethod + def date_part_suffix(date_part: DatePart) -> str: + """Suffix used for names with a date_part.""" + return f"extract_{date_part.value}" + + @override + def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: + spec_set = InstanceSpecSet.from_specs((instance_spec,)) + + for time_dimension_spec in spec_set.time_dimension_specs: + # From existing comment in StructuredLinkableSpecName: + # + # Dunder syntax not supported for querying date_part + # + if time_dimension_spec.date_part is not None: + return None + names = _DunderNameTransform().transform(spec_set) + if len(names) != 1: + raise RuntimeError(f"Did not get 1 name for {instance_spec}. Got {names}") + + return names[0] + + @override + def spec_pattern(self, input_str: str) -> EntityLinkPattern: + if not self.input_str_follows_scheme(input_str): + raise ValueError(f"{repr(input_str)} does not follow this scheme.") + + input_str = input_str.lower() + + input_str_parts = input_str.split(DUNDER) + fields_to_compare: Tuple[ParameterSetField, ...] = ( + ParameterSetField.ELEMENT_NAME, + ParameterSetField.ENTITY_LINKS, + ParameterSetField.DATE_PART, + ) + + time_grain = None + + # No dunder, e.g. "ds" + if len(input_str_parts) == 1: + return EntityLinkPattern( + parameter_set=EntityLinkPatternParameterSet.from_parameters( + element_name=input_str_parts[0], + entity_links=(), + time_granularity=time_grain, + date_part=None, + fields_to_compare=tuple(fields_to_compare), + ) + ) + + # At this point, len(input_str_parts) >= 2 + for granularity in TimeGranularity: + if input_str_parts[-1] == granularity.value: + time_grain = granularity + + # Has a time grain specified. + if time_grain is not None: + fields_to_compare = fields_to_compare + (ParameterSetField.TIME_GRANULARITY,) + # e.g. "ds__month" + if len(input_str_parts) == 2: + return EntityLinkPattern( + parameter_set=EntityLinkPatternParameterSet.from_parameters( + element_name=input_str_parts[0], + entity_links=(), + time_granularity=time_grain, + date_part=None, + fields_to_compare=fields_to_compare, + ) + ) + # e.g. "messages__ds__month" + return EntityLinkPattern( + parameter_set=EntityLinkPatternParameterSet.from_parameters( + element_name=input_str_parts[-2], + entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-2]), + time_granularity=time_grain, + date_part=None, + fields_to_compare=fields_to_compare, + ) + ) + + # e.g. "messages__ds" + return EntityLinkPattern( + parameter_set=EntityLinkPatternParameterSet.from_parameters( + element_name=input_str_parts[-1], + entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-1]), + time_granularity=None, + date_part=None, + fields_to_compare=fields_to_compare, + ) + ) + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + # This naming scheme is case-insensitive. + input_str = input_str.lower() + if DunderNamingScheme._INPUT_REGEX.match(input_str) is None: + return False + + input_str_parts = input_str.split(DUNDER) + + for date_part in DatePart: + if input_str_parts[-1] == DunderNamingScheme.date_part_suffix(date_part=date_part): + # From existing message in StructuredLinkableSpecName: "Dunder syntax not supported for querying + # date_part". + return False + + return True + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(id()={hex(id(self))})" + + +class _DunderNameTransform(InstanceSpecSetTransform[Sequence[str]]): + """Transforms group-by-item spec into the dundered name.""" + + @override + def transform(self, spec_set: InstanceSpecSet) -> Sequence[str]: + names_to_return = [] + + for time_dimension_spec in spec_set.time_dimension_specs: + items = list(entity_link.element_name for entity_link in time_dimension_spec.entity_links) + [ + time_dimension_spec.element_name + ] + if time_dimension_spec.date_part is not None: + items.append(DunderNamingScheme.date_part_suffix(date_part=time_dimension_spec.date_part)) + else: + items.append(time_dimension_spec.time_granularity.value) + names_to_return.append(DUNDER.join(items)) + + for other_group_by_item_specs in spec_set.entity_specs + spec_set.dimension_specs: + items = list(entity_link.element_name for entity_link in other_group_by_item_specs.entity_links) + [ + other_group_by_item_specs.element_name + ] + names_to_return.append(DUNDER.join(items)) + + return sorted(names_to_return) diff --git a/metricflow/naming/metric_scheme.py b/metricflow/naming/metric_scheme.py new file mode 100644 index 0000000000..fdc0244df4 --- /dev/null +++ b/metricflow/naming/metric_scheme.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Optional + +from dbt_semantic_interfaces.references import MetricReference +from typing_extensions import override + +from metricflow.naming.naming_scheme import QueryItemNamingScheme +from metricflow.specs.patterns.metric_pattern import MetricSpecPattern +from metricflow.specs.specs import ( + InstanceSpec, + InstanceSpecSet, +) + + +class MetricNamingScheme(QueryItemNamingScheme): + """A naming scheme for metrics.""" + + @override + def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: + spec_set = InstanceSpecSet.from_specs((instance_spec,)) + names = tuple(spec.element_name for spec in spec_set.metric_specs) + + if len(names) != 1: + raise RuntimeError(f"Did not get 1 name for {instance_spec}. Got {names}") + + return names[0] + + @override + def spec_pattern(self, input_str: str) -> MetricSpecPattern: + input_str = input_str.lower() + if not self.input_str_follows_scheme(input_str): + raise RuntimeError(f"{repr(input_str)} does not follow this scheme.") + return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str)) + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + # TODO: Use regex. + return True + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(id()={hex(id(self))})" diff --git a/metricflow/naming/naming_scheme.py b/metricflow/naming/naming_scheme.py new file mode 100644 index 0000000000..508e1d0029 --- /dev/null +++ b/metricflow/naming/naming_scheme.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from metricflow.specs.patterns.spec_pattern import SpecPattern +from metricflow.specs.specs import InstanceSpec + + +class QueryItemNamingScheme(ABC): + """Describes how to name items that are involved in a MetricFlow query. + + Most useful for group-by-items as there are different ways to name them like "user__country" + or "TimeDimension('metric_time', 'DAY')". + """ + + @abstractmethod + def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: + """Following this scheme, return the string that can be used as an input that would specify the given spec. + + This is used to generate suggestions from available group-by-items if the user specifies a group-by-item that is + invalid. + + If this scheme cannot accommodate the spec, return None. This is needed to handle unsupported cases in + DunderNamingScheme, such as DatePart, but naming schemes should otherwise be complete. + """ + pass + + @abstractmethod + def spec_pattern(self, input_str: str) -> SpecPattern: + """Given an input that follows this scheme, return a spec pattern that matches the described input. + + If the input_str does not follow this scheme, raise a ValueError. In practice, input_str_follows_scheme() should + be called on the input_str beforehand. + """ + pass + + @abstractmethod + def input_str_follows_scheme(self, input_str: str) -> bool: + """Returns true if the given input string follows this naming scheme. + + Consider adding a structured result that indicates why it does not match the scheme. + """ + pass diff --git a/metricflow/naming/object_builder_scheme.py b/metricflow/naming/object_builder_scheme.py new file mode 100644 index 0000000000..20d68c6f89 --- /dev/null +++ b/metricflow/naming/object_builder_scheme.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import logging +import re +from typing import Optional, Sequence + +from dbt_semantic_interfaces.call_parameter_sets import ParseWhereFilterException +from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter +from dbt_semantic_interfaces.naming.keywords import DUNDER +from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import WhereFilterParser +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart +from typing_extensions import override + +from metricflow.naming.naming_scheme import QueryItemNamingScheme +from metricflow.specs.patterns.entity_link_pattern import ( + EntityLinkPattern, + EntityLinkPatternParameterSet, + ParameterSetField, +) +from metricflow.specs.patterns.spec_pattern import SpecPattern +from metricflow.specs.patterns.typed_patterns import DimensionPattern, TimeDimensionPattern +from metricflow.specs.specs import ( + InstanceSpec, + InstanceSpecSet, + InstanceSpecSetTransform, +) + +logger = logging.getLogger(__name__) + + +class ObjectBuilderNamingScheme(QueryItemNamingScheme): + """A naming scheme using a builder syntax like Dimension('metric_time').grain('day').""" + + _NAME_REGEX = re.compile(r"\A(Dimension|TimeDimension|Entity)\(.*\)\Z") + + @override + def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: + names = _ObjectBuilderNameTransform().transform(InstanceSpecSet.from_specs((instance_spec,))) + + if len(names) != 1: + raise RuntimeError(f"Did not get exactly 1 name from {instance_spec}. Got {names}") + + return names[0] + + @override + def spec_pattern(self, input_str: str) -> SpecPattern: + if not self.input_str_follows_scheme(input_str): + raise ValueError( + f"The specified input {repr(input_str)} does not match the input described by the object builder " + f"pattern." + ) + try: + # TODO: Update when more appropriate parsing libraries are available. + call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets + except ParseWhereFilterException as e: + raise ValueError(f"A spec pattern can't be generated from the input string {repr(input_str)}") from e + + num_parameter_sets = ( + len(call_parameter_sets.dimension_call_parameter_sets) + + len(call_parameter_sets.time_dimension_call_parameter_sets) + + len(call_parameter_sets.entity_call_parameter_sets) + ) + if num_parameter_sets != 1: + raise ValueError(f"Did not find exactly 1 call parameter set. Got: {num_parameter_sets}") + + for dimension_call_parameter_set in call_parameter_sets.dimension_call_parameter_sets: + return DimensionPattern( + EntityLinkPatternParameterSet.from_parameters( + element_name=dimension_call_parameter_set.dimension_reference.element_name, + entity_links=dimension_call_parameter_set.entity_path, + time_granularity=None, + date_part=None, + fields_to_compare=( + ParameterSetField.ELEMENT_NAME, + ParameterSetField.ENTITY_LINKS, + ParameterSetField.DATE_PART, + ), + ) + ) + + for time_dimension_call_parameter_set in call_parameter_sets.time_dimension_call_parameter_sets: + fields_to_compare = [ + ParameterSetField.ELEMENT_NAME, + ParameterSetField.ENTITY_LINKS, + ParameterSetField.DATE_PART, + ] + + if time_dimension_call_parameter_set.time_granularity is not None: + fields_to_compare.append(ParameterSetField.TIME_GRANULARITY) + + return TimeDimensionPattern( + EntityLinkPatternParameterSet.from_parameters( + element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name, + entity_links=time_dimension_call_parameter_set.entity_path, + time_granularity=time_dimension_call_parameter_set.time_granularity, + date_part=time_dimension_call_parameter_set.date_part, + fields_to_compare=tuple(fields_to_compare), + ) + ) + + for entity_call_parameter_set in call_parameter_sets.entity_call_parameter_sets: + return EntityLinkPattern( + EntityLinkPatternParameterSet.from_parameters( + element_name=entity_call_parameter_set.entity_reference.element_name, + entity_links=entity_call_parameter_set.entity_path, + time_granularity=None, + date_part=None, + fields_to_compare=( + ParameterSetField.ELEMENT_NAME, + ParameterSetField.ENTITY_LINKS, + ), + ) + ) + + raise RuntimeError("There should have been a return associated with one of the CallParameterSets.") + + @override + def input_str_follows_scheme(self, input_str: str) -> bool: + if ObjectBuilderNamingScheme._NAME_REGEX.match(input_str) is None: + return False + try: + call_parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{ " + input_str + " }}") + return_value = ( + len(call_parameter_sets.dimension_call_parameter_sets) + + len(call_parameter_sets.time_dimension_call_parameter_sets) + + len(call_parameter_sets.entity_call_parameter_sets) + ) == 1 + return return_value + except ParseWhereFilterException: + return False + + @override + def __repr__(self) -> str: + return f"{self.__class__.__name__}(id()={hex(id(self))})" + + +class _ObjectBuilderNameTransform(InstanceSpecSetTransform[Sequence[str]]): + """Transforms specs into strings following the object builder scheme.""" + + @staticmethod + def _get_initializer_parameter_str( + element_name: str, + entity_links: Sequence[EntityReference], + time_granularity: Optional[TimeGranularity], + date_part: Optional[DatePart], + ) -> str: + """Return the parameters that should go in the initializer. + + e.g. `'user__country', time_granularity_name='month'` + """ + initializer_parameters = [] + entity_link_names = list(entity_link.element_name for entity_link in entity_links) + if len(entity_link_names) > 0: + initializer_parameters.append(repr(entity_link_names[-1] + DUNDER + element_name)) + else: + initializer_parameters.append(repr(element_name)) + if time_granularity is not None: + initializer_parameters.append( + f"'{time_granularity.value}'", + ) + if date_part is not None: + initializer_parameters.append(f"date_part_name={repr(date_part.value)}") + if len(entity_link_names) > 1: + initializer_parameters.append(f"entity_path={repr(entity_link_names[:-1])}") + + return ", ".join(initializer_parameters) + + @override + def transform(self, spec_set: InstanceSpecSet) -> Sequence[str]: + assert len(spec_set.entity_specs) + len(spec_set.dimension_specs) + len(spec_set.time_dimension_specs) == 1 + + names_to_return = [] + + for entity_spec in spec_set.entity_specs: + initializer_parameter_str = _ObjectBuilderNameTransform._get_initializer_parameter_str( + element_name=entity_spec.element_name, + entity_links=entity_spec.entity_links, + time_granularity=None, + date_part=None, + ) + names_to_return.append(f"Entity({initializer_parameter_str})") + + for dimension_spec in spec_set.dimension_specs: + initializer_parameter_str = _ObjectBuilderNameTransform._get_initializer_parameter_str( + element_name=dimension_spec.element_name, + entity_links=dimension_spec.entity_links, + time_granularity=None, + date_part=None, + ) + names_to_return.append(f"Dimension({initializer_parameter_str})") + + for time_dimension_spec in spec_set.time_dimension_specs: + initializer_parameter_str = _ObjectBuilderNameTransform._get_initializer_parameter_str( + element_name=time_dimension_spec.element_name, + entity_links=time_dimension_spec.entity_links, + time_granularity=time_dimension_spec.time_granularity, + date_part=time_dimension_spec.date_part, + ) + names_to_return.append(f"TimeDimension({initializer_parameter_str})") + + return names_to_return diff --git a/metricflow/specs/patterns/metric_pattern.py b/metricflow/specs/patterns/metric_pattern.py new file mode 100644 index 0000000000..b08d73e5d7 --- /dev/null +++ b/metricflow/specs/patterns/metric_pattern.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Sequence + +from dbt_semantic_interfaces.references import MetricReference +from typing_extensions import override + +from metricflow.specs.patterns.spec_pattern import SpecPattern +from metricflow.specs.specs import ( + InstanceSpec, + InstanceSpecSet, + MetricSpec, +) + + +@dataclass(frozen=True) +class MetricSpecPattern(SpecPattern): + """Matches MetricSpecs that have the given metric_reference.""" + + metric_reference: MetricReference + + @override + def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[MetricSpec]: + spec_set = InstanceSpecSet.from_specs(candidate_specs) + return tuple( + metric_name for metric_name in spec_set.metric_specs if metric_name.reference == self.metric_reference + ) diff --git a/metricflow/test/naming/__init__.py b/metricflow/test/naming/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/metricflow/test/naming/conftest.py b/metricflow/test/naming/conftest.py new file mode 100644 index 0000000000..546f25c9a0 --- /dev/null +++ b/metricflow/test/naming/conftest.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart + +from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec, TimeDimensionSpec +from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR + + +@pytest.fixture(scope="session") +def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D + return ( + # Time dimensions + MTD_SPEC_WEEK, + MTD_SPEC_MONTH, + MTD_SPEC_YEAR, + TimeDimensionSpec( + element_name="creation_time", + entity_links=(EntityReference("booking"), EntityReference("listing")), + time_granularity=TimeGranularity.MONTH, + date_part=DatePart.DAY, + ), + # Dimensions + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + ), + DimensionSpec(element_name="is_instant", entity_links=(EntityReference(element_name="booking"),)), + # Entities + EntitySpec( + element_name="listing", + entity_links=(EntityReference(element_name="booking"),), + ), + EntitySpec( + element_name="user", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + ), + ) diff --git a/metricflow/test/naming/test_dunder_naming_scheme.py b/metricflow/test/naming/test_dunder_naming_scheme.py new file mode 100644 index 0000000000..d6d716c975 --- /dev/null +++ b/metricflow/test/naming/test_dunder_naming_scheme.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart + +from metricflow.naming.dunder_scheme import DunderNamingScheme +from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec, TimeDimensionSpec +from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR + + +@pytest.fixture(scope="session") +def dunder_naming_scheme() -> DunderNamingScheme: # noqa: D + return DunderNamingScheme() + + +def test_input_str(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D + assert ( + dunder_naming_scheme.input_str( + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + ) + ) + == "booking__listing__country" + ) + + assert ( + dunder_naming_scheme.input_str( + TimeDimensionSpec( + element_name="creation_time", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + time_granularity=TimeGranularity.MONTH, + date_part=DatePart.DAY, + ) + ) + is None + ) + + assert ( + dunder_naming_scheme.input_str( + TimeDimensionSpec( + element_name="creation_time", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + time_granularity=TimeGranularity.MONTH, + ) + ) + == "booking__listing__creation_time__month" + ) + + assert ( + dunder_naming_scheme.input_str( + EntitySpec( + element_name="user", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + ) + ) + == "booking__listing__user" + ) + + +def test_input_follows_scheme(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D + assert dunder_naming_scheme.input_str_follows_scheme("listing__country") + assert dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__month") + assert dunder_naming_scheme.input_str_follows_scheme("booking__listing") + assert not dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month") + assert not dunder_naming_scheme.input_str_follows_scheme("123") + assert not dunder_naming_scheme.input_str_follows_scheme("TimeDimension('metric_time')") + + +def test_spec_pattern( # noqa: D + dunder_naming_scheme: DunderNamingScheme, specs: Sequence[LinkableInstanceSpec] +) -> None: + assert tuple(dunder_naming_scheme.spec_pattern("listing__user__country").match(specs)) == ( + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="listing"), + EntityReference(element_name="user"), + ), + ), + ) + + assert tuple(dunder_naming_scheme.spec_pattern("metric_time").match(specs)) == ( + MTD_SPEC_WEEK, + MTD_SPEC_MONTH, + MTD_SPEC_YEAR, + ) + + assert tuple(dunder_naming_scheme.spec_pattern("booking__listing__user").match(specs)) == ( + EntitySpec( + element_name="user", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + ), + ) + + assert tuple(dunder_naming_scheme.spec_pattern("metric_time__month").match(specs)) == (MTD_SPEC_MONTH,) diff --git a/metricflow/test/naming/test_metric_name_scheme.py b/metricflow/test/naming/test_metric_name_scheme.py new file mode 100644 index 0000000000..7084753408 --- /dev/null +++ b/metricflow/test/naming/test_metric_name_scheme.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest + +from metricflow.naming.metric_scheme import MetricNamingScheme +from metricflow.specs.specs import DimensionSpec, InstanceSpec, MetricSpec + + +@pytest.fixture(scope="session") +def metric_naming_scheme() -> MetricNamingScheme: # noqa: D + return MetricNamingScheme() + + +def test_input_str(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D + assert metric_naming_scheme.input_str(MetricSpec(element_name="example_metric")) == "example_metric" + + +def test_input_follows_scheme(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D + assert metric_naming_scheme.input_str_follows_scheme("some_metric_name") + + +def test_spec_pattern(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D + spec_pattern = metric_naming_scheme.spec_pattern("metric_0") + + specs: Sequence[InstanceSpec] = ( + MetricSpec(element_name="metric_0"), + MetricSpec(element_name="metric_1"), + # Shouldn't happen in practice, but checks to see that only metric specs are matched. + DimensionSpec(element_name="metric_0", entity_links=()), + ) + + assert (MetricSpec(element_name="metric_0"),) == tuple(spec_pattern.match(specs)) diff --git a/metricflow/test/naming/test_object_builder_naming_scheme.py b/metricflow/test/naming/test_object_builder_naming_scheme.py new file mode 100644 index 0000000000..ab52412cd1 --- /dev/null +++ b/metricflow/test/naming/test_object_builder_naming_scheme.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import Sequence + +import pytest +from dbt_semantic_interfaces.references import EntityReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.type_enums.date_part import DatePart + +from metricflow.naming.object_builder_scheme import ObjectBuilderNamingScheme +from metricflow.specs.specs import DimensionSpec, EntitySpec, LinkableInstanceSpec, TimeDimensionSpec +from metricflow.test.time.metric_time_dimension import MTD_SPEC_MONTH, MTD_SPEC_WEEK, MTD_SPEC_YEAR + + +@pytest.fixture(scope="session") +def object_builder_naming_scheme() -> ObjectBuilderNamingScheme: # noqa: D + return ObjectBuilderNamingScheme() + + +def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> None: # noqa: D + assert ( + object_builder_naming_scheme.input_str( + DimensionSpec( + element_name="country", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + ) + ) + == "Dimension('listing__country', entity_path=['booking'])" + ) + + assert object_builder_naming_scheme.input_str( + TimeDimensionSpec( + element_name="creation_time", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + time_granularity=TimeGranularity.MONTH, + date_part=DatePart.DAY, + ) + ) == ("TimeDimension('listing__creation_time', 'month', date_part_name='day', entity_path=['booking'])") + + assert ( + object_builder_naming_scheme.input_str( + EntitySpec( + element_name="user", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + ) + ) + == "Entity('listing__user', entity_path=['booking'])" + ) + + +def test_input_follows_scheme(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> None: # noqa: D + assert object_builder_naming_scheme.input_str_follows_scheme( + "Dimension('listing__country', entity_path=['booking'])" + ) + assert object_builder_naming_scheme.input_str_follows_scheme( + "TimeDimension('listing__creation_time', time_granularity_name='month', date_part_name='day', " + "entity_path=['booking'])" + ) + assert object_builder_naming_scheme.input_str_follows_scheme( + "Entity('user', entity_path=['booking', 'listing'])", + ) + assert not object_builder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month") + assert not object_builder_naming_scheme.input_str_follows_scheme("123") + assert not object_builder_naming_scheme.input_str_follows_scheme("NotADimension('listing__country')") + + +def test_spec_pattern( # noqa: D + object_builder_naming_scheme: ObjectBuilderNamingScheme, specs: Sequence[LinkableInstanceSpec] +) -> None: + assert tuple( + object_builder_naming_scheme.spec_pattern("Dimension('listing__country', entity_path=['booking'])").match(specs) + ) == ( + DimensionSpec( + element_name="country", + entity_links=( + EntityReference(element_name="booking"), + EntityReference(element_name="listing"), + ), + ), + ) + + assert tuple( + object_builder_naming_scheme.spec_pattern( + "TimeDimension('listing__creation_time', time_granularity_name='month', date_part_name='day', " + "entity_path=['booking'])" + ).match(specs) + ) == ( + TimeDimensionSpec( + element_name="creation_time", + entity_links=(EntityReference("booking"), EntityReference("listing")), + time_granularity=TimeGranularity.MONTH, + date_part=DatePart.DAY, + ), + ) + + assert tuple(object_builder_naming_scheme.spec_pattern("TimeDimension('metric_time')").match(specs)) == ( + MTD_SPEC_WEEK, + MTD_SPEC_MONTH, + MTD_SPEC_YEAR, + ) + + assert tuple( + object_builder_naming_scheme.spec_pattern("Entity('user', entity_path=['booking', 'listing'])").match(specs) + ) == ( + EntitySpec( + element_name="user", + entity_links=(EntityReference(element_name="booking"), EntityReference(element_name="listing")), + ), + )