Skip to content

Commit

Permalink
Add GroupByMetricSpec, GroupByMetricInstance, and LinkableMetric (
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Mar 28, 2024
1 parent 48844ee commit fcb7dfa
Show file tree
Hide file tree
Showing 10 changed files with 596 additions and 88 deletions.
15 changes: 15 additions & 0 deletions metricflow/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
GroupByMetricSpec,
InstanceSpec,
InstanceSpecSet,
MeasureSpec,
Expand Down Expand Up @@ -102,6 +103,13 @@ class EntityInstance(MdoInstance[EntitySpec], SemanticModelElementInstance): #
spec: EntitySpec


@dataclass(frozen=True)
class GroupByMetricInstance(MdoInstance[GroupByMetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
spec: GroupByMetricSpec
defined_from: MetricModelReference


@dataclass(frozen=True)
class MetricInstance(MdoInstance[MetricSpec], SerializableDataclass): # noqa: D101
associated_columns: Tuple[ColumnAssociation, ...]
Expand Down Expand Up @@ -143,6 +151,7 @@ class InstanceSet(SerializableDataclass):
dimension_instances: Tuple[DimensionInstance, ...] = ()
time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
entity_instances: Tuple[EntityInstance, ...] = ()
group_by_metric_instances: Tuple[GroupByMetricInstance, ...] = ()
metric_instances: Tuple[MetricInstance, ...] = ()
metadata_instances: Tuple[MetadataInstance, ...] = ()

Expand All @@ -159,6 +168,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
dimension_instances: List[DimensionInstance] = []
time_dimension_instances: List[TimeDimensionInstance] = []
entity_instances: List[EntityInstance] = []
group_by_metric_instances: List[GroupByMetricInstance] = []
metric_instances: List[MetricInstance] = []
metadata_instances: List[MetadataInstance] = []

Expand All @@ -175,6 +185,9 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
for entity_instance in instance_set.entity_instances:
if entity_instance.spec not in {x.spec for x in entity_instances}:
entity_instances.append(entity_instance)
for group_by_metric_instance in instance_set.group_by_metric_instances:
if group_by_metric_instance.spec not in {x.spec for x in group_by_metric_instances}:
group_by_metric_instances.append(group_by_metric_instance)
for metric_instance in instance_set.metric_instances:
if metric_instance.spec not in {x.spec for x in metric_instances}:
metric_instances.append(metric_instance)
Expand All @@ -187,6 +200,7 @@ def merge(instance_sets: List[InstanceSet]) -> InstanceSet:
dimension_instances=tuple(dimension_instances),
time_dimension_instances=tuple(time_dimension_instances),
entity_instances=tuple(entity_instances),
group_by_metric_instances=tuple(group_by_metric_instances),
metric_instances=tuple(metric_instances),
metadata_instances=tuple(metadata_instances),
)
Expand All @@ -198,6 +212,7 @@ def spec_set(self) -> InstanceSpecSet: # noqa: D102
dimension_specs=tuple(x.spec for x in self.dimension_instances),
time_dimension_specs=tuple(x.spec for x in self.time_dimension_instances),
entity_specs=tuple(x.spec for x in self.entity_instances),
group_by_metric_specs=tuple(x.spec for x in self.group_by_metric_instances),
metric_specs=tuple(x.spec for x in self.metric_instances),
metadata_specs=tuple(x.spec for x in self.metadata_instances),
)
2 changes: 2 additions & 0 deletions metricflow/model/semantics/linkable_element_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class LinkableElementProperties(Enum):
ENTITY = "entity"
# See metric_time in DataSet
METRIC_TIME = "metric_time"
# Refers to a metric, not a dimension.
METRIC = "metric"

@staticmethod
def all_properties() -> FrozenSet[LinkableElementProperties]: # noqa: D102
Expand Down
281 changes: 196 additions & 85 deletions metricflow/model/semantics/linkable_spec_resolver.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions metricflow/plan_conversion/column_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from metricflow.specs.specs import (
DimensionSpec,
EntitySpec,
GroupByMetricSpec,
InstanceSpec,
InstanceSpecVisitor,
MeasureSpec,
Expand Down Expand Up @@ -77,6 +78,15 @@ def visit_entity_spec(self, entity_spec: EntitySpec) -> ColumnAssociation: # no
single_column_correlation_key=SingleColumnCorrelationKey(),
)

def visit_group_by_metric_spec(self, group_by_metric_spec: GroupByMetricSpec) -> ColumnAssociation: # noqa: D102
return ColumnAssociation(
column_name=StructuredLinkableSpecName(
entity_link_names=tuple(x.element_name for x in group_by_metric_spec.entity_links),
element_name=group_by_metric_spec.element_name,
).qualified_name,
single_column_correlation_key=SingleColumnCorrelationKey(),
)

def visit_metadata_spec(self, metadata_spec: MetadataSpec) -> ColumnAssociation: # noqa: D102
return ColumnAssociation(
column_name=metadata_spec.qualified_name,
Expand Down
80 changes: 78 additions & 2 deletions metricflow/specs/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> V
def visit_entity_spec(self, entity_spec: EntitySpec) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_group_by_metric_spec(self, group_by_metric_spec: GroupByMetricSpec) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_metric_spec(self, metric_spec: MetricSpec) -> VisitorOutputT: # noqa: D102
raise NotImplementedError
Expand Down Expand Up @@ -236,6 +240,48 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT
return visitor.visit_entity_spec(self)


@dataclass(frozen=True)
class GroupByMetricSpec(LinkableInstanceSpec, SerializableDataclass):
"""Metric used in group by or where filter."""

@property
def without_first_entity_link(self) -> GroupByMetricSpec: # noqa: D102
assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}"
return GroupByMetricSpec(element_name=self.element_name, entity_links=self.entity_links[1:])

@property
def without_entity_links(self) -> GroupByMetricSpec: # noqa: D102
return GroupByMetricSpec(element_name=self.element_name, entity_links=())

@staticmethod
def from_name(name: str) -> GroupByMetricSpec: # noqa: D102
structured_name = StructuredLinkableSpecName.from_name(name)
return GroupByMetricSpec(
entity_links=tuple(EntityReference(idl) for idl in structured_name.entity_link_names),
element_name=structured_name.element_name,
)

def __eq__(self, other: Any) -> bool: # type: ignore[misc] # noqa: D105
if not isinstance(other, GroupByMetricSpec):
return False
return self.element_name == other.element_name and self.entity_links == other.entity_links

def __hash__(self) -> int: # noqa: D105
return hash((self.element_name, self.entity_links))

@property
def reference(self) -> MetricReference: # noqa: D102
return MetricReference(element_name=self.element_name)

@property
@override
def as_spec_set(self) -> InstanceSpecSet:
return InstanceSpecSet(group_by_metric_specs=(self,))

def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_group_by_metric_spec(self)


@dataclass(frozen=True)
class LinklessEntitySpec(EntitySpec, SerializableDataclass):
"""Similar to EntitySpec, but requires that it doesn't have entity links."""
Expand Down Expand Up @@ -656,6 +702,7 @@ class LinkableSpecSet(Mergeable, SerializableDataclass):
dimension_specs: Tuple[DimensionSpec, ...] = ()
time_dimension_specs: Tuple[TimeDimensionSpec, ...] = ()
entity_specs: Tuple[EntitySpec, ...] = ()
group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = ()

@property
def contains_metric_time(self) -> bool:
Expand Down Expand Up @@ -701,14 +748,19 @@ def metric_time_specs(self) -> Sequence[TimeDimensionSpec]:

@property
def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102
return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs))
return tuple(
itertools.chain(
self.dimension_specs, self.time_dimension_specs, self.entity_specs, self.group_by_metric_specs
)
)

