Skip to content

Commit

Permalink
Fixed issue with layer renaming.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Nov 14, 2024
1 parent 5682e27 commit 86dafc0
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4886,11 +4886,17 @@ def _rename_model_history_layer_names(self):
self.conditional_branch_edges[t] = (new_child, new_parent)

for module_pass, arglist in self.module_layer_argnames.items():
inds_to_remove = []
for a, arg in enumerate(arglist):
raw_name = self.module_layer_argnames[module_pass][a][0]
if raw_name not in self.raw_to_final_layer_labels:
inds_to_remove.append(a)
continue
new_name = self.raw_to_final_layer_labels[raw_name]
argname = self.module_layer_argnames[module_pass][a][1]
self.module_layer_argnames[module_pass][a] = (new_name, argname)
self.module_layer_argnames[module_pass] = [self.module_layer_argnames[module_pass][i]
for i in range(len(arglist)) if i not in inds_to_remove]

def _trim_and_reorder_model_history_fields(self):
"""
Expand Down Expand Up @@ -6441,7 +6447,7 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
and ((('scale_factor' in layer_to_validate_parents_for.creation_kwargs)
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_kwargs['scale_factor']))
torch.tensor(layer_to_validate_parents_for.creation_kwargs['scale_factor'])))
or ((len(layer_to_validate_parents_for.creation_args) >= 3)
and torch.equal(
self[layers_to_perturb[0]].tensor_contents,
Expand Down Expand Up @@ -6472,7 +6478,7 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
]: # TODO: fix this
recomputed_output = input_args["args"][0]

if type(recomputed_output) in [list, tuple]:
if any([issubclass(type(recomputed_output), which_type) for which_type in [list, tuple]]):
recomputed_output = recomputed_output[
layer_to_validate_parents_for.iterable_output_index
]
Expand Down

0 comments on commit 86dafc0

Please sign in to comment.