diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index ac2cb77c7..67d77d2ab 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -498,17 +498,15 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet: if join_on_entity: # Remove any instances that already have the join_on_entity as the leading link. This will prevent a duplicate # entity link when we add it in the next step. - right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(join_on_entity).transform( - right_data_set.instance_set - ) + right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink( + join_on_entity.reference + ).transform(right_data_set.instance_set) # After the right data set is joined, update the entity links to indicate that joining on the entity was # required to reach the spec. If the "country" dimension was joined and "user_id" is the join_on_entity, # then the joined data set should have the "user__country" dimension. new_instances: Tuple[MdoInstance, ...] = () for original_instance in right_instance_set_filtered.linkable_instances: - if original_instance.spec == join_on_entity: - continue new_instance = original_instance.with_entity_prefix( join_on_entity.reference, column_association_resolver=self._column_association_resolver ) diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index 0d230f3bc..c6e25835b 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -8,7 +8,7 @@ from itertools import chain from typing import Dict, List, Optional, Sequence, Tuple -from dbt_semantic_interfaces.references import MetricReference, SemanticModelReference +from dbt_semantic_interfaces.references import EntityReference, MetricReference, SemanticModelReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -29,7 +29,6 @@ from metricflow_semantics.model.semantics.metric_lookup import MetricLookup from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver -from metricflow_semantics.specs.entity_spec import LinklessEntitySpec from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec from metricflow_semantics.specs.spec_set import InstanceSpecSet @@ -390,15 +389,14 @@ class FilterLinkableInstancesWithLeadingLink(InstanceSetTransform[InstanceSet]): e.g. Remove "listing__country" if the specified link is "listing". """ - def __init__(self, entity_link: LinklessEntitySpec) -> None: + def __init__(self, entity_link: EntityReference) -> None: """Remove elements with this link as the first element in "entity_links".""" self._entity_link = entity_link def _should_pass(self, linkable_spec: LinkableInstanceSpec) -> bool: - return ( - len(linkable_spec.entity_links) == 0 - or LinklessEntitySpec.from_reference(linkable_spec.entity_links[0]) != self._entity_link - ) + if len(linkable_spec.entity_links) == 0: + return not linkable_spec.reference == self._entity_link + return linkable_spec.entity_links[0] != self._entity_link def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102 # Normal to not filter anything if the instance set has no instances with links.