test: add extended lds benchmark test
aski02 committed Dec 9, 2024
1 parent c199ca0 commit e4c5e06
Showing 2 changed files with 122 additions and 4 deletions.
124 changes: 121 additions & 3 deletions tests/benchmarks/ground_truth/test_linear_datamodeling.py
@@ -1,11 +1,11 @@
-from quanda.benchmarks.ground_truth.linear_datamodeling import (
-    LinearDatamodeling,
-)
 import math
 
 import pytest
 import torch
 
+from quanda.benchmarks.ground_truth.linear_datamodeling import (
+    LinearDatamodeling,
+)
 from quanda.explainers.wrappers.captum_influence import CaptumSimilarity
 from quanda.utils.functions.correlations import spearman_rank_corr
 from quanda.utils.functions.similarities import cosine_similarity
@@ -187,3 +187,121 @@ def test_linear_datamodeling(
     expected_score = spearman_rank_corr(outputs, counterfactual)
     expected_score = expected_score.mean().item()
     assert math.isclose(score, expected_score, abs_tol=0.00001)
+
+
+@pytest.mark.benchmarks
+@pytest.mark.parametrize(
+    "test_id, init_method, model, checkpoint, optimizer, lr, criterion, dataset, n_classes, seed, "
+    "batch_size, explainer_cls, expl_kwargs, use_pred, subset_indices, pretrained_models",
+    [
+        (
+            "mnist0",
+            "assemble",
+            "load_mnist_model",
+            "load_mnist_last_checkpoint",
+            "torch_sgd_optimizer",
+            0.01,
+            "torch_cross_entropy_loss_object",
+            "load_mnist_dataset",
+            10,
+            27,
+            8,
+            CaptumSimilarity,
+            {"layers": "fc_2", "similarity_metric": cosine_similarity},
+            False,
+            "load_subset_indices_lds",
+            "load_pretrained_models_lds",
+        ),
+    ],
+)
+def test_linear_datamodeling_benchmark_extended(
+    test_id,
+    init_method,
+    model,
+    checkpoint,
+    optimizer,
+    lr,
+    criterion,
+    dataset,
+    n_classes,
+    seed,
+    batch_size,
+    explainer_cls,
+    expl_kwargs,
+    use_pred,
+    subset_indices,
+    pretrained_models,
+    tmp_path,
+    request,
+):
+    model = request.getfixturevalue(model)
+    checkpoint = request.getfixturevalue(checkpoint)
+    optimizer = request.getfixturevalue(optimizer)
+    criterion = request.getfixturevalue(criterion)
+    dataset = request.getfixturevalue(dataset)
+    subset_indices = request.getfixturevalue(subset_indices)
+    pretrained_models = request.getfixturevalue(pretrained_models)
+
+    expl_kwargs = {
+        **expl_kwargs,
+        "model_id": test_id,
+        "cache_dir": str(tmp_path),
+    }
+
+    trainer = Trainer(
+        max_epochs=0,
+        optimizer=optimizer,
+        lr=lr,
+        criterion=criterion,
+    )
+
+    if init_method == "generate":
+        benchmark = LinearDatamodeling.generate(
+            model=model,
+            checkpoints=checkpoint,
+            trainer=trainer,
+            train_dataset=dataset,
+            eval_dataset=dataset,
+            n_classes=n_classes,
+            seed=seed,
+            batch_size=batch_size,
+            use_predictions=use_pred,
+            cache_dir=str(tmp_path),
+            model_id=test_id,
+            m=len(subset_indices),
+            alpha=0.5,
+            correlation_fn="spearman",
+            device="cpu",
+            subset_ids=subset_indices,
+            pretrained_models=pretrained_models,
+        )
+    elif init_method == "assemble":
+        benchmark = LinearDatamodeling.assemble(
+            model=model,
+            checkpoints=checkpoint,
+            trainer=trainer,
+            train_dataset=dataset,
+            eval_dataset=dataset,
+            n_classes=n_classes,
+            seed=seed,
+            batch_size=batch_size,
+            use_predictions=use_pred,
+            cache_dir=str(tmp_path),
+            model_id=test_id,
+            m=len(subset_indices),
+            alpha=0.5,
+            correlation_fn="spearman",
+            device="cpu",
+            subset_ids=subset_indices,
+            pretrained_models=pretrained_models,
+        )
+    else:
+        raise ValueError(f"Invalid init_method: {init_method}")
+
+    score = benchmark.evaluate(
+        explainer_cls=explainer_cls,
+        expl_kwargs=expl_kwargs,
+        batch_size=batch_size,
+    )["score"]
+
+    assert isinstance(score, float), "Score should be a float."
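For reference, the score this test checks is the mean Spearman rank correlation between attribution-based predictions and the outputs of models retrained on sampled training subsets, the same quantity computed in test_linear_datamodeling above. A minimal sketch of that final step, using hypothetical predicted/retrained tensors in place of the test's outputs and counterfactual:

import torch

from quanda.utils.functions.correlations import spearman_rank_corr

# Hypothetical stand-ins for the test's `outputs` (attribution-based
# predictions of counterfactual outputs) and `counterfactual` (outputs
# of the pretrained models retrained on the sampled subsets).
predicted = torch.randn(5, 8)
retrained = predicted + 0.1 * torch.randn(5, 8)

# One rank correlation per test point, averaged to a scalar, mirroring
# the spearman_rank_corr(...).mean().item() pattern in the test above.
score = spearman_rank_corr(predicted, retrained).mean().item()
assert isinstance(score, float)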
2 changes: 1 addition & 1 deletion tests/metrics/test_ground_truth_metrics.py
@@ -112,7 +112,7 @@ def test_linear_datamodeling(
         ),
     ],
 )
-def test_linear_datamodeling_with_pretrained_models_and_subsets(
+def test_linear_datamodeling_extended(
     test_id,
     model,
     dataset,
