Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trak wrapper #106

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ dependencies = [
"lightning>=1.4.0",
"torchmetrics>=1.4.0",
"tqdm>=4.0.0",
"traker>=0.3.2"
]
dynamic = ["version"]

[tool.isort]
profile = "black"
extend_skip = ["__init__.py"]
line_length = 79
multi_line_output = 3
include_trailing_comma = true
Expand All @@ -29,6 +31,10 @@ warn_unused_configs = true
check_untyped_defs = true
#ignore_errors = true # TODO: change this

[[tool.mypy.overrides]]
module = ["trak", "trak.projectors", "fast_jl"]
ignore_missing_imports = true
gumityolcu marked this conversation as resolved.
Show resolved Hide resolved

# Black formatting
[tool.black]
line-length = 127
Expand Down
2 changes: 1 addition & 1 deletion quanda/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from quanda.explainers.base import BaseExplainer
from quanda.explainers import utils, wrappers
from quanda.explainers.aggregators import (
AbsSumAggregator,
BaseAggregator,
SumAggregator,
aggr_types,
)
from quanda.explainers.base import BaseExplainer
from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
from quanda.explainers.random import RandomExplainer

Expand Down
5 changes: 5 additions & 0 deletions quanda/explainers/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
captum_tracincp_self_influence,
)

from quanda.explainers.wrappers.trak_wrapper import TRAK, trak_explain, trak_self_influence

__all__ = [
"CaptumInfluence",
"CaptumSimilarity",
Expand All @@ -22,4 +24,7 @@
"CaptumTracInCP",
"captum_tracincp_explain",
"captum_tracincp_self_influence",
"TRAK",
"trak_explain",
"trak_self_influence",
]
2 changes: 1 addition & 1 deletion quanda/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ArnoldiInfluenceFunction,
)

from quanda.explainers.base import BaseExplainer
from quanda.explainers import BaseExplainer
from quanda.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
Expand Down
171 changes: 171 additions & 0 deletions quanda/explainers/wrappers/trak_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import warnings
from typing import Any, Iterable, List, Literal, Optional, Sized, Union

import torch
from trak import TRAKer
from trak.projectors import BasicProjector, CudaProjector, NoOpProjector

from quanda.explainers import BaseExplainer
from quanda.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
)

TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"]
gumityolcu marked this conversation as resolved.
Show resolved Hide resolved
TRAKProjectionTypeLiteral = Literal["rademacher", "normal"]


class TRAK(BaseExplainer):
def __init__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
model_id: str,
cache_dir: Optional[str] = None,
gumityolcu marked this conversation as resolved.
Show resolved Hide resolved
device: Union[str, torch.device] = "cpu",
projector: TRAKProjectorLiteral = "check_cuda",
proj_dim: int = 128,
proj_type: TRAKProjectionTypeLiteral = "normal",
seed: int = 42,
batch_size: int = 32,
params_ldr: Optional[Iterable] = None,
):
super(TRAK, self).__init__(
model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device
)
self.dataset = train_dataset
self.proj_dim = proj_dim
self.batch_size = batch_size
self.cache_dir = cache_dir if cache_dir is not None else f"./trak_{model_id}_cache"

num_params_for_grad = 0
params_iter = params_ldr if params_ldr is not None else self.model.parameters()
for p in list(params_iter):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we sure this is a robust way to count parameters?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very nitpicky, but I'm sure there are way more elegant ways to do 43 - 47

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also count parameters of a given parameter loader, not necessarily all the parameters in the model. User may want to only consider final layer etc.

nn = 1
for s in list(p.size()):
nn = nn * s
num_params_for_grad += nn

# Check if traker was installer with the ["cuda"] option
if projector in ["cuda", "check_cuda"]:
try:
gumityolcu marked this conversation as resolved.
Show resolved Hide resolved
import fast_jl

test_gradient = torch.ones(1, num_params_for_grad).cuda()
num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
gumityolcu marked this conversation as resolved.
Show resolved Hide resolved
projector = "cuda"
except (ImportError, RuntimeError, AttributeError) as e:
warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}")
warnings.warn("Defaulting to BasicProjector.")
projector = "basic"

projector_cls = {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this not outside of the class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, forgot to talk about this during our meeting. For this, I think along these lines: the only reason for that projector_cls dictionary to exist, is that it let's us quickly implement if statements. So if we were looking at it in terms of where it is used, why we have that dictionary, it's only a local thing for this explainer so I wouldn't define it outside the class..

of course, nothing changes practically if we define it outside so i let you decide if you definitely want to keep it outside, or are convinced by my explanation.

"cuda": CudaProjector,
"basic": BasicProjector,
"noop": NoOpProjector,
}

