Skip to content

Commit

Permalink
Add validation rule for SCD query with no time
Browse files Browse the repository at this point in the history
This commit adds a new step to `MetricTimeQueryValidationRule` which
does the following:
1. Selects all the existing SCDs for the queried metric
2. Match them against the spec pattern of the group by input
3. Raise a `SCDRequiresMetricTimeIssue` if no `metric_time` was provided
  and there were matches

To accomplish step 1, I had to create a new `LinkableElementProperty`
called `SCD_HOP`. This new property indicates that the join path to the
linkable element goes through an SCD at some point.

I changed the `ValidLinkableSpecResolver` to add `SCD_HOP` to the
properties of all the elements it finds whenever that element belongs to
an SCD or if the path to it contains an SCD.
  • Loading branch information
serramatutu committed Oct 9, 2024
1 parent a4ea534 commit 3d5170d
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class LinkableElementProperty(Enum):
METRIC = "metric"
# A time dimension with a DatePart.
DATE_PART = "date_part"
# A linkable element that is itself part of an SCD model, or a linkable element that gets joined through another SCD model.
SCD_HOP = "scd_hop"

@staticmethod
def all_properties() -> FrozenSet[LinkableElementProperty]: # noqa: D102
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def __init__(

logger.debug(LazyFormat(lambda: f"Building valid group-by-item indexes took: {time.time() - start_time:.2f}s"))

def _semantic_model_is_scd(self, semantic_model: SemanticModel) -> bool:
"""Whether the semantic model is an SCD."""
return any(dim.validity_params is not None for dim in semantic_model.dimensions)

def _generate_linkable_time_dimensions(
self,
semantic_model_origin: SemanticModelReference,
Expand Down Expand Up @@ -289,6 +293,8 @@ def get_joinable_metrics_for_semantic_model(
necessary.
"""
properties = frozenset({LinkableElementProperty.METRIC, LinkableElementProperty.JOINED})
if self._semantic_model_is_scd(semantic_model):
properties = properties.union({LinkableElementProperty.SCD_HOP})

join_path_has_path_links = len(using_join_path.path_elements) > 0
if join_path_has_path_links:
Expand Down Expand Up @@ -326,8 +332,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
Elements related to metric_time are handled separately in _get_metric_time_elements().
Linkable metrics are not considered local to the semantic model since they always require a join.
"""
semantic_model_is_scd = self._semantic_model_is_scd(semantic_model)

linkable_dimensions = []
linkable_entities = []

entity_properties = frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY})
if semantic_model_is_scd:
entity_properties = entity_properties.union({LinkableElementProperty.SCD_HOP})

for entity in semantic_model.entities:
linkable_entities.append(
LinkableEntity.create(
Expand All @@ -337,7 +350,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
join_path=SemanticModelJoinPath(
left_semantic_model_reference=semantic_model.reference,
),
properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}),
properties=entity_properties,
)
)
for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model):
Expand All @@ -352,12 +365,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link
join_path=SemanticModelJoinPath(
left_semantic_model_reference=semantic_model.reference,
),
properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}),
properties=entity_properties,
)
)

dimension_properties = frozenset({LinkableElementProperty.LOCAL})
if semantic_model_is_scd:
dimension_properties = dimension_properties.union({LinkableElementProperty.SCD_HOP})

for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model):
dimension_properties = frozenset({LinkableElementProperty.LOCAL})
for dimension in semantic_model.dimensions:
dimension_type = dimension.type
if dimension_type is DimensionType.CATEGORICAL:
Expand Down Expand Up @@ -459,6 +475,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
defined_granularity: Optional[ExpandedTimeGranularity] = None
if measure_reference:
measure_semantic_model = self._get_semantic_model_for_measure(measure_reference)
semantic_model_is_scd = self._semantic_model_is_scd(measure_semantic_model)
measure_agg_time_dimension_reference = measure_semantic_model.checked_agg_time_dimension_for_measure(
measure_reference=measure_reference
)
Expand All @@ -471,6 +488,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
# If querying metric_time without metrics, will query from time spines.
# Defaults to DAY granularity if available in time spines, else smallest available granularity.
min_granularity = min(self._time_spine_sources.keys())
semantic_model_is_scd = False
possible_metric_time_granularities = tuple(
ExpandedTimeGranularity.from_time_granularity(time_granularity)
for time_granularity in TimeGranularity
Expand Down Expand Up @@ -501,6 +519,8 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference
properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY)
if date_part:
properties.add(LinkableElementProperty.DATE_PART)
if semantic_model_is_scd:
properties.add(LinkableElementProperty.SCD_HOP)
linkable_dimension = LinkableDimension.create(
defined_in_semantic_model=measure_semantic_model.reference if measure_semantic_model else None,
element_name=MetricFlowReservedKeywords.METRIC_TIME.value,
Expand Down Expand Up @@ -711,12 +731,21 @@ def create_linkable_element_set_from_join_path(
join_path: SemanticModelJoinPath,
) -> LinkableElementSet:
"""Given the current path, generate the respective linkable elements from the last semantic model in the path."""
semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference)
assert semantic_model

