Bug fix: Generate correct instances for JoinOnEntitiesNode (#1499)
There is an internal bug in the dataflow-to-SQL logic for the
`JoinOnEntitiesNode` that does not appear to impact production, but does
impact work I'm doing up the stack. The instances generated in the node
carried incorrect column associations: they were copied from the parent
dataset, so they did not include the entity links added by this node.
These incorrect column associations did not appear to be used for anything,
and the select columns were generated from the specs instead (which is
correct), which is why this bug flew under the radar.
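
To make the mismatch concrete, a hypothetical example (the `listing` and `country` names are made up for illustration, not taken from this PR):

```python
# The spec produced by the join correctly gains the new entity link:
#   DimensionSpec(element_name="country", entity_links=(EntityReference(element_name="listing"),))
#   -> should resolve to the column "listing__country"
# But the instance kept the column association copied from the parent dataset:
#   -> column "country" (no entity link)
```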

This PR removes the logic that was causing the bug (the
`AddLinkToLinkableElements` instance converter). Instead, it loops through
the parent instances and builds the new instances and their select columns
simultaneously.
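
Roughly, the new approach builds each transformed instance and its select column in the same pass, so both are derived from the same spec. A minimal sketch of that pattern; `parent_data_set`, `join_on_entity`, `right_table_alias`, `column_association_resolver`, and the `_make_select_column` helper are illustrative assumptions, not the exact names used in this PR:

```python
new_instances = []
select_columns = []
for parent_instance in parent_data_set.instance_set.linkable_instances:
    # Re-resolve the column association so it includes the entity link added by this node.
    new_instance = parent_instance.with_entity_prefix(
        entity_prefix=join_on_entity,
        column_association_resolver=column_association_resolver,
    )
    new_instances.append(new_instance)
    # Build the select column from the same transformed instance, so the instance's
    # column association and the generated SQL column can no longer drift apart.
    select_columns.append(
        _make_select_column(  # hypothetical helper, not from this PR
            table_alias=right_table_alias,
            source_column_name=parent_instance.associated_column.column_name,
            output_column_name=new_instance.associated_column.column_name,
        )
    )
```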

That change left a lot of boilerplate and duplicated code in the function,
so much of it was moved out into helpers for the instances and specs.

This also simplifies the `JoinOnEntitiesNode` considerably; it is one of
our oldest nodes and was quite verbose.

I recommend reviewing commit by commit. The snapshot changes are primarily
column-order changes and some updated node identifiers; there should be no
actual behavior changes here.
courtneyholcomb authored Nov 5, 2024
1 parent 272f738 commit bfcc13b
Showing 358 changed files with 3,065 additions and 2,950 deletions.
8 changes: 8 additions & 0 deletions Makefile
@@ -98,3 +98,11 @@ populate-persistent-source-schemas:
.PHONY: test-snap
test-snap:
	make test ADDITIONAL_PYTEST_OPTIONS=--overwrite-snapshots

.PHONY: testx
testx:
	make test ADDITIONAL_PYTEST_OPTIONS=-x

.PHONY: testx-snap
testx-snap:
	make test ADDITIONAL_PYTEST_OPTIONS='-x --overwrite-snapshots'
202 changes: 194 additions & 8 deletions metricflow-semantics/metricflow_semantics/instances.py
@@ -3,14 +3,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, List, Tuple, TypeVar
from dataclasses import dataclass, field
from typing import Generic, List, Sequence, Tuple, TypeVar

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
from typing_extensions import override

from metricflow_semantics.aggregation_properties import AggregationState
from metricflow_semantics.specs.column_assoc import ColumnAssociation
from metricflow_semantics.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
@@ -20,6 +21,7 @@
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

# Type for the specification used in the instance.
SpecT = TypeVar("SpecT", bound=InstanceSpec)
@@ -38,6 +40,8 @@ class MdoInstance(ABC, Generic[SpecT]):
"""

# The columns associated with this instance.
# TODO: if poss, remove this and instead add a method that resolves this from the spec + column association resolver
# (ensure we're using consistent logic everywhere so this bug doesn't happen again)
associated_columns: Tuple[ColumnAssociation, ...]
# The spec that describes this instance.
spec: SpecT
@@ -48,6 +52,20 @@ def associated_column(self) -> ColumnAssociation:
assert len(self.associated_columns) == 1
return self.associated_columns[0]

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
"""See Visitable."""
raise NotImplementedError()


class LinkableInstance(MdoInstance, Generic[SpecT]):
"""An MdoInstance whose spec is linkable (i.e., it can have entity links)."""

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> MdoInstance:
"""Add entity link to the underlying spec and associated column."""
raise NotImplementedError()


# Instances for the major metric object types

@@ -82,44 +100,109 @@ class MeasureInstance(MdoInstance[MeasureSpec], SemanticModelElementInstance):
spec: MeasureSpec
aggregation_state: AggregationState

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_measure_instance(self)


@dataclass(frozen=True)
class DimensionInstance(MdoInstance[DimensionSpec], SemanticModelElementInstance): # noqa: D101
class DimensionInstance(LinkableInstance[DimensionSpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: DimensionSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_dimension_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> DimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return DimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class TimeDimensionInstance(MdoInstance[TimeDimensionSpec], SemanticModelElementInstance): # noqa: D101
class TimeDimensionInstance(LinkableInstance[TimeDimensionSpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: TimeDimensionSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_time_dimension_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> TimeDimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return TimeDimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class EntityInstance(MdoInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: EntitySpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_entity_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> EntityInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return EntityInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class GroupByMetricInstance(MdoInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
class GroupByMetricInstance(LinkableInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: GroupByMetricSpec
defined_from: MetricModelReference

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_group_by_metric_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> GroupByMetricInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return GroupByMetricInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetricSpec
defined_from: MetricModelReference

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metric_instance(self)


@dataclass(frozen=True)
class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetadataSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metadata_instance(self)


# Output type of transform function
TransformOutputT = TypeVar("TransformOutputT")
@@ -226,3 +309,106 @@ def as_tuple(self) -> Tuple[MdoInstance, ...]: # noqa: D102
+ self.metric_instances
+ self.metadata_instances
)

@property
def linkable_instances(self) -> Tuple[LinkableInstance, ...]: # noqa: D102
return (
self.dimension_instances
+ self.time_dimension_instances
+ self.entity_instances
+ self.group_by_metric_instances
)


class InstanceVisitor(Generic[VisitorOutputT], ABC):
"""Visitor for the Instance classes."""

@abstractmethod
def visit_measure_instance(self, measure_instance: MeasureInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_dimension_instance(self, dimension_instance: DimensionInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_time_dimension_instance( # noqa: D102
self, time_dimension_instance: TimeDimensionInstance
) -> VisitorOutputT:
raise NotImplementedError

@abstractmethod
def visit_entity_instance(self, entity_instance: EntityInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_group_by_metric_instance( # noqa: D102
self, group_by_metric_instance: GroupByMetricInstance
) -> VisitorOutputT:
raise NotImplementedError

@abstractmethod
def visit_metric_instance(self, metric_instance: MetricInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_metadata_instance(self, metadata_instance: MetadataInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass
class _GroupInstanceByTypeVisitor(InstanceVisitor[None]):
"""Group instances by type into an `InstanceSet`."""

metric_instances: List[MetricInstance] = field(default_factory=list)
measure_instances: List[MeasureInstance] = field(default_factory=list)
dimension_instances: List[DimensionInstance] = field(default_factory=list)
entity_instances: List[EntityInstance] = field(default_factory=list)
time_dimension_instances: List[TimeDimensionInstance] = field(default_factory=list)
group_by_metric_instances: List[GroupByMetricInstance] = field(default_factory=list)
metadata_instances: List[MetadataInstance] = field(default_factory=list)

@override
def visit_measure_instance(self, measure_instance: MeasureInstance) -> None:
self.measure_instances.append(measure_instance)

@override
def visit_dimension_instance(self, dimension_instance: DimensionInstance) -> None:
self.dimension_instances.append(dimension_instance)

@override
def visit_time_dimension_instance(self, time_dimension_instance: TimeDimensionInstance) -> None:
self.time_dimension_instances.append(time_dimension_instance)

@override
def visit_entity_instance(self, entity_instance: EntityInstance) -> None:
self.entity_instances.append(entity_instance)

@override
def visit_group_by_metric_instance(self, group_by_metric_instance: GroupByMetricInstance) -> None:
self.group_by_metric_instances.append(group_by_metric_instance)

@override
def visit_metric_instance(self, metric_instance: MetricInstance) -> None:
self.metric_instances.append(metric_instance)

@override
def visit_metadata_instance(self, metadata_instance: MetadataInstance) -> None:
self.metadata_instances.append(metadata_instance)


def group_instances_by_type(instances: Sequence[MdoInstance]) -> InstanceSet:
"""Groups a sequence of instances by type."""
grouper = _GroupInstanceByTypeVisitor()
for instance in instances:
instance.accept(grouper)

return InstanceSet(
metric_instances=tuple(grouper.metric_instances),
measure_instances=tuple(grouper.measure_instances),
dimension_instances=tuple(grouper.dimension_instances),
entity_instances=tuple(grouper.entity_instances),
time_dimension_instances=tuple(grouper.time_dimension_instances),
group_by_metric_instances=tuple(grouper.group_by_metric_instances),
metadata_instances=tuple(grouper.metadata_instances),
)
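
For context (not part of the diff): once the parent instances have been prefixed, the `group_instances_by_type` helper above can rebuild a typed `InstanceSet` from the flat sequence. A minimal usage sketch, assuming `parent_instance_set`, `join_entity_reference`, and `column_association_resolver` exist at the call site:

```python
from metricflow_semantics.instances import group_instances_by_type

prefixed_instances = [
    instance.with_entity_prefix(
        entity_prefix=join_entity_reference,
        column_association_resolver=column_association_resolver,
    )
    for instance in parent_instance_set.linkable_instances
]

# The visitor-backed helper groups the transformed instances back into an InstanceSet.
new_instance_set = group_instances_by_type(prefixed_instances)
```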
@@ -43,3 +43,6 @@ def element_path_key(self) -> ElementPathKey:
return ElementPathKey(
element_name=self.element_name, element_type=LinkableElementType.DIMENSION, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> DimensionSpec: # noqa: D102
return DimensionSpec(element_name=self.element_name, entity_links=(entity_prefix,) + self.entity_links)
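
An illustrative example of the spec-level `with_entity_prefix` (the `listing`, `user`, and `country` names are made up): the new entity link is prepended, which changes the column the spec resolves to.

```python
from dbt_semantic_interfaces.references import EntityReference
from metricflow_semantics.specs.dimension_spec import DimensionSpec

spec = DimensionSpec(element_name="country", entity_links=(EntityReference(element_name="user"),))
prefixed = spec.with_entity_prefix(EntityReference(element_name="listing"))
# prefixed.entity_links == (EntityReference(element_name="listing"), EntityReference(element_name="user"))
# The associated column would then typically render with dunder-joined links, e.g. "listing__user__country".
```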
@@ -53,6 +53,9 @@ def element_path_key(self) -> ElementPathKey:
element_name=self.element_name, element_type=LinkableElementType.ENTITY, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> EntitySpec: # noqa: D102
return EntitySpec(element_name=self.element_name, entity_links=(entity_prefix,) + self.entity_links)


@dataclass(frozen=True)
class LinklessEntitySpec(EntitySpec, SerializableDataclass):
@@ -111,3 +111,10 @@ def element_path_key(self) -> ElementPathKey:
return ElementPathKey(
element_name=self.element_name, element_type=LinkableElementType.METRIC, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> GroupByMetricSpec: # noqa: D102
return GroupByMetricSpec(
element_name=self.element_name,
entity_links=(entity_prefix,) + self.entity_links,
metric_subquery_entity_links=self.metric_subquery_entity_links,
)
@@ -136,5 +136,10 @@ def element_path_key(self) -> ElementPathKey:
"""Return the ElementPathKey representation of the LinkableInstanceSpec subtype."""
raise NotImplementedError()

@abstractmethod
def with_entity_prefix(self, entity_prefix: EntityReference) -> LinkableInstanceSpec:
"""Add the selected entity prefix to the start of the entity links."""
raise NotImplementedError()


SelfTypeT = TypeVar("SelfTypeT", bound="LinkableInstanceSpec")
@@ -235,3 +235,12 @@ def generate_possible_specs_for_time_dimension(
@property
def is_metric_time(self) -> bool: # noqa: D102
return self.element_name == METRIC_TIME_ELEMENT_NAME

def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=(entity_prefix,) + self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
)
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -221,12 +221,12 @@ def _build_plan(
def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPlanOptimization]) -> DataflowPlan:
optimizer_factory = DataflowPlanOptimizerFactory(self._node_data_set_resolver)
for optimizer in optimizer_factory.get_optimizers(optimizations):
logger.debug(LazyFormat(lambda: f"Applying {optimizer.__class__.__name__}"))
logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
try:
plan = optimizer.optimize(plan)
logger.debug(
LazyFormat(
lambda: f"After applying {optimizer.__class__.__name__}, the dataflow plan is:\n"
lambda: f"After applying optimizer {optimizer.__class__.__name__}, the dataflow plan is:\n"
f"{indent(plan.structure_text())}"
)
)