Skip to content

Commit

Permalink
/* PR_START p--short-term-perf 28 */ Use a measure-based validation f…
Browse files Browse the repository at this point in the history
…or SCDs.
  • Loading branch information
plypaul committed Oct 28, 2024
1 parent 5a5c7fa commit 01b4b22
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,39 @@ def filter_by_spec_patterns(self, spec_patterns: Sequence[SpecPattern]) -> Linka
)
logger.debug(LazyFormat(lambda: f"Filtering valid linkable elements took: {time.time() - start_time:.2f}s"))
return filtered_elements

def filter_by_left_semantic_model(
    self, left_semantic_model_reference: SemanticModelReference
) -> LinkableElementSet:
    """Return a `LinkableElementSet` with only elements that have the given left semantic model in the join path.

    Dimensions and entities are matched on `join_path.left_semantic_model_reference`; metrics are matched on
    `join_path.semantic_model_join_path.left_semantic_model_reference`.

    A path key is kept only when at least one of its elements survives the filter, so the returned set never maps
    a path key to an empty tuple.

    Args:
        left_semantic_model_reference: The reference that the left semantic model of an element's join path must
            be equal to for the element to be kept.

    Returns:
        A new `LinkableElementSet` built from the matching elements. The mappings of `self` are only read.
    """
    path_key_to_linkable_dimensions: Dict[ElementPathKey, Tuple[LinkableDimension, ...]] = {}
    path_key_to_linkable_entities: Dict[ElementPathKey, Tuple[LinkableEntity, ...]] = {}
    path_key_to_linkable_metrics: Dict[ElementPathKey, Tuple[LinkableMetric, ...]] = {}

    # NOTE(review): keys with no matching elements are skipped in all three loops below. Callers compare
    # `spec_count` of the result to 0 to detect "nothing matched"; presumably that count is derived from the path
    # keys, in which case an entry mapped to `()` would be miscounted as a match - confirm against `spec_count`.
    for path_key, linkable_dimensions in self.path_key_to_linkable_dimensions.items():
        matching_dimensions = tuple(
            linkable_dimension
            for linkable_dimension in linkable_dimensions
            if linkable_dimension.join_path.left_semantic_model_reference == left_semantic_model_reference
        )
        if matching_dimensions:
            path_key_to_linkable_dimensions[path_key] = matching_dimensions

    for path_key, linkable_entities in self.path_key_to_linkable_entities.items():
        matching_entities = tuple(
            linkable_entity
            for linkable_entity in linkable_entities
            if linkable_entity.join_path.left_semantic_model_reference == left_semantic_model_reference
        )
        if matching_entities:
            path_key_to_linkable_entities[path_key] = matching_entities

    for path_key, linkable_metrics in self.path_key_to_linkable_metrics.items():
        # Metrics keep their semantic-model join path one level deeper than dimensions and entities do.
        matching_metrics = tuple(
            linkable_metric
            for linkable_metric in linkable_metrics
            if linkable_metric.join_path.semantic_model_join_path.left_semantic_model_reference
            == left_semantic_model_reference
        )
        if matching_metrics:
            path_key_to_linkable_metrics[path_key] = matching_metrics

    return LinkableElementSet(
        path_key_to_linkable_dimensions=path_key_to_linkable_dimensions,
        path_key_to_linkable_entities=path_key_to_linkable_entities,
        path_key_to_linkable_metrics=path_key_to_linkable_metrics,
    )
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import List, Sequence, Tuple
from typing import List, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import Metric, WhereFilterIntersection
from dbt_semantic_interfaces.references import (
MeasureReference,
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import MetricType
from typing_extensions import override

from metricflow_semantics.collection_helpers.lru_cache import LruCache
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.model.semantics.element_filter import LinkableElementFilter
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryResolutionIssue,
Expand All @@ -33,22 +32,11 @@
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if typing.TYPE_CHECKING:
from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult


@dataclass(frozen=True)
class QueryItemsAnalysis:
    """Contains data about which items a query contains.

    Declared with `frozen=True`, so fields cannot be reassigned after construction.
    """

    # Specs returned by matching the query's group-by-item spec patterns against the SCD specs for a metric.
    scds: Sequence[InstanceSpec]
    # True if a group-by-item spec pattern in the query matched one of the `metric_time` specs.
    has_metric_time: bool
    # True if a group-by-item spec pattern in the query matched a valid aggregation time dimension of the metric.
    has_agg_time_dimension: bool

class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
"""Validates cases where a query requires metric_time to be specified as a group-by-item.
Expand All @@ -71,73 +59,44 @@ def __init__( # noqa: D107
resolve_group_by_item_result=resolve_group_by_item_result,
)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME),
entity_links=(),
custom_granularities=self._manifest_lookup.custom_granularities,
)
)
self._query_items_analysis_cache: LruCache[
Tuple[ResolverInputForQuery, MetricReference], QueryItemsAnalysis
] = LruCache(128)

def _get_query_items_analysis(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> QueryItemsAnalysis:
cache_key = (query_resolver_input, metric_reference)
result = self._query_items_analysis_cache.get(cache_key)
if result is not None:
return result
result = self._uncached_query_items_analysis(query_resolver_input, metric_reference)
self._query_items_analysis_cache.set(cache_key, result)
return result

def _uncached_query_items_analysis(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> QueryItemsAnalysis:
has_agg_time_dimension = False
has_metric_time = False
scds: List[InstanceSpec] = []

valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(
metric_reference
self._query_includes_metric_time = (
self._resolve_group_by_item_result.linkable_element_set.filter(
LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.METRIC_TIME}))
).spec_count
> 0
)

scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference)

for group_by_item_input in query_resolver_input.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
has_metric_time = True

if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
has_agg_time_dimension = True

scd_matches = group_by_item_input.spec_pattern.match(scd_specs)
scds.extend(scd_matches)
self._scd_linkable_element_set = self._resolve_group_by_item_result.linkable_element_set.filter(
LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.SCD_HOP}))
)

return QueryItemsAnalysis(
scds=scds,
has_metric_time=has_metric_time,
has_agg_time_dimension=has_agg_time_dimension,
def _query_includes_agg_time_dimension_of_metric(self, metric_reference: MetricReference) -> bool:
    """Return whether the resolved group-by items include a valid aggregation time dimension of the metric.

    Args:
        metric_reference: The metric whose valid aggregation time dimensions are looked up.

    Returns:
        True if at least one of the metric's valid aggregation time dimensions is also present in the group-by
        item specs of the resolution result, False otherwise.
    """
    agg_time_dimension_specs = set(
        self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(metric_reference)
    )
    queried_specs = self._resolve_group_by_item_result.group_by_item_specs
    # "Not disjoint" is the same condition as "the intersection is non-empty".
    return not agg_time_dimension_specs.isdisjoint(queried_specs)

def _validate_cumulative_metric(
self,
metric_reference: MetricReference,
metric: Metric,
query_items_analysis: QueryItemsAnalysis,
resolution_path: MetricFlowQueryResolutionPath,
) -> Sequence[MetricFlowQueryResolutionIssue]:
# A cumulative metric with a window or grain-to-date specified requires a `metric-time` or the aggregation time
# dimension for the metric.
if (
metric.type_params is not None
and metric.type_params.cumulative_type_params is not None
and (
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not (query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension)
):
if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference):
return ()

return (
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand All @@ -151,24 +110,37 @@ def _validate_derived_metric(
metric_reference: MetricReference,
metric: Metric,
resolution_path: MetricFlowQueryResolutionPath,
query_items_analysis: QueryItemsAnalysis,
) -> Sequence[MetricFlowQueryResolutionIssue]:
has_time_offset = any(
input_metric.offset_window is not None or input_metric.offset_to_grain is not None
for input_metric in metric.input_metrics
)

if has_time_offset and not (
query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension
):
return (
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
),
)
return ()
# If a derived metric does not define a time offset, then there are no requirements on what's in the group-by
# items.
if not has_time_offset:
return ()

# If a derived metric has a time offset, then the query needs to include `metric_time` or the aggregation time
# dimension of a metric.
if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference):
return ()

return (
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
),
)

def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet:
    """Return the SCD elements in the query whose join path starts at the measure's semantic model.

    Args:
        measure_reference: The measure whose semantic model is used as the left semantic model for filtering.

    Returns:
        The subset of `self._scd_linkable_element_set` that has the measure's semantic model as the left semantic
        model in the join path.
    """
    semantic_model_lookup = self._manifest_lookup.semantic_model_lookup
    left_model_reference = semantic_model_lookup.get_semantic_model_for_measure(measure_reference).reference
    return self._scd_linkable_element_set.filter_by_left_semantic_model(left_model_reference)

@override
def validate_metric_in_resolution_dag(
Expand All @@ -177,27 +149,13 @@ def validate_metric_in_resolution_dag(
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._manifest_lookup.metric_lookup.get_metric(metric_reference)

query_items_analysis = self._get_query_items_analysis(self._resolver_input_for_query, metric_reference)

issues: List[MetricFlowQueryResolutionIssue] = []

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check
if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time:
issues.append(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=query_items_analysis.scds,
query_resolution_path=resolution_path,
)
)

if metric.type is MetricType.CUMULATIVE:
issues.extend(
self._validate_cumulative_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
)
Expand All @@ -207,7 +165,6 @@ def validate_metric_in_resolution_dag(
self._validate_derived_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
)
Expand Down Expand Up @@ -235,4 +192,19 @@ def validate_measure_in_resolution_dag(
measure_reference: MeasureReference,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
scd_linkable_elemenent_set_for_measure = self._scd_linkable_element_set_for_measure(measure_reference)

if scd_linkable_elemenent_set_for_measure.spec_count == 0:
return MetricFlowQueryResolutionIssueSet.empty_instance()

if self._query_includes_metric_time:
return MetricFlowQueryResolutionIssueSet.empty_instance()

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check for it here.

return MetricFlowQueryResolutionIssueSet.from_issue(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=scd_linkable_elemenent_set_for_measure.specs, query_resolution_path=resolution_path
)
)

0 comments on commit 01b4b22

Please sign in to comment.