Skip to content

Commit

Permalink
Use helper function to find matching instances in dataset (#1538)
Browse files Browse the repository at this point in the history
Clean up. Just reducing some repeated code in dataflow to SQL logic.
  • Loading branch information
courtneyholcomb authored Dec 9, 2024
1 parent aa6ec15 commit dcf17cf
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 37 deletions.
4 changes: 3 additions & 1 deletion metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class MdoInstance(ABC, Generic[SpecT]):
@property
def associated_column(self) -> ColumnAssociation:
"""Helper for getting the associated column until support for multiple associated columns is added."""
assert len(self.associated_columns) == 1
assert (
len(self.associated_columns) == 1
), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}"
return self.associated_columns[0]

def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
Expand Down
41 changes: 25 additions & 16 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.instances import EntityInstance, InstanceSet
from metricflow_semantics.instances import EntityInstance, InstanceSet, TimeDimensionInstance
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.specs.column_assoc import ColumnAssociation
from metricflow_semantics.specs.dimension_spec import DimensionSpec
Expand Down Expand Up @@ -122,30 +122,39 @@ def column_association_for_dimension(

return column_associations_to_return[0]

def column_association_for_time_dimension(
self,
time_dimension_spec: TimeDimensionSpec,
) -> ColumnAssociation:
"""Given the name of the time dimension, return the set of columns associated with it in the data set."""
def instances_for_time_dimensions(
self, time_dimension_specs: Sequence[TimeDimensionSpec]
) -> Tuple[TimeDimensionInstance, ...]:
"""Return the instances associated with these specs in the data set."""
time_dimension_specs_set = set(time_dimension_specs)
matching_instances = 0
column_associations_to_return = None
instances_to_return: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in self.instance_set.time_dimension_instances:
if time_dimension_instance.spec == time_dimension_spec:
column_associations_to_return = time_dimension_instance.associated_columns
if time_dimension_instance.spec in time_dimension_specs_set:
instances_to_return += (time_dimension_instance,)
matching_instances += 1

if matching_instances > 1:
if matching_instances != len(time_dimension_specs_set):
raise RuntimeError(
f"More than one time dimension instance with spec {time_dimension_spec} in "
f"instance set: {self.instance_set}"
f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs_set}\n"
f"Instances: {instances_to_return}"
)

if not column_associations_to_return:
return instances_to_return

def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance:
"""Given the name of the time dimension, return the instance associated with it in the data set."""
instances = self.instances_for_time_dimensions((time_dimension_spec,))
if not len(instances) == 1:
raise RuntimeError(
f"No time dimension instances with spec {time_dimension_spec} in instance set: {self.instance_set}"
f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_spec}\n"
f"Instances: {instances}"
)
return instances[0]

return column_associations_to_return[0]
def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation:
"""Given the name of the time dimension, return the set of columns associated with it in the data set."""
return self.instance_for_time_dimension(time_dimension_spec).associated_column

@property
@override
Expand Down
25 changes: 5 additions & 20 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,16 +1472,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
)

# Select matching instance from time spine data set (using base grain - custom grain will be joined in a later node).
original_time_spine_dim_instance: Optional[TimeDimensionInstance] = None
for time_dimension_instance in time_spine_dataset.instance_set.time_dimension_instances:
if time_dimension_instance.spec == agg_time_dimension_instance_for_join.spec:
original_time_spine_dim_instance = time_dimension_instance
break
assert original_time_spine_dim_instance, (
"Couldn't find requested agg_time_dimension_instance_for_join in time spine data set, which "
f"indicates it may have been configured incorrectly. Expected: {agg_time_dimension_instance_for_join.spec};"
f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}"
original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension(
agg_time_dimension_instance_for_join.spec
)
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
Expand Down Expand Up @@ -1592,17 +1584,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod

# New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node.
parent_alias = parent_data_set.checked_sql_select_node.from_source_alias
parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
for instance in parent_data_set.instance_set.time_dimension_instances:
if instance.spec == node.time_dimension_spec.with_base_grain():
parent_time_dimension_instance = instance
break
parent_column: Optional[SqlSelectColumn] = None
assert parent_time_dimension_instance, (
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain()}; "
f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}"
parent_time_dimension_instance = parent_data_set.instance_for_time_dimension(
node.time_dimension_spec.with_base_grain()
)
parent_column: Optional[SqlSelectColumn] = None
for select_column in parent_data_set.checked_sql_select_node.select_columns:
if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name:
parent_column = select_column
Expand Down

0 comments on commit dcf17cf

Please sign in to comment.