Skip to content

Commit

Permalink
usage + device updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 28, 2024
1 parent d4b0cba commit 84e398d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
4 changes: 4 additions & 0 deletions src/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __init__(
self._layer: Optional[Union[List[str], str]] = None
self.layer = layers

if device != "cpu":
warnings.warn("CaptumSimilarity explainer only supports CPU devices. Setting device to 'cpu'.")
device = "cpu"

# TODO: validate SimilarityInfluence kwargs
explainer_kwargs.update(
{
Expand Down
5 changes: 4 additions & 1 deletion src/metrics/localization/identical_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
test_labels.shape[0] == explanations.shape[0]
), f"Number of explanations ({explanations.shape[0]}) exceeds the number of test labels ({test_labels.shape[0]})."

test_labels = test_labels.to(self.device)
explanations = explanations.to(self.device)

top_one_xpl_indices = explanations.argmax(dim=1)
top_one_xpl_targets = torch.stack([self.train_dataset[i][1] for i in top_one_xpl_indices])
top_one_xpl_targets = torch.tensor([self.train_dataset[i][1] for i in top_one_xpl_indices]).to(self.device)

scores = (test_labels == top_one_xpl_targets) * 1.0
self.scores.append(scores)
Expand Down
5 changes: 4 additions & 1 deletion src/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,12 @@ def update(
explanations: torch.Tensor,
explanation_targets: Optional[torch.Tensor] = None,
):
explanations = explanations.to(self.device)

rand_explanations = self.explain_fn(
model=self.rand_model, test_tensor=test_data, explanation_targets=explanation_targets, device=self.device
)
).to(self.device)

corrs = self.corr_measure(explanations, rand_explanations)
self.results["scores"].append(corrs)

Expand Down
3 changes: 3 additions & 0 deletions src/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def update(
explanations: torch.Tensor,
**kwargs,
):

explanations = explanations.to(self.device)

top_k_indices = torch.topk(explanations, self.top_k).indices
self.all_top_k_examples = torch.concat((self.all_top_k_examples, top_k_indices), dim=0)

Expand Down
21 changes: 12 additions & 9 deletions tutorials/usage_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from src.metrics.unnamed.top_k_overlap import TopKOverlap

DEVICE = "cpu" # "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = "cuda" # "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", DEVICE.upper())

# manual random seed is used for dataset partitioning
Expand All @@ -45,13 +45,13 @@ def main():
)

train_set = torchvision.datasets.CIFAR10(root="./tutorials/data", train=True, download=True, transform=normalize)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=2)

# we split held out data into test and validation set
held_out = torchvision.datasets.CIFAR10(root="./tutorials/data", train=False, download=True, transform=normalize)
test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)
# val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)
test_set, val_set = torch.utils.data.random_split(held_out, [0.1, 0.9], generator=RNG)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)
# val_loader = DataLoader(val_set, batch_size=100, shuffle=False, num_workers=2)

# download pre-trained weights
local_path = "./tutorials/model_weights_resnet18_cifar10.pth"
Expand Down Expand Up @@ -113,7 +113,7 @@ def accuracy(net, loader):
# ++++++++++++++++++++++++++++++++++++++++++

explain = captum_similarity_explain
explain_fn_kwargs = {"layers": "avgpool"}
explain_fn_kwargs = {"layers": "avgpool", "batch_size": 100}
model_id = "default_model_id"
cache_dir = "./cache"
model_rand = ModelRandomizationMetric(
Expand All @@ -130,7 +130,7 @@ def accuracy(net, loader):

id_class = IdenticalClass(model=model, train_dataset=train_set, device=DEVICE)

top_k = TopKOverlap(model=model, train_dataset=train_set, top_k=1, device="cpu")
top_k = TopKOverlap(model=model, train_dataset=train_set, top_k=1, device=DEVICE)

# iterate over test set and feed tensor batches first to explain, then to metric
for i, (data, target) in enumerate(tqdm(test_loader)):
Expand All @@ -146,9 +146,12 @@ def accuracy(net, loader):
)
model_rand.update(data, tda)
id_class.update(target, tda)
top_k.update(target)
top_k.update(tda)

print("Model randomization metric output:", model_rand.compute())
print("Identical class metric output:", id_class.compute())
print("Top-k overlap metric output:", top_k.compute())

print("Model randomization metric output:", model_rand.compute().item())
print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%")


Expand Down

0 comments on commit 84e398d

Please sign in to comment.