Skip to content

Commit

Permalink
Update call sites to use MeasureLookup.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 1, 2024
1 parent 1c869dc commit 6099e3b
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,14 @@ def __init__(
continue
metric_reference = MetricReference(metric.name)
linkable_element_set_for_metric = self.get_linkable_elements_for_metrics([metric_reference])

defined_from_semantic_models = tuple(
self._semantic_model_lookup.get_semantic_model_for_measure(input_measure.measure_reference).reference
self._semantic_model_lookup.measure_lookup.get_properties(
input_measure.measure_reference
).model_reference
for input_measure in metric.input_measures
)

for linkable_entities in linkable_element_set_for_metric.path_key_to_linkable_entities.values():
for linkable_entity in linkable_entities:
# TODO: some users encounter a situation in which the entity reference is in the entity links. Debug why.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def _get_agg_time_dimension_specs_for_metric(
metric = self.get_metric(metric_reference)
specs: Set[TimeDimensionSpec] = set()
for input_measure in metric.input_measures:
time_dimension_specs = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure(
measure_properties = self._semantic_model_lookup.measure_lookup.get_properties(
measure_reference=input_measure.measure_reference
)
specs.update(time_dimension_specs)
specs.update(measure_properties.agg_time_dimension_specs)
return list(specs)

def get_valid_agg_time_dimensions_for_metric(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Sequence
from typing import Dict, Mapping, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols import Dimension
from dbt_semantic_interfaces.protocols.entity import Entity
from dbt_semantic_interfaces.protocols.measure import Measure
Expand All @@ -10,8 +11,9 @@
EntityReference,
LinkableElementReference,
MeasureReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import EntityType
from dbt_semantic_interfaces.type_enums import DimensionType, EntityType, TimeGranularity


class SemanticModelHelper:
Expand Down Expand Up @@ -94,3 +96,25 @@ def get_dimension_from_semantic_model(
raise ValueError(
f"No dimension with name ({dimension_reference}) in semantic_model with name ({semantic_model.name})"
)

@staticmethod
def get_time_dimension_grains(semantic_model: SemanticModel) -> Mapping[TimeDimensionReference, TimeGranularity]:
"""Return a mapping of the defined time granularity of the time dimensions in the semantic mode."""
time_dimension_reference_to_grain: Dict[TimeDimensionReference, TimeGranularity] = {}

for dimension in semantic_model.dimensions:
if dimension.type is DimensionType.TIME:
if dimension.type_params is None:
raise ValueError(
f"A dimension is specified as a time dimension but does not specify a gain. This should have "
f"been caught in semantic-manifest validation {dimension=} {semantic_model=}"
)
time_dimension_reference_to_grain[
dimension.reference.time_dimension_reference
] = dimension.type_params.time_granularity
elif dimension.type is DimensionType.CATEGORICAL:
pass
else:
assert_values_exhausted(dimension.type)

return time_dimension_reference_to_grain
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metricflow_semantics.errors.error_classes import InvalidSemanticModelError
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantics.element_group import ElementGrouper
from metricflow_semantics.model.semantics.measure_lookup import MeasureLookup
from metricflow_semantics.model.semantics.semantic_model_helper import SemanticModelHelper
from metricflow_semantics.model.spec_converters import MeasureConverter
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
Expand Down Expand Up @@ -60,7 +61,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
] = {}

self._semantic_model_reference_to_semantic_model: Dict[SemanticModelReference, SemanticModel] = {}
for semantic_model in sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name):
sorted_semantic_models = sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name)
for semantic_model in sorted_semantic_models:
self._add_semantic_model(semantic_model)

# Cache for defined time granularity.
Expand All @@ -69,6 +71,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa
# Cache for agg. time dimension for measure.
self._measure_reference_to_agg_time_dimension_specs: Dict[MeasureReference, Sequence[TimeDimensionSpec]] = {}

self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities)

def get_dimension_references(self) -> Sequence[DimensionReference]:
"""Retrieve all dimension references from the collection of semantic models."""
return tuple(self._dimension_index.keys())
Expand Down Expand Up @@ -315,3 +319,7 @@ def _get_defined_time_granularity(self, time_dimension_reference: TimeDimensionR
defined_time_granularity = time_dimension.type_params.time_granularity

return defined_time_granularity

@property
def measure_lookup(self) -> MeasureLookup: # noqa: D102
return self._measure_lookup
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,9 @@ def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) ->

model_references: Set[SemanticModelReference] = set()
for measure_reference in measure_references:
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties(
measure_reference
)
model_references.add(measure_semantic_model.reference)
).model_reference
model_references.add(measure_semantic_model)

