From 19d6556e67ab53e8ba74d9e9ec6346553965c04a Mon Sep 17 00:00:00 2001 From: JohnMark Taylor Date: Thu, 22 Aug 2024 14:24:56 -0700 Subject: [PATCH] Fixed a bug where TorchLens wasn't accounting for the dim argument in torch.max and torch.min (they return a tuple subclass with the values and the indices). --- torchlens/helper_funcs.py | 4 ++-- torchlens/model_history.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py index eb9c5b5..0212d2e 100644 --- a/torchlens/helper_funcs.py +++ b/torchlens/helper_funcs.py @@ -130,9 +130,9 @@ def make_var_iterable(x): Returns: Iterable output """ - if type(x) in [tuple, list, set]: + if any([issubclass(type(x), cls) for cls in [list, tuple, set]]): return x - if issubclass(type(x), dict): + elif issubclass(type(x), dict): return list(x.values()) else: return [x] diff --git a/torchlens/model_history.py b/torchlens/model_history.py index bd8bf4d..4bed0a9 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -2424,7 +2424,7 @@ def log_function_output_tensors_exhaustive( } fields_dict["module_entry_exit_thread_output"] = [] - is_part_of_iterable_output = type(out_orig) in [list, tuple, dict, set] + is_part_of_iterable_output = any([issubclass(type(out_orig), cls) for cls in [list, tuple, dict, set]]) fields_dict["is_part_of_iterable_output"] = is_part_of_iterable_output out_iter = make_var_iterable(out_orig) # so we can iterate through it