Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 18, 2024
1 parent d04c4ee commit fc66910
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 44 deletions.
12 changes: 7 additions & 5 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,16 +704,18 @@ def _contains_multihop_linkables(linkable_specs: Sequence[LinkableInstanceSpec])
"""Returns true if any of the linkable specs requires a multi-hop join to realize."""
return any(len(x.entity_links) > 1 for x in linkable_specs)

def _get_semantic_model_names_for_measures(self, measure_names: Sequence[MeasureSpec]) -> Set[str]:
def _get_semantic_model_names_for_measures(self, measures: Sequence[MeasureSpec]) -> Set[str]:
"""Return the names of the semantic models needed to compute the input measures.
This is a temporary method for use in assertion boundaries while we implement support for multiple semantic models
"""
semantic_model_names: Set[str] = set()
for measure_name in measure_names:
semantic_model_names = semantic_model_names.union(
{d.name for d in self._semantic_model_lookup.get_semantic_models_for_measure(measure_name.reference)}
)
for measure in measures:
semantic_model = self._semantic_model_lookup.get_semantic_model_for_measure(measure.reference)
if not semantic_model:
raise ValueError(f"Could not find measure with name {measure.reference} in configured semantic models.")
semantic_model_names.add(semantic_model.name)

return semantic_model_names

def _sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]:
Expand Down
17 changes: 7 additions & 10 deletions metricflow/model/semantics/semantic_model_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ def __init__( # noqa: D
model: SemanticManifest,
) -> None:
self._model = model
self._measure_index: Dict[MeasureReference, List[SemanticModel]] = defaultdict(list)
self._measure_index: Dict[MeasureReference, SemanticModel] = {}
self._measure_aggs: Dict[
MeasureReference, AggregationType
] = {} # maps measures to their one consistent aggregation
self._measure_agg_time_dimension: Dict[MeasureReference, TimeDimensionReference] = {}
self._measure_non_additive_dimension_specs: Dict[MeasureReference, NonAdditiveDimensionSpec] = {}
# TODO: remove defaultdicts. Will add fake elements to the dict that might get referenced.
self._dimension_index: Dict[DimensionReference, List[SemanticModel]] = defaultdict(list)
self._linkable_reference_index: Dict[LinkableElementReference, List[SemanticModel]] = defaultdict(list)
self._entity_index: Dict[Optional[str], List[SemanticModel]] = defaultdict(list)
Expand Down Expand Up @@ -141,12 +142,10 @@ def get_measure_from_semantic_model(semantic_model: SemanticModel, measure_refer
)

def get_measure(self, measure_reference: MeasureReference) -> Measure: # noqa: D
if measure_reference not in self._measure_index:
semantic_model = self._measure_index.get(measure_reference)
if not semantic_model:
raise ValueError(f"Could not find measure with name ({measure_reference}) in configured semantic models")

assert len(self._measure_index[measure_reference]) >= 1
# Measures should be consistent across semantic models, so just use the first one.
semantic_model = list(self._measure_index[measure_reference])[0]
return SemanticModelLookup.get_measure_from_semantic_model(
semantic_model=semantic_model, measure_reference=measure_reference
)
Expand All @@ -155,10 +154,8 @@ def get_entity_references(self) -> Sequence[EntityReference]: # noqa: D
return list(self._entity_ref_to_entity.keys())

# DSC interface
def get_semantic_models_for_measure( # noqa: D
self, measure_reference: MeasureReference
) -> Sequence[SemanticModel]:
return self._measure_index[measure_reference]
def get_semantic_model_for_measure(self, measure_reference: MeasureReference) -> Optional[SemanticModel]: # noqa: D
return self._measure_index.get(measure_reference)

