Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: Generate correct instances for JoinOnEntitiesNode #1499

Merged
merged 10 commits into from
Nov 5, 2024
  •  
  •  
  •  
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,11 @@ populate-persistent-source-schemas:
.PHONY: test-snap
test-snap:
make test ADDITIONAL_PYTEST_OPTIONS=--overwrite-snapshots

.PHONY: testx
testx:
make test ADDITIONAL_PYTEST_OPTIONS=-x

.PHONY: testx-snap
testx-snap:
make test ADDITIONAL_PYTEST_OPTIONS='-x --overwrite-snapshots'
202 changes: 194 additions & 8 deletions metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, List, Tuple, TypeVar
from dataclasses import dataclass, field
from typing import Generic, List, Sequence, Tuple, TypeVar

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
from typing_extensions import override

from metricflow_semantics.aggregation_properties import AggregationState
from metricflow_semantics.specs.column_assoc import ColumnAssociation
from metricflow_semantics.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
Expand All @@ -20,6 +21,7 @@
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

# Type for the specification used in the instance.
SpecT = TypeVar("SpecT", bound=InstanceSpec)
Expand All @@ -38,6 +40,8 @@ class MdoInstance(ABC, Generic[SpecT]):
"""

# The columns associated with this instance.
# TODO: if poss, remove this and instead add a method that resolves this from the spec + column association resolver
# (ensure we're using consistent logic everywhere so this bug doesn't happen again)
associated_columns: Tuple[ColumnAssociation, ...]
# The spec that describes this instance.
spec: SpecT
Expand All @@ -48,6 +52,20 @@ def associated_column(self) -> ColumnAssociation:
assert len(self.associated_columns) == 1
return self.associated_columns[0]

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
"""See Visitable."""
raise NotImplementedError()


class LinkableInstance(MdoInstance, Generic[SpecT]):
"""An MdoInstance whose spec is linkable (i.e., it can have entity links)."""

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> MdoInstance:
"""Add entity link to the underlying spec and associated column."""
raise NotImplementedError()


# Instances for the major metric object types

Expand Down Expand Up @@ -82,44 +100,109 @@ class MeasureInstance(MdoInstance[MeasureSpec], SemanticModelElementInstance):
spec: MeasureSpec
aggregation_state: AggregationState

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_measure_instance(self)


@dataclass(frozen=True)
class DimensionInstance(MdoInstance[DimensionSpec], SemanticModelElementInstance): # noqa: D101
class DimensionInstance(LinkableInstance[DimensionSpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: DimensionSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_dimension_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> DimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return DimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class TimeDimensionInstance(MdoInstance[TimeDimensionSpec], SemanticModelElementInstance): # noqa: D101
class TimeDimensionInstance(LinkableInstance[TimeDimensionSpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: TimeDimensionSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_time_dimension_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> TimeDimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return TimeDimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class EntityInstance(MdoInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: EntitySpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_entity_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> EntityInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return EntityInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class GroupByMetricInstance(MdoInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
class GroupByMetricInstance(LinkableInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: GroupByMetricSpec
defined_from: MetricModelReference

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_group_by_metric_instance(self)

def with_entity_prefix(
self, entity_prefix: EntityReference, column_association_resolver: ColumnAssociationResolver
) -> GroupByMetricInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return GroupByMetricInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)


@dataclass(frozen=True)
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetricSpec
defined_from: MetricModelReference

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metric_instance(self)


@dataclass(frozen=True)
class MetadataInstance(MdoInstance[MetadataSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: MetadataSpec

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_metadata_instance(self)


# Output type of transform function
TransformOutputT = TypeVar("TransformOutputT")
Expand Down Expand Up @@ -226,3 +309,106 @@ def as_tuple(self) -> Tuple[MdoInstance, ...]: # noqa: D102
+ self.metric_instances
+ self.metadata_instances
)

@property
def linkable_instances(self) -> Tuple[LinkableInstance, ...]: # noqa: D102
return (
self.dimension_instances
+ self.time_dimension_instances
+ self.entity_instances
+ self.group_by_metric_instances
)


class InstanceVisitor(Generic[VisitorOutputT], ABC):
"""Visitor for the Instance classes."""

@abstractmethod
def visit_measure_instance(self, measure_instance: MeasureInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_dimension_instance(self, dimension_instance: DimensionInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_time_dimension_instance( # noqa: D102
self, time_dimension_instance: TimeDimensionInstance
) -> VisitorOutputT:
raise NotImplementedError

@abstractmethod
def visit_entity_instance(self, entity_instance: EntityInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_group_by_metric_instance( # noqa: D102
self, group_by_metric_instance: GroupByMetricInstance
) -> VisitorOutputT:
raise NotImplementedError

@abstractmethod
def visit_metric_instance(self, metric_instance: MetricInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_metadata_instance(self, metadata_instance: MetadataInstance) -> VisitorOutputT: # noqa: D102
raise NotImplementedError


@dataclass
class _GroupInstanceByTypeVisitor(InstanceVisitor[None]):
"""Group instances by type into an `InstanceSet`."""

metric_instances: List[MetricInstance] = field(default_factory=list)
measure_instances: List[MeasureInstance] = field(default_factory=list)
dimension_instances: List[DimensionInstance] = field(default_factory=list)
entity_instances: List[EntityInstance] = field(default_factory=list)
time_dimension_instances: List[TimeDimensionInstance] = field(default_factory=list)
group_by_metric_instances: List[GroupByMetricInstance] = field(default_factory=list)
metadata_instances: List[MetadataInstance] = field(default_factory=list)

@override
def visit_measure_instance(self, measure_instance: MeasureInstance) -> None:
self.measure_instances.append(measure_instance)

@override
def visit_dimension_instance(self, dimension_instance: DimensionInstance) -> None:
self.dimension_instances.append(dimension_instance)

@override
def visit_time_dimension_instance(self, time_dimension_instance: TimeDimensionInstance) -> None:
self.time_dimension_instances.append(time_dimension_instance)

@override
def visit_entity_instance(self, entity_instance: EntityInstance) -> None:
self.entity_instances.append(entity_instance)

@override
def visit_group_by_metric_instance(self, group_by_metric_instance: GroupByMetricInstance) -> None:
self.group_by_metric_instances.append(group_by_metric_instance)

@override
def visit_metric_instance(self, metric_instance: MetricInstance) -> None:
self.metric_instances.append(metric_instance)

@override
def visit_metadata_instance(self, metadata_instance: MetadataInstance) -> None:
self.metadata_instances.append(metadata_instance)


def group_instances_by_type(instances: Sequence[MdoInstance]) -> InstanceSet:
"""Groups a sequence of instances by type."""
grouper = _GroupInstanceByTypeVisitor()
for instance in instances:
instance.accept(grouper)

return InstanceSet(
metric_instances=tuple(grouper.metric_instances),
measure_instances=tuple(grouper.measure_instances),
dimension_instances=tuple(grouper.dimension_instances),
entity_instances=tuple(grouper.entity_instances),
time_dimension_instances=tuple(grouper.time_dimension_instances),
group_by_metric_instances=tuple(grouper.group_by_metric_instances),
metadata_instances=tuple(grouper.metadata_instances),
)
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ def element_path_key(self) -> ElementPathKey:
return ElementPathKey(
element_name=self.element_name, element_type=LinkableElementType.DIMENSION, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> DimensionSpec: # noqa: D102
return DimensionSpec(element_name=self.element_name, entity_links=(entity_prefix,) + self.entity_links)
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def element_path_key(self) -> ElementPathKey:
element_name=self.element_name, element_type=LinkableElementType.ENTITY, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> EntitySpec: # noqa: D102
return EntitySpec(element_name=self.element_name, entity_links=(entity_prefix,) + self.entity_links)


@dataclass(frozen=True)
class LinklessEntitySpec(EntitySpec, SerializableDataclass):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,10 @@ def element_path_key(self) -> ElementPathKey:
return ElementPathKey(
element_name=self.element_name, element_type=LinkableElementType.METRIC, entity_links=self.entity_links
)

def with_entity_prefix(self, entity_prefix: EntityReference) -> GroupByMetricSpec: # noqa: D102
return GroupByMetricSpec(
element_name=self.element_name,
entity_links=(entity_prefix,) + self.entity_links,
metric_subquery_entity_links=self.metric_subquery_entity_links,
)
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,10 @@ def element_path_key(self) -> ElementPathKey:
"""Return the ElementPathKey representation of the LinkableInstanceSpec subtype."""
raise NotImplementedError()

@abstractmethod
def with_entity_prefix(self, entity_prefix: EntityReference) -> LinkableInstanceSpec:
"""Add the selected entity prefix to the start of the entity links."""
raise NotImplementedError()


SelfTypeT = TypeVar("SelfTypeT", bound="LinkableInstanceSpec")
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,12 @@ def generate_possible_specs_for_time_dimension(
@property
def is_metric_time(self) -> bool: # noqa: D102
return self.element_name == METRIC_TIME_ELEMENT_NAME

def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=(entity_prefix,) + self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
)
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ def _build_plan(
def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPlanOptimization]) -> DataflowPlan:
optimizer_factory = DataflowPlanOptimizerFactory(self._node_data_set_resolver)
for optimizer in optimizer_factory.get_optimizers(optimizations):
logger.debug(LazyFormat(lambda: f"Applying {optimizer.__class__.__name__}"))
logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
try:
plan = optimizer.optimize(plan)
logger.debug(
LazyFormat(
lambda: f"After applying {optimizer.__class__.__name__}, the dataflow plan is:\n"
lambda: f"After applying optimizer {optimizer.__class__.__name__}, the dataflow plan is:\n"
f"{indent(plan.structure_text())}"
)
)
Expand Down
Loading
Loading