properties = frozenset({LinkableElementProperty.JOINED})
if len(join_path.path_elements) > 1:
properties = properties.union({LinkableElementProperty.MULTI_HOP})

semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference)
assert semantic_model
# If any of the semantic models in the join path is an SCD, add SCD_HOP
for reference_to_derived_model in join_path.derived_from_semantic_models:
derived_model = self._semantic_model_lookup.get_by_reference(reference_to_derived_model)
assert derived_model

if self._semantic_model_is_scd(derived_model):
properties = properties.union({LinkableElementProperty.SCD_HOP})
break

linkable_dimensions: List[LinkableDimension] = []
linkable_entities: List[LinkableEntity] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS
from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.time.granularity import ExpandedTimeGranularity

Expand Down Expand Up @@ -221,3 +222,12 @@ def get_min_queryable_time_granularity(self, metric_reference: MetricReference)
minimum_queryable_granularity = defined_time_granularity

return minimum_queryable_granularity

def get_joinable_scd_specs_for_metric(self, metric_reference: MetricReference) -> Sequence[LinkableInstanceSpec]:
"""Get the SCDs that can be joined to a metric."""
scd_elems = self.linkable_elements_for_metrics(
metric_references=(metric_reference,),
with_any_property=frozenset([LinkableElementProperty.SCD_HOP]),
)

return scd_elems.specs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass

from typing_extensions import override

from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryIssueType,
MetricFlowQueryResolutionIssue,
)
from metricflow_semantics.query.resolver_inputs.base_resolver_inputs import MetricFlowQueryResolverInput


@dataclass(frozen=True)
class SCDRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue):
"""Describes an issue with a query that includes a SCD group by but does not include metric_time."""

scd_qualified_names: Sequence[str]

@override
def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
dim_str = ", ".join(self.scd_qualified_names)
return (
f"Your query contains the Slowly Changing Dimensions (SCDs): [{dim_str}]. "
"A query containing SCDs must also contain the metric_time dimension in order "
"to join the SCD table to the valid time range. Please add metric_time "
"to the query and try again. If you're using agg_time_dimension, use "
"metric_time instead."
)

@override
def with_path_prefix(self, path_prefix: MetricFlowQueryResolutionPath) -> SCDRequiresMetricTimeIssue:
return SCDRequiresMetricTimeIssue(
issue_type=self.issue_type,
parent_issues=self.parent_issues,
query_resolution_path=self.query_resolution_path.with_path_prefix(path_prefix),
scd_qualified_names=self.scd_qualified_names,
)

