From 67c09bd5d7a2a3b4589be5b76c54404dea6bdc6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Galip=20=C3=9Cmit=20Yolcu?= Date: Mon, 12 Aug 2024 20:32:21 +0200 Subject: [PATCH] fix imports --- pyproject.toml | 5 + quanda/explainers/__init__.py | 2 +- quanda/explainers/wrappers/__init__.py | 5 + .../explainers/wrappers/captum_influence.py | 2 +- quanda/explainers/wrappers/trak.py | 128 ++++++++++++------ quanda/metrics/__init__.py | 2 +- .../metrics/localization/class_detection.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../randomization/model_randomization.py | 2 +- quanda/metrics/unnamed/dataset_cleaning.py | 2 +- quanda/metrics/unnamed/top_k_overlap.py | 2 +- quanda/toy_benchmarks/__init__.py | 2 +- .../localization/class_detection.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../localization/subclass_detection.py | 2 +- .../randomization/model_randomization.py | 2 +- .../unnamed/dataset_cleaning.py | 2 +- .../toy_benchmarks/unnamed/top_k_overlap.py | 2 +- 18 files changed, 111 insertions(+), 57 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 445834f3..bd802d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dynamic = ["version"] [tool.isort] profile = "black" +extend_skip = ["__init__.py"] line_length = 79 multi_line_output = 3 include_trailing_comma = true @@ -29,6 +30,10 @@ warn_unused_configs = true check_untyped_defs = true #ignore_errors = true # TODO: change this +[[tool.mypy.overrides]] +module = ["trak", "trak.projectors", "fast_jl"] +ignore_missing_imports = true + # Black formatting [tool.black] line-length = 127 diff --git a/quanda/explainers/__init__.py b/quanda/explainers/__init__.py index c682a366..cf2439bc 100644 --- a/quanda/explainers/__init__.py +++ b/quanda/explainers/__init__.py @@ -1,3 +1,4 @@ +from quanda.explainers.base import BaseExplainer from quanda.explainers import utils, wrappers from quanda.explainers.aggregators import ( AbsSumAggregator, @@ -5,7 +6,6 @@ SumAggregator, aggr_types, ) -from quanda.explainers.base import BaseExplainer from quanda.explainers.functional import ExplainFunc, ExplainFuncMini from quanda.explainers.random import RandomExplainer diff --git a/quanda/explainers/wrappers/__init__.py b/quanda/explainers/wrappers/__init__.py index a69d04f4..ba508a03 100644 --- a/quanda/explainers/wrappers/__init__.py +++ b/quanda/explainers/wrappers/__init__.py @@ -11,6 +11,10 @@ captum_tracincp_self_influence, ) +from quanda.explainers.wrappers.trak import ( + TRAK, +) + __all__ = [ "CaptumInfluence", "CaptumSimilarity", @@ -22,4 +26,5 @@ "CaptumTracInCP", "captum_tracincp_explain", "captum_tracincp_self_influence", + "TRAK", ] diff --git a/quanda/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py index e1bdda55..7c52d486 100644 --- a/quanda/explainers/wrappers/captum_influence.py +++ b/quanda/explainers/wrappers/captum_influence.py @@ -11,7 +11,7 @@ ArnoldiInfluenceFunction, ) -from quanda.explainers.base import BaseExplainer +from quanda.explainers import BaseExplainer from quanda.explainers.utils import ( explain_fn_from_explainer, self_influence_fn_from_explainer, diff --git a/quanda/explainers/wrappers/trak.py b/quanda/explainers/wrappers/trak.py index e500fc23..18c4243a 100644 --- a/quanda/explainers/wrappers/trak.py +++ b/quanda/explainers/wrappers/trak.py @@ -1,15 +1,26 @@ -from trak import TRAKer -from trak.projectors import BasicProjector, CudaProjector, NoOpProjector -from trak.projectors import ProjectionType +import warnings +from typing import Iterable, Literal, Optional, Sized, Union -from typing import Literal, Optional, Union -import os import torch +from trak import TRAKer +from trak.projectors import ( + BasicProjector, + CudaProjector, + NoOpProjector, + ProjectionType, +) from quanda.explainers import BaseExplainer -TRAKProjectorLiteral=Literal["cuda", "noop", "basic"] -TRAKProjectionTypeLiteral=Literal["rademacher", "normal"] +# from quanda.explainers.utils import ( +# explain_fn_from_explainer, +# self_influence_fn_from_explainer, +# ) + + +TRAKProjectorLiteral = Literal["cuda", "noop", "basic", "check_cuda"] +TRAKProjectionTypeLiteral = Literal["rademacher", "normal"] + class TRAK(BaseExplainer): def __init__( @@ -19,55 +30,88 @@ def __init__( cache_dir: Optional[str], train_dataset: torch.utils.data.Dataset, device: Union[str, torch.device], - projector: TRAKProjectorLiteral="basic", - proj_dim: int=128, - proj_type: TRAKProjectionTypeLiteral="normal", - seed: int=42, - batch_size: int=32, + projector: TRAKProjectorLiteral = "check_cuda", + proj_dim: int = 128, + proj_type: TRAKProjectionTypeLiteral = "normal", + seed: int = 42, + batch_size: int = 32, + params_ldr: Optional[Iterable] = None, ): - super(TRAK, self).__init__(model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device) - self.dataset=train_dataset - self.batch_size=batch_size - proj_type=ProjectionType.normal if proj_type=="normal" else ProjectionType.rademacher - - number_of_params=0 + super(TRAK, self).__init__( + model=model, train_dataset=train_dataset, model_id=model_id, cache_dir=cache_dir, device=device + ) + self.dataset = train_dataset + self.proj_dim = proj_dim + self.batch_size = batch_size + proj_type = ProjectionType.normal if proj_type == "normal" else ProjectionType.rademacher + + num_params_for_grad = 0 for p in list(self.model.sim_parameters()): nn = 1 for s in list(p.size()): nn = nn * s - number_of_params += nn - + num_params_for_grad += nn + + # Check if traker was installer with the ["cuda"] option + try: + import fast_jl + + test_gradient = torch.ones(1, num_params_for_grad).cuda() + num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count + fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) + projector = "cuda" + except (ImportError, RuntimeError, AttributeError) as e: + warnings.warn(f"Could not use CudaProjector.\nReason: {str(e)}") + warnings.warn("Defaulting to BasicProjector.") + projector = "basic" + projector_cls = { "cuda": CudaProjector, "basic": BasicProjector, - "noop": NoOpProjector + "noop": NoOpProjector, } - - projector_kwargs={ - "grad_dim": number_of_params, + + projector_kwargs = { + "grad_dim": num_params_for_grad, "proj_dim": proj_dim, "proj_type": proj_type, "seed": seed, - "device": device + "device": device, } - projector=projector_cls[projector](**projector_kwargs) - self.traker = TRAKer(model=model, task='image_classification', train_set_size=len(train_dataset), - projector=projector, proj_dim=proj_dim, projector_seed=seed, save_dir=cache_dir) + projector = projector_cls[projector](**projector_kwargs) + self.traker = TRAKer( + model=model, + task="image_classification", + train_set_size=self.dataset_length, + projector=projector, + proj_dim=proj_dim, + projector_seed=seed, + save_dir=cache_dir, + ) - #Train the TRAK explainer: featurize the training data - ld=torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) - self.traker.load_checkpoint(self.model.state_dict(),model_id=0) - for (i,(x,y)) in enumerate(iter(ld)): - batch=x.to(self.device), y.to(self.device) - self.traker.featurize(batch=batch,inds=torch.tensor([i*self.batch_size+j for j in range(self.batch_size)])) + # Train the TRAK explainer: featurize the training data + ld = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size) + self.traker.load_checkpoint(self.model.state_dict(), model_id=0) + for i, (x, y) in enumerate(iter(ld)): + batch = x.to(self.device), y.to(self.device) + self.traker.featurize(batch=batch, inds=torch.tensor([i * self.batch_size + j for j in range(self.batch_size)])) self.traker.finalize_features() - def explain(self, x, targets): - x=x.to(self.device) - self.traker.start_scoring_checkpoint(model_id=0, - checkpoint=self.model.state_dict(), - exp_name='test', - num_targets=x.shape[0]) - self.traker.score(batch=(x,targets), num_samples=x.shape[0]) - return torch.from_numpy(self.traker.finalize_scores(exp_name='test')).T.to(self.device) + @property + def dataset_length(self) -> int: + """ + By default, the Dataset class does not always have a __len__ method. + :return: + """ + if isinstance(self.dataset, Sized): + return len(self.dataset) + dl = torch.utils.data.DataLoader(self.dataset, batch_size=1) + return len(dl) + def explain(self, x, targets): + x = x.to(self.device) + self.traker.start_scoring_checkpoint( + model_id=0, checkpoint=self.model.state_dict(), exp_name="test", num_targets=x.shape[0] + ) + self.traker.score(batch=(x, targets), num_samples=x.shape[0]) + return torch.from_numpy(self.traker.finalize_scores(exp_name="test")).T.to(self.device) diff --git a/quanda/metrics/__init__.py b/quanda/metrics/__init__.py index 40b3aa8e..8e5a4005 100644 --- a/quanda/metrics/__init__.py +++ b/quanda/metrics/__init__.py @@ -1,9 +1,9 @@ +from quanda.metrics.base import GlobalMetric, Metric from quanda.metrics import localization, randomization, unnamed from quanda.metrics.aggr_strategies import ( GlobalAggrStrategy, GlobalSelfInfluenceStrategy, ) -from quanda.metrics.base import GlobalMetric, Metric __all__ = [ "Metric", diff --git a/quanda/metrics/localization/class_detection.py b/quanda/metrics/localization/class_detection.py index 828b6778..a3234e80 100644 --- a/quanda/metrics/localization/class_detection.py +++ b/quanda/metrics/localization/class_detection.py @@ -2,7 +2,7 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric class ClassDetectionMetric(Metric): diff --git a/quanda/metrics/localization/mislabeling_detection.py b/quanda/metrics/localization/mislabeling_detection.py index f8055692..9266ed73 100644 --- a/quanda/metrics/localization/mislabeling_detection.py +++ b/quanda/metrics/localization/mislabeling_detection.py @@ -2,7 +2,7 @@ import torch -from quanda.metrics.base import GlobalMetric +from quanda.metrics import GlobalMetric class MislabelingDetectionMetric(GlobalMetric): diff --git a/quanda/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py index 2a584d83..97e5616f 100644 --- a/quanda/metrics/randomization/model_randomization.py +++ b/quanda/metrics/randomization/model_randomization.py @@ -3,7 +3,7 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric from quanda.utils.common import get_parent_module_from_name from quanda.utils.functions import CorrelationFnLiterals, correlation_functions diff --git a/quanda/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py index dfd10f5c..d1d95748 100644 --- a/quanda/metrics/unnamed/dataset_cleaning.py +++ b/quanda/metrics/unnamed/dataset_cleaning.py @@ -4,7 +4,7 @@ import lightning as L import torch -from quanda.metrics.base import GlobalMetric +from quanda.metrics import GlobalMetric from quanda.utils.common import class_accuracy from quanda.utils.training import BaseTrainer diff --git a/quanda/metrics/unnamed/top_k_overlap.py b/quanda/metrics/unnamed/top_k_overlap.py index e9ef3e6f..b86353dc 100644 --- a/quanda/metrics/unnamed/top_k_overlap.py +++ b/quanda/metrics/unnamed/top_k_overlap.py @@ -1,6 +1,6 @@ import torch -from quanda.metrics.base import Metric +from quanda.metrics import Metric class TopKOverlapMetric(Metric): diff --git a/quanda/toy_benchmarks/__init__.py b/quanda/toy_benchmarks/__init__.py index 7dc47e53..b275c546 100644 --- a/quanda/toy_benchmarks/__init__.py +++ b/quanda/toy_benchmarks/__init__.py @@ -1,4 +1,4 @@ -from quanda.toy_benchmarks import localization, randomization, unnamed from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import localization, randomization, unnamed __all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"] diff --git a/quanda/toy_benchmarks/localization/class_detection.py b/quanda/toy_benchmarks/localization/class_detection.py index 93dc3f0f..3e2d843d 100644 --- a/quanda/toy_benchmarks/localization/class_detection.py +++ b/quanda/toy_benchmarks/localization/class_detection.py @@ -4,7 +4,7 @@ from tqdm import tqdm from quanda.metrics.localization import ClassDetectionMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark class ClassDetection(ToyBenchmark): diff --git a/quanda/toy_benchmarks/localization/mislabeling_detection.py b/quanda/toy_benchmarks/localization/mislabeling_detection.py index 0f5313e3..e31f8638 100644 --- a/quanda/toy_benchmarks/localization/mislabeling_detection.py +++ b/quanda/toy_benchmarks/localization/mislabeling_detection.py @@ -8,7 +8,7 @@ from quanda.metrics.localization.mislabeling_detection import ( MislabelingDetectionMetric, ) -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.datasets.transformed.label_flipping import ( LabelFlippingDataset, ) diff --git a/quanda/toy_benchmarks/localization/subclass_detection.py b/quanda/toy_benchmarks/localization/subclass_detection.py index f3dbe4f7..a2ced0b9 100644 --- a/quanda/toy_benchmarks/localization/subclass_detection.py +++ b/quanda/toy_benchmarks/localization/subclass_detection.py @@ -6,7 +6,7 @@ from tqdm import tqdm from quanda.metrics.localization.class_detection import ClassDetectionMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.datasets.transformed.label_grouping import ( ClassToGroupLiterals, LabelGroupingDataset, diff --git a/quanda/toy_benchmarks/randomization/model_randomization.py b/quanda/toy_benchmarks/randomization/model_randomization.py index c6d810fb..589f224b 100644 --- a/quanda/toy_benchmarks/randomization/model_randomization.py +++ b/quanda/toy_benchmarks/randomization/model_randomization.py @@ -6,7 +6,7 @@ from quanda.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.functions import CorrelationFnLiterals diff --git a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py index 99a17364..999a51f5 100644 --- a/quanda/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py @@ -6,7 +6,7 @@ from tqdm import tqdm from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark from quanda.utils.training.trainer import BaseTrainer diff --git a/quanda/toy_benchmarks/unnamed/top_k_overlap.py b/quanda/toy_benchmarks/unnamed/top_k_overlap.py index 4ac0f83f..bbd86a29 100644 --- a/quanda/toy_benchmarks/unnamed/top_k_overlap.py +++ b/quanda/toy_benchmarks/unnamed/top_k_overlap.py @@ -4,7 +4,7 @@ from tqdm import tqdm from quanda.metrics.unnamed import TopKOverlapMetric -from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.toy_benchmarks import ToyBenchmark class TopKOverlap(ToyBenchmark):