diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py index 3dd0468a54..499ca16cd4 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py @@ -142,10 +142,14 @@ def __init__( continue metric_reference = MetricReference(metric.name) linkable_element_set_for_metric = self.get_linkable_elements_for_metrics([metric_reference]) + defined_from_semantic_models = tuple( - self._semantic_model_lookup.get_semantic_model_for_measure(input_measure.measure_reference).reference + self._semantic_model_lookup.measure_lookup.get_properties( + input_measure.measure_reference + ).model_reference for input_measure in metric.input_measures ) + for linkable_entities in linkable_element_set_for_metric.path_key_to_linkable_entities.values(): for linkable_entity in linkable_entities: # TODO: some users encounter a situation in which the entity reference is in the entity links. Debug why. diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py index 408fb3c11b..0a7869e25f 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py @@ -218,10 +218,10 @@ def _get_agg_time_dimension_specs_for_metric( metric = self.get_metric(metric_reference) specs: Set[TimeDimensionSpec] = set() for input_measure in metric.input_measures: - time_dimension_specs = self._semantic_model_lookup.get_agg_time_dimension_specs_for_measure( + measure_properties = self._semantic_model_lookup.measure_lookup.get_properties( measure_reference=input_measure.measure_reference ) - specs.update(time_dimension_specs) + specs.update(measure_properties.agg_time_dimension_specs) return list(specs) def get_valid_agg_time_dimensions_for_metric( diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_helper.py b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_helper.py index bafc655316..61c0edb031 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_helper.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_helper.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import Sequence +from typing import Dict, Mapping, Sequence +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.protocols import Dimension from dbt_semantic_interfaces.protocols.entity import Entity from dbt_semantic_interfaces.protocols.measure import Measure @@ -10,8 +11,9 @@ EntityReference, LinkableElementReference, MeasureReference, + TimeDimensionReference, ) -from dbt_semantic_interfaces.type_enums import EntityType +from dbt_semantic_interfaces.type_enums import DimensionType, EntityType, TimeGranularity class SemanticModelHelper: @@ -94,3 +96,25 @@ def get_dimension_from_semantic_model( raise ValueError( f"No dimension with name ({dimension_reference}) in semantic_model with name ({semantic_model.name})" ) + + @staticmethod + def get_time_dimension_grains(semantic_model: SemanticModel) -> Mapping[TimeDimensionReference, TimeGranularity]: + """Return a mapping of the defined time granularity of the time dimensions in the semantic mode.""" + time_dimension_reference_to_grain: Dict[TimeDimensionReference, TimeGranularity] = {} + + for dimension in semantic_model.dimensions: + if dimension.type is DimensionType.TIME: + if dimension.type_params is None: + raise ValueError( + f"A dimension is specified as a time dimension but does not specify a gain. This should have " + f"been caught in semantic-manifest validation {dimension=} {semantic_model=}" + ) + time_dimension_reference_to_grain[ + dimension.reference.time_dimension_reference + ] = dimension.type_params.time_granularity + elif dimension.type is DimensionType.CATEGORICAL: + pass + else: + assert_values_exhausted(dimension.type) + + return time_dimension_reference_to_grain diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py index 70d34f4332..0dd2091633 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py @@ -21,6 +21,7 @@ from metricflow_semantics.errors.error_classes import InvalidSemanticModelError from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.model.semantics.element_group import ElementGrouper +from metricflow_semantics.model.semantics.measure_lookup import MeasureLookup from metricflow_semantics.model.semantics.semantic_model_helper import SemanticModelHelper from metricflow_semantics.model.spec_converters import MeasureConverter from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName @@ -60,7 +61,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa ] = {} self._semantic_model_reference_to_semantic_model: Dict[SemanticModelReference, SemanticModel] = {} - for semantic_model in sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name): + sorted_semantic_models = sorted(model.semantic_models, key=lambda semantic_model: semantic_model.name) + for semantic_model in sorted_semantic_models: self._add_semantic_model(semantic_model) # Cache for defined time granularity. @@ -69,6 +71,8 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa # Cache for agg. time dimension for measure. self._measure_reference_to_agg_time_dimension_specs: Dict[MeasureReference, Sequence[TimeDimensionSpec]] = {} + self._measure_lookup = MeasureLookup(sorted_semantic_models, custom_granularities) + def get_dimension_references(self) -> Sequence[DimensionReference]: """Retrieve all dimension references from the collection of semantic models.""" return tuple(self._dimension_index.keys()) @@ -315,3 +319,7 @@ def _get_defined_time_granularity(self, time_dimension_reference: TimeDimensionR defined_time_granularity = time_dimension.type_params.time_granularity return defined_time_granularity + + @property + def measure_lookup(self) -> MeasureLookup: # noqa: D102 + return self._measure_lookup diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index 87e69a24d2..0469bf633e 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -607,9 +607,9 @@ def _get_models_for_measures(self, resolution_dag: GroupByItemResolutionDag) -> model_references: Set[SemanticModelReference] = set() for measure_reference in measure_references: - measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure( + measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties( measure_reference - ) - model_references.add(measure_semantic_model.reference) + ).model_reference + model_references.add(measure_semantic_model) return model_references diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index 7d09efc1e1..3bc7d55e70 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -136,11 +136,11 @@ def _validate_derived_metric( def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet: """Returns subset of the query's `LinkableElements` that are SCDs and associated with the measure.""" - measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure( + measure_semantic_model = self._manifest_lookup.semantic_model_lookup.measure_lookup.get_properties( measure_reference - ) + ).model_reference - return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model.reference) + return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model) @override def validate_metric_in_resolution_dag( diff --git a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py index 4f57921d94..40cc5fe81a 100644 --- a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py +++ b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py @@ -75,7 +75,9 @@ def included_agg_time_dimension_specs_for_measure( """Get the time dims included that are valid agg time dimensions for the specified measure.""" queried_metric_time_specs = list(self.metric_time_specs) - valid_agg_time_dimensions = semantic_model_lookup.get_agg_time_dimension_specs_for_measure(measure_reference) + valid_agg_time_dimensions = semantic_model_lookup.measure_lookup.get_properties( + measure_reference + ).agg_time_dimension_specs queried_agg_time_dimension_specs = ( list(set(self.time_dimension_specs).intersection(set(valid_agg_time_dimensions))) + queried_metric_time_specs diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py index 9b42dcf2ed..48f8b3f7eb 100644 --- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py +++ b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from functools import lru_cache -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference @@ -205,7 +205,7 @@ def generate_possible_specs_for_time_dimension( cls, time_dimension_reference: TimeDimensionReference, entity_links: Tuple[EntityReference, ...], - custom_granularities: Dict[str, ExpandedTimeGranularity], + custom_granularities: Mapping[str, ExpandedTimeGranularity], ) -> List[TimeDimensionSpec]: """Generate a list of time dimension specs with all combinations of granularity & date part.""" time_dimension_specs: List[TimeDimensionSpec] = [] diff --git a/metricflow-semantics/tests_metricflow_semantics/model/test_semantic_model_container.py b/metricflow-semantics/tests_metricflow_semantics/model/test_semantic_model_container.py index b4b3f68873..907c5d3bb9 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/test_semantic_model_container.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/test_semantic_model_container.py @@ -228,7 +228,9 @@ def test_get_valid_agg_time_dimensions_for_metric( # noqa: D103 metric_agg_time_dims = metric_lookup.get_valid_agg_time_dimensions_for_metric(metric_reference) measure_agg_time_dims = list( { - semantic_model_lookup.get_agg_time_dimension_for_measure(measure.measure_reference) + semantic_model_lookup.measure_lookup.get_properties( + measure.measure_reference + ).agg_time_dimension_reference for measure in metric.input_measures } ) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 7f9d154880..78e74ad8bb 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -303,13 +303,19 @@ def _build_aggregated_conversion_node( ) # Get the agg time dimension for each measure used for matching conversion time windows + base_measure_properties = self._semantic_model_lookup.measure_lookup.get_properties( + base_measure_spec.measure_spec.reference + ) base_time_dimension_spec = TimeDimensionSpec.from_reference( - self._semantic_model_lookup.get_agg_time_dimension_for_measure(base_measure_spec.measure_spec.reference) + base_measure_properties.agg_time_dimension_reference + ) + + conversion_measure_properties = self._semantic_model_lookup.measure_lookup.get_properties( + conversion_measure_spec.measure_spec.reference ) + conversion_time_dimension_spec = TimeDimensionSpec.from_reference( - self._semantic_model_lookup.get_agg_time_dimension_for_measure( - conversion_measure_spec.measure_spec.reference - ) + conversion_measure_properties.agg_time_dimension_reference ) # Filter the source nodes with only the required specs needed for the calculation @@ -968,7 +974,9 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - if len(measure_specs) == 0: raise ValueError("Cannot build MeasureParametersForRecipe when given an empty sequence of measure_specs.") semantic_model_names = { - self._semantic_model_lookup.get_semantic_model_for_measure(measure.reference).name + self._semantic_model_lookup.measure_lookup.get_properties( + measure.reference + ).model_reference.semantic_model_name for measure in measure_specs } if len(semantic_model_names) > 1: @@ -978,14 +986,16 @@ def _build_measure_spec_properties(self, measure_specs: Sequence[MeasureSpec]) - ) semantic_model_name = semantic_model_names.pop() - agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure(measure_specs[0].reference) + agg_time_dimension = self._semantic_model_lookup.measure_lookup.get_properties( + measure_specs[0].reference + ).agg_time_dimension_reference non_additive_dimension_spec = measure_specs[0].non_additive_dimension_spec for measure_spec in measure_specs: if non_additive_dimension_spec != measure_spec.non_additive_dimension_spec: raise ValueError(f"measure_specs {measure_specs} do not have the same non_additive_dimension_spec.") - measure_agg_time_dimension = self._semantic_model_lookup.get_agg_time_dimension_for_measure( + measure_agg_time_dimension = self._semantic_model_lookup.measure_lookup.get_properties( measure_spec.reference - ) + ).agg_time_dimension_reference if measure_agg_time_dimension != agg_time_dimension: raise ValueError(f"measure_specs {measure_specs} do not have the same agg_time_dimension.") return MeasureSpecProperties( diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index 52bc82a6a5..2c4f33ce08 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -562,14 +562,14 @@ def get_measures_for_metrics(self, metric_names: List[str]) -> List[Measure]: # for input_measure in metric.input_measures: measure_reference = MeasureReference(element_name=input_measure.name) # populate new obj - measure = semantic_model_lookup.get_measure(measure_reference=measure_reference) + measure = semantic_model_lookup.measure_lookup.get_measure(measure_reference=measure_reference) measures.add( Measure( name=measure.name, agg=measure.agg, - agg_time_dimension=semantic_model_lookup.get_agg_time_dimension_for_measure( + agg_time_dimension=semantic_model_lookup.measure_lookup.get_properties( measure_reference=measure_reference - ).element_name, + ).agg_time_dimension_reference.element_name, description=measure.description, expr=measure.expr, agg_params=measure.agg_params, diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 2447bcc56f..d49457a637 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -214,7 +214,7 @@ def _make_sql_column_expression_to_aggregate_measure( # Create an expression that will aggregate the given measure. # Figure out the aggregation function for the measure. - measure = self._semantic_model_lookup.get_measure(measure_instance.spec.reference) + measure = self._semantic_model_lookup.measure_lookup.get_measure(measure_instance.spec.reference) aggregation_type = measure.agg expression_to_get_measure = SqlColumnReferenceExpression.create(