From 3a409262be0926a5ad9c19acc5683cdf2dedd61e Mon Sep 17 00:00:00 2001 From: Will Deng Date: Mon, 11 Nov 2024 16:50:48 -0500 Subject: [PATCH] Updated filter call set parameters to fix breaking change --- .../naming/dunder_scheme.py | 4 +-- .../naming/metric_scheme.py | 4 +-- .../naming/naming_scheme.py | 2 +- .../naming/object_builder_scheme.py | 13 ++++++--- .../filter_spec_resolver.py | 4 ++- .../query/query_parser.py | 16 +++++++--- .../naming/test_dunder_naming_scheme.py | 29 ++++++++++++++----- .../naming/test_metric_name_scheme.py | 8 +++-- .../test_object_builder_naming_scheme.py | 24 ++++++++++----- 9 files changed, 74 insertions(+), 30 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py b/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py index 926adbccc1..ebef953fdd 100644 --- a/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/dunder_scheme.py @@ -52,7 +52,7 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: @override def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> EntityLinkPattern: - if not self.input_str_follows_scheme(input_str): + if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup): raise ValueError(f"{repr(input_str)} does not follow this scheme.") input_str = input_str.lower() @@ -119,7 +119,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes ) @override - def input_str_follows_scheme(self, input_str: str) -> bool: + def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool: # This naming scheme is case-insensitive. input_str = input_str.lower() if DunderNamingScheme._INPUT_REGEX.match(input_str) is None: diff --git a/metricflow-semantics/metricflow_semantics/naming/metric_scheme.py b/metricflow-semantics/metricflow_semantics/naming/metric_scheme.py index fcd9fe5ac9..a19407f5d6 100644 --- a/metricflow-semantics/metricflow_semantics/naming/metric_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/metric_scheme.py @@ -28,12 +28,12 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: @override def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> MetricSpecPattern: input_str = input_str.lower() - if not self.input_str_follows_scheme(input_str): + if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup): raise RuntimeError(f"{repr(input_str)} does not follow this scheme.") return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str)) @override - def input_str_follows_scheme(self, input_str: str) -> bool: + def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool: # TODO: Use regex. return True diff --git a/metricflow-semantics/metricflow_semantics/naming/naming_scheme.py b/metricflow-semantics/metricflow_semantics/naming/naming_scheme.py index f9110c01b8..f78cf72ae9 100644 --- a/metricflow-semantics/metricflow_semantics/naming/naming_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/naming_scheme.py @@ -42,7 +42,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes pass @abstractmethod - def input_str_follows_scheme(self, input_str: str) -> bool: + def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool: """Returns true if the given input string follows this naming scheme. Consider adding a structured result that indicates why it does not match the scheme. diff --git a/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py b/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py index b6f11d1ec5..55a3dd51ad 100644 --- a/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py +++ b/metricflow-semantics/metricflow_semantics/naming/object_builder_scheme.py @@ -38,14 +38,16 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]: @override def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> SpecPattern: - if not self.input_str_follows_scheme(input_str): + if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup): raise ValueError( f"The specified input {repr(input_str)} does not match the input described by the object builder " f"pattern." ) try: # TODO: Update when more appropriate parsing libraries are available. - call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets + call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets( + custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names + ) except ParseWhereFilterException as e: raise ValueError(f"A spec pattern can't be generated from the input string {repr(input_str)}") from e @@ -121,11 +123,14 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes raise RuntimeError("There should have been a return associated with one of the CallParameterSets.") @override - def input_str_follows_scheme(self, input_str: str) -> bool: + def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool: if ObjectBuilderNamingScheme._NAME_REGEX.match(input_str) is None: return False try: - call_parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{ " + input_str + " }}") + call_parameter_sets = WhereFilterParser.parse_call_parameter_sets( + where_sql_template="{{ " + input_str + " }}", + custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names, + ) return_value = ( len(call_parameter_sets.dimension_call_parameter_sets) + len(call_parameter_sets.time_dimension_call_parameter_sets) diff --git a/metricflow-semantics/metricflow_semantics/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py b/metricflow-semantics/metricflow_semantics/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py index 091fddc84a..30b64b0811 100644 --- a/metricflow-semantics/metricflow_semantics/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/group_by_item/filter_spec_resolution/filter_spec_resolver.py @@ -331,7 +331,9 @@ def _resolve_specs_for_where_filters( for location, where_filters in where_filters_and_locations.items(): for where_filter in where_filters: try: - filter_call_parameter_sets = where_filter.call_parameter_sets + filter_call_parameter_sets = where_filter.call_parameter_sets( + custom_granularity_names=self._manifest_lookup.semantic_model_lookup.custom_granularity_names + ) except Exception as e: non_parsable_resolutions.append( NonParsableFilterResolution( diff --git a/metricflow-semantics/metricflow_semantics/query/query_parser.py b/metricflow-semantics/metricflow_semantics/query/query_parser.py index cf0c1f948c..555216e1ab 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_parser.py +++ b/metricflow-semantics/metricflow_semantics/query/query_parser.py @@ -210,7 +210,9 @@ def _parse_order_by_names( order_by_name_without_prefix = order_by_name for group_by_item_naming_scheme in self._group_by_item_naming_schemes: - if group_by_item_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix): + if group_by_item_naming_scheme.input_str_follows_scheme( + order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup + ): possible_inputs.append( ResolverInputForGroupByItem( input_obj=order_by_name, @@ -223,7 +225,9 @@ def _parse_order_by_names( break for metric_naming_scheme in self._metric_naming_schemes: - if metric_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix): + if metric_naming_scheme.input_str_follows_scheme( + order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup + ): possible_inputs.append( ResolverInputForMetric( input_obj=order_by_name, @@ -373,7 +377,9 @@ def _parse_and_validate_query( for metric_name in metric_names: resolver_input_for_metric: Optional[MetricFlowQueryResolverInput] = None for metric_naming_scheme in self._metric_naming_schemes: - if metric_naming_scheme.input_str_follows_scheme(metric_name): + if metric_naming_scheme.input_str_follows_scheme( + metric_name, semantic_manifest_lookup=self._manifest_lookup + ): resolver_input_for_metric = ResolverInputForMetric( input_obj=metric_name, naming_scheme=metric_naming_scheme, @@ -405,7 +411,9 @@ def _parse_and_validate_query( for group_by_name in group_by_names: resolver_input_for_group_by_item: Optional[MetricFlowQueryResolverInput] = None for group_by_item_naming_scheme in self._group_by_item_naming_schemes: - if group_by_item_naming_scheme.input_str_follows_scheme(group_by_name): + if group_by_item_naming_scheme.input_str_follows_scheme( + group_by_name, semantic_manifest_lookup=self._manifest_lookup + ): spec_pattern = group_by_item_naming_scheme.spec_pattern( group_by_name, semantic_manifest_lookup=self._manifest_lookup ) diff --git a/metricflow-semantics/tests_metricflow_semantics/naming/test_dunder_naming_scheme.py b/metricflow-semantics/tests_metricflow_semantics/naming/test_dunder_naming_scheme.py index 2724ead45f..142213125b 100644 --- a/metricflow-semantics/tests_metricflow_semantics/naming/test_dunder_naming_scheme.py +++ b/metricflow-semantics/tests_metricflow_semantics/naming/test_dunder_naming_scheme.py @@ -75,13 +75,28 @@ def test_input_str(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D ) -def test_input_follows_scheme(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D103 - assert dunder_naming_scheme.input_str_follows_scheme("listing__country") - assert dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__month") - assert dunder_naming_scheme.input_str_follows_scheme("booking__listing") - assert not dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month") - assert not dunder_naming_scheme.input_str_follows_scheme("123") - assert not dunder_naming_scheme.input_str_follows_scheme("TimeDimension('metric_time')") +def test_input_follows_scheme( # noqa: D103 + dunder_naming_scheme: DunderNamingScheme, + simple_semantic_manifest_lookup: SemanticManifestLookup, +) -> None: + assert dunder_naming_scheme.input_str_follows_scheme( + "listing__country", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert dunder_naming_scheme.input_str_follows_scheme( + "listing__creation_time__month", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert dunder_naming_scheme.input_str_follows_scheme( + "booking__listing", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not dunder_naming_scheme.input_str_follows_scheme( + "listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not dunder_naming_scheme.input_str_follows_scheme( + "123", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not dunder_naming_scheme.input_str_follows_scheme( + "TimeDimension('metric_time')", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) def test_spec_pattern( # noqa: D103 diff --git a/metricflow-semantics/tests_metricflow_semantics/naming/test_metric_name_scheme.py b/metricflow-semantics/tests_metricflow_semantics/naming/test_metric_name_scheme.py index b8e057c491..24c456b5a6 100644 --- a/metricflow-semantics/tests_metricflow_semantics/naming/test_metric_name_scheme.py +++ b/metricflow-semantics/tests_metricflow_semantics/naming/test_metric_name_scheme.py @@ -19,8 +19,12 @@ def test_input_str(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D assert metric_naming_scheme.input_str(MetricSpec(element_name="bookings")) == "bookings" -def test_input_follows_scheme(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D103 - assert metric_naming_scheme.input_str_follows_scheme("listings") +def test_input_follows_scheme( # noqa: D103 + metric_naming_scheme: MetricNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup +) -> None: + assert metric_naming_scheme.input_str_follows_scheme( + "listings", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) def test_spec_pattern( # noqa: D103 diff --git a/metricflow-semantics/tests_metricflow_semantics/naming/test_object_builder_naming_scheme.py b/metricflow-semantics/tests_metricflow_semantics/naming/test_object_builder_naming_scheme.py index 1b8a058a32..5f471d0c47 100644 --- a/metricflow-semantics/tests_metricflow_semantics/naming/test_object_builder_naming_scheme.py +++ b/metricflow-semantics/tests_metricflow_semantics/naming/test_object_builder_naming_scheme.py @@ -64,20 +64,30 @@ def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> N ) -def test_input_follows_scheme(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> None: # noqa: D103 +def test_input_follows_scheme( # noqa: D103 + object_builder_naming_scheme: ObjectBuilderNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup +) -> None: assert object_builder_naming_scheme.input_str_follows_scheme( - "Dimension('listing__country', entity_path=['booking'])" + "Dimension('listing__country', entity_path=['booking'])", + semantic_manifest_lookup=simple_semantic_manifest_lookup, ) assert object_builder_naming_scheme.input_str_follows_scheme( "TimeDimension('listing__creation_time', time_granularity_name='month', date_part_name='day', " - "entity_path=['booking'])" + "entity_path=['booking'])", + semantic_manifest_lookup=simple_semantic_manifest_lookup, ) assert object_builder_naming_scheme.input_str_follows_scheme( - "Entity('user', entity_path=['booking', 'listing'])", + "Entity('user', entity_path=['booking', 'listing'])", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not object_builder_naming_scheme.input_str_follows_scheme( + "listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not object_builder_naming_scheme.input_str_follows_scheme( + "123", semantic_manifest_lookup=simple_semantic_manifest_lookup + ) + assert not object_builder_naming_scheme.input_str_follows_scheme( + "NotADimension('listing__country')", semantic_manifest_lookup=simple_semantic_manifest_lookup ) - assert not object_builder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month") - assert not object_builder_naming_scheme.input_str_follows_scheme("123") - assert not object_builder_naming_scheme.input_str_follows_scheme("NotADimension('listing__country')") def test_spec_pattern( # noqa: D103