Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Apr 23, 2024
1 parent 0e6830a commit a599a60
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 39 deletions.
8 changes: 6 additions & 2 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def join_description(self) -> JoinDescription:
[self.validity_window.window_start_dimension, self.validity_window.window_end_dimension]
)

# `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the first
# entity link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the source node.
include_specs.extend(
Expand All @@ -114,6 +114,7 @@ def join_description(self) -> JoinDescription:
for spec in self.satisfiable_linkable_specs
]
)
# What does this look like for multi-hop?
filtered_node_to_join = FilterElementsNode(
parent_node=self.node_to_join, include_specs=InstanceSpecSet.from_specs(include_specs)
)
Expand Down Expand Up @@ -211,6 +212,8 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(

data_set_in_right_node: SqlDataSet = self._node_data_set_resolver.get_output_data_set(right_node)
linkable_specs_in_right_node = data_set_in_right_node.instance_set.spec_set.linkable_specs
if isinstance(right_node, ComputeMetricsNode):
print("linkable_specs_in_right_node:", linkable_specs_in_right_node)
entity_specs_in_right_node = data_set_in_right_node.instance_set.spec_set.entity_specs

# For each unlinked entity in the data set, create a candidate for joining.
Expand Down Expand Up @@ -426,6 +429,7 @@ def evaluate_node(
needed_linkable_specs=possibly_joinable_linkable_specs,
default_join_type=default_join_type,
)
print("candidates_for_join:", candidates_for_join)
join_candidates: List[JoinLinkableInstancesRecipe] = []

logger.info("Looping over nodes that can be joined to get the required linkable specs")
Expand Down
1 change: 1 addition & 0 deletions metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def associated_column(self) -> ColumnAssociation:
@dataclass(frozen=True)
class SemanticModelElementInstance(SerializableDataclass): # noqa: D101
# This instance is derived from something defined in a semantic model.
# TODO in separate PR: make defined_from not a tuple...
defined_from: Tuple[SemanticModelElementReference, ...]

@property
Expand Down
1 change: 1 addition & 0 deletions metricflow/model/semantics/linkable_spec_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def as_spec_set(self) -> LinkableSpecSet: # noqa: D102
group_by_metric_specs=tuple(
GroupByMetricSpec(
element_name=path_key.element_name,
# TODO: update here??
entity_links=path_key.entity_links,
)
for path_key in self.path_key_to_linkable_metrics
Expand Down
6 changes: 4 additions & 2 deletions metricflow/model/semantics/semantic_model_join_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ def _semantic_model_of_entity_in_instance_set(
"""Return the semantic model where the entity was defined in the instance set."""
matching_instances: List[EntityInstance] = []
for entity_instance in instance_set.entity_instances:
assert len(entity_instance.defined_from) == 1
if len(entity_instance.spec.entity_links) == 0 and entity_instance.spec.reference == entity_reference:
# assert len(entity_instance.defined_from) == 1
# why do we need this??
# if len(entity_instance.spec.entity_links) == 0 and entity_instance.spec.reference == entity_reference:
if entity_instance.spec.reference == entity_reference:
matching_instances.append(entity_instance)

assert len(matching_instances) == 1, (
Expand Down
20 changes: 11 additions & 9 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,14 +723,16 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
spec=metric_spec,
)
)
# for entity_instance in output_instance_set.entity_instances:
# add entities here? or remove from check?
group_by_metric_instances.append(
GroupByMetricInstance(
associated_columns=(output_column_association,),
defined_from=MetricModelReference(metric_name=metric_spec.element_name),
spec=GroupByMetricSpec(element_name=metric_spec.element_name, entity_links=()),
)
)

print("adding group by metrics:", group_by_metric_instances)
transform_func = (
AddGroupByMetrics(group_by_metric_instances)
if node.for_group_by_source_node
Expand Down Expand Up @@ -1119,9 +1121,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down Expand Up @@ -1364,11 +1366,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
and len(time_spine_dataset.checked_sql_select_node.select_columns) == 1
), "Time spine dataset not configured properly. Expected exactly one column."
original_time_spine_dim_instance = time_spine_dataset.instance_set.time_dimension_instances[0]
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
time_spine_column_select_expr: Union[SqlColumnReferenceExpression, SqlDateTruncExpression] = (
SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)
)

Expand Down
52 changes: 26 additions & 26 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from metricflow.dataflow.dataflow_plan import (
BaseOutput,
)
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinToBaseOutputNode
Expand Down Expand Up @@ -116,49 +117,51 @@ def add_time_range_constraint(
processed_nodes.append(source_node)
return processed_nodes

def _node_contains_entity(
def _node_contains_entities(
self,
node: BaseOutput,
entity_reference: EntityReference,
entity_references: Set[EntityReference],
) -> bool:
"""Returns true if the output of the node contains an entity of the given types."""
data_set = self._node_data_set_resolver.get_output_data_set(node)

for entity_instance_in_first_node in data_set.instance_set.entity_instances:
entity_spec_in_first_node = entity_instance_in_first_node.spec
found_entity_references = set()
if isinstance(node, ComputeMetricsNode):
print("node has entities:", data_set.instance_set.entity_instances)
for entity_instance in data_set.instance_set.entity_instances:
entity_spec = entity_instance.spec

if entity_spec_in_first_node.reference != entity_reference:
if entity_spec.reference not in entity_references:
continue

if len(entity_spec_in_first_node.entity_links) > 0:
continue

assert (
len(entity_instance_in_first_node.defined_from) == 1
), "Multiple items in defined_from not yet supported"
# if len(entity_spec.entity_links) > 0: # why is this needed?
# continue

entity = self._semantic_model_lookup.get_entity_in_semantic_model(
entity_instance_in_first_node.defined_from[0]
)
assert len(entity_instance.defined_from) == 1, "Multiple items in defined_from not yet supported"
semantic_model = entity_instance.defined_from[0]
entity = self._semantic_model_lookup.get_entity_in_semantic_model(semantic_model)
if entity is None:
raise RuntimeError(
f"Invalid SemanticModelElementReference {entity_instance_in_first_node.defined_from[0]}"
f"Invalid SemanticModelElementReference {semantic_model} for entity {entity_spec.reference}"
)

return True
found_entity_references.add(entity_spec.reference)

return False
return found_entity_references == entity_references

def _get_candidates_nodes_for_multi_hop(
self, desired_linkable_spec: LinkableInstanceSpec, nodes: Sequence[BaseOutput], join_type: SqlJoinType
) -> Sequence[MultiHopJoinCandidate]:
"""Assemble nodes representing all possible one-hop joins."""
if len(desired_linkable_spec.entity_links) > MAX_JOIN_HOPS:
# TODO: update this error to have one more for group by metrics
raise NotImplementedError(
f"Multi-hop joins with more than {MAX_JOIN_HOPS} entity links not yet supported. "
f"Got: {desired_linkable_spec}"
)
if len(desired_linkable_spec.entity_links) != 2:

# If this linkable spec doesn't require multiple hops, skip it.
if len(desired_linkable_spec.entity_links) < 2:
return ()

multi_hop_join_candidates: List[MultiHopJoinCandidate] = []
Expand All @@ -171,22 +174,19 @@ def _get_candidates_nodes_for_multi_hop(

# When joining on the entity, the first node needs the first and second entity links.
if not (
self._node_contains_entity(
node=first_node_that_could_be_joined,
entity_reference=desired_linkable_spec.entity_links[0],
)
and self._node_contains_entity(
self._node_contains_entities(
node=first_node_that_could_be_joined,
entity_reference=desired_linkable_spec.entity_links[1],
entity_references=set(desired_linkable_spec.entity_links[:2]),
)
):
continue

for second_node_that_could_be_joined in nodes:
# print("names:", element_names_in_data_set, desired_linkable_spec.element_name)
if not (
self._node_contains_entity(
self._node_contains_entities(
node=second_node_that_could_be_joined,
entity_reference=desired_linkable_spec.entity_links[1],
entity_references={desired_linkable_spec.entity_links[1]},
)
):
continue
Expand Down

0 comments on commit a599a60

Please sign in to comment.