@override
def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
return LinkableSpecSet(
dimension_specs=self.dimension_specs + other.dimension_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs,
)

@override
Expand All @@ -731,10 +783,15 @@ def dedupe(self) -> LinkableSpecSet: # noqa: D102
for entity_spec in self.entity_specs:
entity_spec_dict[entity_spec] = None

group_by_metric_spec_dict: Dict[GroupByMetricSpec, None] = {}
for group_by_metric in self.group_by_metric_specs:
group_by_metric_spec_dict[group_by_metric] = None

return LinkableSpecSet(
dimension_specs=tuple(dimension_spec_dict.keys()),
time_dimension_specs=tuple(time_dimension_spec_dict.keys()),
entity_specs=tuple(entity_spec_dict.keys()),
group_by_metric_specs=tuple(group_by_metric_spec_dict.keys()),
)

def is_subset_of(self, other_set: LinkableSpecSet) -> bool: # noqa: D102
Expand All @@ -746,13 +803,15 @@ def as_spec_set(self) -> InstanceSpecSet: # noqa: D102
dimension_specs=self.dimension_specs,
time_dimension_specs=self.time_dimension_specs,
entity_specs=self.entity_specs,
group_by_metric_specs=self.group_by_metric_specs,
)

def difference(self, other: LinkableSpecSet) -> LinkableSpecSet: # noqa: D102
return LinkableSpecSet(
dimension_specs=tuple(set(self.dimension_specs) - set(other.dimension_specs)),
time_dimension_specs=tuple(set(self.time_dimension_specs) - set(other.time_dimension_specs)),
entity_specs=tuple(set(self.entity_specs) - set(other.entity_specs)),
group_by_metric_specs=tuple(set(self.group_by_metric_specs) - set(other.group_by_metric_specs)),
)