projector_kwargs = {
"grad_dim": num_params_for_grad,
"proj_dim": proj_dim,
"proj_type": proj_type,
"seed": seed,
"device": device,
}
if projector == "cuda":
projector_kwargs["max_batch_size"] = self.batch_size
projector_obj = projector_cls[projector](**projector_kwargs)
self.traker = TRAKer(
model=model,
task="image_classification",
train_set_size=self.dataset_length,
projector=projector_obj,
proj_dim=proj_dim,
projector_seed=seed,
save_dir=self.cache_dir,
device=device,
use_half_precision=False,
)

# Train the TRAK explainer: featurize the training data
ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size)
self.traker.load_checkpoint(self.model.state_dict(), model_id=0)
for i, (x, y) in enumerate(iter(ld)):
batch = x.to(self.device), y.to(self.device)
self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(x.shape[0])]))
self.traker.finalize_features()
if projector == "basic":
# finalize_features frees memory so projector.proj_matrix needs to be reconstructed
self.traker.projector = projector_cls[projector](**projector_kwargs)

@property
def dataset_length(self) -> int:
"""
By default, the Dataset class does not always have a __len__ method.
:return:
"""
if isinstance(self.dataset, Sized):
return len(self.dataset)
dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
return len(dl)

def explain(self, test, targets):
test = test.to(self.device)
self.traker.start_scoring_checkpoint(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we make exp_namean argument to init? it can be test by default.

also I didn't look much into the trak library - but does it make more sense to do start_scoring_checkpoint in init as well? it might be doing something intense that we don't want to repeat every time we explain, for all I know

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am assuming we decided to keep this as is. you can make the change very quickly if you decide to do it.

model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=test.shape[0]
)
self.traker.score(batch=(test, targets), num_samples=test.shape[0])
explanations = torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device)

# os.remove(os.path.join(self.cache_dir, "scores", "test.mmap"))
# os.removedirs(os.path.join(self.cache_dir, "scores"))

return explanations


def trak_explain(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
**kwargs: Any,
) -> torch.Tensor:
return explain_fn_from_explainer(
explainer_cls=TRAK,
model=model,
model_id=model_id,
cache_dir=cache_dir,
test_tensor=test_tensor,
targets=explanation_targets,
train_dataset=train_dataset,
device=device,
**kwargs,
)


def trak_self_influence(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
batch_size: Optional[int] = 32,
**kwargs: Any,
) -> torch.Tensor:
self_influence_kwargs = {
"batch_size": batch_size,
}
return self_influence_fn_from_explainer(
explainer_cls=TRAK,
model=model,
model_id=model_id,
cache_dir=cache_dir,
train_dataset=train_dataset,
device=device,
self_influence_kwargs=self_influence_kwargs,
**kwargs,
)
2 changes: 1 addition & 1 deletion quanda/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from quanda.metrics.base import GlobalMetric, Metric
from quanda.metrics import localization, randomization, unnamed
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
from quanda.metrics.base import GlobalMetric, Metric

__all__ = [
"Metric",
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/localization/class_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric


class ClassDetectionMetric(Metric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/localization/mislabeling_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from quanda.metrics.base import GlobalMetric
from quanda.metrics import GlobalMetric


class MislabelingDetectionMetric(GlobalMetric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import CorrelationFnLiterals, correlation_functions

Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/dataset_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import lightning as L
import torch

from quanda.metrics.base import GlobalMetric
from quanda.metrics import GlobalMetric
from quanda.utils.common import class_accuracy
from quanda.utils.training import BaseTrainer

Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric


class TopKOverlapMetric(Metric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from quanda.toy_benchmarks import localization, randomization, unnamed
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import localization, randomization, unnamed

__all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"]
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/localization/class_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm

from quanda.metrics.localization import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark


class ClassDetection(ToyBenchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.datasets.transformed.label_flipping import (
LabelFlippingDataset,
)
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/localization/subclass_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.datasets.transformed.label_grouping import (
ClassToGroupLiterals,
LabelGroupingDataset,
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from quanda.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.functions import CorrelationFnLiterals


Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/unnamed/dataset_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.training.trainer import BaseTrainer


Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm

from quanda.metrics.unnamed import TopKOverlapMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark


class TopKOverlap(ToyBenchmark):
Expand Down
18 changes: 18 additions & 0 deletions src/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from quanda.explainers.base import BaseExplainer
from quanda.explainers import utils, wrappers
from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
from quanda.explainers.random import RandomExplainer
from quanda.explainers.aggregators import BaseAggregator, SumAggregator


__all__ = [
"BaseExplainer",
"RandomExplainer",
"ExplainFunc",
"ExplainFuncMini",
"utils",
"wrappers",
"BaseAggregator",
"SumAggregator",
"AbsSumAggretor",
]
14 changes: 14 additions & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from quanda.metrics.base import GlobalMetric, Metric
from quanda.metrics import localization, randomization, unnamed
from quanda.metrics.aggr_strategies import GlobalAggrStrategy, GlobalSelfInfluenceStrategy


__all__ = [
"Metric",
"GlobalMetric",
"GlobalAggrStrategy",
"GlobalSelfInfluenceStrategy",
"randomization",
"localization",
"unnamed",
]
Loading
Loading