fix: randomize all layer types #252

Merged · 17 commits · Jan 17, 2025
Changes from 2 commits
21 changes: 8 additions & 13 deletions quanda/metrics/heuristics/model_randomization.py
@@ -196,16 +196,12 @@ def load_state_dict(self, state_dict: dict):
     def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:
         """Randomize the model parameters.

-        Currently, only linear and convolutional layers are supported.
-
         Returns
         -------
         torch.nn.Module
             The randomized model.

         """
-        # TODO: Add support for other layer types.
-
         rand_model = copy.deepcopy(self.model)
         rand_checkpoints = []

@@ -214,15 +210,14 @@ def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:

             for name, param in list(rand_model.named_parameters()):
                 parent = get_parent_module_from_name(rand_model, name)
-                # TODO: currently only linear layer is randomized
-                if isinstance(parent, (torch.nn.Linear)):
-                    random_param_tensor = torch.nn.init.normal_(
-                        param, generator=self.generator
-                    )
-                    parent.__setattr__(
-                        name.split(".")[-1],
-                        torch.nn.Parameter(random_param_tensor),
-                    )
+                param_name = name.split(".")[-1]
+
+                random_param_tensor = torch.nn.init.normal_(
+                    param, generator=self.generator
+                )
+                parent.__setattr__(
+                    param_name, torch.nn.Parameter(random_param_tensor)
+                )

             # save randomized checkpoint
             chckpt_path = os.path.join(
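
The net effect of this change: every named parameter is now re-initialized in place, whatever the parent module's type, instead of only parameters owned by `torch.nn.Linear`. A minimal standalone sketch of the same idea (plain PyTorch only; the toy model and names are illustrative, not quanda code):

```python
import torch

# Toy model mixing layer types that the old isinstance(parent, torch.nn.Linear)
# check would have skipped (embedding table, layer norm).
model = torch.nn.Sequential(
    torch.nn.Embedding(50, 8),
    torch.nn.LayerNorm(8),
    torch.nn.Linear(8, 2),
)
before = {name: p.detach().clone() for name, p in model.named_parameters()}

generator = torch.Generator().manual_seed(0)
for name, param in model.named_parameters():
    # In-place re-initialization works on any parameter tensor,
    # so no layer-type check is needed.
    torch.nn.init.normal_(param, generator=generator)

for name, param in model.named_parameters():
    assert not torch.allclose(before[name], param), f"{name} was not randomized"
```
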
6 changes: 3 additions & 3 deletions tests/benchmarks/heuristics/test_model_randomization.py
@@ -30,7 +30,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.43154752254486084,
),
(
"mnist2",
@@ -49,7 +49,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.43154752254486084,
),
],
)
@@ -120,7 +120,7 @@ def test_model_randomization(
"similarity_metric": cosine_similarity,
"load_from_disk": True,
},
0.509926438331604,
0.4639705419540405,
),
],
)
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -26,7 +26,7 @@
     LabelGroupingDataset,
 )
 from quanda.utils.training.base_pl_module import BasicLightningModule
-from tests.models import LeNet
+from tests.models import LeNet, BasicTransformer

 MNIST_IMAGE_SIZE = 28
 BATCH_SIZE = 124
@@ -433,3 +433,8 @@ def get_lds_score():
     with open("tests/assets/lds_score.json", "r") as f:
         score_data = json.load(f)
     return score_data["lds_score"]
+
+
+@pytest.fixture
+def transformer_model():
+    return BasicTransformer()
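
Side note: the new fixture is referenced by name in the parametrized test added further below and resolved at runtime through pytest's `request.getfixturevalue`. A minimal sketch of that pattern (test name and assertion are illustrative only):

```python
import pytest


@pytest.mark.parametrize("model", ["transformer_model"])
def test_fixture_resolved_by_name(model, request):
    # The parametrize entry is just a string; look up the fixture it names.
    model = request.getfixturevalue(model)
    assert sum(p.numel() for p in model.parameters()) > 0
```
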
82 changes: 81 additions & 1 deletion tests/metrics/test_heuristics_metrics.py
@@ -1,3 +1,5 @@
+import os
+import copy
 import math

 import pytest
@@ -9,8 +11,11 @@
     TopKCardinalityMetric,
 )
 from quanda.metrics.heuristics.mixed_datasets import MixedDatasetsMetric
-from quanda.utils.common import get_parent_module_from_name
 from quanda.utils.functions import correlation_functions, cosine_similarity
+from quanda.utils.common import (
+    get_parent_module_from_name,
+    get_load_state_dict_func,
+)


 @pytest.mark.heuristic_metrics
@@ -92,6 +97,81 @@ def test_randomization_metric(
     assert (out >= -1.0) & (out <= 1.0), "Test failed."


+@pytest.mark.heuristic_metrics
+@pytest.mark.parametrize(
+    "test_id, model, seed",
+    [
+        ("transformer_randomization", "transformer_model", 42),
+    ],
+)
+def test_randomization_metric_transformer(
+    test_id, model, seed, tmp_path, request
+):
+    model = request.getfixturevalue(model)
+    model_id = "transformer"
+    cache_dir = str(tmp_path)
+    device = "cpu"
+    checkpoints = [os.path.join(cache_dir, "dummy.ckpt")]
+    checkpoints_load_func = get_load_state_dict_func(device)
+    generator = torch.Generator(device=device)
+    generator.manual_seed(seed)
+
+    original_params = {
+        name: p.clone().detach() for name, p in model.named_parameters()
+    }
+
+    def _randomize_model(
+        model,
+        model_id,
+        cache_dir,
+        checkpoints,
+        checkpoints_load_func,
+        generator,
+    ):
+        rand_model = copy.deepcopy(model)
+        rand_checkpoints = []
+
+        for i, _ in enumerate(checkpoints):
+            for name, param in list(rand_model.named_parameters()):
+                parent = get_parent_module_from_name(rand_model, name)
+                param_name = name.split(".")[-1]
+
+                random_param_tensor = torch.nn.init.normal_(
+                    param, generator=generator
+                )
+                parent.__setattr__(
+                    param_name, torch.nn.Parameter(random_param_tensor)
+                )
+
+            chckpt_path = os.path.join(cache_dir, f"{model_id}_rand_{i}.pth")
+            torch.save(
+                rand_model.state_dict(),
+                chckpt_path,
+            )
+            rand_checkpoints.append(chckpt_path)
+
+        return rand_model, rand_checkpoints
+
+    rand_model, _ = _randomize_model(
+        model,
+        model_id,
+        cache_dir,
+        checkpoints,
+        checkpoints_load_func,
+        generator,
+    )
+
+    for (name, original_param), (_, rand_param) in zip(
+        original_params.items(), rand_model.named_parameters()
+    ):
+        assert not torch.allclose(
+            original_param, rand_param
+        ), f"Parameter {name} was not randomized."
+        assert not torch.isnan(
+            rand_param
+        ).any(), f"Parameter {name} contains NaN values."
+
+
 @pytest.mark.heuristic_metrics
 @pytest.mark.parametrize(
     "test_id, model, checkpoint,dataset, explainer_cls, expl_kwargs, corr_fn",
27 changes: 27 additions & 0 deletions tests/models.py
@@ -28,3 +28,30 @@ def forward(self, x):
         x = self.relu_4(self.fc_2(x))
         x = self.fc_3(x)
         return x
+
+
+class BasicTransformer(torch.nn.Module):
+    def __init__(
+        self, embed_dim=16, nhead=4, dim_feedforward=32, num_layers=1
+    ):
+        super().__init__()
+        self.embedding = torch.nn.Embedding(
+            num_embeddings=100, embedding_dim=embed_dim
+        )
+        encoder_layer = torch.nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            dropout=0.0,
+            activation="relu",
+            batch_first=True,
+        )
+        self.encoder = torch.nn.TransformerEncoder(
+            encoder_layer, num_layers=num_layers
+        )
+        self.fc_out = torch.nn.Linear(embed_dim, 10)
+
+    def forward(self, x):
+        emb = self.embedding(x)
+        enc_out = self.encoder(emb)
+        return self.fc_out(enc_out[:, 0, :])
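
For orientation, a quick sanity check of the new helper model (illustrative, not part of the PR): `BasicTransformer` takes integer token ids in `[0, 100)` and returns one 10-dimensional output per sequence, read from the first token position.

```python
import torch

from tests.models import BasicTransformer

model = BasicTransformer()
tokens = torch.randint(0, 100, (2, 5))  # batch of 2 sequences, 5 tokens each
out = model(tokens)
assert out.shape == (2, 10)
```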