Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Entity() parsing in filters.
Browse files Browse the repository at this point in the history
plypaul committed Jan 29, 2024
1 parent bee9aaf commit 05cb6d2
Showing 4 changed files with 56 additions and 24 deletions.
17 changes: 11 additions & 6 deletions dbt_semantic_interfaces/implementations/filters/where_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import textwrap
import traceback
from typing import Callable, Generator, List, Tuple

from typing_extensions import Self
@@ -16,7 +18,6 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.pretty_print import pformat_big_objects


class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel):
@@ -126,10 +127,14 @@ def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParamete
invalid_filter_expressions.append((where_filter.where_sql_template, e))

if invalid_filter_expressions:
raise ParseWhereFilterException(
f"Encountered one or more errors when parsing the set of filter expressions "
f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n "
f"{pformat_big_objects(invalid_filter_expressions)}"
)
lines = ["Encountered error(s) while parsing:\n"]
for where_sql_template, exception in invalid_filter_expressions:
lines.append("Filter:")
lines.append(textwrap.indent(where_sql_template, prefix=" "))
lines.append("Error Message:")
lines.append(textwrap.indent(str(exception), prefix=" "))
lines.append("Traceback:")
lines.append(textwrap.indent("".join(traceback.format_tb(exception.__traceback__)), prefix=" "))
raise ParseWhereFilterException("\n".join(lines))

return filter_parameter_sets
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ class ParameterSetFactory:
@staticmethod
def _exception_message_for_incorrect_format(element_name: str) -> str:
    """Build the error text used when a dundered name has the wrong shape.

    Args:
        element_name: The name exactly as it was written in the filter call.

    Returns:
        A message that quotes the offending name and states the expected form.
    """
    # `!r` renders the name through repr(), so quotes and escape characters
    # inside the name stay unambiguous in the message.
    return (
        f"Name is in an incorrect format: {element_name!r}. It should be of the form: "
        f"<primary entity name>__<dimension_name>"
    )

@@ -87,14 +87,17 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di
@staticmethod
def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet:
    """Gets called by Jinja when rendering {{ Entity(...) }}.

    Args:
        entity_name: Name of the entity. It may be dundered, in which case the
            entity links parsed out of it become part of the entity path.
        entity_path: Extra entity names that are placed before the links parsed
            from ``entity_name``.

    Returns:
        The parameter set describing the referenced entity and its entity path.

    Raises:
        ParseWhereFilterException: If the parsed name carries a time granularity.
    """
    structured_dundered_name = DunderedNameFormatter.parse_name(entity_name)
    if structured_dundered_name.time_granularity is not None:
        raise ParseWhereFilterException(
            f"Name is in an incorrect format: {entity_name!r}. "
            "It should not contain a time grain suffix."
        )

    # Explicitly passed path elements come first; links embedded in the dundered
    # name follow them.
    additional_entity_path_elements = tuple(
        EntityReference(element_name=entity_path_item) for entity_path_item in entity_path
    )

    return EntityCallParameterSet(
        entity_path=additional_entity_path_elements + structured_dundered_name.entity_links,
        entity_reference=EntityReference(element_name=structured_dundered_name.element_name),
    )
21 changes: 10 additions & 11 deletions tests/implementations/where_filter/test_parse_calls.py
Original file line number Diff line number Diff line change
@@ -13,9 +13,6 @@
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
@@ -145,9 +142,9 @@ def test_metric_time_in_dimension_call_error() -> None: # noqa: D

def test_invalid_entity_name_error() -> None:
    """Test to ensure we throw an error if an entity name is invalid.

    The name used here ends in ``__day``, and the parse is expected to fail
    with the "incorrect format" message.
    """
    bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}")

    with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"):
        # Accessing the attribute is what triggers parsing of the template.
        bad_entity_filter.call_parameter_sets


@@ -198,7 +195,7 @@ def test_where_filter_intersection_error_collection() -> None:
where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'"
)
valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ")
entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}")
entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order__day') }}")
filter_intersection = PydanticWhereFilterIntersection(
where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error]
)
@@ -207,13 +204,15 @@ def test_where_filter_intersection_error_collection() -> None:
filter_intersection.filter_expression_parameter_sets

error_string = str(exc_info.value)
logger.error(f"Error string is:\n{error_string}")
# These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find.
assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string
assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string
# We cannot simply scan for name because the error message contains the filter list, so we assert against the error
assert (
ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address")
not in error_string
"Name is in an incorrect format: 'order_id__is_food_order__day'. It should not contain a time grain "
"suffix." in error_string
)
assert (
"Name is in an incorrect format: 'order_id__order_time__month'. It should be of the form: "
"<primary entity name>__<dimension_name>" in error_string
)


25 changes: 25 additions & 0 deletions tests/parsing/test_where_filter_parsing.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@

import pytest

from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet
from dbt_semantic_interfaces.implementations.base import HashableBaseModel
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
@@ -21,6 +22,7 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart

__BOOLEAN_EXPRESSION__ = "1 > 0"
@@ -162,3 +164,26 @@ def test_dimension_date_part() -> None: # noqa
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.time_dimension_call_parameter_sets) == 1
assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR


def test_entity_without_primary_entity_prefix() -> None:  # noqa
    """An Entity() call with a plain (un-dundered) name yields an empty entity path."""
    parsed = WhereFilterParser.parse_call_parameter_sets("{{ Entity('non_primary_entity') }} = '1'")

    # Exactly one Entity() call appears in the template.
    assert len(parsed.entity_call_parameter_sets) == 1

    expected = EntityCallParameterSet(
        entity_reference=EntityReference(element_name="non_primary_entity"),
        entity_path=(),
    )
    assert parsed.entity_call_parameter_sets[0] == expected


def test_entity() -> None:  # noqa
    """A dundered Entity() name plus an explicit entity_path are merged into one path.

    The explicit ``entity_path`` items come first, followed by the prefix of the
    dundered name; the last part of the dundered name is the entity itself.
    """
    template = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'"
    parsed = WhereFilterParser.parse_call_parameter_sets(template)

    # Exactly one Entity() call appears in the template.
    assert len(parsed.entity_call_parameter_sets) == 1

    expected_path = (
        EntityReference(element_name="entity_0"),
        EntityReference(element_name="entity_1"),
    )
    expected = EntityCallParameterSet(
        entity_path=expected_path,
        entity_reference=EntityReference(element_name="entity_2"),
    )
    assert parsed.entity_call_parameter_sets[0] == expected

0 comments on commit 05cb6d2

Please sign in to comment.