def get_agg_time_dimension_for_measure( # noqa: D
self, measure_reference: MeasureReference
Expand Down Expand Up @@ -202,7 +199,7 @@ def _add_semantic_model(self, semantic_model: SemanticModel) -> None:

for measure in semantic_model.measures:
self._measure_aggs[measure.reference] = measure.agg
self._measure_index[measure.reference].append(semantic_model)
self._measure_index[measure.reference] = semantic_model
agg_time_dimension_reference = semantic_model.checked_agg_time_dimension_for_measure(measure.reference)

matching_dimensions = tuple(
Expand Down
2 changes: 1 addition & 1 deletion metricflow/protocols/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_entity_references(self) -> Sequence[EntityReference]:
raise NotImplementedError

@abstractmethod
def get_semantic_models_for_measure(self, measure_reference: MeasureReference) -> Sequence[SemanticModel]:
def get_semantic_model_for_measure(self, measure_reference: MeasureReference) -> Optional[SemanticModel]:
"""Retrieve a list of all semantic model model objects associated with the measure reference."""
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ class CumulativeMetricRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue):
def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
return (
f"The query includes a cumulative metric {repr(self.metric_reference.element_name)} that does not "
f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)}"
f"accumulate over all-time, but the group-by items do not include {repr(METRIC_TIME_ELEMENT_NAME)} "
"or the metric's agg_time_dimension."
# TODO: add name of agg_time_dim?
)

@override
Expand Down
71 changes: 57 additions & 14 deletions metricflow/query/validation_rules/metric_time_requirements.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from typing import List, Sequence
from typing import List, Sequence, Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference
from dbt_semantic_interfaces.references import EntityReference, MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override
Expand Down Expand Up @@ -34,29 +34,36 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D
super().__init__(manifest_lookup=manifest_lookup)

metric_time_specs: List[TimeDimensionSpec] = []
self._metric_time_specs = tuple(
self._generate_valid_specs_for_time_dimension(
time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME), entity_links=()
)
)

