fix: randomize all layer types #252

Open. Wants to merge 12 commits into base: main.
2 changes: 1 addition & 1 deletion quanda/benchmarks/heuristics/mixed_datasets.py
@@ -346,7 +346,7 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
adversarial_dir=adversarial_dir,
adversarial_label=bench_state["adversarial_label"],
adversarial_transform=adversarial_transform,
adv_train_indices = adv_train_indices,
adv_train_indices=adv_train_indices,
data_transform=dataset_transform,
checkpoint_paths=checkpoint_paths,
)
44 changes: 25 additions & 19 deletions quanda/metrics/heuristics/model_randomization.py
@@ -193,19 +193,35 @@ def load_state_dict(self, state_dict: dict):
self.results = state_dict["results_dict"]
self.rand_model.load_state_dict(state_dict["rnd_model"])

def _randomize_parameter(self, param, parent, param_name):
"""Reset or randomize a parameter.

Parameters
----------
param : torch.Tensor
The parameter tensor.
parent : torch.nn.Module
The parent module of the parameter.
param_name : str
The name of the parameter.

"""
if hasattr(parent, "reset_parameters"):
torch.manual_seed(self.seed)
parent.reset_parameters()
else:
torch.nn.init.normal_(param, generator=self.generator)
parent.__setattr__(param_name, torch.nn.Parameter(param))

def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:
"""Randomize the model parameters.

Currently, only linear and convolutional layers are supported.

Returns
-------
torch.nn.Module
The randomized model.

"""
# TODO: Add support for other layer types.

rand_model = copy.deepcopy(self.model)
rand_checkpoints = []

@@ -214,24 +230,14 @@ def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:

for name, param in list(rand_model.named_parameters()):
parent = get_parent_module_from_name(rand_model, name)
# TODO: currently only linear layer is randomized
if isinstance(parent, (torch.nn.Linear)):
random_param_tensor = torch.nn.init.normal_(
param, generator=self.generator
)
parent.__setattr__(
name.split(".")[-1],
torch.nn.Parameter(random_param_tensor),
)

# save randomized checkpoint
param_name = name.split(".")[-1]
self._randomize_parameter(param, parent, param_name)

# Save randomized checkpoint
chckpt_path = os.path.join(
self.cache_dir, f"{self.model_id}_rand_{i}.pth"
)
torch.save(
rand_model.state_dict(),
chckpt_path,
)
torch.save(rand_model.state_dict(), chckpt_path)
rand_checkpoints.append(chckpt_path)

return rand_model, rand_checkpoints
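
For readers skimming the diff, here is a minimal standalone sketch of the strategy this PR adopts: modules that implement reset_parameters() (Linear, Conv2d, BatchNorm, LayerNorm, ...) are re-initialized in place, and any remaining parameters fall back to a normal-distribution init. The helper and toy model below are illustrative only and not part of the PR.

import copy

import torch


def randomize_all_layers(model: torch.nn.Module, seed: int = 42) -> torch.nn.Module:
    """Sketch: prefer reset_parameters(); otherwise fall back to a normal init."""
    rand_model = copy.deepcopy(model)
    generator = torch.Generator().manual_seed(seed)
    for name, param in list(rand_model.named_parameters()):
        parent_name, _, param_name = name.rpartition(".")
        parent = rand_model.get_submodule(parent_name) if parent_name else rand_model
        if hasattr(parent, "reset_parameters"):
            torch.manual_seed(seed)
            parent.reset_parameters()  # covers Linear, Conv2d, BatchNorm2d, ...
        else:
            # e.g. a bare nn.Parameter registered directly on the model
            torch.nn.init.normal_(param, generator=generator)
            setattr(parent, param_name, torch.nn.Parameter(param))
    return rand_model


toy = torch.nn.Sequential(torch.nn.Conv2d(1, 4, 3), torch.nn.BatchNorm2d(4))
assert not torch.equal(next(toy.parameters()), next(randomize_all_layers(toy).parameters()))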
6 changes: 3 additions & 3 deletions tests/benchmarks/heuristics/test_model_randomization.py
@@ -30,7 +30,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.717261791229248,
),
(
"mnist2",
@@ -49,7 +49,7 @@
"similarity_metric": cosine_similarity,
},
None,
0.5208332538604736,
0.717261791229248,
),
],
)
@@ -120,7 +120,7 @@ def test_model_randomization(
"similarity_metric": cosine_similarity,
"load_from_disk": True,
},
0.509926438331604,
0.19356615841388702,
),
],
)
225 changes: 214 additions & 11 deletions tests/metrics/test_heuristics_metrics.py
@@ -9,61 +9,57 @@
TopKCardinalityMetric,
)
from quanda.metrics.heuristics.mixed_datasets import MixedDatasetsMetric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import correlation_functions, cosine_similarity
from quanda.utils.common import (
get_parent_module_from_name,
)


@pytest.mark.heuristic_metrics
@pytest.mark.parametrize(
"test_id, model, checkpoint,dataset, test_data, batch_size, explainer_cls, \
expl_kwargs, explanations, test_labels, correlation_fn",
"test_id, model, checkpoint, dataset, test_data, "
"explainer_cls, expl_kwargs, explanations, test_labels",
[
(
"mnist_update_only_spearman",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
"spearman",
),
(
"mnist_update_only_kendall",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
"kendall",
),
],
)
def test_randomization_metric(
def test_randomization_metric_score(
test_id,
model,
checkpoint,
dataset,
test_data,
batch_size,
explainer_cls,
expl_kwargs,
explanations,
test_labels,
correlation_fn,
tmp_path,
request,
):
Owner:

Can we add some additional (tiny) architectures to conftest.py, and add a couple of parameterizations (extra items in the parametrization list) with these architectures? Maybe the easiest would be to create those tiny architectures manually, and not train them.
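
One possible shape for such a fixture, as a sketch only (the fixture name and dimensions below are hypothetical and not part of this PR); the modules are deliberately left untrained:

# tests/conftest.py (hypothetical addition following the suggestion above)
import pytest
import torch


@pytest.fixture
def load_tiny_conv_model():
    """Tiny, untrained model mixing Conv2d, BatchNorm2d and Linear layers."""
    torch.manual_seed(0)
    return torch.nn.Sequential(
        torch.nn.Conv2d(1, 2, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(2),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(2 * 28 * 28, 10),
    )

Extra parametrization items could then reference "load_tiny_conv_model" in place of "load_mnist_model"; the explainer's "layers" argument would also need to point at a layer that exists in the tiny model.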

@@ -74,6 +70,7 @@ def test_randomization_metric(
test_labels = request.getfixturevalue(test_labels)
tda = request.getfixturevalue(explanations)
expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}

metric = ModelRandomizationMetric(
model=model,
model_id=0,
@@ -89,7 +86,213 @@
)

out = metric.compute()["score"]
assert (out >= -1.0) & (out <= 1.0), "Test failed."
assert (out >= -1.0) & (
out <= 1.0
), "Metric score is out of expected range."


@pytest.mark.heuristic_metrics
@pytest.mark.parametrize(
"test_id, model, checkpoint, dataset, test_data, batch_size, "
"explainer_cls, expl_kwargs, explanations, test_labels",
[
(
"mnist_update_only_spearman",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
),
(
"mnist_update_only_kendall",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
),
],
)
def test_randomization_metric_randomization(
test_id,
model,
checkpoint,
dataset,
test_data,
batch_size,
explainer_cls,
expl_kwargs,
explanations,
test_labels,
tmp_path,
request,
):
model = request.getfixturevalue(model)
checkpoint = request.getfixturevalue(checkpoint)
test_data = request.getfixturevalue(test_data)
dataset = request.getfixturevalue(dataset)
test_labels = request.getfixturevalue(test_labels)
expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}

metric = ModelRandomizationMetric(
model=model,
model_id=0,
checkpoints=checkpoint,
train_dataset=dataset,
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,
cache_dir=str(tmp_path),
seed=42,
)

# Generate a random batch of data
batch_size = 2
input_shape = test_data[0].shape
random_tensor = torch.randn((batch_size, *input_shape), device="cpu")

# Randomize model
rand_model = metric._randomize_model()[0]
rand_model.eval()
model.eval()

# Check if the outputs differ after randomization
with torch.no_grad():
original_out = model(random_tensor)
randomized_out = rand_model(random_tensor)

assert not torch.allclose(
original_out, randomized_out
), "Outputs do not differ after randomization."
assert not torch.isnan(
randomized_out
).any(), "Randomized model output contains NaNs."


@pytest.mark.heuristic_metrics
@pytest.mark.parametrize(
"test_id, model, checkpoint, dataset, test_data, batch_size, "
"explainer_cls, expl_kwargs, explanations, test_labels",
[
(
"mnist_update_only_spearman",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
),
(
"mnist_update_only_kendall",
"load_mnist_model",
"load_mnist_last_checkpoint",
"load_mnist_dataset",
"load_mnist_test_samples_1",
8,
CaptumSimilarity,
{
"layers": "fc_2",
"similarity_metric": cosine_similarity,
},
"load_mnist_explanations_similarity_1",
"load_mnist_test_labels_1",
),
],
)
def test_randomization_metric_custom_param(
test_id,
model,
checkpoint,
dataset,
test_data,
batch_size,
explainer_cls,
expl_kwargs,
explanations,
test_labels,
tmp_path,
request,
):
model = request.getfixturevalue(model)
checkpoint = request.getfixturevalue(checkpoint)
test_data = request.getfixturevalue(test_data)
dataset = request.getfixturevalue(dataset)
test_labels = request.getfixturevalue(test_labels)
expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}

def _load_flexible_state_dict(model: torch.nn.Module, path: str):
checkpoint = torch.load(path, map_location="cpu")
model.load_state_dict(checkpoint, strict=False)
return model

metric = ModelRandomizationMetric(
model=model,
model_id=0,
checkpoints=checkpoint,
checkpoints_load_func=_load_flexible_state_dict,
train_dataset=dataset,
explainer_cls=explainer_cls,
expl_kwargs=expl_kwargs,
cache_dir=str(tmp_path),
seed=42,
)

# Add a custom parameter to the model
model.custom_param = torch.nn.Parameter(torch.randn(4))
model.eval()

# Save the original custom parameter
original_custom_param = model.custom_param.data.clone()

# Randomize model
rand_model = metric._randomize_model()[0]
rand_model.eval()

# Save the randomized custom parameter
randomized_custom_param = rand_model.custom_param.data.clone()

# Generate a random batch of data
batch_size = 2
input_shape = test_data[0].shape
random_tensor = torch.randn((batch_size, *input_shape), device="cpu")

# Check if both outputs and custom params differ after randomization
with torch.no_grad():
original_out = model(random_tensor)
randomized_out = rand_model(random_tensor)

assert not torch.allclose(
original_out, randomized_out
), "Outputs do not differ after randomization."

assert not torch.allclose(
original_custom_param, randomized_custom_param
), "Custom parameter did not change after randomization."

assert not torch.isnan(
randomized_out
).any(), "Randomized model output contains NaNs."


@pytest.mark.heuristic_metrics