Skip to content

Commit

Permalink
Added fields indicating the argument name or position of each tensor …
Browse files Browse the repository at this point in the history
…in the forward pass of a submodule.
  • Loading branch information
JohnMark Taylor committed Oct 24, 2024
1 parent 07ddc54 commit 5682e27
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name="torchlens",
version="0.1.22",
version="0.1.23",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
Expand Down
2 changes: 2 additions & 0 deletions torchlens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
"module_num_tensors",
"module_pass_num_tensors",
"module_layers",
"module_layer_argnames",
"module_pass_layers",
# Time elapsed
"pass_start_time",
Expand Down Expand Up @@ -231,6 +232,7 @@
"module_nesting_depth",
"modules_entered",
"module_passes_entered",
"modules_entered_argnames",
"is_submodule_input",
"modules_exited",
"module_passes_exited",
Expand Down
17 changes: 17 additions & 0 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(self, fields_dict: Dict):
]
self.module_nesting_depth = fields_dict["module_nesting_depth"]
self.modules_entered = fields_dict["modules_entered"]
self.modules_entered_argnames = fields_dict["modules_entered_argnames"]
self.module_passes_entered = fields_dict["module_passes_entered"]
self.is_submodule_input = fields_dict["is_submodule_input"]
self.modules_exited = fields_dict["modules_exited"]
Expand Down Expand Up @@ -853,6 +854,7 @@ def __init__(
self.module_pass_num_tensors: Dict = defaultdict(lambda: 0)
self.module_layers: Dict = defaultdict(list)
self.module_pass_layers: Dict = defaultdict(list)
self.module_layer_argnames = defaultdict(list)

# Time elapsed:
self.pass_start_time: float = 0
Expand Down Expand Up @@ -1361,6 +1363,12 @@ def decorated_forward(*args, **kwargs):
tensor_entry.modules_entered.append(module_address)
tensor_entry.module_passes_entered.append(module_pass_label)
tensor_entry.is_submodule_input = True
for arg_key, arg_val in list(enumerate(args)) + list(kwargs.items()):
if arg_val is t:
tensor_entry.modules_entered_argnames[
f"{module_pass_label[0]}:{module_pass_label[1]}"].append(arg_key)
self.module_layer_argnames[(f"{module_pass_label[0]}:"
f"{module_pass_label[1]}")].append((t.tl_tensor_label_raw, arg_key))
tensor_entry.module_entry_exit_thread_output.append(
("+", module_pass_label[0], module_pass_label[1])
)
Expand Down Expand Up @@ -2131,6 +2139,7 @@ def log_source_tensor_exhaustive(
"containing_modules_origin_nested": [],
"module_nesting_depth": 0,
"modules_entered": [],
"modules_entered_argnames": defaultdict(list),
"module_passes_entered": [],
"is_submodule_input": False,
"modules_exited": [],
Expand Down Expand Up @@ -2406,6 +2415,7 @@ def log_function_output_tensors_exhaustive(
] = containing_modules_origin_nested
fields_dict["module_nesting_depth"] = len(containing_modules_origin_nested)
fields_dict["modules_entered"] = []
fields_dict["modules_entered_argnames"] = defaultdict(list)
fields_dict["module_passes_entered"] = []
fields_dict["is_submodule_input"] = False
fields_dict["modules_exited"] = []
Expand Down Expand Up @@ -4875,6 +4885,13 @@ def _rename_model_history_layer_names(self):
)
self.conditional_branch_edges[t] = (new_child, new_parent)

for module_pass, arglist in self.module_layer_argnames.items():
for a, arg in enumerate(arglist):
raw_name = self.module_layer_argnames[module_pass][a][0]
new_name = self.raw_to_final_layer_labels[raw_name]
argname = self.module_layer_argnames[module_pass][a][1]
self.module_layer_argnames[module_pass][a] = (new_name, argname)

def _trim_and_reorder_model_history_fields(self):
"""
Sorts the fields in ModelHistory into their desired order, and trims any
Expand Down

0 comments on commit 5682e27

Please sign in to comment.