diff --git a/setup.py b/setup.py index 685f373..d71421a 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="torchlens", - version="0.1.7", + version="0.1.8", description="A package for extracting activations from PyTorch models", long_description="A package for extracting activations from PyTorch models. Contains functionality for " "extracting model activations, visualizing a model's computational graph, and " diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py index 27889bf..89ac14f 100644 --- a/torchlens/helper_funcs.py +++ b/torchlens/helper_funcs.py @@ -353,7 +353,7 @@ def extend_search_stack_from_item( ) for attr_name in dir(item): - if attr_name.startswith("__"): + if (attr_name.startswith("__")) or (attr_name == 'T') or ('grad' in attr_name): continue try: attr = getattr(item, attr_name) diff --git a/torchlens/model_history.py b/torchlens/model_history.py index c8df2b0..08d729e 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -30,6 +30,7 @@ identity, in_notebook, int_list_to_compact_str, + is_iterable, log_current_rng_states, make_random_barcode, make_short_barcode_from_input, @@ -3022,10 +3023,16 @@ def _get_call_stack_dicts(): "call_linenum": caller.lineno, "function": caller.function, "code_context": caller.code_context, - "code_context_str": "".join(caller.code_context), } for caller in call_stack ] + + for call_stack_dict in call_stack_dicts: + if is_iterable(call_stack_dict['code_context']): + call_stack_dict['code_context_str'] = ''.join(call_stack_dict['code_context']) + else: + call_stack_dict['code_context_str'] = str(call_stack_dict['code_context']) + # Only start at the level of that first forward pass, going from shallow to deep. tracking = False filtered_dicts = []