Skip to content

Commit

Permalink
Bug fix: Generate new select columns for instances with new entity li…
Browse files Browse the repository at this point in the history
…nk - commit needs cleanup
  • Loading branch information
courtneyholcomb committed Nov 2, 2024
1 parent 194c3c4 commit 3031733
Show file tree
Hide file tree
Showing 61 changed files with 955 additions and 1,043 deletions.
135 changes: 77 additions & 58 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@
MetadataInstance,
MetricInstance,
TimeDimensionInstance,
group_instances_by_type,
)
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.specs.column_assoc import (
ColumnAssociationResolver,
)
from metricflow_semantics.specs.column_assoc import ColumnAssociation, ColumnAssociationResolver
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec
from metricflow_semantics.specs.metadata_spec import MetadataSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
Expand Down Expand Up @@ -72,7 +71,6 @@
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.instance_converters import (
AddGroupByMetric,
AddLinkToLinkableElements,
AddMetadata,
AddMetrics,
AliasAggregatedMeasures,
Expand Down Expand Up @@ -229,7 +227,7 @@ def convert_to_sql_query_plan(
sql_node = optimizer.optimize(sql_node)
logger.debug(
LazyFormat(
lambda: f"After applying {optimizer.__class__.__name__}, the SQL query plan is:\n"
lambda: f"After applying optimizer {optimizer.__class__.__name__}, the SQL query plan is:\n"
f"{indent(sql_node.structure_text())}"
)
)
Expand Down Expand Up @@ -456,31 +454,58 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
)

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
"""Generates the query that realizes the behavior of the JoinToStandardOutputNode."""
# Keep a mapping between the table aliases that would be used in the query and the MDO instances in that source.
# e.g. when building "FROM from_table a JOIN right_table b", the value for key "a" would be the instances in
# "from_table"
table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()

# Convert the dataflow from the left node to a DataSet and add context for it to table_alias_to_instance_set
# A DataSet is a bundle of the SQL query (in object form) and the MDO instances that the SQL query contains.
"""Generates the query that realizes the behavior of the JoinOnEntitiesNode."""
from_data_set = node.left_node.accept(self)
from_data_set_alias = self._next_unique_table_alias()
table_alias_to_instance_set[from_data_set_alias] = from_data_set.instance_set

# Build the join descriptions for the SqlQueryPlan - different from node.join_descriptions which are the join
# descriptions from the dataflow plan.
sql_join_descs: List[SqlJoinDescription] = []
# TODO: make prettier
def build_columns(spec: LinkableInstanceSpec) -> Tuple[ColumnAssociation]:
return (self._column_association_resolver.resolve_spec(spec),)

def build_select_column(
table_alias: str, original_instance: MdoInstance, new_instance: MdoInstance
) -> SqlSelectColumn:
"""Build new select column using the old column name as the expr and the new column name as the alias.
Example: "country AS user_id__country"
"""
return SqlSelectColumn(
expr=SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=table_alias, column_name=original_instance.associated_column.column_name
),
column_alias=new_instance.associated_column.column_name,
)

# Change the aggregation state for the measures to be partially aggregated if it was previously aggregated
# since we removed the entities and added the dimensions. The dimensions could have the same value for
# multiple rows, so we'll need to re-aggregate.
from_data_set_output_instance_set = from_data_set.instance_set.transform(
# TODO: is this filter doing anything? seems like no?
FilterElements(include_specs=from_data_set.instance_set.spec_set)
).transform(
ChangeMeasureAggregationState(
{
AggregationState.NON_AGGREGATED: AggregationState.NON_AGGREGATED,
AggregationState.COMPLETE: AggregationState.PARTIAL,
AggregationState.PARTIAL: AggregationState.PARTIAL,
}
)
)
instances_to_build_simple_select_columns_for = OrderedDict(
{from_data_set_alias: from_data_set_output_instance_set}
)

# The dataflow plan describes how the data sets coming from the parent nodes should be joined together. Use
# those descriptions to convert them to join descriptions for the SQL query plan.
# Build SQL join description, instance set, and select columns for each join target.
output_instance_set = from_data_set_output_instance_set
select_columns: Tuple[SqlSelectColumn, ...] = ()
sql_join_descs: List[SqlJoinDescription] = []
for join_description in node.join_targets:
join_on_entity = join_description.join_on_entity

