diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index d602c38bb9..8d1e1fbc33 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++ b/metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -704,16 +704,18 @@ def _contains_multihop_linkables(linkable_specs: Sequence[LinkableInstanceSpec])
         """Returns true if any of the linkable specs requires a multi-hop join to realize."""
         return any(len(x.entity_links) > 1 for x in linkable_specs)
 
-    def _get_semantic_model_names_for_measures(self, measure_names: Sequence[MeasureSpec]) -> Set[str]:
+    def _get_semantic_model_names_for_measures(self, measures: Sequence[MeasureSpec]) -> Set[str]:
         """Return the names of the semantic models needed to compute the input measures.
 
         This is a temporary method for use in assertion boundaries while we implement support for multiple semantic models
         """
         semantic_model_names: Set[str] = set()
-        for measure_name in measure_names:
-            semantic_model_names = semantic_model_names.union(
-                {d.name for d in self._semantic_model_lookup.get_semantic_models_for_measure(measure_name.reference)}
-            )
+        for measure in measures:
+            semantic_model = self._semantic_model_lookup.get_semantic_model_for_measure(measure.reference)
+            if not semantic_model:
+                raise ValueError(f"Could not find measure with name {measure.reference} in configured semantic models.")
+            semantic_model_names.add(semantic_model.name)
+
         return semantic_model_names
 
     def _sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]:
diff --git a/metricflow/model/semantics/semantic_model_lookup.py b/metricflow/model/semantics/semantic_model_lookup.py
index 9aceec4582..3d68d460a8 100644
--- a/metricflow/model/semantics/semantic_model_lookup.py
+++ b/metricflow/model/semantics/semantic_model_lookup.py
@@ -52,12 +52,13 @@ def __init__(  # noqa: D
         model: SemanticManifest,
     ) -> None:
         self._model = model
-        self._measure_index: Dict[MeasureReference, List[SemanticModel]] = defaultdict(list)
+        self._measure_index: Dict[MeasureReference, SemanticModel] = {}
         self._measure_aggs: Dict[
             MeasureReference, AggregationType
         ] = {}  # maps measures to their one consistent aggregation
         self._measure_agg_time_dimension: Dict[MeasureReference, TimeDimensionReference] = {}
         self._measure_non_additive_dimension_specs: Dict[MeasureReference, NonAdditiveDimensionSpec] = {}
+        # TODO: remove the defaultdicts below; they silently add empty entries that might later get referenced.
         self._dimension_index: Dict[DimensionReference, List[SemanticModel]] = defaultdict(list)
         self._linkable_reference_index: Dict[LinkableElementReference, List[SemanticModel]] = defaultdict(list)
         self._entity_index: Dict[Optional[str], List[SemanticModel]] = defaultdict(list)
@@ -141,12 +142,10 @@ def get_measure_from_semantic_model(semantic_model: SemanticModel, measure_refer
         )
 
     def get_measure(self, measure_reference: MeasureReference) -> Measure:  # noqa: D
-        if measure_reference not in self._measure_index:
+        semantic_model = self._measure_index.get(measure_reference)
+        if not semantic_model:
             raise ValueError(f"Could not find measure with name ({measure_reference}) in configured semantic models")
 
-        assert len(self._measure_index[measure_reference]) >= 1
-        # Measures should be consistent across semantic models, so just use the first one.
-        semantic_model = list(self._measure_index[measure_reference])[0]
         return SemanticModelLookup.get_measure_from_semantic_model(
             semantic_model=semantic_model, measure_reference=measure_reference
         )
@@ -155,10 +154,8 @@ def get_entity_references(self) -> Sequence[EntityReference]:  # noqa: D
         return list(self._entity_ref_to_entity.keys())
 
     # DSC interface
