Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: randomize all layer types #252

Merged
merged 17 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quanda/benchmarks/heuristics/mixed_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
adversarial_dir=adversarial_dir,
adversarial_label=bench_state["adversarial_label"],
adversarial_transform=adversarial_transform,
adv_train_indices = adv_train_indices,
adv_train_indices=adv_train_indices,
data_transform=dataset_transform,
checkpoint_paths=checkpoint_paths,
)
Expand Down
44 changes: 25 additions & 19 deletions quanda/metrics/heuristics/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,35 @@ def load_state_dict(self, state_dict: dict):
self.results = state_dict["results_dict"]
self.rand_model.load_state_dict(state_dict["rnd_model"])

def _randomize_parameter(self, param, parent, param_name):
"""Reset or randomize a parameter.

Parameters
----------
param : torch.Tensor
The parameter tensor.
parent : torch.nn.Module
The parent module of the parameter.
param_name : str
The name of the parameter.

"""
if hasattr(parent, "reset_parameters"):
torch.manual_seed(self.seed)
parent.reset_parameters()
else:
torch.nn.init.normal_(param, generator=self.generator)
parent.__setattr__(param_name, torch.nn.Parameter(param))

def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:
"""Randomize the model parameters.

Currently, only linear and convolutional layers are supported.

Returns
-------
torch.nn.Module
The randomized model.

"""
# TODO: Add support for other layer types.

rand_model = copy.deepcopy(self.model)
rand_checkpoints = []

Expand All @@ -214,24 +230,14 @@ def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:

for name, param in list(rand_model.named_parameters()):
parent = get_parent_module_from_name(rand_model, name)
# TODO: currently only linear layer is randomized
if isinstance(parent, (torch.nn.Linear)):
random_param_tensor = torch.nn.init.normal_(
param, generator=self.generator
)
parent.__setattr__(
name.split(".")[-1],
torch.nn.Parameter(random_param_tensor),
)

# save randomized checkpoint
param_name = name.split(".")[-1]
self._randomize_parameter(param, parent, param_name)

# Save randomized checkpoint
chckpt_path = os.path.join(
self.cache_dir, f"{self.model_id}_rand_{i}.pth"
)
torch.save(
rand_model.state_dict(),
chckpt_path,
)
torch.save(rand_model.state_dict(), chckpt_path)
rand_checkpoints.append(chckpt_path)

return rand_model, rand_checkpoints
6 changes: 3 additions & 3 deletions tests/benchmarks/heuristics/test_model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.717261791229248,
),
(
"mnist2",
Expand All @@ -49,7 +49,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.717261791229248,
),
],
)
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_model_randomization(
"similarity_metric": cosine_similarity,
"load_from_disk": True,
},
0.509926438331604,
0.19356615841388702,
),
],
)
Expand Down
49 changes: 48 additions & 1 deletion tests/metrics/test_heuristics_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
TopKCardinalityMetric,
)
from quanda.metrics.heuristics.mixed_datasets import MixedDatasetsMetric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import correlation_functions, cosine_similarity
from quanda.utils.common import (
get_parent_module_from_name,
)


@pytest.mark.heuristic_metrics
Expand Down Expand Up @@ -67,17 +69,25 @@ def test_randomization_metric(
tmp_path,
request,
):
aski02 marked this conversation as resolved.
Show resolved Hide resolved
# 1) Check if the metric works correctly
model = request.getfixturevalue(model)
checkpoint = request.getfixturevalue(checkpoint)
test_data = request.getfixturevalue(test_data)
dataset = request.getfixturevalue(dataset)
test_labels = request.getfixturevalue(test_labels)
tda = request.getfixturevalue(explanations)
expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}

def _load_flexible_state_dict(model: torch.nn.Module, path: str):
checkpoint = torch.load(path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
return model

metric = ModelRandomizationMetric(
model=model,
model_id=0,
checkpoints=checkpoint,
checkpoints_load_func=_load_flexible_state_dict,
train_dataset=dataset,
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,
Expand All @@ -91,6 +101,43 @@ def test_randomization_metric(
out = metric.compute()["score"]
assert (out >= -1.0) & (out <= 1.0), "Test failed."

# 2) Check if the randomization works correctly
batch_size = 2
input_shape = test_data[0].shape
random_tensor = torch.randn((batch_size, *input_shape), device="cpu")

rand_model = metric._randomize_model()[0]
rand_model.eval()
model.eval()

with torch.no_grad():
original_out = model(random_tensor)
randomized_out = rand_model(random_tensor)

assert not torch.allclose(
original_out, randomized_out
), "Outputs do not differ after randomization"
assert not torch.isnan(
randomized_out
).any(), "Randomized model output contains NaNs."
aski02 marked this conversation as resolved.
Show resolved Hide resolved

# 3) Check if the randomization works correctly for custom parameters
model.custom_param = torch.nn.Parameter(torch.randn(4))
model.eval()
rand_model = metric._randomize_model()[0]
rand_model.eval()

with torch.no_grad():
original_out = model(random_tensor)
randomized_out = rand_model(random_tensor)

assert not torch.allclose(
original_out, randomized_out
), "Outputs do not differ after randomization"
assert not torch.isnan(
randomized_out
).any(), "Randomized model output contains NaNs."


@pytest.mark.heuristic_metrics
@pytest.mark.parametrize(
Expand Down
Loading