Skip to content

Commit

Permalink
Include last entity link in metric subquery entity path
Browse files Browse the repository at this point in the history
Originally I left out this last entity link because it duplicated the last of the outer query's entity links. It turns out the link is needed when there are no entity links left in the outer query, because otherwise we won't know which entity the metric is being grouped by.
  • Loading branch information
courtneyholcomb committed Apr 26, 2024
1 parent 0a400a5 commit d5fd394
Show file tree
Hide file tree
Showing 29 changed files with 162 additions and 145 deletions.
7 changes: 6 additions & 1 deletion metricflow/model/semantics/linkable_spec_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def metric_to_entity_join_path(self) -> Optional[SemanticModelJoinPath]:
@property
def metric_subquery_entity_links(self) -> Tuple[EntityReference, ...]:
"""Entity links used to join the metric to the entity it's grouped by in the metric subquery."""
return self.metric_to_entity_join_path.entity_links if self.metric_to_entity_join_path else ()
return self.join_path.metric_subquery_join_path_element.entity_links


@dataclass(frozen=True)
Expand Down Expand Up @@ -410,6 +410,11 @@ class MetricSubqueryJoinPathElement:
join_on_entity: EntityReference
metric_to_entity_join_path: Optional[SemanticModelJoinPath] = None

@property
def entity_links(self) -> Tuple[EntityReference, ...]: # noqa: D102
# Entity links for this path element: the links of `metric_to_entity_join_path` when one is set
# (an empty tuple otherwise), followed by `join_on_entity` as the final link.
join_links = self.metric_to_entity_join_path.entity_links if self.metric_to_entity_join_path else ()
# `join_on_entity` is always appended, so the returned tuple is never empty.
return join_links + (self.join_on_entity,)


