From 3d5170d0195ca19b124450c19e1ad1fd71929cb7 Mon Sep 17 00:00:00 2001 From: serramatutu Date: Wed, 9 Oct 2024 19:33:12 +0200 Subject: [PATCH] Add validation rule for SCD query with no time This commit adds a new step to `MetricTimeQueryValidationRule` which does the following: 1. Selects all the existing SCDs for the queried metric 2. Match them against the spec pattern of the group by input 3. Raise a `SCDRequiresMetricTimeIssue` if no `metric_time` was provided and there were matches To accomplish step 1, I had to create a new `LinkableElementProperty` called `SCD_HOP`. This new property indicates that the join path to the linkable element goes through an SCD at some point. I changed the `ValidLinkableSpecResolver` to add `SCD_HOP` to the properties of all the elements it finds whenever that element belongs to an SCD or if the path to it contains an SCD. --- .../model/linkable_element_property.py | 2 + .../model/semantics/linkable_spec_resolver.py | 39 ++++++++++-- .../model/semantics/metric_lookup.py | 10 ++++ .../parsing/scd_requires_metric_time.py | 51 ++++++++++++++++ .../metric_time_requirements.py | 33 +++++++++- .../query/test_query_parser.py | 60 +++++++++++++++++-- 6 files changed, 182 insertions(+), 13 deletions(-) create mode 100644 metricflow-semantics/metricflow_semantics/query/issues/parsing/scd_requires_metric_time.py diff --git a/metricflow-semantics/metricflow_semantics/model/linkable_element_property.py b/metricflow-semantics/metricflow_semantics/model/linkable_element_property.py index 642a9b3449..ab43a1adff 100644 --- a/metricflow-semantics/metricflow_semantics/model/linkable_element_property.py +++ b/metricflow-semantics/metricflow_semantics/model/linkable_element_property.py @@ -29,6 +29,8 @@ class LinkableElementProperty(Enum): METRIC = "metric" # A time dimension with a DatePart. DATE_PART = "date_part" + # A linkable element that is itself part of an SCD model, or a linkable element that gets joined through another SCD model. + SCD_HOP = "scd_hop" @staticmethod def all_properties() -> FrozenSet[LinkableElementProperty]: # noqa: D102 diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py index b8da4e7ce6..6e13efaf2e 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_spec_resolver.py @@ -188,6 +188,10 @@ def __init__( logger.debug(LazyFormat(lambda: f"Building valid group-by-item indexes took: {time.time() - start_time:.2f}s")) + def _semantic_model_is_scd(self, semantic_model: SemanticModel) -> bool: + """Whether the semantic model is an SCD.""" + return any(dim.validity_params is not None for dim in semantic_model.dimensions) + def _generate_linkable_time_dimensions( self, semantic_model_origin: SemanticModelReference, @@ -289,6 +293,8 @@ def get_joinable_metrics_for_semantic_model( necessary. """ properties = frozenset({LinkableElementProperty.METRIC, LinkableElementProperty.JOINED}) + if self._semantic_model_is_scd(semantic_model): + properties = properties.union({LinkableElementProperty.SCD_HOP}) join_path_has_path_links = len(using_join_path.path_elements) > 0 if join_path_has_path_links: @@ -326,8 +332,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link Elements related to metric_time are handled separately in _get_metric_time_elements(). Linkable metrics are not considered local to the semantic model since they always require a join. """ + semantic_model_is_scd = self._semantic_model_is_scd(semantic_model) + linkable_dimensions = [] linkable_entities = [] + + entity_properties = frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}) + if semantic_model_is_scd: + entity_properties = entity_properties.union({LinkableElementProperty.SCD_HOP}) + for entity in semantic_model.entities: linkable_entities.append( LinkableEntity.create( @@ -337,7 +350,7 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link join_path=SemanticModelJoinPath( left_semantic_model_reference=semantic_model.reference, ), - properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}), + properties=entity_properties, ) ) for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model): @@ -352,12 +365,15 @@ def _get_elements_in_semantic_model(self, semantic_model: SemanticModel) -> Link join_path=SemanticModelJoinPath( left_semantic_model_reference=semantic_model.reference, ), - properties=frozenset({LinkableElementProperty.LOCAL, LinkableElementProperty.ENTITY}), + properties=entity_properties, ) ) + dimension_properties = frozenset({LinkableElementProperty.LOCAL}) + if semantic_model_is_scd: + dimension_properties = dimension_properties.union({LinkableElementProperty.SCD_HOP}) + for entity_link in self._semantic_model_lookup.entity_links_for_local_elements(semantic_model): - dimension_properties = frozenset({LinkableElementProperty.LOCAL}) for dimension in semantic_model.dimensions: dimension_type = dimension.type if dimension_type is DimensionType.CATEGORICAL: @@ -459,6 +475,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference defined_granularity: Optional[ExpandedTimeGranularity] = None if measure_reference: measure_semantic_model = self._get_semantic_model_for_measure(measure_reference) + semantic_model_is_scd = self._semantic_model_is_scd(measure_semantic_model) measure_agg_time_dimension_reference = measure_semantic_model.checked_agg_time_dimension_for_measure( measure_reference=measure_reference ) @@ -471,6 +488,7 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference # If querying metric_time without metrics, will query from time spines. # Defaults to DAY granularity if available in time spines, else smallest available granularity. min_granularity = min(self._time_spine_sources.keys()) + semantic_model_is_scd = False possible_metric_time_granularities = tuple( ExpandedTimeGranularity.from_time_granularity(time_granularity) for time_granularity in TimeGranularity @@ -501,6 +519,8 @@ def _get_metric_time_elements(self, measure_reference: Optional[MeasureReference properties.add(LinkableElementProperty.DERIVED_TIME_GRANULARITY) if date_part: properties.add(LinkableElementProperty.DATE_PART) + if semantic_model_is_scd: + properties.add(LinkableElementProperty.SCD_HOP) linkable_dimension = LinkableDimension.create( defined_in_semantic_model=measure_semantic_model.reference if measure_semantic_model else None, element_name=MetricFlowReservedKeywords.METRIC_TIME.value, @@ -711,12 +731,21 @@ def create_linkable_element_set_from_join_path( join_path: SemanticModelJoinPath, ) -> LinkableElementSet: """Given the current path, generate the respective linkable elements from the last semantic model in the path.""" + semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference) + assert semantic_model + properties = frozenset({LinkableElementProperty.JOINED}) if len(join_path.path_elements) > 1: properties = properties.union({LinkableElementProperty.MULTI_HOP}) - semantic_model = self._semantic_model_lookup.get_by_reference(join_path.last_semantic_model_reference) - assert semantic_model + # If any of the semantic models in the join path is an SCD, add SCD_HOP + for reference_to_derived_model in join_path.derived_from_semantic_models: + derived_model = self._semantic_model_lookup.get_by_reference(reference_to_derived_model) + assert derived_model + + if self._semantic_model_is_scd(derived_model): + properties = properties.union({LinkableElementProperty.SCD_HOP}) + break linkable_dimensions: List[LinkableDimension] = [] linkable_entities: List[LinkableEntity] = [] diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py index e39e9063ec..c54e690dce 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py @@ -20,6 +20,7 @@ ) from metricflow_semantics.model.semantics.semantic_model_join_evaluator import MAX_JOIN_HOPS from metricflow_semantics.model.semantics.semantic_model_lookup import SemanticModelLookup +from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec from metricflow_semantics.time.granularity import ExpandedTimeGranularity @@ -221,3 +222,12 @@ def get_min_queryable_time_granularity(self, metric_reference: MetricReference) minimum_queryable_granularity = defined_time_granularity return minimum_queryable_granularity + + def get_joinable_scd_specs_for_metric(self, metric_reference: MetricReference) -> Sequence[LinkableInstanceSpec]: + """Get the SCDs that can be joined to a metric.""" + scd_elems = self.linkable_elements_for_metrics( + metric_references=(metric_reference,), + with_any_property=frozenset([LinkableElementProperty.SCD_HOP]), + ) + + return scd_elems.specs diff --git a/metricflow-semantics/metricflow_semantics/query/issues/parsing/scd_requires_metric_time.py b/metricflow-semantics/metricflow_semantics/query/issues/parsing/scd_requires_metric_time.py new file mode 100644 index 0000000000..1fec2498f8 --- /dev/null +++ b/metricflow-semantics/metricflow_semantics/query/issues/parsing/scd_requires_metric_time.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + +from typing_extensions import override + +from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath +from metricflow_semantics.query.issues.issues_base import ( + MetricFlowQueryIssueType, + MetricFlowQueryResolutionIssue, +) +from metricflow_semantics.query.resolver_inputs.base_resolver_inputs import MetricFlowQueryResolverInput + + +@dataclass(frozen=True) +class SCDRequiresMetricTimeIssue(MetricFlowQueryResolutionIssue): + """Describes an issue with a query that includes a SCD group by but does not include metric_time.""" + + scd_qualified_names: Sequence[str] + + @override + def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str: + dim_str = ", ".join(self.scd_qualified_names) + return ( + f"Your query contains the Slowly Changing Dimensions (SCDs): [{dim_str}]. " + "A query containing SCDs must also contain the metric_time dimension in order " + "to join the SCD table to the valid time range. Please add metric_time " + "to the query and try again. If you're using agg_time_dimension, use " + "metric_time instead." + ) + + @override + def with_path_prefix(self, path_prefix: MetricFlowQueryResolutionPath) -> SCDRequiresMetricTimeIssue: + return SCDRequiresMetricTimeIssue( + issue_type=self.issue_type, + parent_issues=self.parent_issues, + query_resolution_path=self.query_resolution_path.with_path_prefix(path_prefix), + scd_qualified_names=self.scd_qualified_names, + ) + + @staticmethod + def from_parameters( # noqa: D102 + scd_qualified_names: Sequence[str], query_resolution_path: MetricFlowQueryResolutionPath + ) -> SCDRequiresMetricTimeIssue: + return SCDRequiresMetricTimeIssue( + issue_type=MetricFlowQueryIssueType.ERROR, + parent_issues=(), + query_resolution_path=query_resolution_path, + scd_qualified_names=scd_qualified_names, + ) diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index e61dff1d01..e6ce6b344d 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -2,11 +2,15 @@ from collections.abc import Sequence from dataclasses import dataclass +from typing import List from dbt_semantic_interfaces.enum_extension import assert_values_exhausted from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME from dbt_semantic_interfaces.protocols import WhereFilterIntersection -from dbt_semantic_interfaces.references import MetricReference, TimeDimensionReference +from dbt_semantic_interfaces.references import ( + MetricReference, + TimeDimensionReference, +) from dbt_semantic_interfaces.type_enums import MetricType from typing_extensions import override @@ -19,7 +23,12 @@ from metricflow_semantics.query.issues.parsing.offset_metric_requires_metric_time import ( OffsetMetricRequiresMetricTimeIssue, ) -from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery +from metricflow_semantics.query.issues.parsing.scd_requires_metric_time import ( + SCDRequiresMetricTimeIssue, +) +from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ( + ResolverInputForQuery, +) from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec @@ -28,6 +37,7 @@ class QueryItemsAnalysis: """Contains data about which items a query contains.""" + scds: Sequence[str] has_metric_time: bool has_agg_time_dimension: bool @@ -39,6 +49,7 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): * Cumulative metrics. * Derived metrics with an offset time.g + * Slowly changing dimensions """ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 @@ -57,10 +68,14 @@ def _get_query_items_analysis( ) -> QueryItemsAnalysis: has_agg_time_dimension = False has_metric_time = False + scds: List[str] = [] valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric( metric_reference ) + + scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference) + for group_by_item_input in query_resolver_input.group_by_item_inputs: if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs): has_metric_time = True @@ -68,7 +83,11 @@ def _get_query_items_analysis( if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs): has_agg_time_dimension = True + matches = group_by_item_input.spec_pattern.match(scd_specs) + scds.extend(match.qualified_name for match in matches) + return QueryItemsAnalysis( + scds=scds, has_metric_time=has_metric_time, has_agg_time_dimension=has_agg_time_dimension, ) @@ -86,6 +105,16 @@ def validate_metric_in_resolution_dag( issues = MetricFlowQueryResolutionIssueSet.empty_instance() + # Queries that join to an SCD don't support direct references to agg_time_dimension, so we + # only check for metric_time. If we decide to support agg_time_dimension, we should add a check + if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time: + issues = issues.add_issue( + SCDRequiresMetricTimeIssue.from_parameters( + scd_qualified_names=query_items_analysis.scds, + query_resolution_path=resolution_path, + ) + ) + if metric.type is MetricType.CUMULATIVE: if ( metric.type_params is not None diff --git a/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py b/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py index 50c9cc6853..d9b7025201 100644 --- a/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py +++ b/metricflow-semantics/tests_metricflow_semantics/query/test_query_parser.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os import textwrap from typing import List @@ -29,6 +30,9 @@ EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE, ) from metricflow_semantics.test_helpers.metric_time_dimension import MTD +from metricflow_semantics.test_helpers.semantic_manifest_yamls.scd_manifest import ( + SCD_MANIFEST_ANCHOR, +) from metricflow_semantics.test_helpers.snapshot_helpers import assert_object_snapshot_equal logger = logging.getLogger(__name__) @@ -191,6 +195,23 @@ def revenue_query_parser() -> MetricFlowQueryParser: # noqa return query_parser_from_yaml([EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE, revenue_yaml_file]) +@pytest.fixture +def scd_query_parser() -> MetricFlowQueryParser: # noqa + file_paths = [ + os.path.join(SCD_MANIFEST_ANCHOR.directory, f) + for f in os.listdir(SCD_MANIFEST_ANCHOR.directory) + if f.endswith(".yaml") + ] + + contents: List[str] = [] + for fp in file_paths: + with open(fp, "r") as f: + contents.append(f.read()) + + files = [YamlConfigFile(filepath="inline_for_test_1", contents=c) for c in contents] + return query_parser_from_yaml(files) + + def test_query_parser( # noqa: D103 request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, @@ -459,6 +480,34 @@ def test_cumulative_metric_agg_time_dimension_name_validation( assert_object_snapshot_equal(request=request, mf_test_configuration=mf_test_configuration, obj=result) +def test_join_to_scd_no_time_dimension_validation( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + scd_query_parser: MetricFlowQueryParser, +) -> None: + """Test that queries that join to SCD semantic models fail if no time dimensions are selected.""" + with pytest.raises(InvalidQueryException, match="query containing SCDs must also contain the metric_time"): + scd_query_parser.parse_and_validate_query( + metric_names=["bookings"], + group_by_names=["listing__country"], + ) + + +def test_join_through_scd_no_time_dimension_validation( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + scd_query_parser: MetricFlowQueryParser, +) -> None: + """Test that queries that join through SCDs semantic models fail if no time dimensions are selected.""" + with pytest.raises(InvalidQueryException, match="query containing SCDs must also contain the metric_time"): + # "user__home_state_latest" is not an SCD itself, but since we go through + # "listing" and that is an SCD, we should raise an exception here as well + scd_query_parser.parse_and_validate_query( + metric_names=["bookings"], + group_by_names=["listing__user__home_state_latest"], + ) + + def test_derived_metric_query_parsing( request: FixtureRequest, mf_test_configuration: MetricFlowTestConfiguration, @@ -607,12 +656,11 @@ def test_offset_metric_with_diff_agg_time_dims_error() -> None: # noqa: D103 def query_parser_from_yaml(yaml_contents: List[YamlConfigFile]) -> MetricFlowQueryParser: """Given yaml files, return a query parser using default source nodes, resolvers and time spine source.""" - semantic_manifest_lookup = SemanticManifestLookup( - parse_yaml_files_to_validation_ready_semantic_manifest( - yaml_contents, apply_transformations=True - ).semantic_manifest - ) - SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest_lookup.semantic_manifest) + semantic_manifest = parse_yaml_files_to_validation_ready_semantic_manifest( + yaml_contents, apply_transformations=True + ).semantic_manifest + semantic_manifest_lookup = SemanticManifestLookup(semantic_manifest) + SemanticManifestValidator[SemanticManifest]().checked_validations(semantic_manifest) return MetricFlowQueryParser( semantic_manifest_lookup=semantic_manifest_lookup, )