right_node_to_join: DataflowPlanNode = join_description.join_node
right_node_to_join = join_description.join_node
right_data_set: SqlDataSet = right_node_to_join.accept(self)
right_data_set_alias = self._next_unique_table_alias()

# Build join description.
sql_join_desc = SqlQueryPlanJoinBuilder.make_base_output_join_description(
left_data_set=AnnotatedSqlDataSet(data_set=from_data_set, alias=from_data_set_alias),
right_data_set=AnnotatedSqlDataSet(data_set=right_data_set, alias=right_data_set_alias),
Expand All @@ -495,50 +520,44 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity,
).transform(right_data_set.instance_set)

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(join_on_entity=join_on_entity)
# TODO: test if this transformation is necessary and remove it if not. This adds a lot of clutter to the function.
right_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(join_on_entity).transform(
right_data_set.instance_set
)

# After the right data set is joined, we need to change the links to indicate that they a join was used to
# satisfy them. For example, if the right dataset contains the "country" dimension, and "user_id" is the
# join_on_entity, then the joined data set should have the "user__country" dimension.
new_instances: Tuple[MdoInstance, ...] = ()
for original_instance in right_instance_set_filtered.linkable_instances:
# Is this necessary? Does it even work? i.e. diff types here
if original_instance.spec == join_on_entity:
continue
new_instance = original_instance.with_entity_prefix(
join_on_entity.reference, column_association_resolver=self._column_association_resolver
)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
new_instances += (new_instance,)
select_columns += (select_column,)
right_instance_set_after_join = group_instances_by_type(new_instances)
else:
right_data_set_instance_set_after_join = right_data_set.instance_set
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
right_instance_set_after_join = right_data_set.instance_set
instances_to_build_simple_select_columns_for[right_data_set_alias] = right_instance_set_after_join

from_data_set_output_instance_set = from_data_set.instance_set.transform(
FilterElements(include_specs=from_data_set.instance_set.spec_set)
)
output_instance_set = InstanceSet.merge([output_instance_set, right_instance_set_after_join])

# Change the aggregation state for the measures to be partially aggregated if it was previously aggregated
# since we removed the entities and added the dimensions. The dimensions could have the same value for
# multiple rows, so we'll need to re-aggregate.
from_data_set_output_instance_set = from_data_set_output_instance_set.transform(
ChangeMeasureAggregationState(
{
AggregationState.NON_AGGREGATED: AggregationState.NON_AGGREGATED,
AggregationState.COMPLETE: AggregationState.PARTIAL,
AggregationState.PARTIAL: AggregationState.PARTIAL,
}
)
select_columns += create_select_columns_for_instance_sets(
column_resolver=self._column_association_resolver,
table_alias_to_instance_set=instances_to_build_simple_select_columns_for,
)

table_alias_to_instance_set[from_data_set_alias] = from_data_set_output_instance_set

# Construct the data set that contains the updated instances and the SQL nodes that should go in the various
# clauses.
return SqlDataSet(
instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())),
# TODO: Should SqlDataSet have a map like {instance: column}? Trying to match them is a pain in the butt.
instance_set=output_instance_set,
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=create_select_columns_for_instance_sets(
self._column_association_resolver, table_alias_to_instance_set
),
select_columns=select_columns,
from_source=from_data_set.checked_sql_select_node,
from_source_alias=from_data_set_alias,
join_descs=tuple(sql_join_descs),
Expand Down
115 changes: 4 additions & 111 deletions metricflow/plan_conversion/instance_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain
from typing import Dict, List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import EntityReference, MetricReference, SemanticModelReference
from dbt_semantic_interfaces.references import MetricReference, SemanticModelReference
from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
Expand All @@ -29,13 +29,10 @@
from metricflow_semantics.model.semantics.metric_lookup import MetricLookup
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.entity_spec import EntitySpec, LinklessEntitySpec
from metricflow_semantics.specs.group_by_metric_spec import GroupByMetricSpec
from metricflow_semantics.specs.entity_spec import LinklessEntitySpec
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.measure_spec import MeasureSpec, MetricInputMeasureSpec
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from more_itertools import bucket

