Skip to content

Commit

Permalink
Updated filter call set parameters to fix breaking change
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Nov 11, 2024
1 parent bec5255 commit 3a40926
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> EntityLinkPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(f"{repr(input_str)} does not follow this scheme.")

input_str = input_str.lower()
Expand Down Expand Up @@ -119,7 +119,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
)

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# This naming scheme is case-insensitive.
input_str = input_str.lower()
if DunderNamingScheme._INPUT_REGEX.match(input_str) is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:
@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> MetricSpecPattern:
input_str = input_str.lower()
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise RuntimeError(f"{repr(input_str)} does not follow this scheme.")
return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str))

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
# TODO: Use regex.
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
pass

@abstractmethod
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
"""Returns true if the given input string follows this naming scheme.
Consider adding a structured result that indicates why it does not match the scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ def input_str(self, instance_spec: InstanceSpec) -> Optional[str]:

@override
def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> SpecPattern:
if not self.input_str_follows_scheme(input_str):
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise ValueError(
f"The specified input {repr(input_str)} does not match the input described by the object builder "
f"pattern."
)
try:
# TODO: Update when more appropriate parsing libraries are available.
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets
call_parameter_sets = PydanticWhereFilter(where_sql_template="{{ " + input_str + " }}").call_parameter_sets(
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except ParseWhereFilterException as e:
raise ValueError(f"A spec pattern can't be generated from the input string {repr(input_str)}") from e

Expand Down Expand Up @@ -121,11 +123,14 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
raise RuntimeError("There should have been a return associated with one of the CallParameterSets.")

@override
def input_str_follows_scheme(self, input_str: str) -> bool:
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
if ObjectBuilderNamingScheme._NAME_REGEX.match(input_str) is None:
return False
try:
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{ " + input_str + " }}")
call_parameter_sets = WhereFilterParser.parse_call_parameter_sets(
where_sql_template="{{ " + input_str + " }}",
custom_granularity_names=semantic_manifest_lookup.semantic_model_lookup.custom_granularity_names,
)
return_value = (
len(call_parameter_sets.dimension_call_parameter_sets)
+ len(call_parameter_sets.time_dimension_call_parameter_sets)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def _resolve_specs_for_where_filters(
for location, where_filters in where_filters_and_locations.items():
for where_filter in where_filters:
try:
filter_call_parameter_sets = where_filter.call_parameter_sets
filter_call_parameter_sets = where_filter.call_parameter_sets(
custom_granularity_names=self._manifest_lookup.semantic_model_lookup.custom_granularity_names
)
except Exception as e:
non_parsable_resolutions.append(
NonParsableFilterResolution(
Expand Down
16 changes: 12 additions & 4 deletions metricflow-semantics/metricflow_semantics/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def _parse_order_by_names(
order_by_name_without_prefix = order_by_name

for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if group_by_item_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForGroupByItem(
input_obj=order_by_name,
Expand All @@ -223,7 +225,9 @@ def _parse_order_by_names(
break

for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(order_by_name_without_prefix):
if metric_naming_scheme.input_str_follows_scheme(
order_by_name_without_prefix, semantic_manifest_lookup=self._manifest_lookup
):
possible_inputs.append(
ResolverInputForMetric(
input_obj=order_by_name,
Expand Down Expand Up @@ -373,7 +377,9 @@ def _parse_and_validate_query(
for metric_name in metric_names:
resolver_input_for_metric: Optional[MetricFlowQueryResolverInput] = None
for metric_naming_scheme in self._metric_naming_schemes:
if metric_naming_scheme.input_str_follows_scheme(metric_name):
if metric_naming_scheme.input_str_follows_scheme(
metric_name, semantic_manifest_lookup=self._manifest_lookup
):
resolver_input_for_metric = ResolverInputForMetric(
input_obj=metric_name,
naming_scheme=metric_naming_scheme,
Expand Down Expand Up @@ -405,7 +411,9 @@ def _parse_and_validate_query(
for group_by_name in group_by_names:
resolver_input_for_group_by_item: Optional[MetricFlowQueryResolverInput] = None
for group_by_item_naming_scheme in self._group_by_item_naming_schemes:
if group_by_item_naming_scheme.input_str_follows_scheme(group_by_name):
if group_by_item_naming_scheme.input_str_follows_scheme(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
):
spec_pattern = group_by_item_naming_scheme.spec_pattern(
group_by_name, semantic_manifest_lookup=self._manifest_lookup
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,28 @@ def test_input_str(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D
)


def test_input_follows_scheme(dunder_naming_scheme: DunderNamingScheme) -> None: # noqa: D103
assert dunder_naming_scheme.input_str_follows_scheme("listing__country")
assert dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__month")
assert dunder_naming_scheme.input_str_follows_scheme("booking__listing")
assert not dunder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month")
assert not dunder_naming_scheme.input_str_follows_scheme("123")
assert not dunder_naming_scheme.input_str_follows_scheme("TimeDimension('metric_time')")
def test_input_follows_scheme( # noqa: D103
dunder_naming_scheme: DunderNamingScheme,
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> None:
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__country", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert dunder_naming_scheme.input_str_follows_scheme(
"booking__listing", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"123", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not dunder_naming_scheme.input_str_follows_scheme(
"TimeDimension('metric_time')", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def test_input_str(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D
assert metric_naming_scheme.input_str(MetricSpec(element_name="bookings")) == "bookings"


def test_input_follows_scheme(metric_naming_scheme: MetricNamingScheme) -> None: # noqa: D103
assert metric_naming_scheme.input_str_follows_scheme("listings")
def test_input_follows_scheme( # noqa: D103
metric_naming_scheme: MetricNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup
) -> None:
assert metric_naming_scheme.input_str_follows_scheme(
"listings", semantic_manifest_lookup=simple_semantic_manifest_lookup
)


def test_spec_pattern( # noqa: D103
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,30 @@ def test_input_str(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> N
)


def test_input_follows_scheme(object_builder_naming_scheme: ObjectBuilderNamingScheme) -> None: # noqa: D103
def test_input_follows_scheme( # noqa: D103
object_builder_naming_scheme: ObjectBuilderNamingScheme, simple_semantic_manifest_lookup: SemanticManifestLookup
) -> None:
assert object_builder_naming_scheme.input_str_follows_scheme(
"Dimension('listing__country', entity_path=['booking'])"
"Dimension('listing__country', entity_path=['booking'])",
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
assert object_builder_naming_scheme.input_str_follows_scheme(
"TimeDimension('listing__creation_time', time_granularity_name='month', date_part_name='day', "
"entity_path=['booking'])"
"entity_path=['booking'])",
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
assert object_builder_naming_scheme.input_str_follows_scheme(
"Entity('user', entity_path=['booking', 'listing'])",
"Entity('user', entity_path=['booking', 'listing'])", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"listing__creation_time__extract_month", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"123", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme(
"NotADimension('listing__country')", semantic_manifest_lookup=simple_semantic_manifest_lookup
)
assert not object_builder_naming_scheme.input_str_follows_scheme("listing__creation_time__extract_month")
assert not object_builder_naming_scheme.input_str_follows_scheme("123")
assert not object_builder_naming_scheme.input_str_follows_scheme("NotADimension('listing__country')")


def test_spec_pattern( # noqa: D103
Expand Down

0 comments on commit 3a40926

Please sign in to comment.