From 5682e27fc8a5162b659104297f3973c9bb19d28a Mon Sep 17 00:00:00 2001
From: JohnMark Taylor
Date: Thu, 24 Oct 2024 15:30:38 -0400
Subject: [PATCH] Added fields indicating the argument name or position of each tensor in the forward pass of a submodule.

---
 setup.py                   |  2 +-
 torchlens/constants.py     |  2 ++
 torchlens/model_history.py | 17 +++++++++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 6bfe700..760e44b 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 
 setup(
     name="torchlens",
-    version="0.1.22",
+    version="0.1.23",
     description="A package for extracting activations from PyTorch models",
     long_description="A package for extracting activations from PyTorch models. Contains functionality for "
     "extracting model activations, visualizing a model's computational graph, and "
diff --git a/torchlens/constants.py b/torchlens/constants.py
index 278e8fb..1bea589 100644
--- a/torchlens/constants.py
+++ b/torchlens/constants.py
@@ -95,6 +95,7 @@
     "module_num_tensors",
     "module_pass_num_tensors",
     "module_layers",
+    "module_layer_argnames",
     "module_pass_layers",
     # Time elapsed
     "pass_start_time",
@@ -231,6 +232,7 @@
     "module_nesting_depth",
     "modules_entered",
     "module_passes_entered",
+    "modules_entered_argnames",
     "is_submodule_input",
     "modules_exited",
     "module_passes_exited",
diff --git a/torchlens/model_history.py b/torchlens/model_history.py
index 939e921..fa0ef07 100644
--- a/torchlens/model_history.py
+++ b/torchlens/model_history.py
@@ -223,6 +223,7 @@ def __init__(self, fields_dict: Dict):
         ]
         self.module_nesting_depth = fields_dict["module_nesting_depth"]
         self.modules_entered = fields_dict["modules_entered"]
+        self.modules_entered_argnames = fields_dict["modules_entered_argnames"]
         self.module_passes_entered = fields_dict["module_passes_entered"]
         self.is_submodule_input = fields_dict["is_submodule_input"]
         self.modules_exited = fields_dict["modules_exited"]
@@ -853,6 +854,7 @@ def __init__(
         self.module_pass_num_tensors: Dict = defaultdict(lambda: 0)
         self.module_layers: Dict = defaultdict(list)
         self.module_pass_layers: Dict = defaultdict(list)
+        self.module_layer_argnames = defaultdict(list)
 
         # Time elapsed:
         self.pass_start_time: float = 0
@@ -1361,6 +1363,12 @@ def decorated_forward(*args, **kwargs):
             tensor_entry.modules_entered.append(module_address)
             tensor_entry.module_passes_entered.append(module_pass_label)
             tensor_entry.is_submodule_input = True
+            for arg_key, arg_val in list(enumerate(args)) + list(kwargs.items()):
+                if arg_val is t:
+                    tensor_entry.modules_entered_argnames[
+                        f"{module_pass_label[0]}:{module_pass_label[1]}"].append(arg_key)
+                    self.module_layer_argnames[(f"{module_pass_label[0]}:"
+                                                f"{module_pass_label[1]}")].append((t.tl_tensor_label_raw, arg_key))
             tensor_entry.module_entry_exit_thread_output.append(
                 ("+", module_pass_label[0], module_pass_label[1])
             )
@@ -2131,6 +2139,7 @@ def log_source_tensor_exhaustive(
         "containing_modules_origin_nested": [],
         "module_nesting_depth": 0,
         "modules_entered": [],
+        "modules_entered_argnames": defaultdict(list),
         "module_passes_entered": [],
         "is_submodule_input": False,
         "modules_exited": [],
@@ -2406,6 +2415,7 @@ def log_function_output_tensors_exhaustive(
     ] = containing_modules_origin_nested
     fields_dict["module_nesting_depth"] = len(containing_modules_origin_nested)
     fields_dict["modules_entered"] = []
+    fields_dict["modules_entered_argnames"] = defaultdict(list)
     fields_dict["module_passes_entered"] = []
    fields_dict["is_submodule_input"] = False
     fields_dict["modules_exited"] = []
@@ -4875,6 +4885,13 @@ def _rename_model_history_layer_names(self):
             )
             self.conditional_branch_edges[t] = (new_child, new_parent)
 
+        for module_pass, arglist in self.module_layer_argnames.items():
+            for a, arg in enumerate(arglist):
+                raw_name = self.module_layer_argnames[module_pass][a][0]
+                new_name = self.raw_to_final_layer_labels[raw_name]
+                argname = self.module_layer_argnames[module_pass][a][1]
+                self.module_layer_argnames[module_pass][a] = (new_name, argname)
+
     def _trim_and_reorder_model_history_fields(self):
         """
         Sorts the fields in ModelHistory into their desired order, and trims any
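
Illustrative usage sketch (not part of the patch): after this change, the model history carries a module_layer_argnames dict mapping each module pass (keyed as "module_address:pass_num") to a list of (layer_label, argument name or position) tuples, and each tensor entry carries a matching modules_entered_argnames dict. A minimal way to inspect the new field might look like the following; the use of torchlens's log_forward_pass entry point and a torchvision model here are assumptions for illustration, not shown in the patch.

    import torch
    import torchvision
    import torchlens as tl

    # Run a model once through torchlens to obtain its ModelHistory.
    model = torchvision.models.resnet18()
    x = torch.rand(1, 3, 224, 224)
    model_history = tl.log_forward_pass(model, x)

    # Each key is "module_address:pass_num"; each value lists
    # (layer_label, arg_name_or_position) pairs for the tensors fed
    # to that module pass, as populated by this patch.
    for module_pass, argnames in model_history.module_layer_argnames.items():
        print(module_pass, argnames)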