From 9d67bcb3890b7dee58c5e01829f1e12ad7b63b43 Mon Sep 17 00:00:00 2001
From: Kale Kundert
Date: Wed, 9 Oct 2024 11:40:57 -0400
Subject: [PATCH 1/2] More carefully avoid problematic attribute accesses

---
 requirements.test.txt                | 13 ++++++++
 tests/test_validation_and_visuals.py | 50 ++++++++++++++++++++++++++--
 torchlens/helper_funcs.py            | 22 +++++++++++-
 torchlens/model_history.py           | 20 ++++-------
 4 files changed, 89 insertions(+), 16 deletions(-)

diff --git a/requirements.test.txt b/requirements.test.txt
index 6e9e9dd..b039d93 100644
--- a/requirements.test.txt
+++ b/requirements.test.txt
@@ -1,2 +1,15 @@
 black
+cornet @ git+https://github.com/dicarlolab/CORnet
+lightning
+numpy
+pillow
 pytest
+requests
+timm
+torch
+torchaudio
+torch_geometric
+torchlens
+torchvision
+transformers
+visualpriors
diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py
index 98de015..8757b58 100644
--- a/tests/test_validation_and_visuals.py
+++ b/tests/test_validation_and_visuals.py
@@ -9,12 +9,11 @@
 import requests
 import timm.models
 import torch
+import torch.nn as nn
 import torchaudio.models
 import torchvision
 import visualpriors
 from PIL import Image
