From 3f0fefb351d09ccb680b0572d4bda9bf57fd9e6a Mon Sep 17 00:00:00 2001 From: JohnMark Taylor Date: Thu, 10 Aug 2023 20:22:46 -0400 Subject: [PATCH] Tidied up code. --- torchlens/helper_funcs.py | 1 + torchlens/model_history.py | 53 ++++++++++++++++++-------------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py index d323cf4..1f381f7 100644 --- a/torchlens/helper_funcs.py +++ b/torchlens/helper_funcs.py @@ -357,6 +357,7 @@ def safe_to(x: Any, device: str): Args: x: The object. + device: which device to move to Returns: Object either moved to device if a tensor, same object if otherwise. diff --git a/torchlens/model_history.py b/torchlens/model_history.py index 2183c97..519aa6c 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -919,8 +919,8 @@ def decorate_pytorch(self, return decorated_func_mapper - def undecorate_pytorch(self, - torch_module, + @staticmethod + def undecorate_pytorch(torch_module, orig_func_defs: List[Tuple], input_tensors: List[torch.Tensor]): """ @@ -1086,7 +1086,8 @@ def decorated_forward(*args, **kwargs): out = orig_forward(*args, **kwargs) output_tensors = get_vars_of_type_from_obj(out, torch.Tensor) for t in output_tensors: - if module.tl_module_type.lower() == 'identity': # if identity module, run the function for bookkeeping + # if identity module, run the function for bookkeeping + if module.tl_module_type.lower() == 'identity': t = getattr(torch, 'identity')(t) return out @@ -1198,15 +1199,14 @@ def get_all_submodules(self, def cleanup_model(self, model: nn.Module, module_orig_forward_funcs: Dict[nn.Module, Callable], - model_device: str, decorated_func_mapper: Dict[Callable, Callable]): """Reverses all temporary changes to the model (namely, the forward hooks and added model attributes) that were added for PyTorch x-ray (scout's honor; leave no trace). Args: model: PyTorch model. - model_device: Device the model is stored on - module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for each submodule + module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for + each submodule decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs Returns: @@ -1218,10 +1218,10 @@ def cleanup_model(self, continue submodule.forward = module_orig_forward_funcs[submodule] self.restore_model_attributes(model, decorated_func_mapper=decorated_func_mapper, attribute_keyword='tl') - self.undecorate_model_tensors(model, model_device) + self.undecorate_model_tensors(model) - def clear_hooks(self, - hook_handles: List): + @staticmethod + def clear_hooks(hook_handles: List): """Takes in a list of tuples (module, hook_handle), and clears the hook at that handle for each module. @@ -1234,8 +1234,8 @@ def clear_hooks(self, for hook_handle in hook_handles: hook_handle.remove() - def restore_module_attributes(self, - module: nn.Module, + @staticmethod + def restore_module_attributes(module: nn.Module, decorated_func_mapper: Dict[Callable, Callable], attribute_keyword: str = 'tl'): for attribute_name in dir(module): @@ -1269,14 +1269,13 @@ def restore_model_attributes(self, param.requires_grad = getattr(param, 'tl_requires_grad') delattr(param, 'tl_requires_grad') - def undecorate_model_tensors(self, model: nn.Module, model_device: str): + def undecorate_model_tensors(self, model: nn.Module): """Goes through a model and all its submodules, and unmutates any tensor attributes. Normally just clearing parameters would have done this, but some module types (e.g., batchnorm) contain attributes that are tensors, but not parameters. Args: model: PyTorch model - model_device: device the model is stored on Returns: PyTorch model with unmutated versions of all tensor attributes. @@ -1342,11 +1341,11 @@ def save_new_activations(self, layers_to_save: Union[str, List] = 'all', random_seed: Optional[int] = None): """Saves activations to a new input to the model, replacing existing saved activations. - This will be much faster than the initial call to log_forward_pass (since all of the metadata has already been - saved), so if you wish to save the activations to many different inputs for a given model this is the function - you should use. The one caveat is that this function assumes that the computational graph will be the same - for the new input; if the model involves a dynamic computational graph that can change across inputs, - and this graph changes for the new input, then this function will throw an error. In that case, + This will be much faster than the initial call to log_forward_pass (since all the of the metadata has + already been saved), so if you wish to save the activations to many different inputs for a given model + this is the function you should use. The one caveat is that this function assumes that the computational + graph will be the same for the new input; if the model involves a dynamic computational graph that can change + across inputs, and this graph changes for the new input, then this function will throw an error. In that case, you'll have to do a new call to log_forward_pass to log the new graph. Args: @@ -1458,13 +1457,13 @@ def run_and_log_inputs_through_model(self, self.raw_tensor_dict[t.tl_tensor_label_raw].is_output_parent = True tensors_to_undecorate = tensors_to_decorate + output_tensors self.undecorate_pytorch(torch, orig_func_defs, tensors_to_undecorate) - self.cleanup_model(model, module_orig_forward_funcs, model_device, decorated_func_mapper) + self.cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper) self.postprocess() decorated_func_mapper.clear() except Exception as e: # if anything fails, make sure everything gets cleaned up self.undecorate_pytorch(torch, orig_func_defs, input_tensors) - self.cleanup_model(model, module_orig_forward_funcs, model_device, decorated_func_mapper) + self.cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper) print("************\nFeature extraction failed; returning model and environment to normal\n*************") raise e @@ -2017,7 +2016,7 @@ def _add_backward_hook(self, t: torch.Tensor, tensor_label): """ # Define the decorator - def log_grad_to_model_history(g_in, g_out): + def log_grad_to_model_history(_, g_out): self._log_tensor_grad(g_out, tensor_label) if t.grad_fn is not None: @@ -3606,7 +3605,7 @@ def _add_lookup_keys_for_tensor_entry(self, # Allow indexing by modules exited as well: for module_pass in tensor_entry.module_passes_exited: - module_name, pass_num = module_pass.split(':') + module_name, _ = module_pass.split(':') lookup_keys_for_tensor.append(f"{module_pass}") if self.module_num_passes[module_name] == 1: lookup_keys_for_tensor.append(f"{module_name}") @@ -3688,9 +3687,6 @@ def _trim_and_reorder_model_history_fields(self): def _undecorate_all_saved_tensors(self): """Utility function to undecorate all saved tensors. - - Args: - decorated_func_mapper: Maps decorated functions to their original versions. """ tensors_to_undecorate = [] for layer_label in self.layer_labels: @@ -3980,7 +3976,7 @@ def _get_node_address_shape_color(self, if node.is_bottom_level_submodule_output: if type(node) == TensorLogEntry: module_pass_exited = node.bottom_level_submodule_pass_exited - module, pass_num = module_pass_exited.split(':') + module, _ = module_pass_exited.split(':') if self.module_num_passes[module] == 1: node_address = module else: @@ -4965,7 +4961,8 @@ def _str_after_pass(self) -> str: # Model tensors: s += "\n\tTensor info:" - s += f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) computed in forward pass." + s += f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) " \ + f"computed in forward pass." s += f"\n\t\t- {self.num_tensors_saved} tensors ({self.tensor_fsize_saved_nice}) with saved activations." # Model parameters: @@ -4993,7 +4990,7 @@ def _str_after_pass(self) -> str: else: pass_str = '' - if (self.layer_dict_main_keys[layer_barcode].has_saved_activations) and (not self.all_layers_saved): + if self.layer_dict_main_keys[layer_barcode].has_saved_activations and (not self.all_layers_saved): s += "\n\t\t* " else: s += "\n\t\t "