diff --git a/quantus/metrics/randomisation/efficient_mprt.py b/quantus/metrics/randomisation/efficient_mprt.py
index eecbb1f4..3c4afc98 100644
--- a/quantus/metrics/randomisation/efficient_mprt.py
+++ b/quantus/metrics/randomisation/efficient_mprt.py
@@ -400,42 +400,42 @@ def __call__(
                 if self.skip_layers and (l_ix + 1) < n_layers:
                     continue
 
-                # Generate explanations on perturbed model in batches.
-                a_perturbed_generator = self.generate_explanations(
-                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
-                )
+                # Generate explanations on perturbed model in batches.
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
+                )
 
-                # Compute the complexity of explanations of the perturbed model.
-                self.explanation_scores_by_layer[layer_name] = []
-                for a_batch, a_batch_perturbed in zip(
-                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
-                ):
-                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
-                        score = self.evaluate_instance(
-                            model=random_layer_model,
-                            x=None,
-                            y=None,
-                            s=None,
-                            a=a_instance_perturbed,
-                        )
-                        self.explanation_scores_by_layer[layer_name].append(score)
-                        pbar.update(1)
-
-                # Wrap the model.
-                random_layer_model_wrapped = utils.get_wrapped_model(
-                    model=random_layer_model,
-                    channel_first=channel_first,
-                    softmax=softmax,
-                    device=device,
-                    model_predict_kwargs=model_predict_kwargs,
-                )
+                # Compute the complexity of explanations of the perturbed model.
+                self.explanation_scores_by_layer[layer_name] = []
+                for a_batch, a_batch_perturbed in zip(
+                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                        score = self.evaluate_instance(
+                            model=random_layer_model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance_perturbed,
+                        )
+                        self.explanation_scores_by_layer[layer_name].append(score)
+                        pbar.update(1)
+
+                # Wrap the model.
+                random_layer_model_wrapped = utils.get_wrapped_model(
+                    model=random_layer_model,
+                    channel_first=channel_first,
+                    softmax=softmax,
+                    device=device,
+                    model_predict_kwargs=model_predict_kwargs,
+                )
 
-                # Predict and save complexity scores of the perturbed model outputs.
-                self.model_scores_by_layer[layer_name] = []
-                y_preds = random_layer_model_wrapped.predict(x_full_dataset)
-                for y_ix, y_pred in enumerate(y_preds):
-                    score = entropy(a=y_pred, x=y_pred)
-                    self.model_scores_by_layer[layer_name].append(score)
+                # Predict and save complexity scores of the perturbed model outputs.
+                self.model_scores_by_layer[layer_name] = []
+                y_preds = random_layer_model_wrapped.predict(x_full_dataset)
+                for y_ix, y_pred in enumerate(y_preds):
+                    score = entropy(a=y_pred, x=y_pred)
+                    self.model_scores_by_layer[layer_name].append(score)
 
         # Save evaluation scores as the relative rise in complexity.
         explanation_scores = list(self.explanation_scores_by_layer.values())
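In `efficient_mprt.py` the `skip_layers` guard already sits above the perturbed-model blocks, and the hunk leaves the statements themselves unchanged (a whitespace-level fix around the guard). For orientation, the resulting control flow reduces to the toy loop below. This is a minimal sketch with hypothetical stand-ins (`randomisation_loop` and the layer names are not the Quantus internals): with `skip_layers=True`, the guard `(l_ix + 1) < n_layers` is False only on the final iteration, so EfficientMPRT scores exactly two models, the original and the fully randomised one, and later reports the relative rise in complexity between them.

```python
# Toy sketch of EfficientMPRT's skip_layers control flow; the function and
# layer names are hypothetical stand-ins, not the Quantus API.
def randomisation_loop(layer_names, skip_layers=True):
    n_layers = len(layer_names)
    evaluated = ["original"]  # baseline scores, computed once at l_ix == 0
    for l_ix, layer_name in enumerate(layer_names):
        # False only for the last layer, i.e. the fully randomised model.
        if skip_layers and (l_ix + 1) < n_layers:
            continue
        evaluated.append(layer_name)  # perturbed-model scores
    return evaluated

print(randomisation_loop(["fc3", "fc2", "fc1"]))  # ['original', 'fc1']
```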
diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py
index cd686ea2..cca0b43d 100644
--- a/quantus/metrics/randomisation/mprt.py
+++ b/quantus/metrics/randomisation/mprt.py
@@ -323,10 +323,6 @@ def __call__(
             ):
                 pbar.desc = layer_name
 
-                # Skip layers if computing delta.
-                if self.skip_layers and (l_ix + 1) < n_layers:
-                    continue
-
                 if l_ix == 0:
 
                     # Generate explanations on original model in batches.
@@ -354,28 +350,32 @@ def __call__(
                             self.evaluation_scores["original"].append(score)
                             pbar.update(1)
 
-                self.evaluation_scores[layer_name] = []
-
-                # Generate explanations on perturbed model in batches.
-                a_perturbed_generator = self.generate_explanations(
-                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
-                )
+                # Skip layers if computing delta.
+                if self.skip_layers and (l_ix + 1) < n_layers:
+                    continue
 
-                # Compute the similarity of explanations of the perturbed model.
-                for a_batch, a_batch_perturbed in zip(
-                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
-                ):
-                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
-                        score = self.evaluate_instance(
-                            model=random_layer_model,
-                            x=None,
-                            y=None,
-                            s=None,
-                            a=a_instance,
-                            a_perturbed=a_instance_perturbed,
-                        )
-                        self.evaluation_scores[layer_name].append(score)
-                        pbar.update(1)
+                self.evaluation_scores[layer_name] = []
+
+                # Generate explanations on perturbed model in batches.
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model, x_full_dataset, y_full_dataset, batch_size
+                )
+
+                # Compute the similarity of explanations of the perturbed model.
+                for a_batch, a_batch_perturbed in zip(
+                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                        score = self.evaluate_instance(
+                            model=random_layer_model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance,
+                            a_perturbed=a_instance_perturbed,
+                        )
+                        self.evaluation_scores[layer_name].append(score)
+                        pbar.update(1)
 
         if self.return_average_correlation:
             self.evaluation_scores = self.recompute_average_correlation_per_sample()
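The substantive change in `mprt.py` (mirrored in `smooth_mprt.py` below) is where the `skip_layers` guard sits. In the old order the guard preceded the `if l_ix == 0:` branch; since `(0 + 1) < n_layers` holds whenever more than one layer is randomised, `continue` fired on the very first iteration and the `"original"` baseline scores were never collected. Moving the guard below that branch keeps the baseline while still skipping the intermediate layers. The toy loop below reproduces both orderings; all names (`collect`, `guard_first`, the layer list) are hypothetical stand-ins, not the Quantus internals:

```python
# Toy reproduction of the guard-ordering bug fixed above; all names here are
# hypothetical stand-ins, not the Quantus internals.
def collect(layers, guard_first):
    n, scores = len(layers), {}
    for l_ix, name in enumerate(layers):
        if guard_first and (l_ix + 1) < n:
            continue  # old order: also skips the l_ix == 0 block below
        if l_ix == 0:
            scores["original"] = "baseline"
        if not guard_first and (l_ix + 1) < n:
            continue  # new order: the baseline above is already recorded
        scores[name] = "perturbed"
    return scores

print(collect(["fc3", "fc2", "fc1"], guard_first=True))   # {'fc1': 'perturbed'}
print(collect(["fc3", "fc2", "fc1"], guard_first=False))  # {'original': 'baseline', 'fc1': 'perturbed'}
```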
diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py
index a0f308b4..cb62a6aa 100644
--- a/quantus/metrics/randomisation/smooth_mprt.py
+++ b/quantus/metrics/randomisation/smooth_mprt.py
@@ -343,10 +343,6 @@ def __call__(
             ):
                 pbar.desc = layer_name
 
-                # Skip layers if computing delta.
-                if self.skip_layers and (l_ix + 1) < n_layers:
-                    continue
-
                 if l_ix == 0:
 
                     # Generate explanations on original model in batches.
@@ -377,31 +373,35 @@ def __call__(
                             self.evaluation_scores["original"].append(score)
                             pbar.update(1)
 
-                self.evaluation_scores[layer_name] = []
-
-                # Generate explanations on perturbed model in batches.
-                a_perturbed_generator = self.generate_explanations(
-                    random_layer_model,
-                    x_full_dataset,
-                    y_full_dataset,
-                    **kwargs,
-                )
+                # Skip layers if computing delta.
+                if self.skip_layers and (l_ix + 1) < n_layers:
+                    continue
 
-                # Compute the similarity of explanations of the perturbed model.
-                for a_batch, a_batch_perturbed in zip(
-                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
-                ):
-                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
-                        score = self.evaluate_instance(
-                            model=random_layer_model,
-                            x=None,
-                            y=None,
-                            s=None,
-                            a=a_instance,
-                            a_perturbed=a_instance_perturbed,
-                        )
-                        self.evaluation_scores[layer_name].append(score)
-                        pbar.update(1)
+                self.evaluation_scores[layer_name] = []
+
+                # Generate explanations on perturbed model in batches.
+                a_perturbed_generator = self.generate_explanations(
+                    random_layer_model,
+                    x_full_dataset,
+                    y_full_dataset,
+                    **kwargs,
+                )
+
+                # Compute the similarity of explanations of the perturbed model.
+                for a_batch, a_batch_perturbed in zip(
+                    self.generate_a_batches(a_full_dataset), a_perturbed_generator
+                ):
+                    for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
+                        score = self.evaluate_instance(
+                            model=random_layer_model,
+                            x=None,
+                            y=None,
+                            s=None,
+                            a=a_instance,
+                            a_perturbed=a_instance_perturbed,
+                        )
+                        self.evaluation_scores[layer_name].append(score)
+                        pbar.update(1)
 
         if self.return_average_correlation:
             self.evaluation_scores = self.recompute_average_correlation_per_sample()
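With the guard relocated in all three metrics, a `skip_layers=True` run should once again return the `"original"` baseline alongside the scores for the final (fully randomised) layer. The snippet below is a usage sketch under stated assumptions: it presumes the v0.5.x Quantus calling convention (`quantus.MPRT`, `quantus.explain` backed by Captum for the `"Saliency"` method) and uses a throwaway torch model with random data purely as placeholders.

```python
import numpy as np
import torch.nn as nn
import quantus

# Throwaway classifier and random data, only to make the call concrete.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
x_batch = np.random.rand(8, 1, 28, 28).astype(np.float32)
y_batch = np.random.randint(0, 10, size=8)

metric = quantus.MPRT(skip_layers=True)
scores = metric(
    model=model,
    x_batch=x_batch,
    y_batch=y_batch,
    a_batch=None,  # attributions are generated internally via explain_func
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "Saliency"},
)
# After the fix, evaluation_scores holds the "original" entry plus the entry
# for the final, fully randomised layer only.
print(metric.evaluation_scores.keys())
```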