Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of MetricTimeQueryValidationRule #1481

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,39 @@ def filter_by_spec_patterns(self, spec_patterns: Sequence[SpecPattern]) -> Linka
)
logger.debug(LazyFormat(lambda: f"Filtering valid linkable elements took: {time.time() - start_time:.2f}s"))
return filtered_elements

def filter_by_left_semantic_model(
    self, left_semantic_model_reference: SemanticModelReference
) -> LinkableElementSet:
    """Return a `LinkableElementSet` with only elements that have the given left semantic model in the join path.

    Dimensions and entities are matched on `join_path.left_semantic_model_reference`. Metrics are matched on
    `join_path.semantic_model_join_path.left_semantic_model_reference`.

    A path key is included in the result only if at least one of its elements matches, so the returned set never
    maps a path key to an empty tuple.

    Args:
        left_semantic_model_reference: The reference that the left semantic model of an element's join path must
            equal for the element to be kept.

    Returns:
        A new `LinkableElementSet` with the matching elements grouped under their original path keys.
    """
    path_key_to_linkable_dimensions: Dict[ElementPathKey, Tuple[LinkableDimension, ...]] = {}
    path_key_to_linkable_entities: Dict[ElementPathKey, Tuple[LinkableEntity, ...]] = {}
    path_key_to_linkable_metrics: Dict[ElementPathKey, Tuple[LinkableMetric, ...]] = {}

    # NOTE(review): keys without matching elements are skipped in all three loops below. Presumably the set's
    # spec-level accessors (e.g. `spec_count` / `specs`) are derived from the path keys, in which case a key mapped
    # to an empty tuple would be reported as a spec even though nothing matched - confirm against
    # `LinkableElementSet`.
    for path_key, linkable_dimensions in self.path_key_to_linkable_dimensions.items():
        matching_dimensions = tuple(
            linkable_dimension
            for linkable_dimension in linkable_dimensions
            if linkable_dimension.join_path.left_semantic_model_reference == left_semantic_model_reference
        )
        if matching_dimensions:
            path_key_to_linkable_dimensions[path_key] = matching_dimensions

    for path_key, linkable_entities in self.path_key_to_linkable_entities.items():
        matching_entities = tuple(
            linkable_entity
            for linkable_entity in linkable_entities
            if linkable_entity.join_path.left_semantic_model_reference == left_semantic_model_reference
        )
        if matching_entities:
            path_key_to_linkable_entities[path_key] = matching_entities

    for path_key, linkable_metrics in self.path_key_to_linkable_metrics.items():
        # The join path of a metric wraps a semantic-model join path, hence the extra attribute hop.
        matching_metrics = tuple(
            linkable_metric
            for linkable_metric in linkable_metrics
            if linkable_metric.join_path.semantic_model_join_path.left_semantic_model_reference
            == left_semantic_model_reference
        )
        if matching_metrics:
            path_key_to_linkable_metrics[path_key] = matching_metrics

    return LinkableElementSet(
        path_key_to_linkable_dimensions=path_key_to_linkable_dimensions,
        path_key_to_linkable_entities=path_key_to_linkable_entities,
        path_key_to_linkable_metrics=path_key_to_linkable_metrics,
    )
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import List, Sequence, Tuple
from typing import List, Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import Metric, WhereFilterIntersection
from dbt_semantic_interfaces.references import (
MeasureReference,
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import MetricType
from typing_extensions import override

from metricflow_semantics.collection_helpers.lru_cache import LruCache
from metricflow_semantics.model.linkable_element_property import LinkableElementProperty
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.model.semantics.element_filter import LinkableElementFilter
from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryResolutionIssue,
Expand All @@ -33,22 +32,11 @@
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if typing.TYPE_CHECKING:
from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult


@dataclass(frozen=True)
class QueryItemsAnalysis:
    """Immutable summary of which kinds of group-by items a query contains."""

    # Specs from the metric's joinable SCD specs that the query's group-by items match.
    scds: Sequence[InstanceSpec]
    # True if a group-by item in the query matches a `metric_time` spec.
    has_metric_time: bool
    # True if a group-by item in the query matches a valid aggregation time dimension of the metric.
    has_agg_time_dimension: bool

class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
"""Validates cases where a query requires metric_time to be specified as a group-by-item.

Expand All @@ -71,73 +59,44 @@ def __init__( # noqa: D107
resolve_group_by_item_result=resolve_group_by_item_result,
)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME),
entity_links=(),
custom_granularities=self._manifest_lookup.custom_granularities,
)
)
self._query_items_analysis_cache: LruCache[
Tuple[ResolverInputForQuery, MetricReference], QueryItemsAnalysis
] = LruCache(128)

def _get_query_items_analysis(
    self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> QueryItemsAnalysis:
    """Return the query-items analysis for the given query input and metric, using the cache when possible.

    The cache is keyed on the `(query_resolver_input, metric_reference)` pair. On a miss, the analysis is computed
    via `_uncached_query_items_analysis` and stored before it is returned.
    """
    key = (query_resolver_input, metric_reference)
    analysis = self._query_items_analysis_cache.get(key)
    if analysis is None:
        # A `None` lookup result is treated as a cache miss.
        analysis = self._uncached_query_items_analysis(query_resolver_input, metric_reference)
        self._query_items_analysis_cache.set(key, analysis)
    return analysis

def _uncached_query_items_analysis(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> QueryItemsAnalysis:
has_agg_time_dimension = False
has_metric_time = False
scds: List[InstanceSpec] = []

valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(
metric_reference
self._query_includes_metric_time = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol this is so much better I didn't know you could do it like this. Next time I'll do it this way!

self._resolve_group_by_item_result.linkable_element_set.filter(
LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.METRIC_TIME}))
).spec_count
> 0
)

scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference)

for group_by_item_input in query_resolver_input.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
has_metric_time = True

if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
has_agg_time_dimension = True

scd_matches = group_by_item_input.spec_pattern.match(scd_specs)
scds.extend(scd_matches)
self._scd_linkable_element_set = self._resolve_group_by_item_result.linkable_element_set.filter(
LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.SCD_HOP}))
)

return QueryItemsAnalysis(
scds=scds,
has_metric_time=has_metric_time,
has_agg_time_dimension=has_agg_time_dimension,
def _query_includes_agg_time_dimension_of_metric(self, metric_reference: MetricReference) -> bool:
    """Return True if any group-by-item spec of the query is a valid aggregation time dimension for the metric."""
    agg_time_dimension_specs = set(
        self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(metric_reference)
    )
    query_specs = self._resolve_group_by_item_result.group_by_item_specs
    # The two collections overlap exactly when they are not disjoint.
    return not agg_time_dimension_specs.isdisjoint(query_specs)

def _validate_cumulative_metric(
self,
metric_reference: MetricReference,
metric: Metric,
query_items_analysis: QueryItemsAnalysis,
resolution_path: MetricFlowQueryResolutionPath,
) -> Sequence[MetricFlowQueryResolutionIssue]:
# A cumulative metric with a window or grain-to-date specified requires a `metric-time` or the aggregation time
# dimension for the metric.
if (
metric.type_params is not None
and metric.type_params.cumulative_type_params is not None
and (
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not (query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension)
):
if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference):
return ()

return (
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand All @@ -151,24 +110,37 @@ def _validate_derived_metric(
metric_reference: MetricReference,
metric: Metric,
resolution_path: MetricFlowQueryResolutionPath,
query_items_analysis: QueryItemsAnalysis,
) -> Sequence[MetricFlowQueryResolutionIssue]:
has_time_offset = any(
input_metric.offset_window is not None or input_metric.offset_to_grain is not None
for input_metric in metric.input_metrics
)

if has_time_offset and not (
query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension
):
return (
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
),
)
return ()
# If a derived metric does not define a time offset, then there are no requirements on what's in the group-by
# items.
if not has_time_offset:
return ()

# If a derived metric has a time offset, then the query needs to include `metric_time` or the aggregation time
# dimension of a metric.
if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference):
return ()

return (
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
),
)

def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet:
    """Return the query's SCD linkable elements whose left semantic model is the one that defines the measure."""
    semantic_model_lookup = self._manifest_lookup.semantic_model_lookup
    semantic_model = semantic_model_lookup.get_semantic_model_for_measure(measure_reference)

    scd_elements = self._scd_linkable_element_set
    return scd_elements.filter_by_left_semantic_model(semantic_model.reference)

@override
def validate_metric_in_resolution_dag(
Expand All @@ -177,27 +149,13 @@ def validate_metric_in_resolution_dag(
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._manifest_lookup.metric_lookup.get_metric(metric_reference)

query_items_analysis = self._get_query_items_analysis(self._resolver_input_for_query, metric_reference)

issues: List[MetricFlowQueryResolutionIssue] = []

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check
if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time:
issues.append(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=query_items_analysis.scds,
query_resolution_path=resolution_path,
)
)

if metric.type is MetricType.CUMULATIVE:
issues.extend(
self._validate_cumulative_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
)
Expand All @@ -207,7 +165,6 @@ def validate_metric_in_resolution_dag(
self._validate_derived_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
)
Expand Down Expand Up @@ -235,4 +192,19 @@ def validate_measure_in_resolution_dag(
measure_reference: MeasureReference,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
scd_linkable_elemenent_set_for_measure = self._scd_linkable_element_set_for_measure(measure_reference)

if scd_linkable_elemenent_set_for_measure.spec_count == 0:
return MetricFlowQueryResolutionIssueSet.empty_instance()

if self._query_includes_metric_time:
return MetricFlowQueryResolutionIssueSet.empty_instance()

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check for it here.

return MetricFlowQueryResolutionIssueSet.from_issue(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=scd_linkable_elemenent_set_for_measure.specs, query_resolution_path=resolution_path
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Error #1:

[Resolve Query(['bookings'])]
-> [Resolve Metric('bookings')]
-> [Resolve Measure('bookings')]
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Error #1:

[Resolve Query(['bookings'])]
-> [Resolve Metric('bookings')]
-> [Resolve Measure('bookings')]
Loading