from metricflow.dataflow.nodes.join_to_base import ValidityWindowJoinDescription
Expand Down Expand Up @@ -387,118 +384,14 @@ def transform(self, instance_set: InstanceSet) -> Optional[ValidityWindowJoinDes
return None


class AddLinkToLinkableElements(InstanceSetTransform[InstanceSet]):
"""Return a new instance set where the all linkable elements in the set have a new link added.
e.g. "country" -> "user_id__country" after a data set has been joined by entity.
"""

def __init__(self, join_on_entity: LinklessEntitySpec) -> None: # noqa: D107
self._join_on_entity = join_on_entity

def transform(self, instance_set: InstanceSet) -> InstanceSet: # noqa: D102
assert len(instance_set.metric_instances) == 0, "Can't add links to instance sets with metrics"
assert len(instance_set.measure_instances) == 0, "Can't add links to instance sets with measures"

# Handle dimension instances
dimension_instances_with_additional_link = []
for dimension_instance in instance_set.dimension_instances:
# The new dimension spec should include the join on entity.
transformed_dimension_spec_from_right = DimensionSpec(
element_name=dimension_instance.spec.element_name,
entity_links=self._join_on_entity.as_linkless_prefix + dimension_instance.spec.entity_links,
)
dimension_instances_with_additional_link.append(
DimensionInstance(
associated_columns=dimension_instance.associated_columns,
defined_from=dimension_instance.defined_from,
spec=transformed_dimension_spec_from_right,
)
)

# Handle time dimension instances
time_dimension_instances_with_additional_link = []
for time_dimension_instance in instance_set.time_dimension_instances:
# The new dimension spec should include the join on entity.
transformed_time_dimension_spec_from_right = TimeDimensionSpec(
element_name=time_dimension_instance.spec.element_name,
entity_links=(
(EntityReference(element_name=self._join_on_entity.element_name),)
+ time_dimension_instance.spec.entity_links
),
time_granularity=time_dimension_instance.spec.time_granularity,
date_part=time_dimension_instance.spec.date_part,
)
time_dimension_instances_with_additional_link.append(
TimeDimensionInstance(
associated_columns=time_dimension_instance.associated_columns,
defined_from=time_dimension_instance.defined_from,
spec=transformed_time_dimension_spec_from_right,
)
)

# Handle entity instances
entity_instances_with_additional_link = []
for entity_instance in instance_set.entity_instances:
# Don't include adding the entity link to the same entity.
# Otherwise, you would create "user_id__user_id", which is confusing.
if entity_instance.spec == self._join_on_entity:
continue
# The new entity spec should include the join on entity.
transformed_entity_spec_from_right = EntitySpec(
element_name=entity_instance.spec.element_name,
entity_links=self._join_on_entity.as_linkless_prefix + entity_instance.spec.entity_links,
)
entity_instances_with_additional_link.append(
EntityInstance(
associated_columns=entity_instance.associated_columns,
defined_from=entity_instance.defined_from,
spec=transformed_entity_spec_from_right,
)
)

# Handle group by metric instances
group_by_metric_instances_with_additional_link = []
for group_by_metric_instance in instance_set.group_by_metric_instances:
transformed_group_by_metric_spec_from_right = GroupByMetricSpec(
element_name=group_by_metric_instance.spec.element_name,
entity_links=self._join_on_entity.as_linkless_prefix + group_by_metric_instance.spec.entity_links,
metric_subquery_entity_links=group_by_metric_instance.spec.metric_subquery_entity_links,
)
group_by_metric_instances_with_additional_link.append(
GroupByMetricInstance(
associated_columns=group_by_metric_instance.associated_columns,
defined_from=group_by_metric_instance.defined_from,
spec=transformed_group_by_metric_spec_from_right,
)
)

return InstanceSet(
measure_instances=(),
dimension_instances=tuple(dimension_instances_with_additional_link),
time_dimension_instances=tuple(time_dimension_instances_with_additional_link),
entity_instances=tuple(entity_instances_with_additional_link),
group_by_metric_instances=tuple(group_by_metric_instances_with_additional_link),
metric_instances=(),
metadata_instances=(),
)


class FilterLinkableInstancesWithLeadingLink(InstanceSetTransform[InstanceSet]):
"""Return an instance set with the elements that have a specified leading link removed.
e.g. Remove "listing__country" if the specified link is "listing".
"""

def __init__(
self,
entity_link: LinklessEntitySpec,
) -> None:
"""Constructor.
Args:
entity_link: Remove elements with this link as the first element in "entity_links"
"""
def __init__(self, entity_link: LinklessEntitySpec) -> None:
"""Remove elements with this link as the first element in "entity_links"."""
self._entity_link = entity_link

def _should_pass(self, linkable_spec: LinkableInstanceSpec) -> bool:
Expand Down
Loading

0 comments on commit 3031733

Please sign in to comment.