diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py index 19d80eb049..ca5227ae92 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element.py @@ -12,6 +12,7 @@ from dbt_semantic_interfaces.references import ( DimensionReference, EntityReference, + LinkableElementReference, MetricReference, SemanticModelReference, ) @@ -171,6 +172,17 @@ def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]: return sorted(semantic_model_references, key=lambda reference: reference.semantic_model_name) +# TODO: add to DSI +@dataclass(frozen=True) +class GroupByMetricReference(LinkableElementReference): + """Represents a group by metric. + + Different from MetricReference because it inherits linkable element attributes. + """ + + pass + + @dataclass(frozen=True) class LinkableMetric(LinkableElement, SerializableDataclass): """Describes how a metric can be realized by joining based on entity links.""" @@ -199,9 +211,13 @@ def path_key(self) -> ElementPathKey: # noqa: D102 ) @property - def reference(self) -> MetricReference: # noqa: D102 + def metric_reference(self) -> MetricReference: # noqa: D102 return self.join_path.metric_subquery_join_path_element.metric_reference + @property + def reference(self) -> GroupByMetricReference: # noqa: D102 + return GroupByMetricReference(self.metric_reference.element_name) + @property def join_by_semantic_model(self) -> Optional[SemanticModelReference]: # noqa: D102 return self.join_path.last_semantic_model_reference @@ -209,11 +225,16 @@ def join_by_semantic_model(self) -> Optional[SemanticModelReference]: # noqa: D @property @override def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]: - semantic_model_references = set() - for join_path_item in ( - self.join_path.semantic_model_join_path.path_elements if self.join_path.semantic_model_join_path else () - ): - semantic_model_references.add(join_path_item.semantic_model_reference) + """Semantic models needed to build and join to this LinkableMetric. + + Includes semantic models used in the join paths for both the inner and outer queries (if applicable), + plus the semantic models the metric's measure(s) originated from. + """ + semantic_model_references = set(self.join_path.metric_subquery_join_path_element.derived_from_semantic_models) + if self.join_path.semantic_model_join_path: + semantic_model_references.update(self.join_path.semantic_model_join_path.derived_from_semantic_models) + if self.metric_to_entity_join_path: + semantic_model_references.update(self.metric_to_entity_join_path.derived_from_semantic_models) return sorted(semantic_model_references, key=lambda reference: reference.semantic_model_name) @@ -274,6 +295,14 @@ def from_single_element( ) ) + @property + def derived_from_semantic_models(self) -> Sequence[SemanticModelReference]: + """Unique semantic models used in this join path.""" + return sorted( + [path_element.semantic_model_reference for path_element in self.path_elements], + key=lambda reference: reference.semantic_model_name, + ) + @dataclass(frozen=True) class MetricSubqueryJoinPathElement: @@ -281,17 +310,25 @@ class MetricSubqueryJoinPathElement: Args: metric_reference: The metric that's aggregated in the subquery. + derived_from_semantic_models: The semantic models that the measure's input metrics are defined in. join_on_entity: The entity that the metric is grouped by in the subquery. This will be updated in V2 to allow a list of entitites & dimensions. + entity_links: Sequence of entities joined to get from a metric source to the `join_on_entity`. metric_to_entity_join_path: Describes the join path used in the subquery to join the metric to the `join_on_entity`. Can be none if all required elements are defined in the same semantic model. """ metric_reference: MetricReference + derived_from_semantic_models: Tuple[SemanticModelReference, ...] join_on_entity: EntityReference entity_links: Tuple[EntityReference, ...] metric_to_entity_join_path: Optional[SemanticModelJoinPath] = None + def __post_init__(self) -> None: # noqa: D105 + assert ( + self.derived_from_semantic_models + ), "There must be at least one semantic model from which the metric is derived." + @dataclass(frozen=True) class SemanticModelToMetricSubqueryJoinPath: diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py index 525ab681de..b39a591843 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py @@ -192,10 +192,15 @@ def __init__( continue metric_reference = MetricReference(metric.name) linkable_element_set_for_metric = self.get_linkable_elements_for_metrics([metric_reference]) + defined_from_semantic_models = tuple( + self._semantic_model_lookup.get_semantic_model_for_measure(input_measure.measure_reference).reference + for input_measure in metric.input_measures + ) for linkable_entities in linkable_element_set_for_metric.path_key_to_linkable_entities.values(): for linkable_entity in linkable_entities: metric_subquery_join_path_element = MetricSubqueryJoinPathElement( metric_reference=metric_reference, + derived_from_semantic_models=defined_from_semantic_models, join_on_entity=linkable_entity.reference, entity_links=linkable_entity.entity_links, metric_to_entity_join_path=( diff --git a/metricflow-semantics/metricflow_semantics/specs/spec_classes.py b/metricflow-semantics/metricflow_semantics/specs/spec_classes.py index de56d7065f..546f07f58d 100644 --- a/metricflow-semantics/metricflow_semantics/specs/spec_classes.py +++ b/metricflow-semantics/metricflow_semantics/specs/spec_classes.py @@ -38,7 +38,12 @@ from metricflow_semantics.aggregation_properties import AggregationState from metricflow_semantics.collection_helpers.dedupe import ordered_dedupe from metricflow_semantics.collection_helpers.merger import Mergeable -from metricflow_semantics.model.semantics.linkable_element import ElementPathKey, LinkableElement, LinkableElementType +from metricflow_semantics.model.semantics.linkable_element import ( + ElementPathKey, + GroupByMetricReference, + LinkableElement, + LinkableElementType, +) from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from metricflow_semantics.sql.sql_column_type import SqlColumnType @@ -728,17 +733,6 @@ class JoinToTimeSpineDescription: offset_to_grain: Optional[TimeGranularity] -# TODO: add to DSI -@dataclass(frozen=True) -class GroupByMetricReference(LinkableElementReference): - """Represents a group by metric. - - Different from MetricReference because it inherits linkable element attributes. - """ - - pass - - @dataclass(frozen=True) class GroupByMetricSpec(LinkableInstanceSpec, SerializableDataclass): """Metric used in group by or where filter. diff --git a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_linkable_element_set.py b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_linkable_element_set.py index fecf36d400..624f3c5962 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_linkable_element_set.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_linkable_element_set.py @@ -49,6 +49,7 @@ _base_entity_reference = EntityReference(element_name="base_entity") _base_dimension_reference = DimensionReference(element_name="base_dimension") _time_dimension_reference = TimeDimensionReference(element_name="time_dimension") +_metric_semantic_model = SemanticModelReference(semantic_model_name="metric_semantic_model") _base_metric_reference = MetricReference(element_name="base_metric") @@ -136,6 +137,7 @@ join_path=SemanticModelToMetricSubqueryJoinPath( metric_subquery_join_path_element=MetricSubqueryJoinPathElement( metric_reference=_base_metric_reference, + derived_from_semantic_models=(_metric_semantic_model,), join_on_entity=_base_entity_reference, entity_links=(_base_entity_reference,), ), @@ -146,6 +148,7 @@ join_path=SemanticModelToMetricSubqueryJoinPath( metric_subquery_join_path_element=MetricSubqueryJoinPathElement( metric_reference=MetricReference(AMBIGUOUS_NAME), + derived_from_semantic_models=(_metric_semantic_model,), join_on_entity=_base_entity_reference, entity_links=(_base_entity_reference,), ), @@ -157,6 +160,7 @@ join_path=SemanticModelToMetricSubqueryJoinPath( metric_subquery_join_path_element=MetricSubqueryJoinPathElement( metric_reference=MetricReference(AMBIGUOUS_NAME), + derived_from_semantic_models=(_metric_semantic_model,), join_on_entity=_base_entity_reference, entity_links=(_base_entity_reference,), ), @@ -547,6 +551,8 @@ def linkable_set() -> LinkableElementSet: # noqa: D103 entity_2_source = SemanticModelReference("entity_2_source") entity_3 = EntityReference("entity_3") entity_3_source = SemanticModelReference("entity_3_source") + entity_4 = EntityReference("entity_4") + entity_4_source = SemanticModelReference("entity_4_source") return LinkableElementSet( path_key_to_linkable_dimensions={ @@ -626,8 +632,19 @@ def linkable_set() -> LinkableElementSet: # noqa: D103 join_path=SemanticModelToMetricSubqueryJoinPath( metric_subquery_join_path_element=MetricSubqueryJoinPathElement( metric_reference=MetricReference("metric_element"), + derived_from_semantic_models=(_metric_semantic_model,), join_on_entity=entity_3, - entity_links=(entity_3,), + entity_links=(entity_3, entity_4), + metric_to_entity_join_path=SemanticModelJoinPath( + path_elements=( + SemanticModelJoinPathElement( + semantic_model_reference=entity_4_source, join_on_entity=entity_4 + ), + SemanticModelJoinPathElement( + semantic_model_reference=entity_3_source, join_on_entity=entity_3 + ), + ) + ), ), semantic_model_join_path=SemanticModelJoinPath.from_single_element( semantic_model_reference=entity_3_source, join_on_entity=entity_3 @@ -641,14 +658,15 @@ def linkable_set() -> LinkableElementSet: # noqa: D103 def test_derived_semantic_models(linkable_set: LinkableElementSet) -> None: """Tests that the semantic models in the element set are returned via `derived_from_semantic_models`.""" - # TODO: add metric source for linkable metrics assert tuple(linkable_set.derived_from_semantic_models) == ( SemanticModelReference(semantic_model_name="dimension_source"), SemanticModelReference(semantic_model_name="entity_0_source"), SemanticModelReference(semantic_model_name="entity_1_source"), SemanticModelReference(semantic_model_name="entity_2_source"), SemanticModelReference(semantic_model_name="entity_3_source"), + SemanticModelReference(semantic_model_name="entity_4_source"), SemanticModelReference(semantic_model_name="entity_source"), + SemanticModelReference(semantic_model_name="metric_semantic_model"), SemanticModelReference(semantic_model_name="time_dimension_source"), ) diff --git a/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py b/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py index e643bd26c1..4d2557c8ab 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/test_where_filter_spec.py @@ -536,6 +536,7 @@ def test_metric_in_filter( # noqa: D103 join_path=SemanticModelToMetricSubqueryJoinPath( metric_subquery_join_path_element=MetricSubqueryJoinPathElement( metric_reference=MetricReference("bookings"), + derived_from_semantic_models=(SemanticModelReference("bookings"),), join_on_entity=EntityReference("listing"), entity_links=(EntityReference("listing"),), )