-    def get_semantic_models_for_measure(  # noqa: D
-        self, measure_reference: MeasureReference
-    ) -> Sequence[SemanticModel]:
-        return self._measure_index[measure_reference]
+    def get_semantic_model_for_measure(self, measure_reference: MeasureReference) -> Optional[SemanticModel]:  # noqa: D
+        return self._measure_index.get(measure_reference)
 
     def get_agg_time_dimension_for_measure(  # noqa: D
         self, measure_reference: MeasureReference
@@ -202,7 +199,7 @@ def _add_semantic_model(self, semantic_model: SemanticModel) -> None:
 
         for measure in semantic_model.measures:
             self._measure_aggs[measure.reference] = measure.agg
-            self._measure_index[measure.reference].append(semantic_model)
+            self._measure_index[measure.reference] = semantic_model
             agg_time_dimension_reference = semantic_model.checked_agg_time_dimension_for_measure(measure.reference)
 
             matching_dimensions = tuple(
diff --git a/metricflow/protocols/semantics.py b/metricflow/protocols/semantics.py
index cdaeeaa9a6..291eb1d25a 100644
--- a/metricflow/protocols/semantics.py
+++ b/metricflow/protocols/semantics.py
@@ -85,7 +85,7 @@ def get_entity_references(self) -> Sequence[EntityReference]:
         raise NotImplementedError
 
     @abstractmethod
-    def get_semantic_models_for_measure(self, measure_reference: MeasureReference) -> Sequence[SemanticModel]:
-        """Retrieve a list of all semantic model model objects associated with the measure reference."""
+    def get_semantic_model_for_measure(self, measure_reference: MeasureReference) -> Optional[SemanticModel]:
+        """Retrieve the semantic model that contains the given measure, if any."""
         raise NotImplementedError
 
diff --git a/metricflow/query/issues/parsing/cumulative_metric_requires_metric_time.py b/metricflow/query/issues/parsing/cumulative_metric_requires_metric_time.py
index 484cd5131e..60893c6b98 100644
--- a/metricflow/query/issues/parsing/cumulative_metric_requires_metric_time.py
+++ b/metricflow/query/issues/parsing/cumulative_metric_requires_metric_time.py
@@ -24,7 +24,9 @@ class CumulativeMetricRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue):
     def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
         return (
             f"The query includes a cumulative metric {repr(self.metric_reference.element_name)} that does not "
-            f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)}"
+            f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)} "
+            "or the metric's agg_time_dimension."
+            # TODO: add name of agg_time_dim?
         )
 
     @override
diff --git a/metricflow/query/validation_rules/metric_time_requirements.py b/metricflow/query/validation_rules/metric_time_requirements.py
index fe75f6854c..61c8e506aa 100644
--- a/metricflow/query/validation_rules/metric_time_requirements.py
+++ b/metricflow/query/validation_rules/metric_time_requirements.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
-from typing import List, Sequence
+from typing import List, Sequence, Tuple
 
 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
 from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
 from dbt_semantic_interfaces.protocols import WhereFilterIntersection
-from dbt_semantic_interfaces.references import MetricReference
+from dbt_semantic_interfaces.references import EntityReference, MetricReference, TimeDimensionReference
 from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity
 from dbt_semantic_interfaces.type_enums.date_part import DatePart
 from typing_extensions import override
@@ -34,29 +34,36 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
     def __init__(self, manifest_lookup: SemanticManifestLookup) -> None:  # noqa: D
         super().__init__(manifest_lookup=manifest_lookup)
 
-        metric_time_specs: List[TimeDimensionSpec] = []
+        self._metric_time_specs = tuple(
+            self._generate_valid_specs_for_time_dimension(
+                time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME), entity_links=()
+            )
+        )
 
+    def _generate_valid_specs_for_time_dimension(
+        self, time_dimension_reference: TimeDimensionReference, entity_links: Tuple[EntityReference, ...]
+    ) -> List[TimeDimensionSpec]:
+        time_dimension_specs: List[TimeDimensionSpec] = []
         for time_granularity in TimeGranularity:
