Skip to content

Commit

Permalink
Added some additional models to test suite (LSTM, RNN, stable diffusion)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnMark Taylor committed Oct 6, 2023
1 parent 234e8a9 commit 98e9fd1
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 23 deletions.
65 changes: 48 additions & 17 deletions tests/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def forward(x):
y.zero_()
y = y + 1
y = torch.log(y)
x = x**2
x = x ** 2
return x


Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(self):
def forward(x, y, z):
a = x + y
b = torch.log(z)
x = a**b
x = a ** b
return x


Expand All @@ -326,7 +326,7 @@ def forward(input_list):
x, y, z = input_list
a = x + y
b = torch.log(z)
x = a**b
x = a ** b
return x


Expand All @@ -339,7 +339,7 @@ def forward(input_dict):
x, y, z = input_dict["x"], input_dict["y"], input_dict["z"]
a = x + y
b = torch.log(z)
x = a**b
x = a ** b
return x


Expand Down Expand Up @@ -437,7 +437,7 @@ def forward(x):
y = x * 2
y = y + 3
y = torch.log(y)
z = x**2
z = x ** 2
z = torch.sin(z)
x = x + y + z
return x
Expand All @@ -457,7 +457,7 @@ def forward(x):
x = torch.cos(x)
x = x + 1
x = x * 4
x = x**2
x = x ** 2
return x


Expand Down Expand Up @@ -505,7 +505,7 @@ def __init__(self):

@staticmethod
def forward(x):
x = x**3 + torch.ones(x.shape)
x = x ** 3 + torch.ones(x.shape)
x = x / 5
return x

Expand Down Expand Up @@ -567,9 +567,9 @@ def forward(x):
z = torch.ones(5, 5)
z = z + 1
a = z * 2
b = z**2
b = z ** 2
c = a + b
x = x**2
x = x ** 2
return x


Expand Down Expand Up @@ -633,7 +633,7 @@ def forward(x):
if i % 2 == 0:
y = x + 3
y = torch.sin(y)
y = y**2
y = y ** 2
x = torch.sin(x)
if i % 2 == 1:
z = x + 3
Expand All @@ -656,7 +656,7 @@ def forward(x):
if i in [0, 3, 4]:
y = x + 3
y = torch.sin(y)
y = y**2
y = y ** 2
else:
y = x * torch.rand(x.shape)
y = torch.cos(y)
Expand All @@ -681,7 +681,7 @@ def forward(self, x):
x = self.fc(x)
x = x + 1
x = x * 2
x = x**3
x = x ** 3
x = self.fc(x)
return x

Expand Down Expand Up @@ -872,7 +872,7 @@ def forward(self, x):
x = self.relu(x)
x = torch.log(x)
for _ in range(
4
4
): # this tests clashes between what counts as "same"--module-based or looping-based
x = self.relu(x)
x = x + 1
Expand Down Expand Up @@ -922,8 +922,39 @@ def forward(self, x):


# ****************************
# **** Uber Architectures ****
# **** RNN/LSTM Architectures ****
# ****************************
class LSTMModel(nn.Module):
def __init__(self):
super().__init__()
self.lstm = nn.LSTM(5, 10)
self.label = nn.Linear(10, 5)
self.hidden_size = 10

def forward(self, x):
batch_size = x.shape[0]
h_0 = torch.zeros(1, batch_size, self.hidden_size)
c_0 = torch.zeros(1, batch_size, self.hidden_size)

output, (final_hidden_state, final_cell_state) = self.lstm(x, (h_0, c_0))

return self.label(final_hidden_state[-1])


class RNNModel(nn.Module):
def __init__(self):
super().__init__()
self.rnn = nn.RNN(5, 10)
self.label = nn.Linear(10, 5)
self.hidden_size = 10

def forward(self, x):
batch_size = x.shape[0]
h_0 = torch.zeros(1, batch_size, self.hidden_size)

output, final_hidden_state = self.rnn(x, h_0)

return self.label(final_hidden_state[-1])


class UberModel1(nn.Module):
Expand All @@ -939,7 +970,7 @@ def forward(x):
x, y, z = x
x = x + 1
y = y * 2
y = y**3
y = y ** 3
w = torch.rand(5, 5)
w = w * 2
w = w + 4
Expand Down Expand Up @@ -1007,7 +1038,7 @@ def forward(self, x):
x = self.fc(x)
x = x + 1
x = x * 2
x = x**3
x = x ** 3
x = self.fc(x)
x = x + 2
x = x * 3
Expand Down Expand Up @@ -1171,6 +1202,6 @@ def forward(self, x):
a = torch.ones(3, 3)
b = torch.zeros(3, 3)
z3 = z2 + a + b
w1 = x**3
w1 = x ** 3
x = torch.sum(torch.stack([y4, z3, w1]))
return x
51 changes: 50 additions & 1 deletion tests/test_validation_and_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import visualpriors
from PIL import Image
from StyleTTS.models import TextEncoder
from model.Unet import UNet

from transformers import (
BertForNextSentencePrediction,
BertTokenizer,
GPT2Model,
GPT2Tokenizer,
)
from torch_geometric.datasets import QM9
from torch_geometric.nn import DimeNet

import example_models
Expand Down Expand Up @@ -2625,6 +2626,41 @@ def test_deepspeech():

# Language models

def test_lstm():
model = example_models.LSTMModel()
model_input = torch.rand(5, 5, 5)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "language-models", "language_lstm_unrolled"),
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "language-models", "language_lstm_rolled"),
)
assert validate_saved_activations(model, model_input)


def test_rnn():
model = example_models.RNNModel()
model_input = torch.rand(5, 5, 5)
show_model_graph(
model,
model_input,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "language-models", "language_rnn_unrolled"),
)
show_model_graph(
model,
model_input,
vis_opt="rolled",
vis_outpath=opj("visualization_outputs", "language-models", "language_rnn_rolled"),
)
assert validate_saved_activations(model, model_input)


def test_gpt2():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
Expand Down Expand Up @@ -2684,6 +2720,19 @@ def test_clip(): # for some reason CLIP breaks the PyCharm debugger
assert validate_saved_activations(model, [], model_inputs, random_seed=1)


def test_stable_diffusion():
model = UNet(3, 16, 10)
model_inputs = (torch.rand(6, 3, 224, 224), torch.tensor([1]), torch.tensor([1.]), torch.tensor([3.]))
show_model_graph(
model,
model_inputs,
random_seed=1,
vis_opt="unrolled",
vis_outpath=opj("visualization_outputs", "multimodal-models", "stable_diffusion"),
)
assert validate_saved_activations(model, model_inputs, random_seed=1)


# Text to speech
def test_styletts():
model = TextEncoder(3, 3, 3, 100)
Expand Down
11 changes: 6 additions & 5 deletions torchlens/model_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -6076,16 +6076,17 @@ def _check_whether_func_on_saved_parents_yields_saved_tensor(
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1][1]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2][0]) or
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[2][1])
)):
layer_to_validate_parents_for.creation_args[2][1]) or
((type(layer_to_validate_parents_for.creation_args[1]) == torch.Tensor) and
torch.equal(
self[layers_to_perturb[0]].tensor_contents,
layer_to_validate_parents_for.creation_args[1])
))):
return True
elif (
perturb
Expand Down

0 comments on commit 98e9fd1

Please sign in to comment.