Skip to content

Commit

Permalink
Remove changes needed for cross-join
Browse files: browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 10, 2023
1 parent b8b62aa commit 44fc845
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 46 deletions.
2 changes: 1 addition & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _find_dataflow_recipe(

# Nodes containing the linkable instances will be joined to the source node, so these
# entities will need to be present in the source node.
required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes if x.join_on_entity)
required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes)
# Same thing with partitions.
required_local_dimension_specs = tuple(
y.start_node_dimension_spec for x in evaluation.join_recipes for y in x.join_on_partition_dimensions
Expand Down
13 changes: 6 additions & 7 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class JoinLinkableInstancesRecipe:
"""

node_to_join: BaseOutput
# The entity to join "node_to_join" on. Not needed for cross-joins.
join_on_entity: Optional[LinklessEntitySpec]
# The entity to join "node_to_join" on.
join_on_entity: LinklessEntitySpec
# The linkable instances from the query that can be satisfied if we join this node. Note that this is different from
# the linkable specs in the node that can help to satisfy the query. e.g. "user_id__country" might be one of the
# "satisfiable_linkable_specs", but "country" is the linkable spec in the node.
Expand All @@ -74,11 +74,10 @@ def join_description(self) -> JoinDescription:
"""The recipe as a join description to use in the dataflow plan node."""
# Figure out what elements to filter from the joined node.
include_specs: List[LinkableInstanceSpec] = []
if not self.join_type == SqlJoinType.CROSS_JOIN:
assert all([len(spec.entity_links) > 0 for spec in self.satisfiable_linkable_specs])
include_specs.extend(
[LinklessEntitySpec.from_reference(spec.entity_links[0]) for spec in self.satisfiable_linkable_specs]
)
assert all([len(spec.entity_links) > 0 for spec in self.satisfiable_linkable_specs])
include_specs.extend(
[LinklessEntitySpec.from_reference(spec.entity_links[0]) for spec in self.satisfiable_linkable_specs]
)

include_specs.extend([join.node_to_join_dimension_spec for join in self.join_on_partition_dimensions])
include_specs.extend([join.node_to_join_time_dimension_spec for join in self.join_on_partition_time_dimensions])
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class JoinDescription:
"""Describes how data from a node should be joined to data from another node."""

join_node: BaseOutput
join_on_entity: Optional[LinklessEntitySpec]
join_on_entity: LinklessEntitySpec
join_type: SqlJoinType

join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
Expand Down
42 changes: 19 additions & 23 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,29 +396,25 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataS
)
)

if join_description.join_type == SqlJoinType.CROSS_JOIN:
table_alias_to_instance_set[right_data_set_alias] = right_data_set.instance_set
else:
# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.
assert join_on_entity
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity
).transform(right_data_set.instance_set)

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(join_on_entity=join_on_entity)
)
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity
).transform(right_data_set.instance_set)

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(join_on_entity=join_on_entity)
)
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join

from_data_set_output_instance_set = from_data_set.instance_set.transform(
FilterElements(include_specs=from_data_set.instance_set.spec_set)
Expand Down
18 changes: 4 additions & 14 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,7 @@ def make_base_output_join_description(
In addition to the entity equality condition, this will ensure datasets are joined on all partition
columns and account for validity windows, if those are defined in one of the datasets.
"""
validity_conditions = SqlQueryPlanJoinBuilder._make_validity_window_on_conditions(
left_data_set=left_data_set, right_data_set=right_data_set, join_description=join_description
)
if join_description.join_type == SqlJoinType.CROSS_JOIN:
return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=right_data_set.data_set.sql_select_node,
left_source_alias=left_data_set.alias,
right_source_alias=right_data_set.alias,
column_equality_descriptions=[],
join_type=join_description.join_type,
additional_on_conditions=validity_conditions,
)

join_on_entity = join_description.join_on_entity
assert join_on_entity, "Join on entity required unless using cross join."
# Figure out which columns in the "left" data set correspond to the entity that we want to join on.
# The column associations tell us which columns correspond to which instances in the data set.
left_data_set_entity_column_associations = left_data_set.data_set.column_associations_for_entity(join_on_entity)
Expand Down Expand Up @@ -211,6 +197,10 @@ def make_base_output_join_description(
)
)

validity_conditions = SqlQueryPlanJoinBuilder._make_validity_window_on_conditions(
left_data_set=left_data_set, right_data_set=right_data_set, join_description=join_description
)

return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=right_data_set.data_set.sql_select_node,
left_source_alias=left_data_set.alias,
Expand Down

0 comments on commit 44fc845

Please sign in to comment.