def __len__(self) -> int: # noqa: D105
Expand All @@ -765,6 +824,7 @@ def from_specs(specs: Sequence[LinkableInstanceSpec]) -> LinkableSpecSet: # noq
dimension_specs=instance_spec_set.dimension_specs,
time_dimension_specs=instance_spec_set.time_dimension_specs,
entity_specs=instance_spec_set.entity_specs,
group_by_metric_specs=instance_spec_set.group_by_metric_specs,
)


Expand All @@ -776,6 +836,7 @@ class MetricFlowQuerySpec(SerializableDataclass):
dimension_specs: Tuple[DimensionSpec, ...] = ()
entity_specs: Tuple[EntitySpec, ...] = ()
time_dimension_specs: Tuple[TimeDimensionSpec, ...] = ()
group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = ()
order_by_specs: Tuple[OrderBySpec, ...] = ()
time_range_constraint: Optional[TimeRangeConstraint] = None
limit: Optional[int] = None
Expand All @@ -789,6 +850,7 @@ def linkable_specs(self) -> LinkableSpecSet: # noqa: D102
dimension_specs=self.dimension_specs,
time_dimension_specs=self.time_dimension_specs,
entity_specs=self.entity_specs,
group_by_metric_specs=self.group_by_metric_specs,
)

def with_time_range_constraint(self, time_range_constraint: Optional[TimeRangeConstraint]) -> MetricFlowQuerySpec:
Expand All @@ -798,6 +860,7 @@ def with_time_range_constraint(self, time_range_constraint: Optional[TimeRangeCo
dimension_specs=self.dimension_specs,
entity_specs=self.entity_specs,
time_dimension_specs=self.time_dimension_specs,
group_by_metric_specs=self.group_by_metric_specs,
order_by_specs=self.order_by_specs,
time_range_constraint=time_range_constraint,
limit=self.limit,
Expand Down Expand Up @@ -826,6 +889,7 @@ class InstanceSpecSet(Mergeable, SerializableDataclass):
dimension_specs: Tuple[DimensionSpec, ...] = ()
entity_specs: Tuple[EntitySpec, ...] = ()
time_dimension_specs: Tuple[TimeDimensionSpec, ...] = ()
group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = ()
metadata_specs: Tuple[MetadataSpec, ...] = ()

@override
Expand All @@ -835,6 +899,7 @@ def merge(self, other: InstanceSpecSet) -> InstanceSpecSet:
measure_specs=self.measure_specs + other.measure_specs,
dimension_specs=self.dimension_specs + other.dimension_specs,
entity_specs=self.entity_specs + other.entity_specs,
group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs,
time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
metadata_specs=self.metadata_specs + other.metadata_specs,
)
Expand Down Expand Up @@ -874,18 +939,28 @@ def dedupe(self) -> InstanceSpecSet:
if entity_spec not in entity_specs_deduped:
entity_specs_deduped.append(entity_spec)

group_by_metric_specs_deduped = []
for group_by_metric_spec in self.group_by_metric_specs:
if group_by_metric_spec not in group_by_metric_specs_deduped:
group_by_metric_specs_deduped.append(group_by_metric_spec)

return InstanceSpecSet(
metric_specs=tuple(metric_specs_deduped),
measure_specs=tuple(measure_specs_deduped),
dimension_specs=tuple(dimension_specs_deduped),
time_dimension_specs=tuple(time_dimension_specs_deduped),
entity_specs=tuple(entity_specs_deduped),
group_by_metric_specs=tuple(group_by_metric_specs_deduped),
)

@property
def linkable_specs(self) -> Sequence[LinkableInstanceSpec]:
"""All linkable specs in this set."""
return list(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs))
return list(
itertools.chain(
self.dimension_specs, self.time_dimension_specs, self.entity_specs, self.group_by_metric_specs
)
)

