diff --git a/linkml_runtime/utils/schemaview.py b/linkml_runtime/utils/schemaview.py index e792366a..bbcd7271 100644 --- a/linkml_runtime/utils/schemaview.py +++ b/linkml_runtime/utils/schemaview.py @@ -1487,29 +1487,32 @@ def slot_range_as_union(self, slot: SlotDefinition) -> List[ElementName]: if x.range: range_union_of.append(x.range) return range_union_of - - def get_classes_by_slot(self, slot: SlotDefinition, include_induced: bool = False) -> List[ClassDefinitionName]: + + def get_classes_by_slot( + self, slot: SlotDefinition, include_induced: bool = False + ) -> List[ClassDefinitionName]: """Get all classes that use a given slot, either as a direct or induced slot. :param slot: slot in consideration :param include_induced: supplement all direct slots with induced slots, defaults to False :return: list of slots, either direct, or both direct and induced """ - slots_list = [] # list of all direct or induced slots + classes_set = set() # use set to avoid duplicates + all_classes = self.all_classes() - for c_name, c in self.all_classes().items(): - # check if slot is direct specification on class + for c_name, c in all_classes.items(): if slot.name in c.slots: - slots_list.append(c_name) + classes_set.add(c_name) - # include induced classes also if requested if include_induced: - for c_name, c in self.all_classes().items(): - for ind_slot in self.class_induced_slots(c_name): - if ind_slot.name == slot.name: - slots_list.append(c_name) - - return list(dict.fromkeys(slots_list)) + for c_name in all_classes: + induced_slot_names = [ + ind_slot.name for ind_slot in self.class_induced_slots(c_name) + ] + if slot.name in induced_slot_names: + classes_set.add(c_name) + + return list(classes_set) @lru_cache() def get_slots_by_enum(self, enum_name: ENUM_NAME = None) -> List[SlotDefinition]: diff --git a/tests/test_utils/test_schemaview.py b/tests/test_utils/test_schemaview.py index 53ea6f01..3e49e810 100644 --- a/tests/test_utils/test_schemaview.py +++ b/tests/test_utils/test_schemaview.py @@ -741,12 +741,12 @@ def test_get_classes_by_slot(self): actual_result = sv.get_classes_by_slot(slot) expected_result = ["Person"] - self.assertListEqual(expected_result, actual_result) + self.assertListEqual(sorted(expected_result), sorted(actual_result)) actual_result = sv.get_classes_by_slot(slot, include_induced=True) expected_result = ["Person", "Adult"] - self.assertListEqual(actual_result, expected_result) + self.assertListEqual(sorted(actual_result), sorted(expected_result)) def test_materialize_patterns(self): sv = SchemaView(SCHEMA_WITH_STRUCTURED_PATTERNS)