-from StyleTTS.models import TextEncoder
-from model.Unet import UNet

 from transformers import (
     BertForNextSentencePrediction,
@@ -3025,6 +3024,11 @@ def test_clip():  # for some reason CLIP breaks the PyCharm debugger


 def test_stable_diffusion():
+    try:
+        from model.Unet import UNet
+    except ModuleNotFoundError:
+        pytest.xfail()
+
     model = UNet(3, 16, 10)
     model_inputs = (torch.rand(6, 3, 224, 224), torch.tensor([1]), torch.tensor([1.]), torch.tensor([3.]))
     show_model_graph(
@@ -3039,7 +3043,13 @@

 # Text to speech

+
 def test_styletts():
+    try:
+        from StyleTTS.models import TextEncoder
+    except ModuleNotFoundError:
+        pytest.xfail()
+
     model = TextEncoder(3, 3, 3, 100)
     tokens = torch.tensor([[3, 0, 1, 2, 0, 2, 2, 3, 1, 4]])
     input_lengths = torch.ones(1, dtype=torch.long) * 10
@@ -3081,3 +3091,39 @@ def test_dimenet():
         vis_outpath=opj("visualization_outputs", "graph-neural-networks", "dimenet"),
     )
     assert validate_saved_activations(model, model_inputs)
+
+
+# Lightning modules
+
+def test_lightning():
+    try:
+        import lightning as L
+    except ModuleNotFoundError:
+        pytest.xfail()
+
+    class OneHotAutoEncoder(L.LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(
+                nn.Linear(16, 4),
+                nn.ReLU(),
+                nn.Linear(4, 16),
+            )
+
+        def forward(self, x):
+            x_hat = self.model(x)
+            return nn.functional.mse_loss(x_hat, x)
+
+    model = OneHotAutoEncoder()
+    model_inputs = [torch.randn(2, 16)]
+
+    show_model_graph(
+        model,
+        model_inputs,
+        save_only=True,
+        vis_opt="unrolled",
+        vis_outpath=opj("visualization_outputs", "lightning", "one-hot-autoencoder"),
+    )
+    assert validate_saved_activations(model, model_inputs)
+
diff --git a/torchlens/helper_funcs.py b/torchlens/helper_funcs.py
index 327afeb..701131a 100644
--- a/torchlens/helper_funcs.py
+++ b/torchlens/helper_funcs.py
@@ -6,7 +6,7 @@
 import string
 import warnings
 from sys import getsizeof
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Type, Callable

 import numpy as np
 import torch
@@ -448,6 +448,26 @@ def nested_assign(obj, addr, val):
             obj = getattr(obj, entry_val)


+def iter_accessible_attributes(obj: Any, *, short_circuit: Optional[Callable[[Any, str], bool]] = None):
+    for attr_name in dir(obj):
+        if short_circuit and short_circuit(obj, attr_name):
+            continue
+
+        # Attribute access can fail for any number of reasons, especially when
+        # working with objects that we don't know anything about. This
+        # function makes a best-effort attempt to access every attribute, but
+        # gracefully skips any that cause problems.
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                attr = getattr(obj, attr_name)
+            except Exception:
+                continue
+
+        yield attr_name, attr
+
+
 def remove_attributes_starting_with_str(obj: Any, s: str):
     """Given an object removes, any attributes for that object beginning with a given
     substring.
diff --git a/torchlens/model_history.py b/torchlens/model_history.py
index 09a6224..939e921 100644
--- a/torchlens/model_history.py
+++ b/torchlens/model_history.py
@@ -44,6 +44,7 @@
     set_random_seed,
     set_rng_from_saved_states,
     tuple_tolerant_assign,
+    iter_accessible_attributes,
     remove_attributes_starting_with_str,
     tensor_nanequal,
     tensor_all_nan,
@@ -1314,10 +1315,7 @@ def prepare_buffer_tensors(self, model: nn.Module):
         """
         submodules = self.get_all_submodules(model)
         for submodule in submodules:
-            for attribute_name in dir(submodule):
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    attribute = getattr(submodule, attribute_name)
+            for attribute_name, attribute in iter_accessible_attributes(submodule):
                 if issubclass(type(attribute), torch.Tensor) and not issubclass(
                     type(attribute), torch.nn.Parameter
                 ):
@@ -1551,13 +1549,12 @@ def restore_module_attributes(
         decorated_func_mapper: Dict[Callable, Callable],
         attribute_keyword: str = "tl",
     ):
-        for attribute_name in dir(module):
+        def del_attrs_with_prefix(module, attribute_name):
             if attribute_name.startswith(attribute_keyword):
                 delattr(module, attribute_name)
-                continue
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                attr = getattr(module, attribute_name)
+                return True
+
+        for attribute_name, attr in iter_accessible_attributes(module, short_circuit=del_attrs_with_prefix):
             if (
                 isinstance(attr, Callable)
                 and (attr in decorated_func_mapper)
@@ -1608,10 +1605,7 @@ def undecorate_model_tensors(self, model: nn.Module):
         """
         submodules = self.get_all_submodules(model)
         for submodule in submodules:
-            for attribute_name in dir(submodule):
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    attribute = getattr(submodule, attribute_name)
+            for attribute_name, attribute in iter_accessible_attributes(submodule):
                 if issubclass(type(attribute), torch.Tensor):
                     if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
                         attribute, "tl_tensor_label_raw"

From 40aab43991457d1e49a9166810cb621a3f5c47fa Mon Sep 17 00:00:00 2001
From: Kale Kundert
Date: Wed, 9 Oct 2024 14:18:54 -0400
Subject: [PATCH 2/2] Indicate that 'varying_loop_noparam2' is currently expected to fail

---
 tests/test_validation_and_visuals.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/test_validation_and_visuals.py b/tests/test_validation_and_visuals.py
index 8757b58..9ac2ebf 100644
--- a/tests/test_validation_and_visuals.py
+++ b/tests/test_validation_and_visuals.py
@@ -836,6 +836,7 @@ def test_varying_loop_noparam1(default_input1):
     )


+@pytest.mark.xfail
 def test_varying_loop_noparam2(default_input1):
     model = example_models.VaryingLoopNoParam2()
     assert validate_saved_activations(model, default_input1)
@@ -3027,7 +3028,7 @@ def test_stable_diffusion():
     try:
         from model.Unet import UNet
     except ModuleNotFoundError:
-        pytest.xfail()
+        pytest.skip()

     model = UNet(3, 16, 10)
     model_inputs = (torch.rand(6, 3, 224, 224), torch.tensor([1]), torch.tensor([1.]), torch.tensor([3.]))
@@ -3048,7 +3049,7 @@ def test_styletts():
     try:
         from StyleTTS.models import TextEncoder
     except ModuleNotFoundError:
-        pytest.xfail()
+        pytest.skip()

     model = TextEncoder(3, 3, 3, 100)
     tokens = torch.tensor([[3, 0, 1, 2, 0, 2, 2, 3, 1, 4]])
     input_lengths = torch.ones(1, dtype=torch.long) * 10
@@ -3099,7 +3100,7 @@ def test_lightning():
     try:
         import lightning as L
     except ModuleNotFoundError:
-        pytest.xfail()
+        pytest.skip()

     class OneHotAutoEncoder(L.LightningModule):