Skip to content

Commit

Permalink
optimize get_classes_by_slot() in schemaview.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sujaypatil96 committed Feb 20, 2024
1 parent c9adcae commit 54d9515
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
29 changes: 16 additions & 13 deletions linkml_runtime/utils/schemaview.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,29 +1487,32 @@ def slot_range_as_union(self, slot: SlotDefinition) -> List[ElementName]:
if x.range:
range_union_of.append(x.range)
return range_union_of

def get_classes_by_slot(self, slot: SlotDefinition, include_induced: bool = False) -> List[ClassDefinitionName]:

def get_classes_by_slot(
self, slot: SlotDefinition, include_induced: bool = False
) -> List[ClassDefinitionName]:
"""Get all classes that use a given slot, either as a direct or induced slot.
:param slot: slot in consideration
:param include_induced: supplement all direct slots with induced slots, defaults to False
:return: list of slots, either direct, or both direct and induced
"""
slots_list = [] # list of all direct or induced slots
classes_set = set() # use set to avoid duplicates
all_classes = self.all_classes()

for c_name, c in self.all_classes().items():
# check if slot is direct specification on class
for c_name, c in all_classes.items():
if slot.name in c.slots:
slots_list.append(c_name)
classes_set.add(c_name)

# include induced classes also if requested
if include_induced:
for c_name, c in self.all_classes().items():
for ind_slot in self.class_induced_slots(c_name):
if ind_slot.name == slot.name:
slots_list.append(c_name)

return list(dict.fromkeys(slots_list))
for c_name in all_classes:
induced_slot_names = [
ind_slot.name for ind_slot in self.class_induced_slots(c_name)
]
if slot.name in induced_slot_names:
classes_set.add(c_name)

return list(classes_set)

@lru_cache()
def get_slots_by_enum(self, enum_name: ENUM_NAME = None) -> List[SlotDefinition]:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils/test_schemaview.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,12 +741,12 @@ def test_get_classes_by_slot(self):
actual_result = sv.get_classes_by_slot(slot)
expected_result = ["Person"]

self.assertListEqual(expected_result, actual_result)
self.assertListEqual(sorted(expected_result), sorted(actual_result))

actual_result = sv.get_classes_by_slot(slot, include_induced=True)
expected_result = ["Person", "Adult"]

self.assertListEqual(actual_result, expected_result)
self.assertListEqual(sorted(actual_result), sorted(expected_result))

def test_materialize_patterns(self):
sv = SchemaView(SCHEMA_WITH_STRUCTURED_PATTERNS)
Expand Down

0 comments on commit 54d9515

Please sign in to comment.