diff --git a/quanda/benchmarks/heuristics/mixed_datasets.py b/quanda/benchmarks/heuristics/mixed_datasets.py
index 5e186551..40dd25a9 100644
--- a/quanda/benchmarks/heuristics/mixed_datasets.py
+++ b/quanda/benchmarks/heuristics/mixed_datasets.py
@@ -346,7 +346,7 @@ def download(cls, name: str, cache_dir: str, device: str, *args, **kwargs):
             adversarial_dir=adversarial_dir,
             adversarial_label=bench_state["adversarial_label"],
             adversarial_transform=adversarial_transform,
-            adv_train_indices = adv_train_indices,
+            adv_train_indices=adv_train_indices,
             data_transform=dataset_transform,
             checkpoint_paths=checkpoint_paths,
         )
diff --git a/quanda/metrics/heuristics/model_randomization.py b/quanda/metrics/heuristics/model_randomization.py
index 48d46b00..a8e7c163 100644
--- a/quanda/metrics/heuristics/model_randomization.py
+++ b/quanda/metrics/heuristics/model_randomization.py
@@ -193,19 +193,35 @@ def load_state_dict(self, state_dict: dict):
         self.results = state_dict["results_dict"]
         self.rand_model.load_state_dict(state_dict["rnd_model"])

+    def _randomize_parameter(self, param, parent, param_name):
+        """Reset or randomize a parameter.
+
+        Parameters
+        ----------
+        param : torch.Tensor
+            The parameter tensor.
+        parent : torch.nn.Module
+            The parent module of the parameter.
+        param_name : str
+            The name of the parameter.
+
+        """
+        if hasattr(parent, "reset_parameters"):
+            torch.manual_seed(self.seed)
+            parent.reset_parameters()
+        else:
+            torch.nn.init.normal_(param, generator=self.generator)
+            parent.__setattr__(param_name, torch.nn.Parameter(param))
+
     def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:
         """Randomize the model parameters.

-        Currently, only linear and convolutional layers are supported.
-
         Returns
         -------
         torch.nn.Module
             The randomized model.

         """
-        # TODO: Add support for other layer types.
-
         rand_model = copy.deepcopy(self.model)
         rand_checkpoints = []
@@ -214,24 +230,14 @@ def _randomize_model(self) -> Tuple[torch.nn.Module, List[str]]:

             for name, param in list(rand_model.named_parameters()):
                 parent = get_parent_module_from_name(rand_model, name)
-                # TODO: currently only linear layer is randomized
-                if isinstance(parent, (torch.nn.Linear)):
-                    random_param_tensor = torch.nn.init.normal_(
-                        param, generator=self.generator
-                    )
-                    parent.__setattr__(
-                        name.split(".")[-1],
-                        torch.nn.Parameter(random_param_tensor),
-                    )
-
-            # save randomized checkpoint
+                param_name = name.split(".")[-1]
+                self._randomize_parameter(param, parent, param_name)
+
+            # Save randomized checkpoint
             chckpt_path = os.path.join(
                 self.cache_dir, f"{self.model_id}_rand_{i}.pth"
             )
-            torch.save(
-                rand_model.state_dict(),
-                chckpt_path,
-            )
+            torch.save(rand_model.state_dict(), chckpt_path)
             rand_checkpoints.append(chckpt_path)

         return rand_model, rand_checkpoints
diff --git a/tests/benchmarks/heuristics/test_model_randomization.py b/tests/benchmarks/heuristics/test_model_randomization.py
index 7dd98baa..2869fb55 100644
--- a/tests/benchmarks/heuristics/test_model_randomization.py
+++ b/tests/benchmarks/heuristics/test_model_randomization.py
@@ -30,7 +30,7 @@
                 "similarity_metric": cosine_similarity,
             },
             None,
-            0.5208332538604736,
+            0.717261791229248,
         ),
         (
             "mnist2",
@@ -49,7 +49,7 @@
                 "similarity_metric": cosine_similarity,
             },
             None,
-            0.5208332538604736,
+            0.717261791229248,
         ),
     ],
 )
@@ -120,7 +120,7 @@ def test_model_randomization(
                 "similarity_metric": cosine_similarity,
                 "load_from_disk": True,
             },
-            0.509926438331604,
+            0.19356615841388702,
         ),
     ],
 )
diff --git a/tests/metrics/test_heuristics_metrics.py b/tests/metrics/test_heuristics_metrics.py
index 22e69f63..a189fa6e 100644
--- a/tests/metrics/test_heuristics_metrics.py
+++ b/tests/metrics/test_heuristics_metrics.py
@@ -9,14 +9,16 @@
     TopKCardinalityMetric,
 )
 from quanda.metrics.heuristics.mixed_datasets import MixedDatasetsMetric
-from quanda.utils.common import get_parent_module_from_name
 from quanda.utils.functions import correlation_functions, cosine_similarity
+from quanda.utils.common import (
+    get_parent_module_from_name,
+)


 @pytest.mark.heuristic_metrics
 @pytest.mark.parametrize(
-    "test_id, model, checkpoint,dataset, test_data, batch_size, explainer_cls, \
-        expl_kwargs, explanations, test_labels, correlation_fn",
+    "test_id, model, checkpoint, dataset, test_data, "
+    "explainer_cls, expl_kwargs, explanations, test_labels",
     [
         (
             "mnist_update_only_spearman",
@@ -24,7 +26,6 @@
             "load_mnist_last_checkpoint",
             "load_mnist_dataset",
             "load_mnist_test_samples_1",
-            8,
             CaptumSimilarity,
             {
                 "layers": "fc_2",
@@ -32,7 +33,6 @@
             },
             "load_mnist_explanations_similarity_1",
             "load_mnist_test_labels_1",
-            "spearman",
         ),
         (
             "mnist_update_only_kendall",
@@ -40,7 +40,6 @@
             "load_mnist_last_checkpoint",
             "load_mnist_dataset",
             "load_mnist_test_samples_1",
-            8,
             CaptumSimilarity,
             {
                 "layers": "fc_2",
@@ -48,22 +47,19 @@
             },
             "load_mnist_explanations_similarity_1",
             "load_mnist_test_labels_1",
-            "kendall",
         ),
     ],
 )
-def test_randomization_metric(
+def test_randomization_metric_score(
     test_id,
     model,
     checkpoint,
     dataset,
     test_data,
-    batch_size,
     explainer_cls,
     expl_kwargs,
     explanations,
     test_labels,
-    correlation_fn,
     tmp_path,
     request,
 ):
@@ -74,6 +70,7 @@ def test_randomization_metric(
     test_labels = request.getfixturevalue(test_labels)
     tda = request.getfixturevalue(explanations)
     expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}
+
     metric = ModelRandomizationMetric(
         model=model,
         model_id=0,
@@ -89,7 +86,213 @@
     )
     out = metric.compute()["score"]

-    assert (out >= -1.0) & (out <= 1.0), "Test failed."
+    assert (out >= -1.0) & (
+        out <= 1.0
+    ), "Metric score is out of expected range."
+
+
+@pytest.mark.heuristic_metrics
+@pytest.mark.parametrize(
+    "test_id, model, checkpoint, dataset, test_data, batch_size, "
+    "explainer_cls, expl_kwargs, explanations, test_labels",
+    [
+        (
+            "mnist_update_only_spearman",
+            "load_mnist_model",
+            "load_mnist_last_checkpoint",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            8,
+            CaptumSimilarity,
+            {
+                "layers": "fc_2",
+                "similarity_metric": cosine_similarity,
+            },
+            "load_mnist_explanations_similarity_1",
+            "load_mnist_test_labels_1",
+        ),
+        (
+            "mnist_update_only_kendall",
+            "load_mnist_model",
+            "load_mnist_last_checkpoint",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            8,
+            CaptumSimilarity,
+            {
+                "layers": "fc_2",
+                "similarity_metric": cosine_similarity,
+            },
+            "load_mnist_explanations_similarity_1",
+            "load_mnist_test_labels_1",
+        ),
+    ],
+)
+def test_randomization_metric_randomization(
+    test_id,
+    model,
+    checkpoint,
+    dataset,
+    test_data,
+    batch_size,
+    explainer_cls,
+    expl_kwargs,
+    explanations,
+    test_labels,
+    tmp_path,
+    request,
+):
+    model = request.getfixturevalue(model)
+    checkpoint = request.getfixturevalue(checkpoint)
+    test_data = request.getfixturevalue(test_data)
+    dataset = request.getfixturevalue(dataset)
+    test_labels = request.getfixturevalue(test_labels)
+    expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}
+
+    metric = ModelRandomizationMetric(
+        model=model,
+        model_id=0,
+        checkpoints=checkpoint,
+        train_dataset=dataset,
+        explainer_cls=explainer_cls,
+        expl_kwargs=expl_kwargs,
+        cache_dir=str(tmp_path),
+        seed=42,
+    )
+
+    # Generate a random batch of data
+    batch_size = 2
+    input_shape = test_data[0].shape
+    random_tensor = torch.randn((batch_size, *input_shape), device="cpu")
+
+    # Randomize model
+    rand_model = metric._randomize_model()[0]
+    rand_model.eval()
+    model.eval()
+
+    # Check if the outputs differ after randomization
+    with torch.no_grad():
+        original_out = model(random_tensor)
+        randomized_out = rand_model(random_tensor)
+
+    assert not torch.allclose(
+        original_out, randomized_out
+    ), "Outputs do not differ after randomization."
+    assert not torch.isnan(
+        randomized_out
+    ).any(), "Randomized model output contains NaNs."
+
+
+@pytest.mark.heuristic_metrics
+@pytest.mark.parametrize(
+    "test_id, model, checkpoint, dataset, test_data, batch_size, "
+    "explainer_cls, expl_kwargs, explanations, test_labels",
+    [
+        (
+            "mnist_update_only_spearman",
+            "load_mnist_model",
+            "load_mnist_last_checkpoint",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            8,
+            CaptumSimilarity,
+            {
+                "layers": "fc_2",
+                "similarity_metric": cosine_similarity,
+            },
+            "load_mnist_explanations_similarity_1",
+            "load_mnist_test_labels_1",
+        ),
+        (
+            "mnist_update_only_kendall",
+            "load_mnist_model",
+            "load_mnist_last_checkpoint",
+            "load_mnist_dataset",
+            "load_mnist_test_samples_1",
+            8,
+            CaptumSimilarity,
+            {
+                "layers": "fc_2",
+                "similarity_metric": cosine_similarity,
+            },
+            "load_mnist_explanations_similarity_1",
+            "load_mnist_test_labels_1",
+        ),
+    ],
+)
+def test_randomization_metric_custom_param(
+    test_id,
+    model,
+    checkpoint,
+    dataset,
+    test_data,
+    batch_size,
+    explainer_cls,
+    expl_kwargs,
+    explanations,
+    test_labels,
+    tmp_path,
+    request,
+):
+    model = request.getfixturevalue(model)
+    checkpoint = request.getfixturevalue(checkpoint)
+    test_data = request.getfixturevalue(test_data)
+    dataset = request.getfixturevalue(dataset)
+    test_labels = request.getfixturevalue(test_labels)
+    expl_kwargs = {"model_id": "0", "cache_dir": str(tmp_path), **expl_kwargs}
+
+    def _load_flexible_state_dict(model: torch.nn.Module, path: str):
+        checkpoint = torch.load(path, map_location="cpu")
+        model.load_state_dict(checkpoint, strict=False)
+        return model
+
+    metric = ModelRandomizationMetric(
+        model=model,
+        model_id=0,
+        checkpoints=checkpoint,
+        checkpoints_load_func=_load_flexible_state_dict,
+        train_dataset=dataset,
+        explainer_cls=explainer_cls,
+        expl_kwargs=expl_kwargs,
+        cache_dir=str(tmp_path),
+        seed=42,
+    )
+
+    # Add a custom parameter to the model
+    model.custom_param = torch.nn.Parameter(torch.randn(4))
+    model.eval()
+
+    # Save the original custom parameter
+    original_custom_param = model.custom_param.data.clone()
+
+    # Randomize model
+    rand_model = metric._randomize_model()[0]
+    rand_model.eval()
+
+    # Save the randomized custom parameter
+    randomized_custom_param = rand_model.custom_param.data.clone()
+
+    # Generate a random batch of data
+    batch_size = 2
+    input_shape = test_data[0].shape
+    random_tensor = torch.randn((batch_size, *input_shape), device="cpu")
+
+    # Check if both outputs and custom params differ after randomization
+    with torch.no_grad():
+        original_out = model(random_tensor)
+        randomized_out = rand_model(random_tensor)
+
+    assert not torch.allclose(
+        original_out, randomized_out
+    ), "Outputs do not differ after randomization."
+
+    assert not torch.allclose(
+        original_custom_param, randomized_custom_param
+    ), "Custom parameter did not change after randomization."
+
+    assert not torch.isnan(
+        randomized_out
+    ).any(), "Randomized model output contains NaNs."


 @pytest.mark.heuristic_metrics
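Below is a minimal standalone sketch, not part of the patch, of the randomization strategy the new _randomize_parameter helper follows: modules that expose reset_parameters are re-initialized under a fixed seed, and any remaining parameters (for example a bare nn.Parameter attached directly to the model, as exercised by test_randomization_metric_custom_param) fall back to torch.nn.init.normal_. The function name randomize_module_parameters and the toy Sequential model are illustrative only, and nn.Module.get_submodule stands in here for quanda's get_parent_module_from_name.

import copy

import torch


def randomize_module_parameters(model: torch.nn.Module, seed: int = 42) -> torch.nn.Module:
    # Work on a copy so the original model stays untouched.
    rand_model = copy.deepcopy(model)
    generator = torch.Generator().manual_seed(seed)

    for name, param in list(rand_model.named_parameters()):
        # Resolve the module that owns the parameter, e.g. "0.weight" -> rand_model[0].
        parent_path = name.rpartition(".")[0]
        parent = rand_model.get_submodule(parent_path) if parent_path else rand_model

        if hasattr(parent, "reset_parameters"):
            # Layers such as nn.Linear or nn.Conv2d re-run their own init under a fixed seed.
            torch.manual_seed(seed)
            parent.reset_parameters()
        else:
            # Fallback for parameters whose parent has no reset_parameters,
            # e.g. a bare nn.Parameter attached directly to the model.
            torch.nn.init.normal_(param, generator=generator)

    return rand_model


if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
    rand_model = randomize_module_parameters(model)
    x = torch.randn(2, 4)
    with torch.no_grad():
        # After randomization the two models should disagree on the same input.
        assert not torch.allclose(model(x), rand_model(x))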