Skip to content

Commit

Permalink
Fixed a bug with fetching the code context.
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Sep 21, 2023
1 parent 8aec5c6 commit 3556f9b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name="torchlens",
version="0.1.7",
version="0.1.8",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
Expand Down
2 changes: 1 addition & 1 deletion torchlens/helper_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def extend_search_stack_from_item(
)

for attr_name in dir(item):
if attr_name.startswith("__"):
if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name):
continue
try:
attr = getattr(item, attr_name)
Expand Down
9 changes: 8 additions & 1 deletion torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
identity,
in_notebook,
int_list_to_compact_str,
is_iterable,
log_current_rng_states,
make_random_barcode,
make_short_barcode_from_input,
Expand Down Expand Up @@ -3022,10 +3023,16 @@ def _get_call_stack_dicts():
"call_linenum": caller.lineno,
"function": caller.function,
"code_context": caller.code_context,
"code_context_str": "".join(caller.code_context),
}
for caller in call_stack
]

for call_stack_dict in call_stack_dicts:
if is_iterable(call_stack_dict['code_context']):
call_stack_dict['code_context_str'] = ''.join(call_stack_dict['code_context'])
else:
call_stack_dict['code_context_str'] = str(call_stack_dict['code_context'])

# Only start at the level of that first forward pass, going from shallow to deep.
tracking = False
filtered_dicts = []
Expand Down

0 comments on commit 3556f9b

Please sign in to comment.