diff --git a/bmt/toolkit.py b/bmt/toolkit.py index 79bf629..b61da7a 100644 --- a/bmt/toolkit.py +++ b/bmt/toolkit.py @@ -441,22 +441,24 @@ def get_associations( A list of elements """ + filtered_elements: List[str] = list() + inverse_predicates: Optional[List[str]] = None subject_categories_formatted = [] - for sc in subject_categories: - sc_formatted = format_element(self.get_element(sc)) - subject_categories_formatted.append(sc_formatted) object_categories_formatted = [] - for oc in object_categories: - oc_formatted = format_element(self.get_element(oc)) - object_categories_formatted.append(oc_formatted) predicates_formatted = [] - for pred in predicates: - pred_formatted = format_element(self.get_element(pred)) - predicates_formatted.append(pred_formatted) association_elements = self.get_descendants("association") - filtered_elements: List[str] = list() - inverse_predicates: Optional[List[str]] = None - if predicates_formatted: + if subject_categories: + for sc in subject_categories: + sc_formatted = format_element(self.get_element(sc)) + subject_categories_formatted.append(sc_formatted) + if object_categories: + for oc in object_categories: + oc_formatted = format_element(self.get_element(oc)) + object_categories_formatted.append(oc_formatted) + if predicates: + for pred in predicates: + pred_formatted = format_element(self.get_element(pred)) + predicates_formatted.append(pred_formatted) inverse_predicates = list() for pred_curie in predicates_formatted: predicate = self.get_element(pred_curie) @@ -466,6 +468,8 @@ def get_associations( inverse_predicates.append(inverse_p) inverse_predicates = self._format_all_elements(elements=inverse_predicates, formatted=True) + + if subject_categories_formatted or predicates_formatted or object_categories_formatted: # This feels like a bit of a brute force approach as an implementation, # but we just use the list of all association names to retrieve each