Add join_type to JoinDescription
courtneyholcomb committed Nov 10, 2023
Parent: 1410c46 · Commit: b8b62aa
Showing 23 changed files with 166 additions and 109 deletions.
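The change in one picture: JoinDescription gains a join_type field, join_on_entity becomes optional (a cross join has no join key), and the default join type is chosen by the plan builder and threaded through node evaluation down to SQL rendering. A minimal sketch of the resulting shape, with stand-in types for MetricFlow's spec classes (the real definitions live in metricflow/dataflow/dataflow_plan.py and metricflow/sql/sql_plan.py; this is a hedged sketch, not the verbatim source):

# Sketch only: field names follow the diff below; `object` stands in for the
# actual spec classes, and the enum values approximate metricflow/sql/sql_plan.py.
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class SqlJoinType(Enum):
    LEFT_OUTER = "LEFT OUTER JOIN"    # default for queries with measures
    FULL_OUTER = "FULL OUTER JOIN"    # default for distinct-values queries
    CROSS_JOIN = "CROSS JOIN"         # no join key; join_on_entity is None


@dataclass(frozen=True)
class JoinDescription:
    """Describes how data from a node should be joined to data from another node."""

    join_node: object                          # BaseOutput in MetricFlow
    join_on_entity: Optional[object]           # LinklessEntitySpec; None for cross joins
    join_type: SqlJoinType                     # new in this commit
    join_on_partition_dimensions: Tuple[object, ...] = ()
    join_on_partition_time_dimensions: Tuple[object, ...] = ()
    validity_window: Optional[object] = None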
71 changes: 15 additions & 56 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -28,7 +28,6 @@
     ConstrainTimeRangeNode,
     DataflowPlan,
     FilterElementsNode,
-    JoinDescription,
     JoinOverTimeRangeNode,
     JoinToBaseOutputNode,
     JoinToTimeSpineNode,
@@ -77,53 +76,6 @@ class DataflowRecipe:
     required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
     join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]
 
-    @property
-    def join_targets(self) -> List[JoinDescription]:
-        """Joins to be made to source node."""
-        join_targets = []
-        for join_recipe in self.join_linkable_instances_recipes:
-            # Figure out what elements to filter from the joined node.
-
-            # Sanity check - all linkable specs should have a link, or else why would we be joining them.
-            assert all([len(x.entity_links) > 0 for x in join_recipe.satisfiable_linkable_specs])
-
-            # If we're joining something in, then we need the associated entity, partitions, and time dimension
-            # specs defining the validity window (if necessary)
-            include_specs: List[LinkableInstanceSpec] = [
-                LinklessEntitySpec.from_reference(x.entity_links[0]) for x in join_recipe.satisfiable_linkable_specs
-            ]
-            include_specs.extend([x.node_to_join_dimension_spec for x in join_recipe.join_on_partition_dimensions])
-            include_specs.extend(
-                [x.node_to_join_time_dimension_spec for x in join_recipe.join_on_partition_time_dimensions]
-            )
-            if join_recipe.validity_window:
-                include_specs.extend(
-                    [
-                        join_recipe.validity_window.window_start_dimension,
-                        join_recipe.validity_window.window_end_dimension,
-                    ]
-                )
-
-            # satisfiable_linkable_specs describes what can be satisfied after the join, so remove the entity
-            # link when filtering before the join.
-            # e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
-            # "user_id" and the "country" dimension so that it can be joined to the measure node.
-            include_specs.extend([x.without_first_entity_link for x in join_recipe.satisfiable_linkable_specs])
-            filtered_node_to_join = FilterElementsNode(
-                parent_node=join_recipe.node_to_join,
-                include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs),
-            )
-            join_targets.append(
-                JoinDescription(
-                    join_node=filtered_node_to_join,
-                    join_on_entity=join_recipe.join_on_entity,
-                    join_on_partition_dimensions=join_recipe.join_on_partition_dimensions,
-                    join_on_partition_time_dimensions=join_recipe.join_on_partition_time_dimensions,
-                    validity_window=join_recipe.validity_window,
-                )
-            )
-        return join_targets


@dataclass(frozen=True)
class MeasureSpecProperties:
@@ -306,10 +258,11 @@ def build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Dat
             raise UnableToSatisfyQueryError(f"Recipe not found for linkable specs: {query_spec.linkable_specs}")
 
         joined_node: Optional[JoinToBaseOutputNode] = None
-        if dataflow_recipe.join_targets:
-            joined_node = JoinToBaseOutputNode(
-                left_node=dataflow_recipe.source_node, join_targets=dataflow_recipe.join_targets
-            )
+        if dataflow_recipe.join_linkable_instances_recipes:
+            join_targets = [
+                join_recipe.join_description for join_recipe in dataflow_recipe.join_linkable_instances_recipes
+            ]
+            joined_node = JoinToBaseOutputNode(left_node=dataflow_recipe.source_node, join_targets=join_targets)
 
         where_constraint_node: Optional[WhereConstraintNode] = None
         if query_spec.where_constraint:
@@ -485,13 +438,15 @@ def _find_dataflow_recipe(
             potential_source_nodes: Sequence[BaseOutput] = self._select_source_nodes_with_measures(
                 measure_specs=set(measure_spec_properties.measure_specs), source_nodes=source_nodes
             )
+            default_join_type = SqlJoinType.LEFT_OUTER
         else:
             # Only read nodes can be source nodes for queries without measures
            source_nodes = self._read_nodes
            source_nodes_to_linkable_specs = self._select_read_nodes_with_linkable_specs(
                linkable_specs=linkable_spec_set, read_nodes=source_nodes
            )
            potential_source_nodes = list(source_nodes_to_linkable_specs.keys())
+            default_join_type = SqlJoinType.FULL_OUTER
 
         logger.info(f"There are {len(potential_source_nodes)} potential source nodes")
 
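Note the rule this hunk introduces: queries with measures default to LEFT OUTER joins (the measure source drives the row set), while distinct-values queries over read nodes default to FULL OUTER joins (no side is privileged). A hedged restatement of just that rule, importing SqlJoinType from the module this diff imports:

from metricflow.sql.sql_plan import SqlJoinType

def default_join_type_for_query(has_measures: bool) -> SqlJoinType:
    # Measure queries keep every measure row; distinct-values queries keep
    # rows from all joined sides.
    return SqlJoinType.LEFT_OUTER if has_measures else SqlJoinType.FULL_OUTER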
@@ -518,7 +473,9 @@
             f"After removing unnecessary nodes, there are {len(nodes_available_for_joins)} nodes available for joins"
         )
         if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
-            nodes_available_for_joins = node_processor.add_multi_hop_joins(linkable_specs, source_nodes)
+            nodes_available_for_joins = node_processor.add_multi_hop_joins(
+                desired_linkable_specs=linkable_specs, nodes=source_nodes, join_type=default_join_type
+            )
         logger.info(
             f"After adding multi-hop nodes, there are {len(nodes_available_for_joins)} nodes available for joins:\n"
             f"{pformat_big_objects(nodes_available_for_joins)}"
@@ -552,7 +509,9 @@
             logger.debug(f"Evaluating source node:\n{pformat_big_objects(source_node=dataflow_dag_as_text(node))}")
 
             start_time = time.time()
-            evaluation = node_evaluator.evaluate_node(start_node=node, required_linkable_specs=list(linkable_specs))
+            evaluation = node_evaluator.evaluate_node(
+                start_node=node, required_linkable_specs=list(linkable_specs), default_join_type=default_join_type
+            )
             logger.info(f"Evaluation of {node} took {time.time() - start_time:.2f}s")
 
             logger.debug(
@@ -597,7 +556,7 @@
 
             # Nodes containing the linkable instances will be joined to the source node, so these
             # entities will need to be present in the source node.
-            required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes)
+            required_local_entity_specs = tuple(x.join_on_entity for x in evaluation.join_recipes if x.join_on_entity)
             # Same thing with partitions.
             required_local_dimension_specs = tuple(
                 y.start_node_dimension_spec for x in evaluation.join_recipes for y in x.join_on_partition_dimensions
@@ -780,7 +739,7 @@ def _build_aggregated_measure_from_measure_source_node(
                 ),
             )
 
-        join_targets = measure_recipe.join_targets
+        join_targets = [join_recipe.join_description for join_recipe in measure_recipe.join_linkable_instances_recipes]
         unaggregated_measure_node: BaseOutput
         if len(join_targets) > 0:
             filtered_measures_with_joined_elements = JoinToBaseOutputNode(
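With DataflowRecipe.join_targets deleted, both call sites in this file now derive join targets directly from the recipes. A sketch of the shared pattern (dataflow_recipe stands for any DataflowRecipe produced by _find_dataflow_recipe; JoinToBaseOutputNode is the class this file already imports):

from typing import Optional
from metricflow.dataflow.dataflow_plan import JoinToBaseOutputNode

# `dataflow_recipe` is an assumed local: the DataflowRecipe found for the query.
join_targets = [recipe.join_description for recipe in dataflow_recipe.join_linkable_instances_recipes]
joined_node: Optional[JoinToBaseOutputNode] = None
if join_targets:
    joined_node = JoinToBaseOutputNode(left_node=dataflow_recipe.source_node, join_targets=join_targets)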
50 changes: 42 additions & 8 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -27,6 +27,7 @@
 from metricflow.dataflow.builder.partitions import PartitionJoinResolver
 from metricflow.dataflow.dataflow_plan import (
     BaseOutput,
+    FilterElementsNode,
     JoinDescription,
     PartitionDimensionJoinDescription,
     PartitionTimeDimensionJoinDescription,
@@ -37,10 +38,8 @@
 from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator
 from metricflow.plan_conversion.instance_converters import CreateValidityWindowJoinDescription
 from metricflow.protocols.semantics import SemanticModelAccessor
-from metricflow.specs.specs import (
-    LinkableInstanceSpec,
-    LinklessEntitySpec,
-)
+from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec
+from metricflow.sql.sql_plan import SqlJoinType
 
 logger = logging.getLogger(__name__)
 
@@ -55,12 +54,14 @@ class JoinLinkableInstancesRecipe:
     """
 
     node_to_join: BaseOutput
-    # The entity to join "node_to_join" on.
-    join_on_entity: LinklessEntitySpec
+    # The entity to join "node_to_join" on. Not needed for cross-joins.
+    join_on_entity: Optional[LinklessEntitySpec]
     # The linkable instances from the query that can be satisfied if we join this node. Note that this is different from
     # the linkable specs in the node that can help to satisfy the query. e.g. "user_id__country" might be one of the
     # "satisfiable_linkable_specs", but "country" is the linkable spec in the node.
     satisfiable_linkable_specs: List[LinkableInstanceSpec]
+    # Join type to use when joining node
+    join_type: SqlJoinType
 
     # The partitions to join on, if there are matching partitions between the start_node and node_to_join.
     join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
@@ -71,12 +72,37 @@
     @property
     def join_description(self) -> JoinDescription:
         """The recipe as a join description to use in the dataflow plan node."""
+        # Figure out what elements to filter from the joined node.
+        include_specs: List[LinkableInstanceSpec] = []
+        if not self.join_type == SqlJoinType.CROSS_JOIN:
+            assert all([len(spec.entity_links) > 0 for spec in self.satisfiable_linkable_specs])
+            include_specs.extend(
+                [LinklessEntitySpec.from_reference(spec.entity_links[0]) for spec in self.satisfiable_linkable_specs]
+            )
+
+        include_specs.extend([join.node_to_join_dimension_spec for join in self.join_on_partition_dimensions])
+        include_specs.extend([join.node_to_join_time_dimension_spec for join in self.join_on_partition_time_dimensions])
+        if self.validity_window:
+            include_specs.extend(
+                [self.validity_window.window_start_dimension, self.validity_window.window_end_dimension]
+            )
+
+        # `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the entity
+        # link when filtering before the join.
+        # e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
+        # "user_id" and the "country" dimension so that it can be joined to the source node.
+        include_specs.extend([spec.without_first_entity_link for spec in self.satisfiable_linkable_specs])
+        filtered_node_to_join = FilterElementsNode(
+            parent_node=self.node_to_join, include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs)
+        )
+
         return JoinDescription(
-            join_node=self.node_to_join,
+            join_node=filtered_node_to_join,
             join_on_entity=self.join_on_entity,
             join_on_partition_dimensions=self.join_on_partition_dimensions,
             join_on_partition_time_dimensions=self.join_on_partition_time_dimensions,
             validity_window=self.validity_window,
+            join_type=self.join_type,
         )
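A concrete reading of the filter above: if the joined node satisfies "user_id__country", the pre-join filter keeps the join key "user_id" plus the entity-stripped dimension "country"; under a CROSS_JOIN there is no join key to keep. A toy illustration on dotted names (hypothetical helper, not MetricFlow API; partition and validity-window columns are ignored):

from typing import List

def pre_join_columns(satisfiable: List[str], cross_join: bool) -> List[str]:
    """Toy version of the include_specs selection, on names like "user_id__country"."""
    columns: List[str] = []
    if not cross_join:
        assert all("__" in name for name in satisfiable)  # every spec needs an entity link
        columns.extend(name.split("__", 1)[0] for name in satisfiable)  # join keys
    columns.extend(name.split("__", 1)[-1] for name in satisfiable)  # entity links stripped
    return columns

assert pre_join_columns(["user_id__country"], cross_join=False) == ["user_id", "country"]
assert pre_join_columns(["country"], cross_join=True) == ["country"]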

@@ -133,6 +159,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
         self,
         start_node_instance_set: InstanceSet,
         needed_linkable_specs: List[LinkableInstanceSpec],
+        join_type: SqlJoinType,
     ) -> List[JoinLinkableInstancesRecipe]:
         """Get nodes that can be joined to get 1 or more of the "needed_linkable_specs".
@@ -257,6 +284,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
                     join_on_partition_dimensions=join_on_partition_dimensions,
                     join_on_partition_time_dimensions=join_on_partition_time_dimensions,
                     validity_window=validity_window_join_description,
+                    join_type=join_type,
                 )
             )
 
@@ -271,6 +299,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
     def _update_candidates_that_can_satisfy_linkable_specs(
         candidates_for_join: List[JoinLinkableInstancesRecipe],
         already_satisfisfied_linkable_specs: List[LinkableInstanceSpec],
+        join_type: SqlJoinType,
     ) -> List[JoinLinkableInstancesRecipe]:
         """Update / filter candidates_for_join based on linkable instance specs that we have already satisfied.
@@ -294,6 +323,7 @@ def _update_candidates_that_can_satisfy_linkable_specs(
                         join_on_partition_dimensions=candidate_for_join.join_on_partition_dimensions,
                         join_on_partition_time_dimensions=candidate_for_join.join_on_partition_time_dimensions,
                         validity_window=candidate_for_join.validity_window,
+                        join_type=join_type,
                     )
                 )
         return sorted(
@@ -306,6 +336,7 @@ def evaluate_node(
         self,
         start_node: BaseOutput,
         required_linkable_specs: Sequence[LinkableInstanceSpec],
+        default_join_type: SqlJoinType,
     ) -> LinkableInstanceSatisfiabilityEvaluation:
         """Evaluates if the "required_linkable_specs" can be realized by joining the "start_node" with other nodes.
@@ -345,7 +376,9 @@
                 possibly_joinable_linkable_specs.append(required_linkable_spec)
 
         candidates_for_join = self._find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
-            start_node_instance_set=candidate_instance_set, needed_linkable_specs=possibly_joinable_linkable_specs
+            start_node_instance_set=candidate_instance_set,
+            needed_linkable_specs=possibly_joinable_linkable_specs,
+            join_type=default_join_type,
         )
         join_candidates: List[JoinLinkableInstancesRecipe] = []
 
@@ -378,6 +411,7 @@
             candidates_for_join = self._update_candidates_that_can_satisfy_linkable_specs(
                 candidates_for_join=candidates_for_join,
                 already_satisfisfied_linkable_specs=next_candidate.satisfiable_linkable_specs,
+                join_type=default_join_type,
             )
 
             # The once possibly joinable specs are definitely joinable and no longer need to be searched for.
4 changes: 3 additions & 1 deletion metricflow/dataflow/dataflow_plan.py
@@ -249,7 +249,8 @@ class JoinDescription:
     """Describes how data from a node should be joined to data from another node."""
 
     join_node: BaseOutput
-    join_on_entity: LinklessEntitySpec
+    join_on_entity: Optional[LinklessEntitySpec]
+    join_type: SqlJoinType
 
     join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
     join_on_partition_time_dimensions: Tuple[PartitionTimeDimensionJoinDescription, ...]
@@ -339,6 +340,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToBase
                     join_on_partition_dimensions=old_join_target.join_on_partition_dimensions,
                     join_on_partition_time_dimensions=old_join_target.join_on_partition_time_dimensions,
                     validity_window=old_join_target.validity_window,
+                    join_type=old_join_target.join_type,
                 )
                 for i, old_join_target in enumerate(self._join_targets)
             ],
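Since JoinDescription is rebuilt field by field here, forgetting join_type would have silently dropped it. Assuming JoinDescription is a frozen dataclass (as MetricFlow plan descriptions generally are), dataclasses.replace expresses the same rebuild without re-listing every field — a sketch with assumed local names:

from dataclasses import replace

# `old_join_targets` / `new_parent_nodes` are assumed names; index 0 of the new
# parents is taken to be the left node, so join targets pair with the remainder.
new_join_targets = [
    replace(old_target, join_node=new_parent)
    for old_target, new_parent in zip(old_join_targets, new_parent_nodes[1:])
]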
45 changes: 23 additions & 22 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -396,28 +396,29 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataS
                 )
             )
 
-            # Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
-            # link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
-            # create an instance in the next step that has the same entity link.
-            # e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
-            # data set. The next step would create an instance like "listing__listing__country_latest" without this
-            # filter.
-
-            # logger.error(f"before filter is:\n{pformat_big_objects(right_data_set.instance_set.spec_set)}")
-            right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
-                entity_link=join_on_entity,
-            ).transform(right_data_set.instance_set)
-            # logger.error(f"after filter is:\n{pformat_big_objects(right_data_set_instance_set_filtered.spec_set)}")
-
-            # After the right data set is joined to the "from" data set, we need to change the links for some of the
-            # instances that represent the right data set. For example, if the "from" data set contains the "bookings"
-            # measure instance and the right dataset contains the "country" dimension instance, then after the join,
-            # the output data set should have the "country" dimension instance with the "user_id" entity link
-            # (if "user_id" equality was the join condition). "country" -> "user_id__country"
-            right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
-                AddLinkToLinkableElements(join_on_entity=join_on_entity)
-            )
-            table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
+            if join_description.join_type == SqlJoinType.CROSS_JOIN:
+                table_alias_to_instance_set[right_data_set_alias] = right_data_set.instance_set
+            else:
+                # Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
+                # link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
+                # create an instance in the next step that has the same entity link.
+                # e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
+                # data set. The next step would create an instance like "listing__listing__country_latest" without this
+                # filter.
+                assert join_on_entity
+                right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
+                    entity_link=join_on_entity
+                ).transform(right_data_set.instance_set)
+
+                # After the right data set is joined to the "from" data set, we need to change the links for some of the
+                # instances that represent the right data set. For example, if the "from" data set contains the "bookings"
+                # measure instance and the right dataset contains the "country" dimension instance, then after the join,
+                # the output data set should have the "country" dimension instance with the "user_id" entity link
+                # (if "user_id" equality was the join condition). "country" -> "user_id__country"
+                right_data_set_instance_set_after_join = right_data_set_instance_set_filtered.transform(
+                    AddLinkToLinkableElements(join_on_entity=join_on_entity)
+                )
+                table_alias_to_instance_set[right_data_set_alias] = right_data_set_instance_set_after_join
 
         from_data_set_output_instance_set = from_data_set.instance_set.transform(
             FilterElements(include_specs=from_data_set.instance_set.spec_set)
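The new branch makes the cross-join case explicit: the right data set's instances pass through untouched, while an entity join first drops instances already led by the join entity (avoiding "listing__listing__country_latest") and then prefixes the rest ("country" becomes "user_id__country"). A toy model of that rewrite on plain dotted names (hypothetical helper, not the InstanceSet API; it models only linkable instances, not measures):

from typing import List, Optional

def relink(names: List[str], join_on_entity: Optional[str], cross_join: bool) -> List[str]:
    if cross_join:
        return list(names)  # cross joins keep the right side's instances untouched
    assert join_on_entity is not None
    # Drop names already led by the join entity, then re-prefix what remains.
    kept = [name for name in names if not name.startswith(join_on_entity + "__")]
    return [join_on_entity + "__" + name for name in kept]

assert relink(["country", "listing__country_latest"], "listing", cross_join=False) == ["listing__country"]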