diff --git a/.changes/unreleased/Fixes-20240129-110034.yaml b/.changes/unreleased/Fixes-20240129-110034.yaml new file mode 100644 index 00000000..705768a2 --- /dev/null +++ b/.changes/unreleased/Fixes-20240129-110034.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Referencing entities in where filter causes an error. +time: 2024-01-29T11:00:34.690723-08:00 +custom: + Author: plypaul + Issue: "256" diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py index af8fc085..e0ef5332 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -87,14 +87,18 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di @staticmethod def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet: """Gets called by Jinja when rendering {{ Entity(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(entity_name) - if len(group_by_item_name.entity_links) > 0 or group_by_item_name.time_granularity is not None: + structured_dundered_name = DunderedNameFormatter.parse_name(entity_name) + if structured_dundered_name.time_granularity is not None: raise ParseWhereFilterException( f"Entity name is in an incorrect format: '{entity_name}'. " f"It should not contain any dunders (double underscores, or __)." ) + additional_entity_path_elements = tuple( + EntityReference(element_name=entity_path_item) for entity_path_item in entity_path + ) + return EntityCallParameterSet( - entity_path=tuple(EntityReference(element_name=arg) for arg in entity_path), - entity_reference=EntityReference(element_name=entity_name), + entity_path=additional_entity_path_elements + structured_dundered_name.entity_links, + entity_reference=EntityReference(element_name=structured_dundered_name.element_name), ) diff --git a/pyproject.toml b/pyproject.toml index 70a8a86d..edcd5d34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dbt-semantic-interfaces" -version = "0.5.0a3" +version = "0.5.0a4" description = 'The shared semantic layer definitions that dbt-core and MetricFlow use' readme = "README.md" requires-python = ">=3.8" diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py index cf741f32..553d9fa1 100644 --- a/tests/parsing/test_where_filter_parsing.py +++ b/tests/parsing/test_where_filter_parsing.py @@ -13,6 +13,7 @@ import pytest +from dbt_semantic_interfaces.call_parameter_sets import EntityCallParameterSet from dbt_semantic_interfaces.implementations.base import HashableBaseModel from dbt_semantic_interfaces.implementations.filters.where_filter import ( PydanticWhereFilter, @@ -21,6 +22,7 @@ from dbt_semantic_interfaces.parsing.where_filter.where_filter_parser import ( WhereFilterParser, ) +from dbt_semantic_interfaces.references import EntityReference from dbt_semantic_interfaces.type_enums.date_part import DatePart __BOOLEAN_EXPRESSION__ = "1 > 0" @@ -162,3 +164,26 @@ def test_dimension_date_part() -> None: # noqa param_sets = WhereFilterParser.parse_call_parameter_sets(where) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR + + +def test_entity_without_primary_entity_prefix() -> None: # noqa + where = "{{ Entity('non_primary_entity') }} = '1'" + param_sets = WhereFilterParser.parse_call_parameter_sets(where) + assert len(param_sets.entity_call_parameter_sets) == 1 + assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( + entity_path=(), + entity_reference=EntityReference(element_name="non_primary_entity"), + ) + + +def test_entity() -> None: # noqa + where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'" + param_sets = WhereFilterParser.parse_call_parameter_sets(where) + assert len(param_sets.entity_call_parameter_sets) == 1 + assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( + entity_path=( + EntityReference(element_name="entity_0"), + EntityReference(element_name="entity_1"), + ), + entity_reference=EntityReference(element_name="entity_2"), + )