
Commit

Tidied up code.
JohnMark Taylor committed Aug 11, 2023
1 parent 9336f45 commit 3f0fefb
Showing 2 changed files with 26 additions and 28 deletions.
1 change: 1 addition & 0 deletions torchlens/helper_funcs.py
@@ -357,6 +357,7 @@ def safe_to(x: Any, device: str):
Args:
x: The object.
+ device: which device to move to
Returns:
Object either moved to device if a tensor, same object if otherwise.
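For orientation, a minimal sketch of what a helper with this docstring might look like; this is an assumption for illustration, not the actual torchlens implementation:

```python
from typing import Any

import torch

def safe_to(x: Any, device: str):
    """Move x to the given device if it is a tensor; otherwise return it unchanged."""
    if isinstance(x, torch.Tensor):
        return x.to(device)
    return x
```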
53 changes: 25 additions & 28 deletions torchlens/model_history.py
@@ -919,8 +919,8 @@ def decorate_pytorch(self,

return decorated_func_mapper

- def undecorate_pytorch(self,
- torch_module,
+ @staticmethod
+ def undecorate_pytorch(torch_module,
orig_func_defs: List[Tuple],
input_tensors: List[torch.Tensor]):
"""
@@ -1086,7 +1086,8 @@ def decorated_forward(*args, **kwargs):
out = orig_forward(*args, **kwargs)
output_tensors = get_vars_of_type_from_obj(out, torch.Tensor)
for t in output_tensors:
- if module.tl_module_type.lower() == 'identity': # if identity module, run the function for bookkeeping
+ # if identity module, run the function for bookkeeping
+ if module.tl_module_type.lower() == 'identity':
t = getattr(torch, 'identity')(t)
return out
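The change above only moves a trailing comment onto its own line; the underlying idea is that an `nn.Identity` module calls no torch function of its own, so the wrapper applies one purely for bookkeeping. A hedged sketch of that idea, where `wrap_forward` and `record_op` are hypothetical names:

```python
import torch
import torch.nn as nn

def wrap_forward(module: nn.Module, record_op):
    orig_forward = module.forward

    def decorated_forward(*args, **kwargs):
        out = orig_forward(*args, **kwargs)
        if isinstance(module, nn.Identity):
            # Identity performs no torch op, so run a no-op on the output
            # just so the pass through this module leaves a trace.
            out = out.clone()
            record_op('identity', out)
        return out

    module.forward = decorated_forward
    return module

layer = wrap_forward(nn.Identity(), lambda name, t: print(name, tuple(t.shape)))
layer(torch.zeros(2, 3))  # prints: identity (2, 3)
```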

@@ -1198,15 +1199,14 @@ def get_all_submodules(self,
def cleanup_model(self,
model: nn.Module,
module_orig_forward_funcs: Dict[nn.Module, Callable],
- model_device: str,
decorated_func_mapper: Dict[Callable, Callable]):
"""Reverses all temporary changes to the model (namely, the forward hooks and added
model attributes) that were added for PyTorch x-ray (scout's honor; leave no trace).
Args:
model: PyTorch model.
- model_device: Device the model is stored on
- module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for each submodule
+ module_orig_forward_funcs: Dict containing the original, undecorated forward pass functions for
+ each submodule
decorated_func_mapper: Dict mapping between original and decorated PyTorch funcs
Returns:
@@ -1218,10 +1218,10 @@ def cleanup_model(self,
continue
submodule.forward = module_orig_forward_funcs[submodule]
self.restore_model_attributes(model, decorated_func_mapper=decorated_func_mapper, attribute_keyword='tl')
- self.undecorate_model_tensors(model, model_device)
+ self.undecorate_model_tensors(model)
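`cleanup_model` now drops the unused `model_device` argument; conceptually, the forward-restoration it performs looks like this simplified sketch (not the actual implementation):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())

# Before decorating anything, remember each submodule's original forward.
module_orig_forward_funcs = {m: m.forward for m in model.modules()}

# ... decorated forwards would be installed and used for logging here ...

# Cleanup: put every original forward back, leaving the model as it was found.
for submodule, orig_forward in module_orig_forward_funcs.items():
    submodule.forward = orig_forward
```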

- def clear_hooks(self,
- hook_handles: List):
+ @staticmethod
+ def clear_hooks(hook_handles: List):
"""Takes in a list of tuples (module, hook_handle), and clears the hook at that
handle for each module.
@@ -1234,8 +1234,8 @@ def clear_hooks(self,
for hook_handle in hook_handles:
hook_handle.remove()

- def restore_module_attributes(self,
- module: nn.Module,
+ @staticmethod
+ def restore_module_attributes(module: nn.Module,
decorated_func_mapper: Dict[Callable, Callable],
attribute_keyword: str = 'tl'):
for attribute_name in dir(module):
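The cleanup helpers above (`clear_hooks` and `restore_module_attributes`) share a pattern: undo hooks via their handles, then strip temporary `tl`-prefixed attributes. A condensed, hypothetical sketch with made-up attribute names:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

hook_handles = [m.register_forward_hook(lambda mod, inp, out: None)
                for m in model.modules()]
for m in model.modules():
    m.tl_module_type = type(m).__name__.lower()   # hypothetical bookkeeping attribute

# clear_hooks: every handle knows how to detach its own hook.
for handle in hook_handles:
    handle.remove()

# restore_module_attributes: delete the temporary attributes again.
for m in model.modules():
    for attribute_name in dir(m):
        if attribute_name.startswith('tl_'):
            delattr(m, attribute_name)
```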
@@ -1269,14 +1269,13 @@ def restore_model_attributes(self,
param.requires_grad = getattr(param, 'tl_requires_grad')
delattr(param, 'tl_requires_grad')

- def undecorate_model_tensors(self, model: nn.Module, model_device: str):
+ def undecorate_model_tensors(self, model: nn.Module):
"""Goes through a model and all its submodules, and unmutates any tensor attributes. Normally just clearing
parameters would have done this, but some module types (e.g., batchnorm) contain attributes that are tensors,
but not parameters.
Args:
model: PyTorch model
- model_device: device the model is stored on
Returns:
PyTorch model with unmutated versions of all tensor attributes.
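The docstring's point about batchnorm is that some module state lives in buffers rather than parameters, so a parameter-only cleanup would miss it. For example:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(8)
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```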
@@ -1342,11 +1341,11 @@ def save_new_activations(self,
layers_to_save: Union[str, List] = 'all',
random_seed: Optional[int] = None):
"""Saves activations to a new input to the model, replacing existing saved activations.
- This will be much faster than the initial call to log_forward_pass (since all of the metadata has already been
- saved), so if you wish to save the activations to many different inputs for a given model this is the function
- you should use. The one caveat is that this function assumes that the computational graph will be the same
- for the new input; if the model involves a dynamic computational graph that can change across inputs,
- and this graph changes for the new input, then this function will throw an error. In that case,
+ This will be much faster than the initial call to log_forward_pass (since all of the metadata has
+ already been saved), so if you wish to save the activations to many different inputs for a given model
+ this is the function you should use. The one caveat is that this function assumes that the computational
+ graph will be the same for the new input; if the model involves a dynamic computational graph that can change
+ across inputs, and this graph changes for the new input, then this function will throw an error. In that case,
you'll have to do a new call to log_forward_pass to log the new graph.
Args:
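Roughly, the workflow this docstring describes might look like the following; a hedged sketch in which `layers_to_save` comes from the signature above, but the exact public call pattern of torchlens may differ:

```python
import torch
import torchvision
import torchlens as tl

model = torchvision.models.resnet18()

# First pass: logs the full computational graph and saves activations.
x1 = torch.rand(1, 3, 224, 224)
model_history = tl.log_forward_pass(model, x1, layers_to_save='all')

# Later inputs: reuse the logged graph, only refresh the saved activations.
# This assumes the graph does not change between inputs, as noted above.
x2 = torch.rand(1, 3, 224, 224)
model_history.save_new_activations(x2)
```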
@@ -1458,13 +1457,13 @@ def run_and_log_inputs_through_model(self,
self.raw_tensor_dict[t.tl_tensor_label_raw].is_output_parent = True
tensors_to_undecorate = tensors_to_decorate + output_tensors
self.undecorate_pytorch(torch, orig_func_defs, tensors_to_undecorate)
- self.cleanup_model(model, module_orig_forward_funcs, model_device, decorated_func_mapper)
+ self.cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
self.postprocess()
decorated_func_mapper.clear()

except Exception as e: # if anything fails, make sure everything gets cleaned up
self.undecorate_pytorch(torch, orig_func_defs, input_tensors)
- self.cleanup_model(model, module_orig_forward_funcs, model_device, decorated_func_mapper)
+ self.cleanup_model(model, module_orig_forward_funcs, decorated_func_mapper)
print("************\nFeature extraction failed; returning model and environment to normal\n*************")
raise e
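The call sites above sit inside a try/except that always restores the model; the same guarantee can be sketched generically (illustrative names, not torchlens internals):

```python
def run_instrumented(model, x, instrument, restore):
    state = instrument(model)          # decorate functions, add hooks, etc.
    try:
        return model(x)
    except Exception:
        print("Feature extraction failed; returning model and environment to normal")
        raise
    finally:
        restore(model, state)          # cleanup happens on success and on failure
```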

@@ -2017,7 +2016,7 @@ def _add_backward_hook(self, t: torch.Tensor, tensor_label):
"""

# Define the decorator
- def log_grad_to_model_history(g_in, g_out):
+ def log_grad_to_model_history(_, g_out):
self._log_tensor_grad(g_out, tensor_label)

if t.grad_fn is not None:
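Renaming `g_in` to `_` signals that only the outgoing gradients are logged. A small, self-contained illustration of hooking a tensor's `grad_fn` (illustrative, not the surrounding torchlens code):

```python
import torch

grads = {}

def log_grad(_, g_out):
    # Second argument: gradients with respect to the node's outputs.
    grads['mul'] = g_out

t = torch.ones(3, requires_grad=True) * 2.0
if t.grad_fn is not None:
    t.grad_fn.register_hook(log_grad)

t.sum().backward()
print(grads['mul'])  # (tensor([1., 1., 1.]),)
```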
@@ -3606,7 +3605,7 @@ def _add_lookup_keys_for_tensor_entry(self,

# Allow indexing by modules exited as well:
for module_pass in tensor_entry.module_passes_exited:
- module_name, pass_num = module_pass.split(':')
+ module_name, _ = module_pass.split(':')
lookup_keys_for_tensor.append(f"{module_pass}")
if self.module_num_passes[module_name] == 1:
lookup_keys_for_tensor.append(f"{module_name}")
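A sketch of the lookup-key scheme this loop builds: a layer that exits module 'conv1' on its second pass is reachable as 'conv1:2', and also as plain 'conv1' when that module only runs once. The dictionaries below are hypothetical stand-ins for the real bookkeeping:

```python
module_num_passes = {'conv1': 1, 'block1': 2}   # hypothetical bookkeeping dict
module_passes_exited = ['conv1:1', 'block1:2']

lookup_keys = []
for module_pass in module_passes_exited:
    module_name, _ = module_pass.split(':')
    lookup_keys.append(module_pass)
    if module_num_passes[module_name] == 1:
        lookup_keys.append(module_name)

print(lookup_keys)  # ['conv1:1', 'conv1', 'block1:2']
```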
@@ -3688,9 +3687,6 @@ def _trim_and_reorder_model_history_fields(self):

def _undecorate_all_saved_tensors(self):
"""Utility function to undecorate all saved tensors.
- Args:
- decorated_func_mapper: Maps decorated functions to their original versions.
"""
tensors_to_undecorate = []
for layer_label in self.layer_labels:
@@ -3980,7 +3976,7 @@ def _get_node_address_shape_color(self,
if node.is_bottom_level_submodule_output:
if type(node) == TensorLogEntry:
module_pass_exited = node.bottom_level_submodule_pass_exited
- module, pass_num = module_pass_exited.split(':')
+ module, _ = module_pass_exited.split(':')
if self.module_num_passes[module] == 1:
node_address = module
else:
@@ -4965,7 +4961,8 @@ def _str_after_pass(self) -> str:
# Model tensors:

s += "\n\tTensor info:"
- s += f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) computed in forward pass."
+ s += f"\n\t\t- {self.num_tensors_total} total tensors ({self.tensor_fsize_total_nice}) " \
+ f"computed in forward pass."
s += f"\n\t\t- {self.num_tensors_saved} tensors ({self.tensor_fsize_saved_nice}) with saved activations."

# Model parameters:
@@ -4993,7 +4990,7 @@ def _str_after_pass(self) -> str:
else:
pass_str = ''

- if (self.layer_dict_main_keys[layer_barcode].has_saved_activations) and (not self.all_layers_saved):
+ if self.layer_dict_main_keys[layer_barcode].has_saved_activations and (not self.all_layers_saved):
s += "\n\t\t* "
else:
s += "\n\t\t "