def _generate_valid_specs_for_time_dimension(
self, time_dimension_reference: TimeDimensionReference, entity_links: Tuple[EntityReference, ...]
) -> List[TimeDimensionSpec]:
time_dimension_specs: List[TimeDimensionSpec] = []
for time_granularity in TimeGranularity:
metric_time_specs.append(
time_dimension_specs.append(
TimeDimensionSpec(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
element_name=time_dimension_reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=None,
)
)
for date_part in DatePart:
for time_granularity in date_part.compatible_granularities:
metric_time_specs.append(
time_dimension_specs.append(
TimeDimensionSpec(
element_name=METRIC_TIME_ELEMENT_NAME,
entity_links=(),
element_name=time_dimension_reference.element_name,
entity_links=entity_links,
time_granularity=time_granularity,
date_part=date_part,
)
)

self._metric_time_specs = tuple(metric_time_specs)
return time_dimension_specs

def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
for group_by_item_input in query_resolver_input.group_by_item_inputs:
Expand All @@ -65,6 +72,38 @@ def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInpu

return False

def _group_by_items_include_agg_time_dimension(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> bool:
metric = self._manifest_lookup.metric_lookup.get_metric(metric_reference=metric_reference)
semantic_model_lookup = self._manifest_lookup.semantic_model_lookup

valid_agg_time_dimension_specs: List[TimeDimensionSpec] = []
for measure_reference in metric.measure_references:
agg_time_dimension_reference = semantic_model_lookup.get_agg_time_dimension_for_measure(measure_reference)
semantic_model = semantic_model_lookup.get_semantic_model_for_measure(measure_reference)
assert semantic_model, f"No semantic model found for measure {measure_reference}."

# is this too broad? need to narrow entity links?
possible_entity_links = semantic_model_lookup.entity_links_for_local_elements(semantic_model)
for entity_link in possible_entity_links:
valid_agg_time_dimension_specs.extend(
self._generate_valid_specs_for_time_dimension(
time_dimension_reference=agg_time_dimension_reference, entity_links=(entity_link,)
)
)
print("valid:::")
for x in valid_agg_time_dimension_specs:
print(f"\n{x}")

print("requested:::")
for group_by_item_input in query_resolver_input.group_by_item_inputs:
print(print(f"\n{group_by_item_input}"))
if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
return True

return False

@override
def validate_metric_in_resolution_dag(
self,
Expand All @@ -73,15 +112,19 @@ def validate_metric_in_resolution_dag(
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._get_metric(metric_reference)
query_includes_metric_time = self._group_by_items_include_metric_time(resolver_input_for_query)
query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time(
resolver_input_for_query
) or self._group_by_items_include_agg_time_dimension(
query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
)

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
elif metric.type is MetricType.CUMULATIVE:
if (
metric.type_params is not None
and (metric.type_params.window is not None or metric.type_params.grain_to_date is not None)
and not query_includes_metric_time
and not query_includes_metric_time_or_agg_time_dimension
):
return MetricFlowQueryResolutionIssueSet.from_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
Expand All @@ -97,7 +140,7 @@ def validate_metric_in_resolution_dag(
for input_metric in metric.input_metrics
)

if has_time_offset and not query_includes_metric_time:
if has_time_offset and not query_includes_metric_time_or_agg_time_dimension:
return MetricFlowQueryResolutionIssueSet.from_issue(
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,31 @@ integration_test:
OR {{ render_time_constraint("a.ds", "2020-01-04", "2020-01-04") }}
GROUP BY a.ds
ORDER BY a.ds
---
integration_test:
name: cumulative_metric_with_agg_time_dimension
description: Query a cumulative metric with its agg_time_dimension.
model: SIMPLE_MODEL
metrics: ["trailing_2_months_revenue"]
group_bys: ["company__ds__day"]
order_bys: ["company__ds__day"]
time_constraint: ["2020-03-05", "2021-01-04"]
check_query: |
SELECT
SUM(b.txn_revenue) as trailing_2_months_revenue
, a.ds AS company__ds__day
FROM (
SELECT ds
FROM {{ mf_time_spine_source }}
WHERE {{ render_time_constraint("ds", "2020-01-05", "2021-01-04") }}
) a
INNER JOIN (
SELECT
revenue as txn_revenue
, created_at AS ds
FROM {{ source_schema }}.fct_revenue
) b
ON b.ds <= a.ds AND b.ds > {{ render_date_sub("a", "ds", 2, TimeGranularity.MONTH) }}
WHERE {{ render_time_constraint("a.ds", "2020-03-05", "2021-01-04") }}
GROUP BY a.ds
ORDER BY a.ds
3 changes: 2 additions & 1 deletion metricflow/test/integration/test_configured_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def filter_not_supported_features(

@pytest.mark.parametrize(
"name",
CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
# CONFIGURED_INTEGRATION_TESTS_REPOSITORY.all_test_case_names,
["itest_cumulative_metric.yaml/cumulative_metric_with_agg_time_dimension"],
ids=lambda name: f"name={name}",
)
def test_case(
Expand Down
24 changes: 12 additions & 12 deletions metricflow/test/model/test_semantic_model_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ def test_get_elements(semantic_model_lookup: SemanticModelLookup) -> None: # no
assert semantic_model_lookup.get_measure(measure_reference=measure_reference).reference == measure_reference


def test_get_semantic_models_for_measure(semantic_model_lookup: SemanticModelLookup) -> None: # noqa: D
bookings_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="bookings"))
assert len(bookings_sources) == 1
assert bookings_sources[0].name == "bookings_source"

views_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="views"))
assert len(views_sources) == 1
assert views_sources[0].name == "views_source"

listings_sources = semantic_model_lookup.get_semantic_models_for_measure(MeasureReference(element_name="listings"))
assert len(listings_sources) == 1
assert listings_sources[0].name == "listings_latest"
def test_get_semantic_model_for_measure(semantic_model_lookup: SemanticModelLookup) -> None: # noqa: D
bookings_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="bookings"))
assert bookings_source
assert bookings_source.name == "bookings_source"

views_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="views"))
assert views_source
assert views_source.name == "views_source"

listings_source = semantic_model_lookup.get_semantic_model_for_measure(MeasureReference(element_name="listings"))
assert listings_source
assert listings_source.name == "listings_latest"


def test_elements_for_metric( # noqa: D
Expand Down

0 comments on commit fc66910

Please sign in to comment.