Skip to content

Commit

Permalink
Fixed an issue where tl_requires_grad wasn't being cleaned up properly.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Aug 30, 2024
1 parent d01a3d1 commit 4da3e31
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,8 +1591,9 @@ def restore_model_attributes(
)

for param in model.parameters():
param.requires_grad = getattr(param, "tl_requires_grad")
delattr(param, "tl_requires_grad")
if hasattr(param, "tl_requires_grad"):
param.requires_grad = getattr(param, "tl_requires_grad")
delattr(param, "tl_requires_grad")

def undecorate_model_tensors(self, model: nn.Module):
"""Goes through a model and all its submodules, and unmutates any tensor attributes. Normally just clearing
Expand Down

0 comments on commit 4da3e31

Please sign in to comment.