@property
def all_specs(self) -> Sequence[InstanceSpec]: # noqa: D102
Expand All @@ -895,6 +970,7 @@ def all_specs(self) -> Sequence[InstanceSpec]: # noqa: D102
self.dimension_specs,
self.time_dimension_specs,
self.entity_specs,
self.group_by_metric_specs,
self.metric_specs,
self.metadata_specs,
)
Expand Down
50 changes: 49 additions & 1 deletion tests/model/semantics/test_linkable_spec_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import pytest
from _pytest.fixtures import FixtureRequest
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.references import EntityReference, MetricReference, SemanticModelReference

from metricflow.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow.model.semantics.linkable_element_properties import LinkableElementProperties
from metricflow.model.semantics.linkable_spec_resolver import (
SemanticModelJoinPath,
SemanticModelJoinPathElement,
ValidLinkableSpecResolver,
)
from metricflow.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
Expand Down Expand Up @@ -123,3 +125,49 @@ def test_cyclic_join_manifest( # noqa: D103
without_any_of=frozenset(),
),
)


def test_create_linkable_element_set_from_join_path( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
simple_model_spec_resolver: ValidLinkableSpecResolver,
) -> None:
assert_linkable_element_set_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
set_id="result0",
linkable_element_set=simple_model_spec_resolver.create_linkable_element_set_from_join_path(
join_path=SemanticModelJoinPath(
path_elements=(
SemanticModelJoinPathElement(
semantic_model_reference=SemanticModelReference("listings_latest"),
join_on_entity=EntityReference("listing"),
),
)
),
with_properties=frozenset({LinkableElementProperties.JOINED}),
),
)


def test_create_linkable_element_set_from_join_path_multi_hop( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
simple_model_spec_resolver: ValidLinkableSpecResolver,
) -> None:
assert_linkable_element_set_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
set_id="result0",
linkable_element_set=simple_model_spec_resolver.create_linkable_element_set_from_join_path(
join_path=SemanticModelJoinPath(
path_elements=(
SemanticModelJoinPathElement(
semantic_model_reference=SemanticModelReference("listings_latest"),
join_on_entity=EntityReference("listing"),
),
)
),
with_properties=frozenset({LinkableElementProperties.JOINED, LinkableElementProperties.MULTI_HOP}),
),
)
15 changes: 15 additions & 0 deletions tests/snapshot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,21 @@ def assert_linkable_element_set_snapshot_equal( # noqa: D103
sorted(linkable_element_property.name for linkable_element_property in linkable_entity.properties),
)
)

for linkable_metric_iterable in linkable_element_set.path_key_to_linkable_metrics.values():
for linkable_metric in linkable_metric_iterable:
rows.append(
(
# Checking a limited set of fields as the result is large due to the paths in the object.
linkable_metric.join_by_semantic_model.semantic_model_name,
tuple(entity_link.element_name for entity_link in linkable_entity.entity_links),
linkable_metric.element_name,
"",
"",
sorted(linkable_element_property.name for linkable_element_property in linkable_metric.properties),
)
)

assert_str_snapshot_equal(
request=request,
mf_test_configuration=mf_test_configuration,
Expand Down
Loading

0 comments on commit fcb7dfa

Please sign in to comment.