Bug fix: use FULL OUTER JOIN for dimension-only queries #863

Merged Nov 14, 2023 (7 commits). The diff below shows changes from 1 commit.
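In short: when a query asks only for dimension values (no measures), joining the candidate source tables with LEFT OUTER JOIN can silently drop values that exist only in the right-hand table, while FULL OUTER JOIN keeps values from both sides. A toy illustration of the two join semantics (plain Python, not MetricFlow code; the dates are invented):

```python
# Toy model of the two join semantics over a shared "metric_time" key.
# Not MetricFlow code; values are invented for illustration.
def left_outer_keys(left: set, right: set) -> set:
    # LEFT OUTER JOIN: only keys present on the left survive.
    return set(left)

def full_outer_keys(left: set, right: set) -> set:
    # FULL OUTER JOIN: keys from either side survive.
    return left | right

table_a = {"2023-01-01", "2023-01-02"}
table_b = {"2023-01-02", "2023-01-03"}

# A dimension-only query should return all three distinct values:
assert full_outer_keys(table_a, table_b) == {"2023-01-01", "2023-01-02", "2023-01-03"}
# ...but LEFT OUTER JOIN would lose "2023-01-03":
assert "2023-01-03" not in left_outer_keys(table_a, table_b)
```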
71 changes: 15 additions & 56 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -28,7 +28,6 @@
ConstrainTimeRangeNode,
DataflowPlan,
FilterElementsNode,
JoinDescription,
JoinOverTimeRangeNode,
JoinToBaseOutputNode,
JoinToTimeSpineNode,
@@ -77,53 +76,6 @@ class DataflowRecipe:
required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]

@property
def join_targets(self) -> List[JoinDescription]:
"""Joins to be made to source node."""
join_targets = []
for join_recipe in self.join_linkable_instances_recipes:
Comment (Contributor, PR author):
Moved this logic to JoinLinkableInstancesRecipe, which seemed cleaner & more appropriate.

# Figure out what elements to filter from the joined node.

# Sanity check - all linkable specs should have a link, or else why would we be joining them.
assert all([len(x.entity_links) > 0 for x in join_recipe.satisfiable_linkable_specs])

# If we're joining something in, then we need the associated entity, partitions, and time dimension
# specs defining the validity window (if necessary)
include_specs: List[LinkableInstanceSpec] = [
LinklessEntitySpec.from_reference(x.entity_links[0]) for x in join_recipe.satisfiable_linkable_specs
]
include_specs.extend([x.node_to_join_dimension_spec for x in join_recipe.join_on_partition_dimensions])
include_specs.extend(
[x.node_to_join_time_dimension_spec for x in join_recipe.join_on_partition_time_dimensions]
)
if join_recipe.validity_window:
include_specs.extend(
[
join_recipe.validity_window.window_start_dimension,
join_recipe.validity_window.window_end_dimension,
]
)

# satisfiable_linkable_specs describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the measure node.
include_specs.extend([x.without_first_entity_link for x in join_recipe.satisfiable_linkable_specs])
filtered_node_to_join = FilterElementsNode(
parent_node=join_recipe.node_to_join,
include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs),
)
join_targets.append(
JoinDescription(
join_node=filtered_node_to_join,
join_on_entity=join_recipe.join_on_entity,
join_on_partition_dimensions=join_recipe.join_on_partition_dimensions,
join_on_partition_time_dimensions=join_recipe.join_on_partition_time_dimensions,
validity_window=join_recipe.validity_window,
)
)
return join_targets


@dataclass(frozen=True)
class MeasureSpecProperties:
Expand Down Expand Up @@ -306,10 +258,11 @@ def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Dat
raise UnableToSatisfyQueryError(f"Recipe not found for linkable specs: {query_spec.linkable_specs}")

joined_node: Optional[JoinToBaseOutputNode] = None
if dataflow_recipe.join_targets:
joined_node = JoinToBaseOutputNode(
left_node=dataflow_recipe.source_node, join_targets=dataflow_recipe.join_targets
)
if dataflow_recipe.join_linkable_instances_recipes:
join_targets = [
join_recipe.join_description for join_recipe in dataflow_recipe.join_linkable_instances_recipes
]
joined_node = JoinToBaseOutputNode(left_node=dataflow_recipe.source_node, join_targets=join_targets)

where_constraint_node: Optional[WhereConstraintNode] = None
if query_spec.where_constraint:
@@ -485,13 +438,15 @@ def _find_dataflow_recipe(
potential_source_nodes: Sequence[BaseOutput] = self._select_source_nodes_with_measures(
measure_specs=set(measure_spec_properties.measure_specs), source_nodes=source_nodes
)
default_join_type = SqlJoinType.LEFT_OUTER
else:
# Only read nodes can be source nodes for queries without measures
source_nodes = self._read_nodes
source_nodes_to_linkable_specs = self._select_read_nodes_with_linkable_specs(
linkable_specs=linkable_spec_set, read_nodes=source_nodes
)
potential_source_nodes = list(source_nodes_to_linkable_specs.keys())
default_join_type = SqlJoinType.FULL_OUTER

logger.info(f"There are {len(potential_source_nodes)} potential source nodes")

@@ -518,7 +473,9 @@
f"After removing unnecessary nodes, there are {len(nodes_available_for_joins)} nodes available for joins"
)
if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
nodes_available_for_joins = node_processor.add_multi_hop_joins(linkable_specs, source_nodes)
nodes_available_for_joins = node_processor.add_multi_hop_joins(
desired_linkable_specs=linkable_specs, nodes=source_nodes, join_type=default_join_type
)
logger.info(
f"After adding multi-hop nodes, there are {len(nodes_available_for_joins)} nodes available for joins:\n"
f"{pformat_big_objects(nodes_available_for_joins)}"
@@ -552,7 +509,9 @@
logger.debug(f"Evaluating source node:\n{pformat_big_objects(source_node=dataflow_dag_as_text(node))}")

start_time = time.time()
evaluation = node_evaluator.evaluate_node(start_node=node, required_linkable_specs=list(linkable_specs))
evaluation = node_evaluator.evaluate_node(
start_node=node, required_linkable_specs=list(linkable_specs), default_join_type=default_join_type
)
logger.info(f"Evaluation of {node} took {time.time() - start_time:.2f}s")

logger.debug(
@@ -597,7 +556,7 @@ def _find_dataflow_recipe(

# Nodes containing the linkable instances will be joined to the source node, so these
# entities will need to be present in the source node.
required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes)
required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes if x.join_on_entity)
# Same thing with partitions.
required_local_dimension_specs = tuple(
y.start_node_dimension_spec for x in evaluation.join_recipes for y in x.join_on_partition_dimensions
Expand Down Expand Up @@ -780,7 +739,7 @@ def _build_aggregated_measure_from_measure_source_node(
),
)

join_targets = measure_recipe.join_targets
join_targets = [join_recipe.join_description for join_recipe in measure_recipe.join_linkable_instances_recipes]
unaggregated_measure_node: BaseOutput
if len(join_targets) > 0:
filtered_measures_with_joined_elements = JoinToBaseOutputNode(
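To summarize the builder-side change above: the default join type is now chosen from whether the query contains measures, and is threaded through to node evaluation. A condensed sketch (simplified names; the real logic lives in `_find_dataflow_recipe`):

```python
# Condensed sketch of the selection logic added above (not a verbatim excerpt).
from metricflow.sql.sql_plan import SqlJoinType

def choose_default_join_type(query_has_measures: bool) -> SqlJoinType:
    # Measure queries keep the old LEFT OUTER behavior; dimension-only
    # queries use FULL OUTER so no source's values are dropped.
    return SqlJoinType.LEFT_OUTER if query_has_measures else SqlJoinType.FULL_OUTER

# The result is then passed along, e.g.:
#   node_evaluator.evaluate_node(
#       start_node=node,
#       required_linkable_specs=list(linkable_specs),
#       default_join_type=choose_default_join_type(query_has_measures),
#   )
```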
50 changes: 42 additions & 8 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -27,6 +27,7 @@
from metricflow.dataflow.builder.partitions import PartitionJoinResolver
from metricflow.dataflow.dataflow_plan import (
BaseOutput,
FilterElementsNode,
JoinDescription,
PartitionDimensionJoinDescription,
PartitionTimeDimensionJoinDescription,
@@ -37,10 +38,8 @@
from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator
from metricflow.plan_conversion.instance_converters import CreateValidityWindowJoinDescription
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.specs import (
LinkableInstanceSpec,
LinklessEntitySpec,
)
from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec
from metricflow.sql.sql_plan import SqlJoinType

logger = logging.getLogger(__name__)

@@ -55,12 +54,14 @@ class JoinLinkableInstancesRecipe:
"""

node_to_join: BaseOutput
# The entity to join "node_to_join" on.
join_on_entity: LinklessEntitySpec
# The entity to join "node_to_join" on. Not needed for cross-joins.
join_on_entity: Optional[LinklessEntitySpec]
# The linkable instances from the query that can be satisfied if we join this node. Note that this is different from
# the linkable specs in the node that can help to satisfy the query. e.g. "user_id__country" might be one of the
# "satisfiable_linkable_specs", but "country" is the linkable spec in the node.
satisfiable_linkable_specs: List[LinkableInstanceSpec]
# Join type to use when joining node
join_type: SqlJoinType

# The partitions to join on, if there are matching partitions between the start_node and node_to_join.
join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
@@ -71,12 +72,37 @@
@property
def join_description(self) -> JoinDescription:
Comment (Contributor, PR author): Prior to this PR, this attribute was not used.

Comment (Contributor): lolwut

"""The recipe as a join description to use in the dataflow plan node."""
# Figure out what elements to filter from the joined node.
include_specs: List[LinkableInstanceSpec] = []
if not self.join_type == SqlJoinType.CROSS_JOIN:
assert all([len(spec.entity_links) > 0 for spec in self.satisfiable_linkable_specs])
include_specs.extend(
[LinklessEntitySpec.from_reference(spec.entity_links[0]) for spec in self.satisfiable_linkable_specs]
)

include_specs.extend([join.node_to_join_dimension_spec for join in self.join_on_partition_dimensions])
include_specs.extend([join.node_to_join_time_dimension_spec for join in self.join_on_partition_time_dimensions])
if self.validity_window:
include_specs.extend(
[self.validity_window.window_start_dimension, self.validity_window.window_end_dimension]
)

# `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the source node.
include_specs.extend([spec.without_first_entity_link for spec in self.satisfiable_linkable_specs])
filtered_node_to_join = FilterElementsNode(
parent_node=self.node_to_join, include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs)
)

return JoinDescription(
join_node=self.node_to_join,
join_node=filtered_node_to_join,
join_on_entity=self.join_on_entity,
join_on_partition_dimensions=self.join_on_partition_dimensions,
join_on_partition_time_dimensions=self.join_on_partition_time_dimensions,
validity_window=self.validity_window,
join_type=self.join_type,
)


@@ -133,6 +159,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
self,
start_node_instance_set: InstanceSet,
needed_linkable_specs: List[LinkableInstanceSpec],
join_type: SqlJoinType,
) -> List[JoinLinkableInstancesRecipe]:
"""Get nodes that can be joined to get 1 or more of the "needed_linkable_specs".

@@ -257,6 +284,7 @@
join_on_partition_dimensions=join_on_partition_dimensions,
join_on_partition_time_dimensions=join_on_partition_time_dimensions,
validity_window=validity_window_join_description,
join_type=join_type,
)
)

@@ -271,6 +299,7 @@
def _update_candidates_that_can_satisfy_linkable_specs(
candidates_for_join: List[JoinLinkableInstancesRecipe],
already_satisfisfied_linkable_specs: List[LinkableInstanceSpec],
join_type: SqlJoinType,
) -> List[JoinLinkableInstancesRecipe]:
"""Update / filter candidates_for_join based on linkable instance specs that we have already satisfied.

@@ -294,6 +323,7 @@
join_on_partition_dimensions=candidate_for_join.join_on_partition_dimensions,
join_on_partition_time_dimensions=candidate_for_join.join_on_partition_time_dimensions,
validity_window=candidate_for_join.validity_window,
join_type=join_type,
)
)
return sorted(
@@ -306,6 +336,7 @@
self,
start_node: BaseOutput,
required_linkable_specs: Sequence[LinkableInstanceSpec],
default_join_type: SqlJoinType,
Comment (Contributor): Why is this default_join_type? Does it get overridden somehow? Or should it be called join_type here as well?

Comment (Contributor, PR author): This was in preparation for my next PR - we'll need to use CROSS JOIN when querying metric_time with other dimensions (but no metrics)... which will override the default join type for dimension-only queries.

) -> LinkableInstanceSatisfiabilityEvaluation:
"""Evaluates if the "required_linkable_specs" can be realized by joining the "start_node" with other nodes.

@@ -345,7 +376,9 @@
possibly_joinable_linkable_specs.append(required_linkable_spec)

candidates_for_join = self._find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
start_node_instance_set=candidate_instance_set, needed_linkable_specs=possibly_joinable_linkable_specs
start_node_instance_set=candidate_instance_set,
needed_linkable_specs=possibly_joinable_linkable_specs,
join_type=default_join_type,
)
join_candidates: List[JoinLinkableInstancesRecipe] = []

@@ -378,6 +411,7 @@
candidates_for_join = self._update_candidates_that_can_satisfy_linkable_specs(
candidates_for_join=candidates_for_join,
already_satisfisfied_linkable_specs=next_candidate.satisfiable_linkable_specs,
join_type=default_join_type,
)

# The specs that were only possibly joinable are now definitely joinable and no longer need to be searched for.
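The net effect in node_evaluator.py: each `JoinLinkableInstancesRecipe` now carries its own `join_type` and builds the filtered `JoinDescription` itself, which also supports the CROSS JOIN case discussed in the review thread above (no join entity at all). A toy re-implementation of the shape of that design (not MetricFlow's actual classes; enum values are invented and the element-filtering step is elided):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class SqlJoinType(Enum):  # stand-in for metricflow.sql.sql_plan.SqlJoinType
    LEFT_OUTER = "LEFT OUTER JOIN"
    FULL_OUTER = "FULL OUTER JOIN"
    CROSS_JOIN = "CROSS JOIN"

@dataclass(frozen=True)
class JoinDescription:
    join_node: str
    join_on_entity: Optional[str]  # None is now legal: cross joins have no join key
    join_type: SqlJoinType

@dataclass(frozen=True)
class JoinLinkableInstancesRecipe:
    node_to_join: str
    join_on_entity: Optional[str]
    join_type: SqlJoinType

    @property
    def join_description(self) -> JoinDescription:
        # The real property also wraps node_to_join in a FilterElementsNode
        # that keeps only the specs needed for the join.
        return JoinDescription(
            join_node=self.node_to_join,
            join_on_entity=self.join_on_entity,
            join_type=self.join_type,
        )

recipe = JoinLinkableInstancesRecipe("dim_users", "user_id", SqlJoinType.FULL_OUTER)
assert recipe.join_description.join_type is SqlJoinType.FULL_OUTER
```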
4 changes: 3 additions & 1 deletion metricflow/dataflow/dataflow_plan.py
@@ -249,7 +249,8 @@ class JoinDescription:
"""Describes how data from a node should be joined to data from another node."""

join_node: BaseOutput
join_on_entity: LinklessEntitySpec
join_on_entity: Optional[LinklessEntitySpec]
join_type: SqlJoinType

join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
join_on_partition_time_dimensions: Tuple[PartitionTimeDimensionJoinDescription, ...]
@@ -339,6 +340,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToBase
join_on_partition_dimensions=old_join_target.join_on_partition_dimensions,
join_on_partition_time_dimensions=old_join_target.join_on_partition_time_dimensions,
validity_window=old_join_target.validity_window,
join_type=old_join_target.join_type,
)
for i, old_join_target in enumerate(self._join_targets)
],
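Because `join_on_entity` is now `Optional` on `JoinDescription`, consumers must guard before using it, which is why the builder change earlier filters with `if x.join_on_entity`. A minimal sketch of that guard pattern (hypothetical helper, not MetricFlow code):

```python
from typing import List, Optional, Tuple

def required_local_entities(join_on_entities: List[Optional[str]]) -> Tuple[str, ...]:
    # Cross-join recipes contribute no equi-join entity, so None entries
    # are skipped rather than propagated as required local specs.
    return tuple(entity for entity in join_on_entities if entity is not None)

assert required_local_entities(["user_id", None, "listing"]) == ("user_id", "listing")
```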
45 changes: 23 additions & 22 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -396,28 +396,29 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataS
)
)

# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.

# logger.error(f"before filter is:\n{pformat_big_objects(right_data_set.instance_set.spec_set)}")
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity,
).transform(right_data_set.instance_set)
# logger.error(f"after filter is:\n{pformat_big_objects(right_data_set_instance_set_filtered.spec_set)}")

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(join_on_entity=join_on_entity)
)
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
if join_description.join_type == SqlJoinType.CROSS_JOIN:
table_alias_to_instance_set[right_data_set_alias] = right_data_set.instance_set
else:
# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.
assert join_on_entity
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity
).transform(right_data_set.instance_set)

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
# measure instance and the right dataset contains the "country" dimension instance, then after the join,
# the output data set should have the "country" dimension instance with the "user_id" entity link
# (if "user_id" equality was the join condition). "country" -> "user_id__country"
right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
AddLinkToLinkableElements(join_on_entity=join_on_entity)
)
table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join

from_data_set_output_instance_set = from_data_set.instance_set.transform(
FilterElements(include_specs=from_data_set.instance_set.spec_set)
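Finally, the SQL-conversion change above: for an equi-join, the right data set's instances get the join entity prepended as a link ("country" joined on "user_id" becomes "user_id__country"), while a cross join keeps the instance set untouched. A toy version of that rewrite (plain string manipulation, not MetricFlow's instance transforms):

```python
from typing import Optional

def add_entity_link(dimension_name: str, join_on_entity: Optional[str]) -> str:
    # CROSS JOIN branch: no join entity, so the name is left as-is.
    if join_on_entity is None:
        return dimension_name
    # Equi-join branch: prefix the join entity, mirroring what
    # AddLinkToLinkableElements does in dataflow_to_sql.py.
    return f"{join_on_entity}__{dimension_name}"

assert add_entity_link("country", "user_id") == "user_id__country"
assert add_entity_link("country", None) == "country"
```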