@dataclass(frozen=True)
class SemanticModelToMetricSubqueryJoinPath:
Expand Down
4 changes: 2 additions & 2 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,8 +736,8 @@ def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
defined_from=metric_instance.defined_from,
spec=GroupByMetricSpec(
element_name=metric_spec.element_name,
entity_links=(), # check this
metric_subquery_entity_links=entity_instance.spec.entity_links,
entity_links=(),
metric_subquery_entity_links=entity_instance.spec.entity_links + (entity_instance.spec.reference,),
),
)
transform_func = AddGroupByMetric(group_by_metric_instance)
Expand Down
26 changes: 13 additions & 13 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,18 @@ class GroupByMetricSpec(LinkableInstanceSpec, SerializableDataclass):
entity_links: Sequence of entities traversed to join the metric subquery to the outer query. The last entity is
the one that joins the subquery to the outer query.
metric_subquery_entity_links: Sequence of entities used in the metric subquery to join the metric to the entity.
Does not include the top-level entity (to avoid duplicating the last element of `entity_links`).
"""

metric_subquery_entity_links: Tuple[EntityReference, ...]

def __post_init__(self) -> None:
"""The inner query and outer query entity paths must end with the same entity (that's what they join on).
If no entity links, it's because we're already in the final joined node (no links left).
"""
# The check only runs when there are outer-query entity links to compare against.
if self.entity_links:
# NOTE(review): if `metric_subquery_entity_links` is empty here, the `[-1]` lookup raises IndexError
# before the comparison is evaluated. Asserts are also stripped when Python runs with `-O`.
assert self.metric_subquery_entity_links[-1] == self.entity_links[-1]

@property
def without_first_entity_link(self) -> GroupByMetricSpec: # noqa: D102
assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}"
Expand All @@ -277,21 +283,15 @@ def last_entity_link(self) -> EntityReference: # noqa: D102
assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}"
return self.entity_links[-1]

@staticmethod
def from_name(name: str) -> GroupByMetricSpec: # noqa: D102
# Split the name into its structured parts; the entity link names and element name are read from the result.
structured_name = StructuredLinkableSpecName.from_name(name)
return GroupByMetricSpec(
# One EntityReference is built per entity link name in the parsed name.
entity_links=tuple(EntityReference(idl) for idl in structured_name.entity_link_names),
element_name=structured_name.element_name,
# NOTE(review): always empty. With a `__post_init__` that indexes `metric_subquery_entity_links[-1]`
# whenever `entity_links` is non-empty, this would fail for any name that has entity links —
# presumably why this commit drops the method; confirm against the full diff.
metric_subquery_entity_links=(),
)

@property
def metric_subquery_entity_spec(self) -> EntitySpec:
"""Spec for the entity that the metric will be grouped by in the metric subquery."""
assert (
len(self.metric_subquery_entity_links) > 0
), "GroupByMetricSpec must have at least one metric_subquery_entity_link."
return EntitySpec(
element_name=self.last_entity_link.element_name,
entity_links=self.metric_subquery_entity_links,
element_name=self.metric_subquery_entity_links[-1].element_name,
entity_links=self.metric_subquery_entity_links[:-1],
)

def __eq__(self, other: Any) -> bool: # type: ignore[misc] # noqa: D105
Expand Down
4 changes: 3 additions & 1 deletion tests/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,9 @@ def test_metric_in_filter( # noqa: D103
where_filter = PydanticWhereFilter(where_sql_template="{{ Metric('bookings', group_by=['listing']) }} > 2")

group_by_metric_spec = GroupByMetricSpec(
element_name="bookings", entity_links=(EntityReference("listing"),), metric_subquery_entity_links=()
element_name="bookings",
entity_links=(EntityReference("listing"),),
metric_subquery_entity_links=(EntityReference(element_name="listing"),),
)
where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
Expand Down
2 changes: 1 addition & 1 deletion tests/naming/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,6 @@ def specs() -> Sequence[LinkableInstanceSpec]: # noqa: D103
GroupByMetricSpec(
element_name="bookings",
entity_links=(EntityReference(element_name="listing"),),
metric_subquery_entity_links=(),
metric_subquery_entity_links=(EntityReference(element_name="listing"),),
),
)
4 changes: 2 additions & 2 deletions tests/naming/test_object_builder_naming_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> N
GroupByMetricSpec(
element_name="bookings",
entity_links=(EntityReference(element_name="listing"),),
metric_subquery_entity_links=(),
metric_subquery_entity_links=(EntityReference(element_name="listing"),),
)
)
== "Metric('bookings', group_by=['listing'])"
Expand Down Expand Up @@ -125,6 +125,6 @@ def test_spec_pattern( # noqa: D103
GroupByMetricSpec(
element_name="bookings",
entity_links=(EntityReference(element_name="listing"),),
metric_subquery_entity_links=(),
metric_subquery_entity_links=(EntityReference(element_name="listing"),),
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
<!-- description = "Pass Only Elements: ['listings', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_3') -->
<!-- include_spec = MeasureSpec(element_name='listings') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- metric_subquery_entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- distinct = False -->
<JoinToBaseOutputNode>
<!-- description = 'Join Standard Outputs' -->
Expand Down Expand Up @@ -56,10 +57,14 @@
</MetricTimeDimensionTransformNode>
</FilterElementsNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['listing', 'bookings']" -->
<!-- description = "Pass Only Elements: ['listing', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- include_spec = LinklessEntitySpec(element_name='listing') -->
<!-- include_spec = GroupByMetricSpec(element_name='bookings') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- metric_subquery_entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- distinct = False -->
<ComputeMetricsNode>
<!-- description = 'Compute Metrics via Expressions' -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@
<!-- description = "Pass Only Elements: ['listings', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_3') -->
<!-- include_spec = MeasureSpec(element_name='listings') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- entity_links=(EntityReference(element_name='listing'),), -->
<!-- metric_subquery_entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- distinct = False -->
<JoinToBaseOutputNode>
<!-- description = 'Join Standard Outputs' -->
Expand Down Expand Up @@ -65,10 +66,14 @@
</MetricTimeDimensionTransformNode>
</FilterElementsNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['listing', 'bookings']" -->
<!-- description = "Pass Only Elements: ['listing', 'listing__bookings']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- include_spec = LinklessEntitySpec(element_name='listing') -->
<!-- include_spec = GroupByMetricSpec(element_name='bookings') -->
<!-- include_spec = -->
<!-- GroupByMetricSpec( -->
<!-- element_name='bookings', -->
<!-- metric_subquery_entity_links=(EntityReference(element_name='listing'),), -->
<!-- ) -->
<!-- distinct = False -->
<ComputeMetricsNode>
<!-- description = 'Compute Metrics via Expressions' -->
Expand Down
Loading

0 comments on commit d5fd394

Please sign in to comment.