Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP #1079

Closed
wants to merge 5 commits into from
Closed

WIP #1079

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def build_sink_node(
@staticmethod
def _contains_multihop_linkables(linkable_specs: Sequence[LinkableInstanceSpec]) -> bool:
"""Returns true if any of the linkable specs requires a multi-hop join to realize."""
return any(len(x.entity_links) > 1 for x in linkable_specs)
return any(len(x.group_by_links) > 1 for x in linkable_specs)

def _get_semantic_model_names_for_measures(self, measures: Sequence[MeasureSpec]) -> Set[str]:
"""Return the names of the semantic models needed to compute the input measures.
Expand Down Expand Up @@ -815,6 +815,23 @@ def _find_dataflow_recipe(
measure_specs=set(measure_spec_properties.measure_specs),
source_nodes=self._source_node_set.source_nodes_for_metric_queries,
)
# If there are MetricGroupBys in the requested linkable specs, build source nodes to satisfy them.
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
# We do this at query time instead of during usual source node generation because the number of potential
# MetricGroupBy source nodes would be extremely large (and potentially slow).
for group_by_metric_spec in linkable_spec_set.group_by_metric_specs:
# TODO: handle dimensions
group_by_metric_source_node = self._build_query_output_node(
# TODO: move this logic into MetricGroupBySpec
MetricFlowQuerySpec(
metric_specs=(MetricSpec(element_name=group_by_metric_spec.element_name),),
entity_specs=tuple(
EntitySpec.from_name(group_by_link.element_name)
for group_by_link in group_by_metric_spec.group_by_links
),
)
)
candidate_nodes_for_right_side_of_join += (group_by_metric_source_node,)

default_join_type = SqlJoinType.LEFT_OUTER
else:
candidate_nodes_for_right_side_of_join = list(self._source_node_set.source_nodes_for_group_by_item_queries)
Expand Down Expand Up @@ -1329,6 +1346,8 @@ def _build_aggregated_measure_from_measure_source_node(
required_linkable_specs.as_spec_set,
)

# TODO: ensure that group-by metrics are present in the data set before reaching this
# point — i.e., after the join to the base output node.
#
after_join_filtered_node = FilterElementsNode(
parent_node=filtered_measures_with_joined_elements, include_specs=specs_to_keep_after_join
)
Expand Down
34 changes: 21 additions & 13 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PartitionTimeDimensionJoinDescription,
)
from metricflow.dataflow.dataflow_plan import BaseOutput
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, ValidityWindowJoinDescription
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
Expand Down Expand Up @@ -82,17 +83,17 @@ def join_description(self) -> JoinDescription:
include_specs: List[LinkableInstanceSpec] = []
assert all(
[
len(spec.entity_links) > 0
len(spec.group_by_links) > 0
for spec in self.satisfiable_linkable_specs
if not spec.element_name == METRIC_TIME_ELEMENT_NAME
]
)

include_specs.extend(
[
LinklessEntitySpec.from_reference(spec.entity_links[0])
LinklessEntitySpec.from_reference(spec.group_by_links[0])
for spec in self.satisfiable_linkable_specs
if len(spec.entity_links) > 0
if len(spec.group_by_links) > 0
]
)

Expand All @@ -109,7 +110,7 @@ def join_description(self) -> JoinDescription:
# "user_id" and the "country" dimension so that it can be joined to the source node.
include_specs.extend(
[
spec.without_first_entity_link if len(spec.entity_links) > 0 else spec
spec.without_first_group_by_link if len(spec.group_by_links) > 0 else spec
for spec in self.satisfiable_linkable_specs
]
)
Expand Down Expand Up @@ -220,7 +221,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
for entity_spec_in_right_node in entity_specs_in_right_node:
# If an entity has links, what that means and whether it can be used is unclear at the moment,
# so skip it.
if len(entity_spec_in_right_node.entity_links) > 0:
if len(entity_spec_in_right_node.group_by_links) > 0:
continue

entity_instance_in_right_node = None
Expand Down Expand Up @@ -264,15 +265,17 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
].semantic_model_reference,
on_entity_reference=entity_spec_in_right_node.reference,
):
continue
# Check if it's joining to something pre-aggregated. If so, we can allow the supposed fan-out join.
if not isinstance(right_node, ComputeMetricsNode):
courtneyholcomb marked this conversation as resolved.
Show resolved Hide resolved
continue

linkless_entity_spec_in_node = LinklessEntitySpec.from_element_name(
entity_spec_in_right_node.element_name
)

satisfiable_linkable_specs = []
for needed_linkable_spec in needed_linkable_specs:
if len(needed_linkable_spec.entity_links) == 0:
if len(needed_linkable_spec.group_by_links) == 0:
assert (
needed_linkable_spec.element_name == METRIC_TIME_ELEMENT_NAME
), "Only metric_time should have 0 entity links."
Expand All @@ -288,21 +291,26 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
#
# Then the data set must contain "device_id__platform", which is realized with
#
# required_linkable_spec.remove_first_entity_link()
# required_linkable_spec.remove_first_group_by_link()
#
# We might also need to check the entity type and see if it's the type of join we're allowing,
# but since we're doing all left joins now, it's been left out.

required_entity_matches_data_set_entity = (
LinklessEntitySpec.from_reference(needed_linkable_spec.entity_links[0])
LinklessEntitySpec.from_reference(needed_linkable_spec.group_by_links[0])
== linkless_entity_spec_in_node
)
needed_linkable_spec_in_node = (
needed_linkable_spec.without_first_entity_link in linkable_specs_in_right_node
needed_linkable_spec.without_first_group_by_link in linkable_specs_in_right_node
)
if required_entity_matches_data_set_entity and needed_linkable_spec_in_node:
satisfiable_linkable_specs.append(needed_linkable_spec)

if isinstance(right_node, ComputeMetricsNode):
print(
"made it here!3",
needed_linkable_spec.without_first_group_by_link,
linkable_specs_in_right_node,
)
# If this node can satisfy some linkable specs, it could be useful to join on, so add it to the
# candidate list.
if len(satisfiable_linkable_specs) > 0:
Expand Down Expand Up @@ -406,8 +414,8 @@ def evaluate_node(
is_metric_time = required_linkable_spec.element_name == DataSet.metric_time_dimension_name()
is_local = required_linkable_spec in data_set_linkable_specs
is_unjoinable = not is_metric_time and (
len(required_linkable_spec.entity_links) == 0
or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0])
len(required_linkable_spec.group_by_links) == 0
or LinklessEntitySpec.from_reference(required_linkable_spec.group_by_links[0])
not in data_set_linkable_specs
)
if is_local:
Expand Down
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_partitions(self, spec_set: InstanceSpecSet) -> PartitionSpecSet:
def _get_simplest_dimension_spec(dimension_specs: Sequence[DimensionSpec]) -> DimensionSpec:
"""Return the time dimension spec with the fewest entity links."""
assert len(dimension_specs) > 0
sorted_dimension_specs = sorted(dimension_specs, key=lambda x: len(x.entity_links))
sorted_dimension_specs = sorted(dimension_specs, key=lambda x: len(x.group_by_links))
return sorted_dimension_specs[0]

def resolve_partition_dimension_joins(
Expand Down Expand Up @@ -99,7 +99,7 @@ def resolve_partition_dimension_joins(
def _get_simplest_time_dimension_spec(time_dimension_specs: Sequence[TimeDimensionSpec]) -> TimeDimensionSpec:
"""Return the time dimension spec with the smallest granularity, then fewest entity links."""
assert len(time_dimension_specs) > 0
sorted_specs = sorted(time_dimension_specs, key=lambda x: (x.time_granularity, len(x.entity_links)))
sorted_specs = sorted(time_dimension_specs, key=lambda x: (x.time_granularity, len(x.group_by_links)))
return sorted_specs[0]

def resolve_partition_time_dimension_joins(
Expand Down
Loading
Loading