return model_references
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ def _validate_derived_metric(

def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet:
"""Returns subset of the query's `LinkableElements` that are SCDs and associated with the measure."""
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure(
measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties(
measure_reference
)
).model_reference

return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model.reference)
return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model)

@override
def validate_metric_in_resolution_dag(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def included_agg_time_dimension_specs_for_measure(
"""Get the time dims included that are valid agg time dimensions for the specified measure."""
queried_metric_time_specs = list(self.metric_time_specs)

valid_agg_time_dimensions = semantic_model_lookup.get_agg_time_dimension_specs_for_measure(measure_reference)
valid_agg_time_dimensions = semantic_model_lookup.measure_lookup.get_properties(
measure_reference
).agg_time_dimension_specs
queried_agg_time_dimension_specs = (
list(set(self.time_dimension_specs).intersection(set(valid_agg_time_dimensions)))
+ queried_metric_time_specs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
Expand Down Expand Up @@ -205,7 +205,7 @@ def generate_possible_specs_for_time_dimension(
cls,
time_dimension_reference: TimeDimensionReference,
entity_links: Tuple[EntityReference, ...],
custom_granularities: Dict[str, ExpandedTimeGranularity],
custom_granularities: Mapping[str, ExpandedTimeGranularity],
) -> List[TimeDimensionSpec]:
"""Generate a list of time dimension specs with all combinations of granularity & date part."""
time_dimension_specs: List[TimeDimensionSpec] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def test_get_valid_agg_time_dimensions_for_metric( # noqa: D103
metric_agg_time_dims = metric_lookup.get_valid_agg_time_dimensions_for_metric(metric_reference)
measure_agg_time_dims = list(
{
semantic_model_lookup.get_agg_time_dimension_for_measure(measure.measure_reference)
semantic_model_lookup.measure_lookup.get_properties(
measure.measure_reference
).agg_time_dimension_reference
for measure in metric.input_measures
}
)
Expand Down
12 changes: 8 additions & 4 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,9 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
if len(measure_specs) == 0:
raise ValueError("Cannot build MeasureParametersForRecipe when given an empty sequence of measure_specs.")
semantic_model_names = {
self._semantic_model_lookup.get_semantic_model_for_measure(measure.reference).name
self._semantic_model_lookup.measure_lookup.get_properties(
measure.reference
).model_reference.semantic_model_name
for measure in measure_specs
}
if len(semantic_model_names) > 1:
Expand All @@ -978,14 +980,16 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) -
)
semantic_model_name = semantic_model_names.pop()

agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure(measure_specs[0].reference)
agg_time_dimension = self._semantic_model_lookup.measure_lookup.get_properties(
measure_specs[0].reference
).agg_time_dimension_reference
non_additive_dimension_spec = measure_specs[0].non_additive_dimension_spec
for measure_spec in measure_specs:
if non_additive_dimension_spec != measure_spec.non_additive_dimension_spec:
raise ValueError(f"measure_specs {measure_specs} do not have the same non_additive_dimension_spec.")
measure_agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure(
measure_agg_time_dimension = self._semantic_model_lookup.measure_lookup.get_properties(
measure_spec.reference
)
).agg_time_dimension_reference
if measure_agg_time_dimension != agg_time_dimension:
raise ValueError(f"measure_specs {measure_specs} do not have the same agg_time_dimension.")
return MeasureSpecProperties(
Expand Down
6 changes: 3 additions & 3 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,14 @@ def get_measures_for_metrics(self, metric_names: List[str]) -> List[Measure]: #
for input_measure in metric.input_measures:
measure_reference = MeasureReference(element_name=input_measure.name)
# populate new obj
measure = semantic_model_lookup.get_measure(measure_reference=measure_reference)
measure = semantic_model_lookup.measure_lookup.get_measure(measure_reference=measure_reference)
measures.add(
Measure(
name=measure.name,
agg=measure.agg,
agg_time_dimension=semantic_model_lookup.get_agg_time_dimension_for_measure(
agg_time_dimension=semantic_model_lookup.measure_lookup.get_properties(
measure_reference=measure_reference
).element_name,
).agg_time_dimension_reference.element_name,
description=measure.description,
expr=measure.expr,
agg_params=measure.agg_params,
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _make_sql_column_expression_to_aggregate_measure(

# Create an expression that will aggregate the given measure.
# Figure out the aggregation function for the measure.
measure = self._semantic_model_lookup.get_measure(measure_instance.spec.reference)
measure = self._semantic_model_lookup.measure_lookup.get_measure(measure_instance.spec.reference)
aggregation_type = measure.agg

expression_to_get_measure = SqlColumnReferenceExpression.create(
Expand Down

0 comments on commit 6099e3b

Please sign in to comment.