Skip to content

Commit

Permalink
Add MeasureLookup to break up SemanticModelLookup (#1486)
Browse files Browse the repository at this point in the history
To help break up `SemanticModelLookup`, this PR moves measure-related
lookup methods to the separate class `MeasureLookup`. Note that there
are additional methods to move later.
  • Loading branch information
plypaul authored Nov 1, 2024
1 parent 66f4ea2 commit 38dcd77
Show file tree
Hide file tree
Showing 15 changed files with 1,324 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,14 @@ def __init__(
continue
metric_reference = MetricReference(metric.name)
linkable_element_set_for_metric = self.get_linkable_elements_for_metrics([metric_reference])

defined_from_semantic_models = tuple(
self._semantic_model_lookup.get_semantic_model_for_measure(input_measure.measure_reference).reference
self._semantic_model_lookup.measure_lookup.get_properties(
input_measure.measure_reference
).model_reference
for input_measure in metric.input_measures
)

for linkable_entities in linkable_element_set_for_metric.path_key_to_linkable_entities.values():
for linkable_entity in linkable_entities:
# TODO: some users encounter a situation in which the entity reference is in the entity links. Debug why.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Mapping, Sequence, Tuple

from dbt_semantic_interfaces.protocols import Measure, SemanticModel
from dbt_semantic_interfaces.references import (
EntityReference,
MeasureReference,
SemanticModelReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantics.semantic_model_helper import SemanticModelHelper
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.time.granularity import ExpandedTimeGranularity


@dataclass(frozen=True)
class MeasureRelationshipPropertySet:
    """Properties of a measure that include how it relates to other elements in the semantic model.

    Instances are immutable. `MeasureLookup` builds one instance per measure while iterating over the semantic
    models it is given, and returns it from `MeasureLookup.get_properties()`.
    """

    # Reference to the semantic model in which the measure is defined.
    model_reference: SemanticModelReference
    # The resolved primary entity of the semantic model in which the measure is defined.
    model_primary_entity: EntityReference

    # This is the time dimension along which the measure will be aggregated when a metric built on this measure
    # is queried with metric_time.
    agg_time_dimension_reference: TimeDimensionReference
    # This is the grain of the above dimension in the semantic model.
    agg_time_granularity: TimeGranularity
    # Specs that can be used to query the aggregation time dimension. `MeasureLookup` generates these with the
    # model's primary entity as the only entity link.
    agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...]


class MeasureLookup:
    """Looks up properties related to measures.

    The functionality of this class was split off from `SemanticModelLookup`, and there are additional items to
    migrate.
    """

    def __init__(
        self,
        semantic_models: Sequence[SemanticModel],
        custom_granularities: Mapping[str, ExpandedTimeGranularity],
    ) -> None:
        """Index every measure in the given semantic models by its reference.

        Args:
            semantic_models: The semantic models whose measures should be made available for lookup.
            custom_granularities: Passed through to `TimeDimensionSpec.generate_possible_specs_for_time_dimension`
                when generating the specs for each measure's aggregation time dimension.

        Raises:
            ValueError: If the grain of a measure's aggregation time dimension is not found among the time
                dimension grains of the measure's semantic model.
        """
        self._measure_reference_to_property_set: Dict[MeasureReference, MeasureRelationshipPropertySet] = {}
        self._measure_reference_to_measure: Dict[MeasureReference, Measure] = {}
        for semantic_model in semantic_models:
            # These values are the same for every measure in the model, so resolve them once per model.
            semantic_model_reference = semantic_model.reference
            primary_entity = SemanticModelHelper.resolved_primary_entity(semantic_model)
            time_dimension_reference_to_grain = SemanticModelHelper.get_time_dimension_grains(semantic_model)

            for measure in semantic_model.measures:
                measure_reference = measure.reference
                self._measure_reference_to_measure[measure_reference] = measure

                # Resolve the aggregation time dimension once and reuse it for the grain lookup, the property
                # set, and the spec generation below.
                agg_time_dimension_reference = semantic_model.checked_agg_time_dimension_for_measure(measure_reference)
                agg_time_granularity = time_dimension_reference_to_grain.get(agg_time_dimension_reference)
                if agg_time_granularity is None:
                    raise ValueError(
                        f"Could not find the defined grain of the aggregation time dimension for {measure=}"
                    )
                self._measure_reference_to_property_set[measure_reference] = MeasureRelationshipPropertySet(
                    model_reference=semantic_model_reference,
                    model_primary_entity=primary_entity,
                    agg_time_dimension_reference=agg_time_dimension_reference,
                    agg_time_granularity=agg_time_granularity,
                    agg_time_dimension_specs=tuple(
                        TimeDimensionSpec.generate_possible_specs_for_time_dimension(
                            time_dimension_reference=agg_time_dimension_reference,
                            entity_links=(primary_entity,),
                            custom_granularities=custom_granularities,
                        )
                    ),
                )

    def get_properties(self, measure_reference: MeasureReference) -> MeasureRelationshipPropertySet:
        """Return properties of the measure as it relates to other elements in the semantic model.

        Raises:
            ValueError: If the given measure reference is unknown to this lookup.
        """
        property_set = self._measure_reference_to_property_set.get(measure_reference)
        if property_set is None:
            raise ValueError(
                str(
                    LazyFormat(
                        "Unable to get properties as the given measure reference is unknown",
                        measure_reference=measure_reference,
                        known_measures=list(self._measure_reference_to_property_set.keys()),
                    )
                )
            )

        return property_set

    def get_measure(self, measure_reference: MeasureReference) -> Measure:
        """Return the measure object with the given reference.

        Raises:
            ValueError: If the given measure reference is unknown to this lookup.
        """
        measure = self._measure_reference_to_measure.get(measure_reference)
        if measure is None:
            raise ValueError(
                str(
                    LazyFormat(
                        "Unable to get the measure as the given reference is unknown",
                        measure_reference=measure_reference,
                        # Report the keys of the mapping that was actually searched, as a list so that the
                        # output matches the error raised by `get_properties()`.
                        known_measures=list(self._measure_reference_to_measure.keys()),
                    )
                )
            )

        return measure
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def _get_agg_time_dimension_specs_for_metric(
metric = self.get_metric(metric_reference)
specs: Set[TimeDimensionSpec] = set()
for input_measure in metric.input_measures:
time_dimension_specs = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_properties = self._semantic_model_lookup.measure_lookup.get_properties(
measure_reference=input_measure.measure_reference
)
specs.update(time_dimension_specs)
specs.update(measure_properties.agg_time_dimension_specs)
return list(specs)

def get_valid_agg_time_dimensions_for_metric(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Sequence
from typing import Dict, Mapping, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols import Dimension
from dbt_semantic_interfaces.protocols.entity import Entity
from dbt_semantic_interfaces.protocols.measure import Measure
Expand All @@ -10,8 +11,9 @@
EntityReference,
LinkableElementReference,
MeasureReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import EntityType
from dbt_semantic_interfaces.type_enums import DimensionType, EntityType, TimeGranularity


class SemanticModelHelper:
Expand Down Expand Up @@ -94,3 +96,25 @@ def get_dimension_from_semantic_model(
raise ValueError(
f"No dimension with name ({dimension_reference}) in semantic_model with name ({semantic_model.name})"
)

@staticmethod
def get_time_dimension_grains(semantic_model: SemanticModel) -> Mapping[TimeDimensionReference, TimeGranularity]:
    """Return a mapping from each time dimension in the semantic model to its defined time granularity.

    Categorical dimensions are skipped.

    Raises:
        ValueError: If a time dimension does not have `type_params`, and therefore has no defined grain.
    """
    time_dimension_reference_to_grain: Dict[TimeDimensionReference, TimeGranularity] = {}

    for dimension in semantic_model.dimensions:
        if dimension.type is DimensionType.TIME:
            if dimension.type_params is None:
                raise ValueError(
                    "A dimension is specified as a time dimension but does not specify a grain. This should have "
                    f"been caught in semantic-manifest validation {dimension=} {semantic_model=}"
                )
            time_dimension_reference_to_grain[
                dimension.reference.time_dimension_reference
            ] = dimension.type_params.time_granularity
        elif dimension.type is DimensionType.CATEGORICAL:
            # Categorical dimensions have no time grain.
            pass
        else:
            # Exhaustiveness check so that a newly added dimension type is not silently ignored.
            assert_values_exhausted(dimension.type)

    return time_dimension_reference_to_grain
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metricflow_semantics.errors.error_classes import InvalidSemanticModelError
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantics.element_group import ElementGrouper
from metricflow_semantics.model.semantics.measure_lookup import MeasureLookup
from metricflow_semantics.model.semantics.semantic_model_helper import SemanticModelHelper
from metricflow_semantics.model.spec_converters import MeasureConverter
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
Expand Down Expand Up @@ -60,7 +61,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
] = {}

self._semantic_model_reference_to_semantic_model: Dict[SemanticModelReference, SemanticModel] = {}
for semantic_model in sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name):
sorted_semantic_models = sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name)
for semantic_model in sorted_semantic_models:
self._add_semantic_model(semantic_model)

# Cache for defined time granularity.
Expand All @@ -69,6 +71,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
# Cache for agg. time dimension for measure.
self._measure_reference_to_agg_time_dimension_specs: Dict[MeasureReference, Sequence[TimeDimensionSpec]] = {}

self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities)

def get_dimension_references(self) -> Sequence[DimensionReference]:
"""Retrieve all dimension references from the collection of semantic models."""
return tuple(self._dimension_index.keys())
Expand Down Expand Up @@ -315,3 +319,7 @@ def _get_defined_time_granularity(self, time_dimension_reference: TimeDimensionR
defined_time_granularity = time_dimension.type_params.time_granularity

return defined_time_granularity

@property
def measure_lookup(self) -> MeasureLookup:
    """Return the `MeasureLookup` held by this object for measure-related lookups."""
    return self._measure_lookup
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,9 @@ def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) ->

model_references: Set[SemanticModelReference] = set()
for measure_reference in measure_references:
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties(
measure_reference
)
model_references.add(measure_semantic_model.reference)
).model_reference
model_references.add(measure_semantic_model)

