Skip to content

Commit

Permalink
fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed Aug 12, 2024
1 parent de348f0 commit 67c09bd
Show file tree
Hide file tree
Showing 18 changed files with 111 additions and 57 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dynamic = ["version"]

[tool.isort]
profile = "black"
extend_skip = ["__init__.py"]
line_length = 79
multi_line_output = 3
include_trailing_comma = true
Expand All @@ -29,6 +30,10 @@ warn_unused_configs = true
check_untyped_defs = true
#ignore_errors = true # TODO: change this

[[tool.mypy.overrides]]
module = ["trak", "trak.projectors", "fast_jl"]
ignore_missing_imports = true

# Black formatting
[tool.black]
line-length = 127
Expand Down
2 changes: 1 addition & 1 deletion quanda/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from quanda.explainers.base import BaseExplainer
from quanda.explainers import utils, wrappers
from quanda.explainers.aggregators import (
AbsSumAggregator,
BaseAggregator,
SumAggregator,
aggr_types,
)
from quanda.explainers.base import BaseExplainer
from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
from quanda.explainers.random import RandomExplainer

Expand Down
5 changes: 5 additions & 0 deletions quanda/explainers/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
captum_tracincp_self_influence,
)

from quanda.explainers.wrappers.trak import (
TRAK,
)

__all__ = [
"CaptumInfluence",
"CaptumSimilarity",
Expand All @@ -22,4 +26,5 @@
"CaptumTracInCP",
"captum_tracincp_explain",
"captum_tracincp_self_influence",
"TRAK",
]
2 changes: 1 addition & 1 deletion quanda/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ArnoldiInfluenceFunction,
)

from quanda.explainers.base import BaseExplainer
from quanda.explainers import BaseExplainer
from quanda.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
Expand Down
128 changes: 86 additions & 42 deletions quanda/explainers/wrappers/trak.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from trak import TRAKer
from trak.projectors import BasicProjector, CudaProjector, NoOpProjector
from trak.projectors import ProjectionType
import warnings
from typing import Iterable, Literal, Optional, Sized, Union

from typing import Literal, Optional, Union
import os
import torch
from trak import TRAKer
from trak.projectors import (
BasicProjector,
CudaProjector,
NoOpProjector,
ProjectionType,
)

from quanda.explainers import BaseExplainer

TRAKProjectorLiteral=Literal["cuda", "noop", "basic"]
TRAKProjectionTypeLiteral=Literal["rademacher", "normal"]
# from quanda.explainers.utils import (
# explain_fn_from_explainer,
# self_influence_fn_from_explainer,
# )


TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"]
TRAKProjectionTypeLiteral = Literal["rademacher", "normal"]


class TRAK(BaseExplainer):
def __init__(
Expand All @@ -19,55 +30,88 @@ def __init__(
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
projector: TRAKProjectorLiteral="basic",
proj_dim: int=128,
proj_type: TRAKProjectionTypeLiteral="normal",
seed: int=42,
batch_size: int=32,
projector: TRAKProjectorLiteral = "check_cuda",
proj_dim: int = 128,
proj_type: TRAKProjectionTypeLiteral = "normal",
seed: int = 42,
batch_size: int = 32,
params_ldr: Optional[Iterable] = None,
):
super(TRAK, self).__init__(model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device)
self.dataset=train_dataset
self.batch_size=batch_size
proj_type=ProjectionType.normal if proj_type=="normal" else ProjectionType.rademacher

number_of_params=0
super(TRAK, self).__init__(
model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device
)
self.dataset = train_dataset
self.proj_dim = proj_dim
self.batch_size = batch_size
proj_type = ProjectionType.normal if proj_type == "normal" else ProjectionType.rademacher

num_params_for_grad = 0
for p in list(self.model.sim_parameters()):
nn = 1
for s in list(p.size()):
nn = nn * s
number_of_params += nn

num_params_for_grad += nn

# Check if traker was installer with the ["cuda"] option
try:
import fast_jl

test_gradient = torch.ones(1, num_params_for_grad).cuda()
num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
projector = "cuda"
except (ImportError, RuntimeError, AttributeError) as e:
warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}")
warnings.warn("Defaulting to BasicProjector.")
projector = "basic"

