diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py
index 89ac14f..eb9c5b5 100644
--- a/torchlens/helper_funcs.py
+++ b/torchlens/helper_funcs.py
@@ -356,7 +356,9 @@ def extend_search_stack_from_item(
         if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
             continue
         try:
-            attr = getattr(item, attr_name)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                attr = getattr(item, attr_name)
         except:
             continue
         attr_cls = type(attr)
@@ -417,6 +419,9 @@ def nested_getattr(obj: Any, attr: str) -> Any:
         if a in [
             "volatile",
             "T",
+            "H",
+            "mH",
+            "mT",
         ]:  # avoid annoying warning; if there's more, make a list
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
@@ -438,7 +443,9 @@ def nested_assign(obj, addr, val):
         if entry_type == "ind":
             obj = obj[entry_val]
         elif entry_type == "attr":
-            obj = getattr(obj, entry_val)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                obj = getattr(obj, entry_val)
 
 
 def remove_attributes_starting_with_str(obj: Any, s: str):
diff --git a/torchlens/model_history.py b/torchlens/model_history.py
index 0e21dd8..6d77bc4 100644
--- a/torchlens/model_history.py
+++ b/torchlens/model_history.py
@@ -9,6 +9,7 @@ import types
 
 from collections import OrderedDict, defaultdict
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+import warnings
 
 import graphviz
 import numpy as np
@@ -1085,7 +1086,9 @@ def decorate_pytorch(
             continue
         new_func = self.torch_func_decorator(orig_func)
         try:
-            setattr(local_func_namespace, func_name, new_func)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                setattr(local_func_namespace, func_name, new_func)
         except (AttributeError, TypeError) as _:
             pass
         new_func.tl_is_decorated_function = True
@@ -1117,10 +1120,14 @@ def undecorate_pytorch(
     for namespace_name, func_name, orig_func in orig_func_defs:
         namespace_name_notorch = namespace_name.replace("torch.", "")
         local_func_namespace = nested_getattr(torch_module, namespace_name_notorch)
-        decorated_func = getattr(local_func_namespace, func_name)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            decorated_func = getattr(local_func_namespace, func_name)
         del decorated_func
         try:
-            setattr(local_func_namespace, func_name, orig_func)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                setattr(local_func_namespace, func_name, orig_func)
         except (AttributeError, TypeError) as _:
             continue
     delattr(torch, "identity")
@@ -1260,7 +1267,9 @@ def prepare_buffer_tensors(self, model: nn.Module):
     submodules = self.get_all_submodules(model)
     for submodule in submodules:
         for attribute_name in dir(submodule):
-            attribute = getattr(submodule, attribute_name)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                attribute = getattr(submodule, attribute_name)
             if issubclass(type(attribute), torch.Tensor) and not issubclass(
                 type(attribute), torch.nn.Parameter
             ):
@@ -1484,13 +1493,17 @@ def restore_module_attributes(
         if attribute_name.startswith(attribute_keyword):
            delattr(module, attribute_name)
            continue
-        attr = getattr(module, attribute_name)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            attr = getattr(module, attribute_name)
         if (
             isinstance(attr, Callable)
             and (attr in decorated_func_mapper)
             and (attribute_name[0:2] != "__")
         ):
-            setattr(module, attribute_name, decorated_func_mapper[attr])
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                setattr(module, attribute_name, decorated_func_mapper[attr])
 
 def restore_model_attributes(
     self,
@@ -1533,7 +1546,9 @@ def undecorate_model_tensors(self, model: nn.Module):
     submodules = self.get_all_submodules(model)
     for submodule in submodules:
         for attribute_name in dir(submodule):
-            attribute = getattr(submodule, attribute_name)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                attribute = getattr(submodule, attribute_name)
             if issubclass(type(attribute), torch.Tensor):
                 if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
                     attribute, "tl_tensor_label_raw"
@@ -3077,8 +3092,10 @@ def _remove_log_entry(
     else:
         tensor_label = log_entry.tensor_label_raw
     for attr in dir(log_entry):
-        if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
-            delattr(log_entry, attr)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            if not attr.startswith("_") and not callable(getattr(log_entry, attr)):
+                delattr(log_entry, attr)
     del log_entry
     if remove_references:
         self._remove_log_entry_references(tensor_label)
@@ -4152,7 +4169,7 @@ def _fix_modules_for_single_internal_tensor(
     for enter_or_exit, module_address, module_pass in thread_modules[::step_val]:
         module_pass_label = (module_address, module_pass)
         if node_type_to_fix == "parent":
-            if enter_or_exit == "+":
+            if (enter_or_exit == "+") and (module_pass_label in node_to_fix.containing_modules_origin_nested):
                 node_to_fix.containing_modules_origin_nested.remove(
                     module_pass_label
                 )
@@ -4664,7 +4681,9 @@ def _trim_and_reorder_model_history_fields(self):
         new_dir_dict[field] = getattr(self, field)
     for field in dir(self):
         if field.startswith("_"):
-            new_dir_dict[field] = getattr(self, field)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                new_dir_dict[field] = getattr(self, field)
     self.__dict__ = new_dir_dict
 
 def _undecorate_all_saved_tensors(self):
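
The change applied throughout both files is one recurring three-line wrapper: every dynamic getattr/setattr performed while scanning dir(...) now runs under warnings.catch_warnings() with simplefilter("ignore"), since merely touching attributes such as Tensor.T (and, per the nested_getattr hunk, H, mH, and mT) can emit PyTorch warnings during the scan. Below is a minimal standalone sketch of that pattern, assuming only the standard library and PyTorch; the quiet_getattr helper is illustrative and is not part of this patch or of torchlens.

import warnings

import torch


def quiet_getattr(obj, name, default=None):
    # Hypothetical helper (not part of torchlens) showing the pattern used in
    # this patch: resolve an attribute with all warnings suppressed, and fall
    # back to a default for attributes whose property getters raise on access.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            return getattr(obj, name, default)
        except Exception:
            return default


t = torch.zeros(2, 3, 4)
x = quiet_getattr(t, "T")  # .T on a 3-D tensor warns on recent PyTorch versions
print(x.shape)  # torch.Size([4, 3, 2])

Factoring the wrapper into a single helper like this would shrink the diff considerably; inlining it at each call site, as the patch does, avoids an extra function call inside the attribute-scanning loops.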