Merge remote-tracking branch 'origin/main'
JohnMark Taylor committed Oct 16, 2024
2 parents 34cbb9a + 3fd8a31 commit 25b061e
Showing 4 changed files with 90 additions and 16 deletions.
13 changes: 13 additions & 0 deletions requirements.test.txt
@@ -1,2 +1,15 @@
black
cornet @ git+https://github.com/dicarlolab/CORnet
lightning
numpy
pillow
pytest
requests
timm
torch
torchaudio
torch_geometric
torchlens
torchvision
transformers
visualpriors
51 changes: 49 additions & 2 deletions tests/test_validation_and_visuals.py
@@ -9,12 +9,11 @@
import requests
import timm.models
import torch
import torch.nn as nn
import torchaudio.models
import torchvision
import visualpriors
from PIL import Image
from StyleTTS.models import TextEncoder
from model.Unet import UNet

from transformers import (
    BertForNextSentencePrediction,
@@ -837,6 +836,7 @@ def test_varying_loop_noparam1(default_input1):
)


@pytest.mark.xfail
def test_varying_loop_noparam2(default_input1):
    model = example_models.VaryingLoopNoParam2()
    assert validate_saved_activations(model, default_input1)
@@ -3025,6 +3025,11 @@ def test_clip(): # for some reason CLIP breaks the PyCharm debugger


def test_stable_diffusion():
    try:
        from model.Unet import UNet
    except ModuleNotFoundError:
        pytest.skip()

    model = UNet(3, 16, 10)
    model_inputs = (torch.rand(6, 3, 224, 224), torch.tensor([1]), torch.tensor([1.]), torch.tensor([3.]))
    show_model_graph(
@@ -3039,7 +3044,13 @@ def test_stable_diffusion():


# Text to speech

def test_styletts():
    try:
        from StyleTTS.models import TextEncoder
    except ModuleNotFoundError:
        pytest.skip()

    model = TextEncoder(3, 3, 3, 100)
    tokens = torch.tensor([[3, 0, 1, 2, 0, 2, 2, 3, 1, 4]])
    input_lengths = torch.ones(1, dtype=torch.long) * 10
@@ -3081,3 +3092,39 @@ def test_dimenet():
        vis_outpath=opj("visualization_outputs", "graph-neural-networks", "dimenet"),
    )
    assert validate_saved_activations(model, model_inputs)


# Lightning modules

def test_lightning():
    try:
        import lightning as L
    except ModuleNotFoundError:
        pytest.skip()

    class OneHotAutoEncoder(L.LightningModule):

        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(16, 4),
                nn.ReLU(),
                nn.Linear(4, 16),
            )

        def forward(self, x):
            x_hat = self.model(x)
            return nn.functional.mse_loss(x_hat, x)

    model = OneHotAutoEncoder()
    model_inputs = [torch.randn(2, 16)]

    show_model_graph(
        model,
        model_inputs,
        save_only=True,
        vis_opt="unrolled",
        vis_outpath=opj("visualization_outputs", "lightning", "one-hot-autoencoder"),
    )
    assert validate_saved_activations(model, model_inputs)
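Usage note (not part of the commit): assuming the dependencies listed in requirements.test.txt are installed, this new test can be run on its own with, for example:

    pytest tests/test_validation_and_visuals.py::test_lightning -q

If the lightning package is not installed, the test skips itself rather than failing.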

22 changes: 21 additions & 1 deletion torchlens/helper_funcs.py
@@ -6,7 +6,7 @@
import string
import warnings
from sys import getsizeof
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Callable

import numpy as np
import torch
@@ -448,6 +448,26 @@ def nested_assign(obj, addr, val):
obj = getattr(obj, entry_val)


def iter_accessible_attributes(obj: Any, *, short_circuit: Optional[Callable[[Any, str], bool]] = None):
    for attr_name in dir(obj):
        if short_circuit and short_circuit(obj, attr_name):
            continue

        # Attribute access can fail for any number of reasons, especially when
        # working with objects that we don't know anything about. This
        # function makes a best-effort attempt to access every attribute, but
        # gracefully skips any that cause problems.

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                attr = getattr(obj, attr_name)
            except Exception:
                continue

        yield attr_name, attr


def remove_attributes_starting_with_str(obj: Any, s: str):
    """Given an object, removes any attributes of that object beginning with a given
    substring.
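For reference, a minimal usage sketch (not part of the diff) of the new iter_accessible_attributes helper; the Probe class and skip_dunders callback below are hypothetical illustrations, not code from this repository:

    from torchlens.helper_funcs import iter_accessible_attributes

    class Probe:
        """Hypothetical object with one attribute that raises on access."""
        value = 42

        @property
        def broken(self):
            raise RuntimeError("cannot access this attribute")

    def skip_dunders(obj, attr_name):
        # Hypothetical short_circuit callback: returning True skips the attribute entirely.
        return attr_name.startswith("__")

    for name, attr in iter_accessible_attributes(Probe(), short_circuit=skip_dunders):
        print(name, attr)  # 'value' is yielded; 'broken' is skipped because getattr raises

The short_circuit hook lets callers filter or act on attributes by name before any getattr call, which is how restore_module_attributes below deletes the "tl"-prefixed attributes without ever accessing them.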
20 changes: 7 additions & 13 deletions torchlens/model_history.py
@@ -44,6 +44,7 @@
    set_random_seed,
    set_rng_from_saved_states,
    tuple_tolerant_assign,
    iter_accessible_attributes,
    remove_attributes_starting_with_str,
    tensor_nanequal,
    tensor_all_nan,
@@ -1314,10 +1315,7 @@ def prepare_buffer_tensors(self, model: nn.Module):
"""
submodules = self.get_all_submodules(model)
for submodule in submodules:
for attribute_name in dir(submodule):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attribute = getattr(submodule, attribute_name)
for attribute_name, attribute in iter_accessible_attributes(submodule):
if issubclass(type(attribute), torch.Tensor) and not issubclass(
type(attribute), torch.nn.Parameter
):
@@ -1551,13 +1549,12 @@ def restore_module_attributes(
decorated_func_mapper: Dict[Callable, Callable],
attribute_keyword: str = "tl",
):
for attribute_name in dir(module):
def del_attrs_with_prefix(module, attribute_name):
if attribute_name.startswith(attribute_keyword):
delattr(module, attribute_name)
continue
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attr = getattr(module, attribute_name)
return True

for attribute_name, attr in iter_accessible_attributes(module, short_circuit=del_attrs_with_prefix):
if (
isinstance(attr, Callable)
and (attr in decorated_func_mapper)
@@ -1608,10 +1605,7 @@ def undecorate_model_tensors(self, model: nn.Module):
"""
submodules = self.get_all_submodules(model)
for submodule in submodules:
for attribute_name in dir(submodule):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
attribute = getattr(submodule, attribute_name)
for attribute_name, attribute in iter_accessible_attributes(submodule):
if issubclass(type(attribute), torch.Tensor):
if not issubclass(type(attribute), torch.nn.Parameter) and hasattr(
attribute, "tl_tensor_label_raw"