-            metric_time_specs.append(
+            time_dimension_specs.append(
                 TimeDimensionSpec(
-                    element_name=METRIC_TIME_ELEMENT_NAME,
-                    entity_links=(),
+                    element_name=time_dimension_reference.element_name,
+                    entity_links=entity_links,
                     time_granularity=time_granularity,
                     date_part=None,
                 )
             )
         for date_part in DatePart:
             for time_granularity in date_part.compatible_granularities:
-                metric_time_specs.append(
+                time_dimension_specs.append(
                     TimeDimensionSpec(
-                        element_name=METRIC_TIME_ELEMENT_NAME,
-                        entity_links=(),
+                        element_name=time_dimension_reference.element_name,
+                        entity_links=entity_links,
                         time_granularity=time_granularity,
                         date_part=date_part,
                     )
                 )
-
-        self._metric_time_specs = tuple(metric_time_specs)
+        return time_dimension_specs
 
     def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
         for group_by_item_input in query_resolver_input.group_by_item_inputs:
@@ -65,6 +72,38 @@ def _group_by_items_include_metric_time(self, query_resolver_inpu
 
         return False
 
+    def _group_by_items_include_agg_time_dimension(
+        self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
+    ) -> bool:
+        metric = self._manifest_lookup.metric_lookup.get_metric(metric_reference=metric_reference)
+        semantic_model_lookup = self._manifest_lookup.semantic_model_lookup
+
+        valid_agg_time_dimension_specs: List[TimeDimensionSpec] = []
+        for measure_reference in metric.measure_references:
+            agg_time_dimension_reference = semantic_model_lookup.get_agg_time_dimension_for_measure(measure_reference)
+            semantic_model = semantic_model_lookup.get_semantic_model_for_measure(measure_reference)
+            assert semantic_model, f"No semantic model found for measure {measure_reference}."
+
+            # is this too broad? need to narrow entity links?
+            possible_entity_links = semantic_model_lookup.entity_links_for_local_elements(semantic_model)
+            for entity_link in possible_entity_links:
+                valid_agg_time_dimension_specs.extend(
+                    self._generate_valid_specs_for_time_dimension(
+                        time_dimension_reference=agg_time_dimension_reference, entity_links=(entity_link,)
+                    )
+                )
+        print("valid:::")
+        for x in valid_agg_time_dimension_specs:
+            print(f"\n{x}")
+
+        print("requested:::")
+        for group_by_item_input in query_resolver_input.group_by_item_inputs:
+            print(f"\n{group_by_item_input}")
+            if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
+                return True
+
+        return False
+
     @override
     def validate_metric_in_resolution_dag(
         self,
@@ -73,7 +112,11 @@
         resolution_path: MetricFlowQueryResolutionPath,
     ) -> MetricFlowQueryResolutionIssueSet:
         metric = self._get_metric(metric_reference)
-        query_includes_metric_time = self._group_by_items_include_metric_time(resolver_input_for_query)
+        query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time(
+            resolver_input_for_query
+        ) or self._group_by_items_include_agg_time_dimension(
+            query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
+        )
 
         if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
             return MetricFlowQueryResolutionIssueSet.empty_instance()
@@ -81,7 +124,7 @@
         if (
             metric.type_params is not None
             and (metric.type_params.window is not None or metric.type_params.grain_to_date is not None)
-            and not query_includes_metric_time
+            and not query_includes_metric_time_or_agg_time_dimension
         ):
             return MetricFlowQueryResolutionIssueSet.from_issue(
                 CumulativeMetricRequiresMetricTimeIssue.from_parameters(
@@ -97,7 +140,7 @@
             for input_metric in metric.input_metrics
         )
 
-        if has_time_offset and not query_includes_metric_time:
+        if has_time_offset and not query_includes_metric_time_or_agg_time_dimension:
             return MetricFlowQueryResolutionIssueSet.from_issue(
                 OffsetMetricRequiresMetricTimeIssue.from_parameters(
                     metric_reference=metric_reference,
diff --git a/metricflow/test/integration/test_cases/itest_cumulative_metric.yaml b/metricflow/test/integration/test_cases/itest_cumulative_metric.yaml
index a86a9ea838..1f5d8b0eb9 100644
--- a/metricflow/test/integration/test_cases/itest_cumulative_metric.yaml
+++ b/metricflow/test/integration/test_cases/itest_cumulative_metric.yaml
@@ -347,3 +347,31 @@ integration_test:
     OR {{ render_time_constraint("a.ds", "2020-01-04", "2020-01-04") }}
     GROUP BY a.ds
     ORDER BY a.ds
+---
+integration_test:
+  name: cumulative_metric_with_agg_time_dimension
+  description: Query a cumulative metric with its agg_time_dimension.
+  model: SIMPLE_MODEL
+  metrics: ["trailing_2_months_revenue"]
+  group_bys: ["company__ds__day"]
+  order_bys: ["company__ds__day"]
+  time_constraint: ["2020-03-05", "2021-01-04"]
+  check_query: |
+    SELECT
+      SUM(b.txn_revenue) as trailing_2_months_revenue
+      , a.ds AS company__ds__day
+    FROM (
+      SELECT ds
+      FROM {{ mf_time_spine_source }}
+      WHERE {{ render_time_constraint("ds", "2020-01-05", "2021-01-04") }}
+    ) a
+    INNER JOIN (
+      SELECT
+        revenue as txn_revenue
+        , created_at AS ds
+      FROM {{ source_schema }}.fct_revenue
+    ) b
+    ON b.ds <= a.ds AND b.ds > {{ render_date_sub("a", "ds", 2, TimeGranularity.MONTH) }}
+    WHERE {{ render_time_constraint("a.ds", "2020-03-05", "2021-01-04") }}
+    GROUP BY a.ds
+    ORDER BY a.ds
diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py
index 92462b515f..6b6e080470 100644
--- a/metricflow/test/integration/test_configured_cases.py
+++ b/metricflow/test/integration/test_configured_cases.py
@@ -230,7 +230,8 @@ def filter_not_supported_features(
 
 @pytest.mark.parametrize(
     "name",
-    CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
+    # CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
+    ["itest_cumulative_metric.yaml/cumulative_metric_with_agg_time_dimension"],
     ids=lambda name: f"name={name}",
 )
 def test_case(
diff --git a/metricflow/test/model/test_semantic_model_container.py b/metricflow/test/model/test_semantic_model_container.py
index 80ec27d5c7..1623eae3cb 100644
--- a/metricflow/test/model/test_semantic_model_container.py
+++ b/metricflow/test/model/test_semantic_model_container.py
@@ -58,18 +58,18 @@ def test_get_elements(semantic_model_lookup: SemanticModelLookup) -> None:  # no
     assert semantic_model_lookup.get_measure(measure_reference=measure_reference).reference == measure_reference
 
 
-def test_get_semantic_models_for_measure(semantic_model_lookup: SemanticModelLookup) -> None:  # noqa: D
-    bookings_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="bookings"))
-    assert len(bookings_sources) == 1
-    assert bookings_sources[0].name == "bookings_source"
-
-    views_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="views"))
-    assert len(views_sources) == 1
-    assert views_sources[0].name == "views_source"
-
-    listings_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="listings"))
-    assert len(listings_sources) == 1
-    assert listings_sources[0].name == "listings_latest"
+def test_get_semantic_model_for_measure(semantic_model_lookup: SemanticModelLookup) -> None:  # noqa: D
+    bookings_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="bookings"))
+    assert bookings_source
+    assert bookings_source.name == "bookings_source"
+
+    views_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="views"))
+    assert views_source
+    assert views_source.name == "views_source"
+
+    listings_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="listings"))
+    assert listings_source
+    assert listings_source.name == "listings_latest"
 
 
 def test_elements_for_metric(  # noqa: D