From 234e8a90c007a416c5100d976fd4e642c35fec72 Mon Sep 17 00:00:00 2001 From: JohnMark Taylor Date: Tue, 3 Oct 2023 16:43:10 -0400 Subject: [PATCH] Small fix with scalar parameters. --- setup.py | 2 +- tests/test_validation_and_visuals.py | 15 +-------------- torchlens/model_history.py | 13 ++++++++++++- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 4529bae..6a76251 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="torchlens", - version="0.1.9", + version="0.1.10", description="A package for extracting activations from PyTorch models", long_description="A package for extracting activations from PyTorch models. Contains functionality for " "extracting model activations, visualizing a model's computational graph, and " diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py index 2e38f39..696602d 100644 --- a/tests/test_validation_and_visuals.py +++ b/tests/test_validation_and_visuals.py @@ -2330,19 +2330,6 @@ def test_timm_beit_base_patch16_224(default_input1): assert validate_saved_activations(model, default_input1) -def test_timm_beit_base_patch16_224_in22k(default_input1): - model = timm.models.beit_base_patch16_224_in22k() - show_model_graph( - model, - default_input1, - vis_opt="unrolled", - vis_outpath=opj( - "visualization_outputs", "timm", "timm_beit_base_patch16_224_in22k" - ), - ) - assert validate_saved_activations(model, default_input1) - - def test_timm_cait_s24_224(default_input1): model = timm.models.cait_s24_224() show_model_graph( @@ -2377,7 +2364,7 @@ def test_timm_convit_base(default_input1): def test_timm_darknet21(): - model = timm.create_model("darknet21", pretrained=True) + model = timm.create_model("darknet21", pretrained=False) model_input = torch.randn(1, 3, 224, 224) show_model_graph( model, diff --git a/torchlens/model_history.py b/torchlens/model_history.py index 237ca5c..22bcb9f 100644 --- a/torchlens/model_history.py +++ b/torchlens/model_history.py @@ -5180,8 +5180,10 @@ def _make_param_label(node: Union[TensorLogEntry, RolledTensorLogEntry]) -> str: for param_shape in node.parent_param_shapes: if len(param_shape) > 1: each_param_shape.append("x".join([str(s) for s in param_shape])) - else: + elif len(param_shape) == 1: each_param_shape.append(f"x{param_shape[0]}") + else: + each_param_shape.append("x1") param_label = "
params: " + ", ".join( [param_shape for param_shape in each_param_shape] @@ -6009,6 +6011,15 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor( self[layers_to_perturb[0]].tensor_contents, layer_to_validate_parents_for.creation_args[1], ) + ): + return True + elif ( + perturb + and (layer_to_validate_parents_for.func_applied_name == "__getitem__") + and not torch.equal( + self[layers_to_perturb[0]].tensor_contents, + layer_to_validate_parents_for.creation_args[0], + ) ): return True elif layer_to_validate_parents_for.func_applied_name == 'empty_like':