Skip to content

Commit

Permalink
changes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Aug 15, 2024
1 parent 25280fe commit 09ad3a5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 21 deletions.
27 changes: 9 additions & 18 deletions quanda/explainers/wrappers/trak_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from importlib.util import find_spec
from typing import Any, Iterable, List, Literal, Optional, Sized, Union

import torch
Expand All @@ -11,7 +12,7 @@
self_influence_fn_from_explainer,
)

TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"]
TRAKProjectorLiteral = Literal["cuda", "noop", "basic"]
TRAKProjectionTypeLiteral = Literal["rademacher", "normal"]


Expand All @@ -21,9 +22,9 @@ def __init__(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
model_id: str,
cache_dir: Optional[str] = None,
cache_dir: str,
projector: TRAKProjectorLiteral,
device: Union[str, torch.device] = "cpu",
projector: TRAKProjectorLiteral = "check_cuda",
proj_dim: int = 128,
proj_type: TRAKProjectionTypeLiteral = "normal",
seed: int = 42,
Expand All @@ -41,23 +42,13 @@ def __init__(
num_params_for_grad = 0
params_iter = params_ldr if params_ldr is not None else self.model.parameters()
for p in list(params_iter):
nn = 1
for s in list(p.size()):
nn = nn * s
num_params_for_grad += nn

num_params_for_grad = num_params_for_grad + p.numel()
    # Check if traker was installed with the ["cuda"] option
if projector in ["cuda", "check_cuda"]:
try:
import fast_jl

test_gradient = torch.ones(1, num_params_for_grad).cuda()
num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
if projector == "cuda":
if find_spec("fast_jl"):
projector = "cuda"
except (ImportError, RuntimeError, AttributeError) as e:
warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}")
warnings.warn("Defaulting to BasicProjector.")
else:
warnings.warn("Could not find cuda installation of TRAK. Defaulting to BasicProjector.")
projector = "basic"

projector_cls = {
Expand Down
6 changes: 3 additions & 3 deletions tests/explainers/wrappers/test_trak_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"load_mnist_explanations_trak_1",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10},
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
),
],
)
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_trak_wrapper_explain_stateful(
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10},
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
"load_mnist_explanations_trak_1",
),
],
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_trak_wrapper_explain_functional(
"load_mnist_dataset",
"load_mnist_test_samples_1",
"load_mnist_test_labels_1",
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10},
{"model_id": "0", "batch_size": 8, "seed": 42, "proj_dim": 10, "projector": "basic"},
"load_mnist_explanations_trak_si_1",
),
],
Expand Down

0 comments on commit 09ad3a5

Please sign in to comment.