diff --git a/metricflow/model/semantics/linkable_element_set.py b/metricflow/model/semantics/linkable_element_set.py index 424a02d5c4..7ef277b61f 100644 --- a/metricflow/model/semantics/linkable_element_set.py +++ b/metricflow/model/semantics/linkable_element_set.py @@ -26,6 +26,35 @@ class LinkableElementSet: path_key_to_linkable_entities: Dict[ElementPathKey, Tuple[LinkableEntity, ...]] = field(default_factory=dict) path_key_to_linkable_metrics: Dict[ElementPathKey, Tuple[LinkableMetric, ...]] = field(default_factory=dict) + def __post_init__(self) -> None: + """Basic validation for ensuring consistency between path key type and value type.""" + mismatched_dimensions = tuple( + path_key + for path_key in self.path_key_to_linkable_dimensions.keys() + if not path_key.element_type.is_dimension_type + ) + mismatched_entities = tuple( + path_key + for path_key in self.path_key_to_linkable_entities + if path_key.element_type is not LinkableElementType.ENTITY + ) + mismatched_metrics = tuple( + path_key + for path_key in self.path_key_to_linkable_metrics + if path_key.element_type is not LinkableElementType.METRIC + ) + + mismatched_elements = { + "dimensions": mismatched_dimensions, + "entities": mismatched_entities, + "metrics": mismatched_metrics, + } + + assert all(len(mismatches) == 0 for mismatches in mismatched_elements.values()), ( + f"Found one or more elements where the element type defined in the path key does not match the value " + f"type! Mismatched elements: {mismatched_elements}" + ) + @staticmethod def merge_by_path_key(linkable_element_sets: Sequence[LinkableElementSet]) -> LinkableElementSet: """Combine multiple sets together by the path key. @@ -237,21 +266,21 @@ def as_spec_set(self) -> LinkableSpecSet: # noqa: D102 @property def only_unique_path_keys(self) -> LinkableElementSet: - """Returns a set that only includes path keys that map to a single element.""" + """Returns a set that only includes path keys that map to a single distinct element.""" return LinkableElementSet( path_key_to_linkable_dimensions={ - path_key: linkable_dimensions + path_key: tuple(set(linkable_dimensions)) for path_key, linkable_dimensions in self.path_key_to_linkable_dimensions.items() - if len(linkable_dimensions) <= 1 + if len(set(linkable_dimensions)) <= 1 }, path_key_to_linkable_entities={ - path_key: linkable_entities + path_key: tuple(set(linkable_entities)) for path_key, linkable_entities in self.path_key_to_linkable_entities.items() - if len(linkable_entities) <= 1 + if len(set(linkable_entities)) <= 1 }, path_key_to_linkable_metrics={ - path_key: linkable_metrics + path_key: tuple(set(linkable_metrics)) for path_key, linkable_metrics in self.path_key_to_linkable_metrics.items() - if len(linkable_metrics) <= 1 + if len(set(linkable_metrics)) <= 1 }, ) diff --git a/tests/model/semantics/test_linkable_element_set.py b/tests/model/semantics/test_linkable_element_set.py index a4842b5de4..578e5dcd8d 100644 --- a/tests/model/semantics/test_linkable_element_set.py +++ b/tests/model/semantics/test_linkable_element_set.py @@ -469,7 +469,8 @@ def test_only_unique_path_keys() -> None: unique_path_keys = base_set.only_unique_path_keys assert unique_path_keys.path_key_to_linkable_dimensions == { - _time_dimension.path_key: (_time_dimension,) + _categorical_dimension.path_key: (_categorical_dimension,), + _time_dimension.path_key: (_time_dimension,), }, "Got an unexpected value for unique dimensions sets!" assert unique_path_keys.path_key_to_linkable_entities == { _base_entity.path_key: (_base_entity,)