Skip to content

Commit

Permalink
Merge pull request #261 from dbt-labs/plypaul--56--backport-256-to-0.4
Browse files Browse the repository at this point in the history
Backport #256 to `0.4.latest`
  • Loading branch information
plypaul authored Feb 15, 2024
2 parents 48e3fb7 + 35f92b2 commit da42d9e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240129-110034.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Referencing entities in where filter causes an error.
time: 2024-01-29T11:00:34.690723-08:00
custom:
Author: plypaul
Issue: "256"
17 changes: 11 additions & 6 deletions dbt_semantic_interfaces/implementations/filters/where_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import textwrap
import traceback
from typing import Callable, Generator, List, Tuple

from typing_extensions import Self
Expand All @@ -16,7 +18,6 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.pretty_print import pformat_big_objects


class PydanticWhereFilter(PydanticCustomInputParser, HashableBaseModel):
Expand Down Expand Up @@ -126,10 +127,14 @@ def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParamete
invalid_filter_expressions.append((where_filter.where_sql_template, e))

if invalid_filter_expressions:
raise ParseWhereFilterException(
f"Encountered one or more errors when parsing the set of filter expressions "
f"{pformat_big_objects(self.where_filters)}! Invalid expressions: \n "
f"{pformat_big_objects(invalid_filter_expressions)}"
)
lines = ["Encountered error(s) while parsing:\n"]
for where_sql_template, exception in invalid_filter_expressions:
lines.append("Filter:")
lines.append(textwrap.indent(where_sql_template, prefix=" "))
lines.append("Error Message:")
lines.append(textwrap.indent(str(exception), prefix=" "))
lines.append("Traceback:")
lines.append(textwrap.indent("".join(traceback.format_tb(exception.__traceback__)), prefix=" "))
raise ParseWhereFilterException("\n".join(lines))

return filter_parameter_sets
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ParameterSetFactory:
@staticmethod
def _exception_message_for_incorrect_format(element_name: str) -> str:
return (
f"Name is in an incorrect format: '{element_name}'. It should be of the form: "
f"Name is in an incorrect format: {repr(element_name)}. It should be of the form: "
f"<primary entity name>__<dimension_name>"
)

Expand Down Expand Up @@ -87,14 +87,17 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di
@staticmethod
def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet:
"""Gets called by Jinja when rendering {{ Entity(...) }}."""
group_by_item_name = DunderedNameFormatter.parse_name(entity_name)
if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None:
structured_dundered_name = DunderedNameFormatter.parse_name(entity_name)
if structured_dundered_name.time_granularity is not None:
raise ParseWhereFilterException(
f"Entity name is in an incorrect format: '{entity_name}'. "
f"It should not contain any dunders (double underscores, or __)."
f"Name is in an incorrect format: {repr(entity_name)}. " f"It should not contain a time grain suffix."
)

additional_entity_path_elements = tuple(
EntityReference(element_name=entity_path_item) for entity_path_item in entity_path
)

return EntityCallParameterSet(
entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path),
entity_reference=EntityReference(element_name=entity_name),
entity_path=additional_entity_path_elements + structured_dundered_name.entity_links,
entity_reference=EntityReference(element_name=structured_dundered_name.element_name),
)
20 changes: 9 additions & 11 deletions tests/implementations/where_filter/test_parse_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
PydanticWhereFilter,
PydanticWhereFilterIntersection,
)
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.references import (
DimensionReference,
EntityReference,
Expand Down Expand Up @@ -145,9 +142,9 @@ def test_metric_time_in_dimension_call_error() -> None: # noqa: D

def test_invalid_entity_name_error() -> None:
"""Test to ensure we throw an error if an entity name is invalid."""
bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order' )}}")
bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}")

with pytest.raises(ParseWhereFilterException, match="Entity name is in an incorrect format"):
with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"):
bad_entity_filter.call_parameter_sets


Expand Down Expand Up @@ -198,7 +195,7 @@ def test_where_filter_intersection_error_collection() -> None:
where_sql_template="{{ TimeDimension('order_id__order_time__month', 'week') }} > '2020-01-01'"
)
valid_dimension = PydanticWhereFilter(where_sql_template=" {Dimension('customer__has_delivery_address')} ")
entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order') }}")
entity_format_error = PydanticWhereFilter(where_sql_template="{{ Entity('order_id__is_food_order__day') }}")
filter_intersection = PydanticWhereFilterIntersection(
where_filters=[metric_time_in_dimension_error, valid_dimension, entity_format_error]
)
Expand All @@ -208,12 +205,13 @@ def test_where_filter_intersection_error_collection() -> None:

error_string = str(exc_info.value)
# These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find.
assert ParameterSetFactory._exception_message_for_incorrect_format("order_id__order_time__month") in error_string
assert "Entity name is in an incorrect format: 'order_id__is_food_order'" in error_string
# We cannot simply scan for name because the error message contains the filter list, so we assert against the error
assert (
ParameterSetFactory._exception_message_for_incorrect_format("customer__has_delivery_address")
not in error_string
"Name is in an incorrect format: 'order_id__is_food_order__day'. It should not contain a time grain "
"suffix." in error_string
)
assert (
"Name is in an incorrect format: 'order_id__order_time__month'. It should be of the form: "
"<primary entity name>__<dimension_name>" in error_string
)


Expand Down
25 changes: 25 additions & 0 deletions tests/parsing/test_where_filter_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import pytest

from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet
from dbt_semantic_interfaces.implementations.base import HashableBaseModel
from dbt_semantic_interfaces.implementations.filters.where_filter import (
PydanticWhereFilter,
Expand All @@ -21,6 +22,7 @@
from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import (
WhereFilterParser,
)
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart

__BOOLEAN_EXPRESSION__ = "1 > 0"
Expand Down Expand Up @@ -162,3 +164,26 @@ def test_dimension_date_part() -> None: # noqa
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.time_dimension_call_parameter_sets) == 1
assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR


def test_entity_without_primary_entity_prefix() -> None: # noqa
where = "{{ Entity('non_primary_entity') }} = '1'"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.entity_call_parameter_sets) == 1
assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet(
entity_path=(),
entity_reference=EntityReference(element_name="non_primary_entity"),
)


def test_entity() -> None: # noqa
where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'"
param_sets = WhereFilterParser.parse_call_parameter_sets(where)
assert len(param_sets.entity_call_parameter_sets) == 1
assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet(
entity_path=(
EntityReference(element_name="entity_0"),
EntityReference(element_name="entity_1"),
),
entity_reference=EntityReference(element_name="entity_2"),
)

0 comments on commit da42d9e

Please sign in to comment.