projector_cls = {
"cuda": CudaProjector,
"basic": BasicProjector,
"noop": NoOpProjector
"noop": NoOpProjector,
}
projector_kwargs={
"grad_dim": number_of_params,

projector_kwargs = {
"grad_dim": num_params_for_grad,
"proj_dim": proj_dim,
"proj_type": proj_type,
"seed": seed,
"device": device
"device": device,
}
projector=projector_cls[projector](**projector_kwargs)
self.traker = TRAKer(model=model, task='image_classification', train_set_size=len(train_dataset),
projector=projector, proj_dim=proj_dim, projector_seed=seed, save_dir=cache_dir)
projector = projector_cls[projector](**projector_kwargs)
self.traker = TRAKer(
model=model,
task="image_classification",
train_set_size=self.dataset_length,
projector=projector,
proj_dim=proj_dim,
projector_seed=seed,
save_dir=cache_dir,
)

#Train the TRAK explainer: featurize the training data
ld=torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size)
self.traker.load_checkpoint(self.model.state_dict(),model_id=0)
for (i,(x,y)) in enumerate(iter(ld)):
batch=x.to(self.device), y.to(self.device)
self.traker.featurize(batch=batch,inds=torch.tensor([i*self.batch_size+j for j in range(self.batch_size)]))
# Train the TRAK explainer: featurize the training data
ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size)
self.traker.load_checkpoint(self.model.state_dict(), model_id=0)
for i, (x, y) in enumerate(iter(ld)):
batch = x.to(self.device), y.to(self.device)
self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(self.batch_size)]))
self.traker.finalize_features()

def explain(self, x, targets):
x=x.to(self.device)
self.traker.start_scoring_checkpoint(model_id=0,
checkpoint=self.model.state_dict(),
exp_name='test',
num_targets=x.shape[0])
self.traker.score(batch=(x,targets), num_samples=x.shape[0])
return torch.from_numpy(self.traker.finalize_scores(exp_name='test')).T.to(self.device)
@property
def dataset_length(self) -> int:
"""
By default, the Dataset class does not always have a __len__ method.
:return:
"""
if isinstance(self.dataset, Sized):
return len(self.dataset)
dl = torch.utils.data.DataLoader(self.dataset, batch_size=1)
return len(dl)

def explain(self, x, targets):
x = x.to(self.device)
self.traker.start_scoring_checkpoint(
model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=x.shape[0]
)
self.traker.score(batch=(x, targets), num_samples=x.shape[0])
return torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device)
2 changes: 1 addition & 1 deletion quanda/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from quanda.metrics.base import GlobalMetric, Metric
from quanda.metrics import localization, randomization, unnamed
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
from quanda.metrics.base import GlobalMetric, Metric

__all__ = [
"Metric",
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/localization/class_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric


class ClassDetectionMetric(Metric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/localization/mislabeling_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from quanda.metrics.base import GlobalMetric
from quanda.metrics import GlobalMetric


class MislabelingDetectionMetric(GlobalMetric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import CorrelationFnLiterals, correlation_functions

Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/dataset_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import lightning as L
import torch

from quanda.metrics.base import GlobalMetric
from quanda.metrics import GlobalMetric
from quanda.utils.common import class_accuracy
from quanda.utils.training import BaseTrainer

Expand Down
2 changes: 1 addition & 1 deletion quanda/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from quanda.metrics.base import Metric
from quanda.metrics import Metric


class TopKOverlapMetric(Metric):
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from quanda.toy_benchmarks import localization, randomization, unnamed
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import localization, randomization, unnamed

__all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"]
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/localization/class_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm

from quanda.metrics.localization import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark


class ClassDetection(ToyBenchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.datasets.transformed.label_flipping import (
LabelFlippingDataset,
)
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/localization/subclass_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.datasets.transformed.label_grouping import (
ClassToGroupLiterals,
LabelGroupingDataset,
Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from quanda.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.functions import CorrelationFnLiterals


Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/unnamed/dataset_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark
from quanda.utils.training.trainer import BaseTrainer


Expand Down
2 changes: 1 addition & 1 deletion quanda/toy_benchmarks/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tqdm import tqdm

from quanda.metrics.unnamed import TopKOverlapMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.toy_benchmarks import ToyBenchmark


class TopKOverlap(ToyBenchmark):
Expand Down

0 comments on commit 67c09bd

Please sign in to comment.