diff --git a/metricflow/model/semantics/linkable_spec_resolver.py b/metricflow/model/semantics/linkable_spec_resolver.py index 777efab43d..f1c19b2cb4 100644 --- a/metricflow/model/semantics/linkable_spec_resolver.py +++ b/metricflow/model/semantics/linkable_spec_resolver.py @@ -507,6 +507,7 @@ def __init__( linkable_sets_for_measure.append( self._get_linkable_element_set_for_measure(measure).filter( with_any_of=LinkableElementProperties.all_properties(), + # Use filter() here becasue `without_all_of` param is only available on that method. without_all_of=frozenset( { LinkableElementProperties.METRIC_TIME, @@ -576,6 +577,27 @@ def _get_semantic_model_for_measure(self, measure_reference: MeasureReference) - ) return semantic_models_where_measure_was_found[0] + def _get_joinable_metrics_for_semantic_model(self, semantic_model: SemanticModel) -> LinkableElementSet: + linkable_metrics = [] + for metric_ref in self._joinable_metrics_for_semantic_models.get(semantic_model.reference, set()): + for entity_link in [entity.reference for entity in semantic_model.entities]: + linkable_metrics.append( + LinkableMetric( + element_name=metric_ref.element_name, + join_by_semantic_model=semantic_model.reference, + entity_links=(entity_link,), + properties=frozenset({LinkableElementProperties.METRIC}), + join_path=(), + ) + ) + return LinkableElementSet( + path_key_to_linkable_dimensions={}, + path_key_to_linkable_entities={}, + path_key_to_linkable_metrics={ + linkable_metric.path_key: (linkable_metric,) for linkable_metric in linkable_metrics + }, + ) + def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> LinkableElementSet: """Gets the elements in the semantic model, without requiring any joins. @@ -699,7 +721,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference measure_agg_time_dimension_reference = measure_semantic_model.checked_agg_time_dimension_for_measure( measure_reference=measure_reference ) - defined_granularity = ValidLinkableSpecResolver._get_time_granularity_for_dimension( + defined_granularity = self._get_time_granularity_for_dimension( semantic_model=measure_semantic_model, time_dimension_reference=measure_agg_time_dimension_reference, ) @@ -838,12 +860,14 @@ def _get_linkable_element_set_for_measure( measure_semantic_model = self._get_semantic_model_for_measure(measure_reference) elements_in_semantic_model = self._get_elements_in_semantic_model(measure_semantic_model) + metrics_linked_to_semantic_model = self._get_joinable_metrics_for_semantic_model(measure_semantic_model) metric_time_elements = self._get_metric_time_elements(measure_reference) joined_elements = self._get_joined_elements(measure_semantic_model) return LinkableElementSet.merge_by_path_key( ( elements_in_semantic_model, + metrics_linked_to_semantic_model, metric_time_elements, joined_elements, )