return model_references
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ def _validate_derived_metric(

def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet:
"""Returns subset of the query's `LinkableElements` that are SCDs and associated with the measure."""
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties(
measure_reference
)
).model_reference

return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model.reference)
return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model)

@override
def validate_metric_in_resolution_dag(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def included_agg_time_dimension_specs_for_measure(
"""Get the time dims included that are valid agg time dimensions for the specified measure."""
queried_metric_time_specs = list(self.metric_time_specs)

valid_agg_time_dimensions = semantic_model_lookup.get_agg_time_dimension_specs_for_measure(measure_reference)
valid_agg_time_dimensions = semantic_model_lookup.measure_lookup.get_properties(
measure_reference
).agg_time_dimension_specs
queried_agg_time_dimension_specs = (
list(set(self.time_dimension_specs).intersection(set(valid_agg_time_dimensions)))
+ queried_metric_time_specs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
Expand Down Expand Up @@ -205,7 +205,7 @@ def generate_possible_specs_for_time_dimension(
cls,
time_dimension_reference: TimeDimensionReference,
entity_links: Tuple[EntityReference, ...],
custom_granularities: Dict[str, ExpandedTimeGranularity],
custom_granularities: Mapping[str, ExpandedTimeGranularity],
) -> List[TimeDimensionSpec]:
"""Generate a list of time dimension specs with all combinations of granularity & date part."""
time_dimension_specs: List[TimeDimensionSpec] = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.implementations.semantic_manifest import PydanticSemanticManifest
from dbt_semantic_interfaces.references import MeasureReference
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.model.semantics.measure_lookup import MeasureLookup
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration
from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal


@pytest.fixture(scope="module")
def measure_lookup(extended_date_semantic_manifest_lookup: SemanticManifestLookup) -> MeasureLookup:
    """Module-scoped fixture returning the `MeasureLookup` reached through the given manifest lookup."""
    return extended_date_semantic_manifest_lookup.semantic_model_lookup.measure_lookup


def test_get_measure(extended_date_manifest: PydanticSemanticManifest, measure_lookup: MeasureLookup) -> None:
    """Check that every measure defined in the manifest is returned by `get_measure()` for its own reference."""
    manifest_measures = [
        defined_measure
        for model in extended_date_manifest.semantic_models
        for defined_measure in model.measures
    ]
    for expected_measure in manifest_measures:
        assert expected_measure == measure_lookup.get_measure(expected_measure.reference)


def test_measure_properties(
    request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, measure_lookup: MeasureLookup
) -> None:
    """Test a couple of measures for correct properties.

    The property sets returned by `get_properties()` are compared against a stored object snapshot.
    """
    # Check `bookings` and `bookings_monthly` together as they are expected to have different aggregation time
    # dimensions.
    measure_names = ["bookings", "bookings_monthly"]
    result = {
        measure_name: measure_lookup.get_properties(MeasureReference(measure_name)) for measure_name in measure_names
    }
    assert_object_snapshot_equal(
        request=request, mf_test_configuration=mf_test_configuration, obj_id="obj_0", obj=result
    )
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def test_get_valid_agg_time_dimensions_for_metric( # noqa: D103
metric_agg_time_dims = metric_lookup.get_valid_agg_time_dimensions_for_metric(metric_reference)
measure_agg_time_dims = list(
{
semantic_model_lookup.get_agg_time_dimension_for_measure(measure.measure_reference)
semantic_model_lookup.measure_lookup.get_properties(
measure.measure_reference
).agg_time_dimension_reference
for measure in metric.input_measures
}
)
Expand Down
Loading

0 comments on commit 38dcd77

Please sign in to comment.