@staticmethod
def from_parameters( # noqa: D102
scd_qualified_names: Sequence[str], query_resolution_path: MetricFlowQueryResolutionPath
) -> SCDRequiresMetricTimeIssue:
return SCDRequiresMetricTimeIssue(
issue_type=MetricFlowQueryIssueType.ERROR,
parent_issues=(),
query_resolution_path=query_resolution_path,
scd_qualified_names=scd_qualified_names,
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from collections.abc import Sequence
from dataclasses import dataclass
from typing import List

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference
from dbt_semantic_interfaces.references import (
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums import MetricType
from typing_extensions import override

Expand All @@ -19,7 +23,12 @@
from metricflow_semantics.query.issues.parsing.offset_metric_requires_metric_time import (
OffsetMetricRequiresMetricTimeIssue,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.issues.parsing.scd_requires_metric_time import (
SCDRequiresMetricTimeIssue,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import (
ResolverInputForQuery,
)
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

Expand All @@ -28,6 +37,7 @@
class QueryItemsAnalysis:
"""Contains data about which items a query contains."""

scds: Sequence[str]
has_metric_time: bool
has_agg_time_dimension: bool

Expand All @@ -39,6 +49,7 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
* Cumulative metrics.
* Derived metrics with an offset time.g
* Slowly changing dimensions
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
Expand All @@ -57,18 +68,26 @@ def _get_query_items_analysis(
) -> QueryItemsAnalysis:
has_agg_time_dimension = False
has_metric_time = False
scds: List[str] = []

valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric(
metric_reference
)

scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference)

for group_by_item_input in query_resolver_input.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
has_metric_time = True

if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs):
has_agg_time_dimension = True

matches = group_by_item_input.spec_pattern.match(scd_specs)
scds.extend(match.qualified_name for match in matches)

return QueryItemsAnalysis(
scds=scds,
has_metric_time=has_metric_time,
has_agg_time_dimension=has_agg_time_dimension,
)
Expand All @@ -86,6 +105,16 @@ def validate_metric_in_resolution_dag(

issues = MetricFlowQueryResolutionIssueSet.empty_instance()

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check
if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time:
issues = issues.add_issue(
SCDRequiresMetricTimeIssue.from_parameters(
scd_qualified_names=query_items_analysis.scds,
query_resolution_path=resolution_path,
)
)

if metric.type is MetricType.CUMULATIVE:
if (
metric.type_params is not None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
import textwrap
from typing import List

Expand Down Expand Up @@ -29,6 +30,9 @@
EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE,
)
from metricflow_semantics.test_helpers.metric_time_dimension import MTD
from metricflow_semantics.test_helpers.semantic_manifest_yamls.scd_manifest import (
SCD_MANIFEST_ANCHOR,
)
from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -191,6 +195,23 @@ def revenue_query_parser() -> MetricFlowQueryParser: # noqa
return query_parser_from_yaml([EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE, revenue_yaml_file])


@pytest.fixture
def scd_query_parser() -> MetricFlowQueryParser: # noqa
file_paths = [
os.path.join(SCD_MANIFEST_ANCHOR.directory, f)
for f in os.listdir(SCD_MANIFEST_ANCHOR.directory)
if f.endswith(".yaml")
]

contents: List[str] = []
for fp in file_paths:
with open(fp, "r") as f:
contents.append(f.read())

files = [YamlConfigFile(filepath="inline_for_test_1", contents=c) for c in contents]
return query_parser_from_yaml(files)


def test_query_parser( # noqa: D103
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
Expand Down Expand Up @@ -459,6 +480,34 @@ def test_cumulative_metric_agg_time_dimension_name_validation(
assert_object_snapshot_equal(request=request, mf_test_configuration=mf_test_configuration, obj=result)


def test_join_to_scd_no_time_dimension_validation(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
scd_query_parser: MetricFlowQueryParser,
) -> None:
"""Test that queries that join to SCD semantic models fail if no time dimensions are selected."""
with pytest.raises(InvalidQueryException, match="query containing SCDs must also contain the metric_time"):
scd_query_parser.parse_and_validate_query(
metric_names=["bookings"],
group_by_names=["listing__country"],
)


def test_join_through_scd_no_time_dimension_validation(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
scd_query_parser: MetricFlowQueryParser,
) -> None:
"""Test that queries that join through SCDs semantic models fail if no time dimensions are selected."""
with pytest.raises(InvalidQueryException, match="query containing SCDs must also contain the metric_time"):
# "user__home_state_latest" is not an SCD itself, but since we go through
# "listing" and that is an SCD, we should raise an exception here as well
scd_query_parser.parse_and_validate_query(
metric_names=["bookings"],
group_by_names=["listing__user__home_state_latest"],
)


def test_derived_metric_query_parsing(
request: FixtureRequest,
mf_test_configuration: MetricFlowTestConfiguration,
Expand Down Expand Up @@ -607,12 +656,11 @@ def test_offset_metric_with_diff_agg_time_dims_error() -> None: # noqa: D103

def query_parser_from_yaml(yaml_contents: List[YamlConfigFile]) -> MetricFlowQueryParser:
"""Given yaml files, return a query parser using default source nodes, resolvers and time spine source."""
semantic_manifest_lookup = SemanticManifestLookup(
parse_yaml_files_to_validation_ready_semantic_manifest(
yaml_contents, apply_transformations=True
).semantic_manifest
)
SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest_lookup.semantic_manifest)
semantic_manifest = parse_yaml_files_to_validation_ready_semantic_manifest(
yaml_contents, apply_transformations=True
).semantic_manifest
semantic_manifest_lookup = SemanticManifestLookup(semantic_manifest)
SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest)
return MetricFlowQueryParser(
semantic_manifest_lookup=semantic_manifest_lookup,
)
Expand Down

0 comments on commit 3d5170d

Please sign in to comment.