diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py index eb9c5b5..0212d2e 100644 --- a/torchlens/helper_funcs.py +++ b/torchlens/helper_funcs.py @@ -130,9 +130,9 @@ def make_var_iterable(x): Returns: Iterable output """ - if type(x) in [tuple, list, set]: + if any([issubclass(type(x), cls) for cls in [list, tuple, set]]): return x - if issubclass(type(x), dict): + elif issubclass(type(x), dict): return list(x.values()) else: return [x] diff --git a/torchlens/model_history.py b/torchlens/model_history.py index bd8bf4d..4bed0a9 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -2424,7 +2424,7 @@ def log_function_output_tensors_exhaustive( } fields_dict["module_entry_exit_thread_output"] = [] - is_part_of_iterable_output = type(out_orig) in [list, tuple, dict, set] + is_part_of_iterable_output = any([issubclass(type(out_orig), cls) for cls in [list, tuple, dict, set]]) fields_dict["is_part_of_iterable_output"] = is_part_of_iterable_output out_iter = make_var_iterable(out_orig) # so we can iterate through it