diff --git a/.coveragerc b/.coveragerc index a3f969d1..67e57da7 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,13 +1,13 @@ [run] -source = src +source = quanda omit = /tests/* - src/utils/explanations.py - src/utils/transforms.py - src/utils/datasets/transformed/sample.py - src/utils/cache.py - src/utils/datasets/activation_dataset.py - src/utils/datasets/indexed_subset.py - src/explainers/functional.py + quanda/utils/explanations.py + quanda/utils/transforms.py + quanda/utils/datasets/transformed/sample.py + quanda/utils/cache.py + quanda/utils/datasets/activation_dataset.py + quanda/utils/datasets/indexed_subset.py + quanda/explainers/functional.py [report] ignore_errors = True diff --git a/Makefile b/Makefile index 92349780..2b5381d5 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ style: black . python -m flake8 . --pytest-parametrize-names-type=csv python -m isort . - python -m mypy src --check-untyped-defs + python -m mypy quanda --check-untyped-defs rm -f .coverage rm -f .coverage.* find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf diff --git a/README.md b/README.md index 1da1d2eb..ff604fad 100644 --- a/README.md +++ b/README.md @@ -32,12 +32,10 @@ Excerpts from `tutorials/usage_testing.py`: Step 1. Import library components ```python -from src.explainers.wrappers.captum_influence import captum_similarity_explain, CaptumSimilarity -from src.metrics.localization.class_detection import ClassDetectionMetric -from src.metrics.randomization.model_randomization import ( - ModelRandomizationMetric, -) -from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric +from quanda.explainers.wrappers import captum_similarity_explain, CaptumSimilarity +from quanda.metrics.localization import ClassDetectionMetric +from quanda.metrics.randomization import ModelRandomizationMetric +from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric ``` diff --git a/pyproject.toml b/pyproject.toml index 21d75252..445834f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,9 @@ exclude = ''' testpaths = ["tests"] python_files = "test_*.py" +[tool.setuptools] +py-modules = [] + [project.optional-dependencies] dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh "coverage>=7.2.3", diff --git a/quanda/__init__.py b/quanda/__init__.py new file mode 100644 index 00000000..6e54f639 --- /dev/null +++ b/quanda/__init__.py @@ -0,0 +1,3 @@ +from quanda import explainers, metrics, toy_benchmarks, utils + +__all__ = ["explainers", "metrics", "toy_benchmarks", "utils"] diff --git a/quanda/explainers/__init__.py b/quanda/explainers/__init__.py new file mode 100644 index 00000000..c682a366 --- /dev/null +++ b/quanda/explainers/__init__.py @@ -0,0 +1,23 @@ +from quanda.explainers import utils, wrappers +from quanda.explainers.aggregators import ( + AbsSumAggregator, + BaseAggregator, + SumAggregator, + aggr_types, +) +from quanda.explainers.base import BaseExplainer +from quanda.explainers.functional import ExplainFunc, ExplainFuncMini +from quanda.explainers.random import RandomExplainer + +__all__ = [ + "BaseExplainer", + "RandomExplainer", + "ExplainFunc", + "ExplainFuncMini", + "utils", + "wrappers", + "BaseAggregator", + "SumAggregator", + "AbsSumAggregator", + "aggr_types", +] diff --git a/src/explainers/aggregators.py b/quanda/explainers/aggregators.py similarity index 100% rename from src/explainers/aggregators.py rename to quanda/explainers/aggregators.py diff --git a/src/explainers/base.py b/quanda/explainers/base.py similarity index 95% rename from src/explainers/base.py rename to quanda/explainers/base.py index 197986af..ddc3c888 100644 --- a/src/explainers/base.py +++ b/quanda/explainers/base.py @@ -3,8 +3,8 @@ import torch -from src.utils.common import cache_result -from src.utils.validation import validate_1d_tensor_or_int_list +from quanda.utils.common import cache_result +from quanda.utils.validation import validate_1d_tensor_or_int_list class BaseExplainer(ABC): diff --git a/src/explainers/functional.py b/quanda/explainers/functional.py similarity index 100% rename from src/explainers/functional.py rename to quanda/explainers/functional.py diff --git a/src/explainers/random.py b/quanda/explainers/random.py similarity index 94% rename from src/explainers/random.py rename to quanda/explainers/random.py index afa65500..f4470daf 100644 --- a/src/explainers/random.py +++ b/quanda/explainers/random.py @@ -2,8 +2,8 @@ import torch -from src.explainers.base import BaseExplainer -from src.utils.common import cache_result +from quanda.explainers import BaseExplainer +from quanda.utils.common import cache_result class RandomExplainer(BaseExplainer): diff --git a/src/explainers/utils.py b/quanda/explainers/utils.py similarity index 100% rename from src/explainers/utils.py rename to quanda/explainers/utils.py diff --git a/quanda/explainers/wrappers/__init__.py b/quanda/explainers/wrappers/__init__.py new file mode 100644 index 00000000..a69d04f4 --- /dev/null +++ b/quanda/explainers/wrappers/__init__.py @@ -0,0 +1,25 @@ +from quanda.explainers.wrappers.captum_influence import ( + CaptumArnoldi, + CaptumInfluence, + CaptumSimilarity, + CaptumTracInCP, + captum_arnoldi_explain, + captum_arnoldi_self_influence, + captum_similarity_explain, + captum_similarity_self_influence, + captum_tracincp_explain, + captum_tracincp_self_influence, +) + +__all__ = [ + "CaptumInfluence", + "CaptumSimilarity", + "captum_similarity_explain", + "captum_similarity_self_influence", + "CaptumArnoldi", + "captum_arnoldi_explain", + "captum_arnoldi_self_influence", + "CaptumTracInCP", + "captum_tracincp_explain", + "captum_tracincp_self_influence", +] diff --git a/src/explainers/wrappers/captum_influence.py b/quanda/explainers/wrappers/captum_influence.py similarity index 98% rename from src/explainers/wrappers/captum_influence.py rename to quanda/explainers/wrappers/captum_influence.py index 9575f547..e1bdda55 100644 --- a/src/explainers/wrappers/captum_influence.py +++ b/quanda/explainers/wrappers/captum_influence.py @@ -11,14 +11,14 @@ ArnoldiInfluenceFunction, ) -from src.explainers.base import BaseExplainer -from src.explainers.utils import ( +from quanda.explainers.base import BaseExplainer +from quanda.explainers.utils import ( explain_fn_from_explainer, self_influence_fn_from_explainer, ) -from src.utils.common import get_load_state_dict_func -from src.utils.functions.similarities import cosine_similarity -from src.utils.validation import validate_checkpoints_load_func +from quanda.utils.common import get_load_state_dict_func +from quanda.utils.functions import cosine_similarity +from quanda.utils.validation import validate_checkpoints_load_func class CaptumInfluence(BaseExplainer, ABC): diff --git a/quanda/metrics/__init__.py b/quanda/metrics/__init__.py new file mode 100644 index 00000000..40b3aa8e --- /dev/null +++ b/quanda/metrics/__init__.py @@ -0,0 +1,16 @@ +from quanda.metrics import localization, randomization, unnamed +from quanda.metrics.aggr_strategies import ( + GlobalAggrStrategy, + GlobalSelfInfluenceStrategy, +) +from quanda.metrics.base import GlobalMetric, Metric + +__all__ = [ + "Metric", + "GlobalMetric", + "GlobalAggrStrategy", + "GlobalSelfInfluenceStrategy", + "randomization", + "localization", + "unnamed", +] diff --git a/src/metrics/aggr_strategies.py b/quanda/metrics/aggr_strategies.py similarity index 95% rename from src/metrics/aggr_strategies.py rename to quanda/metrics/aggr_strategies.py index ba12074a..c82ff115 100644 --- a/src/metrics/aggr_strategies.py +++ b/quanda/metrics/aggr_strategies.py @@ -4,8 +4,7 @@ import torch -from src.explainers.aggregators import BaseAggregator -from src.explainers.base import BaseExplainer +from quanda.explainers import BaseAggregator, BaseExplainer class GlobalSelfInfluenceStrategy: diff --git a/src/metrics/base.py b/quanda/metrics/base.py similarity index 98% rename from src/metrics/base.py rename to quanda/metrics/base.py index 6f4770c3..79dfc8d0 100644 --- a/src/metrics/base.py +++ b/quanda/metrics/base.py @@ -3,8 +3,8 @@ import torch -from src.explainers.aggregators import aggr_types -from src.metrics.aggr_strategies import ( +from quanda.explainers import aggr_types +from quanda.metrics.aggr_strategies import ( GlobalAggrStrategy, GlobalSelfInfluenceStrategy, ) diff --git a/quanda/metrics/localization/__init__.py b/quanda/metrics/localization/__init__.py new file mode 100644 index 00000000..0cc6dbe4 --- /dev/null +++ b/quanda/metrics/localization/__init__.py @@ -0,0 +1,9 @@ +from quanda.metrics.localization.class_detection import ClassDetectionMetric +from quanda.metrics.localization.mislabeling_detection import ( + MislabelingDetectionMetric, +) +from quanda.metrics.localization.subclass_detection import ( + SubclassDetectionMetric, +) + +__all__ = ["ClassDetectionMetric", "SubclassDetectionMetric", "MislabelingDetectionMetric"] diff --git a/src/metrics/localization/class_detection.py b/quanda/metrics/localization/class_detection.py similarity index 97% rename from src/metrics/localization/class_detection.py rename to quanda/metrics/localization/class_detection.py index e2a6206b..828b6778 100644 --- a/src/metrics/localization/class_detection.py +++ b/quanda/metrics/localization/class_detection.py @@ -2,7 +2,7 @@ import torch -from src.metrics.base import Metric +from quanda.metrics.base import Metric class ClassDetectionMetric(Metric): diff --git a/src/metrics/localization/mislabeling_detection.py b/quanda/metrics/localization/mislabeling_detection.py similarity index 98% rename from src/metrics/localization/mislabeling_detection.py rename to quanda/metrics/localization/mislabeling_detection.py index d6533179..f8055692 100644 --- a/src/metrics/localization/mislabeling_detection.py +++ b/quanda/metrics/localization/mislabeling_detection.py @@ -2,7 +2,7 @@ import torch -from src.metrics.base import GlobalMetric +from quanda.metrics.base import GlobalMetric class MislabelingDetectionMetric(GlobalMetric): diff --git a/src/metrics/localization/subclass_detection.py b/quanda/metrics/localization/subclass_detection.py similarity index 94% rename from src/metrics/localization/subclass_detection.py rename to quanda/metrics/localization/subclass_detection.py index 2c5cd768..bc79c84b 100644 --- a/src/metrics/localization/subclass_detection.py +++ b/quanda/metrics/localization/subclass_detection.py @@ -1,6 +1,6 @@ import torch -from src.metrics.localization.class_detection import ClassDetectionMetric +from quanda.metrics.localization import ClassDetectionMetric class SubclassDetectionMetric(ClassDetectionMetric): diff --git a/quanda/metrics/randomization/__init__.py b/quanda/metrics/randomization/__init__.py new file mode 100644 index 00000000..03486024 --- /dev/null +++ b/quanda/metrics/randomization/__init__.py @@ -0,0 +1,5 @@ +from quanda.metrics.randomization.model_randomization import ( + ModelRandomizationMetric, +) + +__all__ = ["ModelRandomizationMetric"] diff --git a/src/metrics/randomization/model_randomization.py b/quanda/metrics/randomization/model_randomization.py similarity index 95% rename from src/metrics/randomization/model_randomization.py rename to quanda/metrics/randomization/model_randomization.py index 69177c85..2a584d83 100644 --- a/src/metrics/randomization/model_randomization.py +++ b/quanda/metrics/randomization/model_randomization.py @@ -3,12 +3,9 @@ import torch -from src.metrics.base import Metric -from src.utils.common import get_parent_module_from_name -from src.utils.functions.correlations import ( - CorrelationFnLiterals, - correlation_functions, -) +from quanda.metrics.base import Metric +from quanda.utils.common import get_parent_module_from_name +from quanda.utils.functions import CorrelationFnLiterals, correlation_functions class ModelRandomizationMetric(Metric): diff --git a/quanda/metrics/unnamed/__init__.py b/quanda/metrics/unnamed/__init__.py new file mode 100644 index 00000000..c4abf4a2 --- /dev/null +++ b/quanda/metrics/unnamed/__init__.py @@ -0,0 +1,4 @@ +from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric +from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric + +__all__ = ["DatasetCleaningMetric", "TopKOverlapMetric"] diff --git a/src/metrics/unnamed/dataset_cleaning.py b/quanda/metrics/unnamed/dataset_cleaning.py similarity index 97% rename from src/metrics/unnamed/dataset_cleaning.py rename to quanda/metrics/unnamed/dataset_cleaning.py index d6d34e6b..dfd10f5c 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/quanda/metrics/unnamed/dataset_cleaning.py @@ -4,9 +4,9 @@ import lightning as L import torch -from src.metrics.base import GlobalMetric -from src.utils.common import class_accuracy -from src.utils.training.trainer import BaseTrainer +from quanda.metrics.base import GlobalMetric +from quanda.utils.common import class_accuracy +from quanda.utils.training import BaseTrainer class DatasetCleaningMetric(GlobalMetric): diff --git a/src/metrics/unnamed/top_k_overlap.py b/quanda/metrics/unnamed/top_k_overlap.py similarity index 96% rename from src/metrics/unnamed/top_k_overlap.py rename to quanda/metrics/unnamed/top_k_overlap.py index 99b534f6..e9ef3e6f 100644 --- a/src/metrics/unnamed/top_k_overlap.py +++ b/quanda/metrics/unnamed/top_k_overlap.py @@ -1,6 +1,6 @@ import torch -from src.metrics.base import Metric +from quanda.metrics.base import Metric class TopKOverlapMetric(Metric): diff --git a/quanda/toy_benchmarks/__init__.py b/quanda/toy_benchmarks/__init__.py new file mode 100644 index 00000000..7dc47e53 --- /dev/null +++ b/quanda/toy_benchmarks/__init__.py @@ -0,0 +1,4 @@ +from quanda.toy_benchmarks import localization, randomization, unnamed +from quanda.toy_benchmarks.base import ToyBenchmark + +__all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"] diff --git a/src/toy_benchmarks/base.py b/quanda/toy_benchmarks/base.py similarity index 100% rename from src/toy_benchmarks/base.py rename to quanda/toy_benchmarks/base.py diff --git a/quanda/toy_benchmarks/localization/__init__.py b/quanda/toy_benchmarks/localization/__init__.py new file mode 100644 index 00000000..95fd7e71 --- /dev/null +++ b/quanda/toy_benchmarks/localization/__init__.py @@ -0,0 +1,9 @@ +from quanda.toy_benchmarks.localization.class_detection import ClassDetection +from quanda.toy_benchmarks.localization.mislabeling_detection import ( + MislabelingDetection, +) +from quanda.toy_benchmarks.localization.subclass_detection import ( + SubclassDetection, +) + +__all__ = ["ClassDetection", "SubclassDetection", "MislabelingDetection"] diff --git a/src/toy_benchmarks/localization/class_detection.py b/quanda/toy_benchmarks/localization/class_detection.py similarity index 96% rename from src/toy_benchmarks/localization/class_detection.py rename to quanda/toy_benchmarks/localization/class_detection.py index aa8df7a1..93dc3f0f 100644 --- a/src/toy_benchmarks/localization/class_detection.py +++ b/quanda/toy_benchmarks/localization/class_detection.py @@ -3,8 +3,8 @@ import torch from tqdm import tqdm -from src.metrics.localization.class_detection import ClassDetectionMetric -from src.toy_benchmarks.base import ToyBenchmark +from quanda.metrics.localization import ClassDetectionMetric +from quanda.toy_benchmarks.base import ToyBenchmark class ClassDetection(ToyBenchmark): diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/quanda/toy_benchmarks/localization/mislabeling_detection.py similarity index 94% rename from src/toy_benchmarks/localization/mislabeling_detection.py rename to quanda/toy_benchmarks/localization/mislabeling_detection.py index d2a16127..0f5313e3 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/quanda/toy_benchmarks/localization/mislabeling_detection.py @@ -5,12 +5,14 @@ import torch from tqdm import tqdm -from src.metrics.localization.mislabeling_detection import ( +from quanda.metrics.localization.mislabeling_detection import ( MislabelingDetectionMetric, ) -from src.toy_benchmarks.base import ToyBenchmark -from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset -from src.utils.training.trainer import BaseTrainer +from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.utils.datasets.transformed.label_flipping import ( + LabelFlippingDataset, +) +from quanda.utils.training.trainer import BaseTrainer class MislabelingDetection(ToyBenchmark): @@ -256,16 +258,16 @@ def evaluate( pbar = tqdm(expl_dl) n_batches = len(expl_dl) - for i, (input, labels) in enumerate(pbar): + for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + inputs, labels = inputs.to(device), labels.to(device) if use_predictions: with torch.no_grad(): - targets = self.model(input).argmax(dim=-1) + targets = self.model(inputs).argmax(dim=-1) else: targets = labels - explanations = explainer.explain(test=input, targets=targets) + explanations = explainer.explain(test=inputs, targets=targets) metric.update(explanations) else: metric = MislabelingDetectionMetric.self_influence_based( diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/quanda/toy_benchmarks/localization/subclass_detection.py similarity index 95% rename from src/toy_benchmarks/localization/subclass_detection.py rename to quanda/toy_benchmarks/localization/subclass_detection.py index ffbe88f4..f3dbe4f7 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/quanda/toy_benchmarks/localization/subclass_detection.py @@ -5,13 +5,13 @@ import torch from tqdm import tqdm -from src.metrics.localization.class_detection import ClassDetectionMetric -from src.toy_benchmarks.base import ToyBenchmark -from src.utils.datasets.transformed.label_grouping import ( +from quanda.metrics.localization.class_detection import ClassDetectionMetric +from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.utils.datasets.transformed.label_grouping import ( ClassToGroupLiterals, LabelGroupingDataset, ) -from src.utils.training.trainer import BaseTrainer +from quanda.utils.training.trainer import BaseTrainer class SubclassDetection(ToyBenchmark): @@ -231,19 +231,19 @@ def evaluate( pbar = tqdm(expl_dl) n_batches = len(expl_dl) - for i, (input, labels) in enumerate(pbar): + for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + inputs, labels = inputs.to(device), labels.to(device) grouped_labels = torch.tensor([self.class_to_group[i.item()] for i in labels], device=labels.device) if use_predictions: with torch.no_grad(): - output = self.group_model(input) + output = self.group_model(inputs) targets = output.argmax(dim=-1) else: targets = grouped_labels explanations = explainer.explain( - test=input, + test=inputs, targets=targets, ) diff --git a/quanda/toy_benchmarks/randomization/__init__.py b/quanda/toy_benchmarks/randomization/__init__.py new file mode 100644 index 00000000..b2c2a079 --- /dev/null +++ b/quanda/toy_benchmarks/randomization/__init__.py @@ -0,0 +1,5 @@ +from quanda.toy_benchmarks.randomization.model_randomization import ( + ModelRandomization, +) + +__all__ = ["ModelRandomization"] diff --git a/src/toy_benchmarks/randomization/model_randomization.py b/quanda/toy_benchmarks/randomization/model_randomization.py similarity index 95% rename from src/toy_benchmarks/randomization/model_randomization.py rename to quanda/toy_benchmarks/randomization/model_randomization.py index a449f3d2..c6d810fb 100644 --- a/src/toy_benchmarks/randomization/model_randomization.py +++ b/quanda/toy_benchmarks/randomization/model_randomization.py @@ -3,11 +3,11 @@ import torch from tqdm import tqdm -from src.metrics.randomization.model_randomization import ( +from quanda.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) -from src.toy_benchmarks.base import ToyBenchmark -from src.utils.functions.correlations import CorrelationFnLiterals +from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.utils.functions import CorrelationFnLiterals class ModelRandomization(ToyBenchmark): diff --git a/quanda/toy_benchmarks/unnamed/__init__.py b/quanda/toy_benchmarks/unnamed/__init__.py new file mode 100644 index 00000000..a4e795dd --- /dev/null +++ b/quanda/toy_benchmarks/unnamed/__init__.py @@ -0,0 +1,4 @@ +from quanda.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning +from quanda.toy_benchmarks.unnamed.top_k_overlap import TopKOverlap + +__all__ = ["DatasetCleaning", "TopKOverlap"] diff --git a/src/toy_benchmarks/unnamed/dataset_cleaning.py b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py similarity index 92% rename from src/toy_benchmarks/unnamed/dataset_cleaning.py rename to quanda/toy_benchmarks/unnamed/dataset_cleaning.py index c89d9484..99a17364 100644 --- a/src/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/quanda/toy_benchmarks/unnamed/dataset_cleaning.py @@ -5,9 +5,9 @@ import torch from tqdm import tqdm -from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric -from src.toy_benchmarks.base import ToyBenchmark -from src.utils.training.trainer import BaseTrainer +from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric +from quanda.toy_benchmarks.base import ToyBenchmark +from quanda.utils.training.trainer import BaseTrainer class DatasetCleaning(ToyBenchmark): @@ -123,20 +123,20 @@ def evaluate( pbar = tqdm(expl_dl) n_batches = len(expl_dl) - for i, (input, labels) in enumerate(pbar): + for i, (inputs, labels) in enumerate(pbar): pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches)) - input, labels = input.to(device), labels.to(device) + inputs, labels = inputs.to(device), labels.to(device) if use_predictions: with torch.no_grad(): - output = self.model(input) + output = self.model(inputs) targets = output.argmax(dim=-1) else: targets = labels explanations = explainer.explain( - test=input, + test=inputs, targets=targets, ) metric.update(explanations) diff --git a/src/toy_benchmarks/unnamed/top_k_overlap.py b/quanda/toy_benchmarks/unnamed/top_k_overlap.py similarity index 96% rename from src/toy_benchmarks/unnamed/top_k_overlap.py rename to quanda/toy_benchmarks/unnamed/top_k_overlap.py index 28fc6fe4..4ac0f83f 100644 --- a/src/toy_benchmarks/unnamed/top_k_overlap.py +++ b/quanda/toy_benchmarks/unnamed/top_k_overlap.py @@ -3,8 +3,8 @@ import torch from tqdm import tqdm -from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric -from src.toy_benchmarks.base import ToyBenchmark +from quanda.metrics.unnamed import TopKOverlapMetric +from quanda.toy_benchmarks.base import ToyBenchmark class TopKOverlap(ToyBenchmark): diff --git a/quanda/utils/__init__.py b/quanda/utils/__init__.py new file mode 100644 index 00000000..b2db06cf --- /dev/null +++ b/quanda/utils/__init__.py @@ -0,0 +1,3 @@ +from quanda.utils import common, datasets, functions, training, validation + +__all__ = ["common", "validation", "datasets", "functions", "training"] diff --git a/src/utils/cache.py b/quanda/utils/cache.py similarity index 97% rename from src/utils/cache.py rename to quanda/utils/cache.py index 3054c8cf..188637cf 100644 --- a/src/utils/cache.py +++ b/quanda/utils/cache.py @@ -8,9 +8,9 @@ from torch import Tensor from torch.utils.data import DataLoader -from src.utils.common import _get_module_from_name -from src.utils.datasets.activation_dataset import ActivationDataset -from src.utils.explanations import BatchedCachedExplanations +from quanda.utils.common import _get_module_from_name +from quanda.utils.datasets import ActivationDataset +from quanda.utils.explanations import BatchedCachedExplanations class Cache: diff --git a/src/utils/common.py b/quanda/utils/common.py similarity index 100% rename from src/utils/common.py rename to quanda/utils/common.py diff --git a/quanda/utils/datasets/__init__.py b/quanda/utils/datasets/__init__.py new file mode 100644 index 00000000..3c563cde --- /dev/null +++ b/quanda/utils/datasets/__init__.py @@ -0,0 +1,4 @@ +from quanda.utils.datasets.activation_dataset import ActivationDataset +from quanda.utils.datasets.indexed_subset import IndexedSubset + +__all__ = ["ActivationDataset", "IndexedSubset"] diff --git a/src/utils/datasets/activation_dataset.py b/quanda/utils/datasets/activation_dataset.py similarity index 100% rename from src/utils/datasets/activation_dataset.py rename to quanda/utils/datasets/activation_dataset.py diff --git a/src/utils/datasets/indexed_subset.py b/quanda/utils/datasets/indexed_subset.py similarity index 100% rename from src/utils/datasets/indexed_subset.py rename to quanda/utils/datasets/indexed_subset.py diff --git a/quanda/utils/datasets/transformed/__init__.py b/quanda/utils/datasets/transformed/__init__.py new file mode 100644 index 00000000..3fb4d3e9 --- /dev/null +++ b/quanda/utils/datasets/transformed/__init__.py @@ -0,0 +1,19 @@ +from quanda.utils.datasets.transformed.base import TransformedDataset +from quanda.utils.datasets.transformed.label_flipping import ( + LabelFlippingDataset, +) +from quanda.utils.datasets.transformed.label_grouping import ( + ClassToGroupLiterals, + LabelGroupingDataset, +) +from quanda.utils.datasets.transformed.sample import ( + SampleTransformationDataset, +) + +__all__ = [ + "TransformedDataset", + "SampleTransformationDataset", + "LabelFlippingDataset", + "LabelGroupingDataset", + "ClassToGroupLiterals", +] diff --git a/src/utils/datasets/transformed/base.py b/quanda/utils/datasets/transformed/base.py similarity index 100% rename from src/utils/datasets/transformed/base.py rename to quanda/utils/datasets/transformed/base.py diff --git a/src/utils/datasets/transformed/label_flipping.py b/quanda/utils/datasets/transformed/label_flipping.py similarity index 96% rename from src/utils/datasets/transformed/label_flipping.py rename to quanda/utils/datasets/transformed/label_flipping.py index 1a68c202..c612d60d 100644 --- a/src/utils/datasets/transformed/label_flipping.py +++ b/quanda/utils/datasets/transformed/label_flipping.py @@ -2,7 +2,7 @@ import torch -from src.utils.datasets.transformed.base import TransformedDataset +from quanda.utils.datasets.transformed import TransformedDataset class LabelFlippingDataset(TransformedDataset): diff --git a/src/utils/datasets/transformed/label_grouping.py b/quanda/utils/datasets/transformed/label_grouping.py similarity index 97% rename from src/utils/datasets/transformed/label_grouping.py rename to quanda/utils/datasets/transformed/label_grouping.py index 539a2ac6..7a4ad7a2 100644 --- a/src/utils/datasets/transformed/label_grouping.py +++ b/quanda/utils/datasets/transformed/label_grouping.py @@ -3,7 +3,7 @@ import torch -from src.utils.datasets.transformed.base import TransformedDataset +from quanda.utils.datasets.transformed import TransformedDataset ClassToGroupLiterals = Literal["random"] diff --git a/src/utils/datasets/transformed/sample.py b/quanda/utils/datasets/transformed/sample.py similarity index 92% rename from src/utils/datasets/transformed/sample.py rename to quanda/utils/datasets/transformed/sample.py index 33360b3a..c4709657 100644 --- a/src/utils/datasets/transformed/sample.py +++ b/quanda/utils/datasets/transformed/sample.py @@ -2,7 +2,7 @@ import torch -from src.utils.datasets.transformed.base import TransformedDataset +from quanda.utils.datasets.transformed import TransformedDataset ClassToGroupLiterals = Literal["random"] diff --git a/src/utils/explanations.py b/quanda/utils/explanations.py similarity index 100% rename from src/utils/explanations.py rename to quanda/utils/explanations.py diff --git a/quanda/utils/functions/__init__.py b/quanda/utils/functions/__init__.py new file mode 100644 index 00000000..90956339 --- /dev/null +++ b/quanda/utils/functions/__init__.py @@ -0,0 +1,19 @@ +from quanda.utils.functions.correlations import ( + CorrelationFnLiterals, + correlation_functions, + kendall_rank_corr, + spearman_rank_corr, +) +from quanda.utils.functions.similarities import ( + cosine_similarity, + dot_product_similarity, +) + +__all__ = [ + "kendall_rank_corr", + "spearman_rank_corr", + "correlation_functions", + "CorrelationFnLiterals", + "dot_product_similarity", + "cosine_similarity", +] diff --git a/src/utils/functions/correlations.py b/quanda/utils/functions/correlations.py similarity index 100% rename from src/utils/functions/correlations.py rename to quanda/utils/functions/correlations.py diff --git a/src/utils/functions/similarities.py b/quanda/utils/functions/similarities.py similarity index 100% rename from src/utils/functions/similarities.py rename to quanda/utils/functions/similarities.py diff --git a/quanda/utils/training/__init__.py b/quanda/utils/training/__init__.py new file mode 100644 index 00000000..afcf6016 --- /dev/null +++ b/quanda/utils/training/__init__.py @@ -0,0 +1,4 @@ +from quanda.utils.training.base_pl_module import BasicLightningModule +from quanda.utils.training.trainer import BaseTrainer, Trainer + +__all__ = ["BasicLightningModule", "BaseTrainer", "Trainer"] diff --git a/src/utils/training/base_pl_module.py b/quanda/utils/training/base_pl_module.py similarity index 100% rename from src/utils/training/base_pl_module.py rename to quanda/utils/training/base_pl_module.py diff --git a/src/utils/training/trainer.py b/quanda/utils/training/trainer.py similarity index 96% rename from src/utils/training/trainer.py rename to quanda/utils/training/trainer.py index a86a43dc..33149944 100644 --- a/src/utils/training/trainer.py +++ b/quanda/utils/training/trainer.py @@ -5,7 +5,7 @@ import lightning as L import torch -from src.utils.training.base_pl_module import BasicLightningModule +from quanda.utils.training import BasicLightningModule class BaseTrainer(metaclass=abc.ABCMeta): diff --git a/src/utils/transforms.py b/quanda/utils/transforms.py similarity index 100% rename from src/utils/transforms.py rename to quanda/utils/transforms.py diff --git a/src/utils/validation.py b/quanda/utils/validation.py similarity index 100% rename from src/utils/validation.py rename to quanda/utils/validation.py diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/explainers/__init__.py b/src/explainers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/explainers/wrappers/__init__.py b/src/explainers/wrappers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/metrics/localization/__init__.py b/src/metrics/localization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/metrics/randomization/__init__.py b/src/metrics/randomization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/metrics/unnamed/__init__.py b/src/metrics/unnamed/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/toy_benchmarks/__init__.py b/src/toy_benchmarks/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/toy_benchmarks/localization/__init__.py b/src/toy_benchmarks/localization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/toy_benchmarks/randomization/__init__.py b/src/toy_benchmarks/randomization/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/toy_benchmarks/unnamed/__init__.py b/src/toy_benchmarks/unnamed/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/__init__.py b/src/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/datasets/__init__.py b/src/utils/datasets/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/datasets/transformed/__init__.py b/src/utils/datasets/transformed/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/functions/__init__.py b/src/utils/functions/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/utils/training/__init__.py b/src/utils/training/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/assets/mnist_dataset_cleaning_state_dict b/tests/assets/mnist_dataset_cleaning_state_dict new file mode 100644 index 00000000..74b2a298 Binary files /dev/null and b/tests/assets/mnist_dataset_cleaning_state_dict differ diff --git a/tests/conftest.py b/tests/conftest.py index e4e6439c..3c200ccd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,13 @@ import torch from torch.utils.data import TensorDataset -from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset -from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset -from src.utils.training.base_pl_module import BasicLightningModule +from quanda.utils.datasets.transformed.label_flipping import ( + LabelFlippingDataset, +) +from quanda.utils.datasets.transformed.label_grouping import ( + LabelGroupingDataset, +) +from quanda.utils.training.base_pl_module import BasicLightningModule from tests.models import LeNet MNIST_IMAGE_SIZE = 28 diff --git a/tests/explainers/test_aggregators.py b/tests/explainers/test_aggregators.py index a83f6432..b9b3dc67 100644 --- a/tests/explainers/test_aggregators.py +++ b/tests/explainers/test_aggregators.py @@ -1,7 +1,7 @@ import pytest import torch -from src.explainers.aggregators import AbsSumAggregator, SumAggregator +from quanda.explainers import AbsSumAggregator, SumAggregator @pytest.mark.aggregators diff --git a/tests/explainers/test_base_explainer.py b/tests/explainers/test_base_explainer.py index 45701285..2fc4cd34 100644 --- a/tests/explainers/test_base_explainer.py +++ b/tests/explainers/test_base_explainer.py @@ -4,8 +4,8 @@ import pytest import torch -from src.explainers.base import BaseExplainer -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers import BaseExplainer +from quanda.utils.functions import cosine_similarity @pytest.mark.explainers diff --git a/tests/explainers/wrappers/test_captum_influence.py b/tests/explainers/wrappers/test_captum_influence.py index e3c8b6a2..79e70f53 100644 --- a/tests/explainers/wrappers/test_captum_influence.py +++ b/tests/explainers/wrappers/test_captum_influence.py @@ -10,7 +10,7 @@ ) from torch.utils.data import TensorDataset -from src.explainers.wrappers.captum_influence import ( +from quanda.explainers.wrappers import ( CaptumArnoldi, CaptumSimilarity, CaptumTracInCP, @@ -21,11 +21,8 @@ captum_tracincp_explain, captum_tracincp_self_influence, ) -from src.utils.common import get_load_state_dict_func -from src.utils.functions.similarities import ( - cosine_similarity, - dot_product_similarity, -) +from quanda.utils.common import get_load_state_dict_func +from quanda.utils.functions import cosine_similarity, dot_product_similarity @pytest.mark.self_influence diff --git a/tests/metrics/test_aggr_strategies.py b/tests/metrics/test_aggr_strategies.py index b7518bf0..2af37012 100644 --- a/tests/metrics/test_aggr_strategies.py +++ b/tests/metrics/test_aggr_strategies.py @@ -1,13 +1,10 @@ import pytest import torch -from src.explainers.aggregators import AbsSumAggregator -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.metrics.aggr_strategies import ( - GlobalAggrStrategy, - GlobalSelfInfluenceStrategy, -) -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers import AbsSumAggregator +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.metrics import GlobalAggrStrategy, GlobalSelfInfluenceStrategy +from quanda.utils.functions import cosine_similarity @pytest.mark.aggr_strategies diff --git a/tests/metrics/test_localization_metrics.py b/tests/metrics/test_localization_metrics.py index 0dd0515c..6e8e24ad 100644 --- a/tests/metrics/test_localization_metrics.py +++ b/tests/metrics/test_localization_metrics.py @@ -1,13 +1,13 @@ import pytest -from src.explainers.aggregators import SumAggregator -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.metrics.localization.class_detection import ClassDetectionMetric -from src.metrics.localization.mislabeling_detection import ( +from quanda.explainers import SumAggregator +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.metrics.localization import ( + ClassDetectionMetric, MislabelingDetectionMetric, + SubclassDetectionMetric, ) -from src.metrics.localization.subclass_detection import SubclassDetectionMetric -from src.utils.functions.similarities import cosine_similarity +from quanda.utils.functions import cosine_similarity @pytest.mark.localization_metrics diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py index 690a0982..3fe97766 100644 --- a/tests/metrics/test_randomization_metrics.py +++ b/tests/metrics/test_randomization_metrics.py @@ -1,18 +1,15 @@ import pytest import torch -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.metrics.randomization.model_randomization import ( - ModelRandomizationMetric, -) -from src.utils.functions.correlations import correlation_functions -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.metrics.randomization import ModelRandomizationMetric +from quanda.utils.functions import correlation_functions, cosine_similarity @pytest.mark.randomization_metrics @pytest.mark.parametrize( "test_id, model, dataset, test_data, batch_size, explainer_cls, \ - expl_kwargs, explanations, test_labels, correlation_fn, tmp_path", + expl_kwargs, explanations, test_labels, correlation_fn", [ ( "mnist_update_only_spearman", @@ -28,7 +25,6 @@ "load_mnist_explanations_1", "load_mnist_test_labels_1", "spearman", - "tmp_path", ), ( "mnist_update_only_kendall", @@ -44,7 +40,6 @@ "load_mnist_explanations_1", "load_mnist_test_labels_1", "kendall", - "tmp_path", ), ( "mnist_explain_update", @@ -60,7 +55,6 @@ "load_mnist_explanations_1", "load_mnist_test_labels_1", "spearman", - "tmp_path", ), ], ) diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index c7146514..9e6e8c25 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -1,10 +1,10 @@ import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric -from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric -from src.utils.functions.similarities import cosine_similarity -from src.utils.training.trainer import Trainer +from quanda.explainers.wrappers.captum_influence import CaptumSimilarity +from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric +from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric +from quanda.utils.functions.similarities import cosine_similarity +from quanda.utils.training.trainer import Trainer @pytest.mark.unnamed_metrics diff --git a/tests/toy_benchmarks/localization/test_class_detection.py b/tests/toy_benchmarks/localization/test_class_detection.py index a44868db..7686f7a6 100644 --- a/tests/toy_benchmarks/localization/test_class_detection.py +++ b/tests/toy_benchmarks/localization/test_class_detection.py @@ -1,8 +1,8 @@ import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.localization.class_detection import ClassDetection -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.toy_benchmarks.localization import ClassDetection +from quanda.utils.functions import cosine_similarity @pytest.mark.toy_benchmarks diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index de9719be..b2417955 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -1,13 +1,13 @@ import lightning as L import pytest -from src.explainers.aggregators import SumAggregator -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.localization.mislabeling_detection import ( +from quanda.explainers.aggregators import SumAggregator +from quanda.explainers.wrappers.captum_influence import CaptumSimilarity +from quanda.toy_benchmarks.localization.mislabeling_detection import ( MislabelingDetection, ) -from src.utils.functions.similarities import cosine_similarity -from src.utils.training.trainer import Trainer +from quanda.utils.functions.similarities import cosine_similarity +from quanda.utils.training.trainer import Trainer @pytest.mark.toy_benchmarks diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index 63e4f1fd..d5a6a7cc 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -1,12 +1,12 @@ import lightning as L import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.localization.subclass_detection import ( +from quanda.explainers.wrappers.captum_influence import CaptumSimilarity +from quanda.toy_benchmarks.localization.subclass_detection import ( SubclassDetection, ) -from src.utils.functions.similarities import cosine_similarity -from src.utils.training.trainer import Trainer +from quanda.utils.functions.similarities import cosine_similarity +from quanda.utils.training.trainer import Trainer @pytest.mark.toy_benchmarks diff --git a/tests/toy_benchmarks/randomization/test_model_randomization.py b/tests/toy_benchmarks/randomization/test_model_randomization.py index 01ee1959..540c40db 100644 --- a/tests/toy_benchmarks/randomization/test_model_randomization.py +++ b/tests/toy_benchmarks/randomization/test_model_randomization.py @@ -1,10 +1,8 @@ import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.randomization.model_randomization import ( - ModelRandomization, -) -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.toy_benchmarks.randomization import ModelRandomization +from quanda.utils.functions import cosine_similarity @pytest.mark.toy_benchmarks diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index f3c8d5a1..b7fa1ce4 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -1,10 +1,10 @@ import lightning as L import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning -from src.utils.functions.similarities import cosine_similarity -from src.utils.training.trainer import Trainer +from quanda.explainers.wrappers.captum_influence import CaptumSimilarity +from quanda.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning +from quanda.utils.functions.similarities import cosine_similarity +from quanda.utils.training.trainer import Trainer @pytest.mark.toy_benchmarks @@ -116,7 +116,6 @@ def test_dataset_cleaning( train_dataset=dataset, device="cpu", ) - dst_eval.save("tests/assets/mnist_dataset_cleaning_state_dict") elif init_method == "load": dst_eval = DatasetCleaning.load(path=load_path) @@ -206,7 +205,6 @@ def test_dataset_cleaning_generate_from_pl_module( train_dataset=dataset, device="cpu", ) - dst_eval.save("tests/assets/mnist_dataset_cleaning_state_dict") score = dst_eval.evaluate( expl_dataset=dataset, diff --git a/tests/toy_benchmarks/unnamed/test_top_k_overlap.py b/tests/toy_benchmarks/unnamed/test_top_k_overlap.py index 1462342a..bd417a1c 100644 --- a/tests/toy_benchmarks/unnamed/test_top_k_overlap.py +++ b/tests/toy_benchmarks/unnamed/test_top_k_overlap.py @@ -1,8 +1,8 @@ import pytest -from src.explainers.wrappers.captum_influence import CaptumSimilarity -from src.toy_benchmarks.unnamed.top_k_overlap import TopKOverlap -from src.utils.functions.similarities import cosine_similarity +from quanda.explainers.wrappers import CaptumSimilarity +from quanda.toy_benchmarks.unnamed import TopKOverlap +from quanda.utils.functions import cosine_similarity @pytest.mark.toy_benchmarks diff --git a/tests/utils/datasets/transformed/test_base.py b/tests/utils/datasets/transformed/test_base.py index 0930c13d..d95124cf 100644 --- a/tests/utils/datasets/transformed/test_base.py +++ b/tests/utils/datasets/transformed/test_base.py @@ -2,7 +2,7 @@ import torch from torch.utils.data import Dataset, TensorDataset -from src.utils.datasets.transformed.base import TransformedDataset +from quanda.utils.datasets.transformed import TransformedDataset class UnsizedTensorDataset(Dataset): diff --git a/tests/utils/datasets/transformed/test_label_flipping.py b/tests/utils/datasets/transformed/test_label_flipping.py index 9a62eb42..89782913 100644 --- a/tests/utils/datasets/transformed/test_label_flipping.py +++ b/tests/utils/datasets/transformed/test_label_flipping.py @@ -1,6 +1,6 @@ import pytest -from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset +from quanda.utils.datasets.transformed import LabelFlippingDataset @pytest.mark.utils diff --git a/tests/utils/datasets/transformed/test_label_grouping.py b/tests/utils/datasets/transformed/test_label_grouping.py index 13ccb874..823b25b0 100644 --- a/tests/utils/datasets/transformed/test_label_grouping.py +++ b/tests/utils/datasets/transformed/test_label_grouping.py @@ -1,6 +1,6 @@ import pytest -from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset +from quanda.utils.datasets.transformed import LabelGroupingDataset @pytest.mark.utils diff --git a/tests/utils/test_common.py b/tests/utils/test_common.py index dfe37a5b..a19863c5 100644 --- a/tests/utils/test_common.py +++ b/tests/utils/test_common.py @@ -1,6 +1,6 @@ import pytest -from src.utils.common import make_func +from quanda.utils.common import make_func @pytest.mark.utils diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index eb72b8f4..eb9dcc7d 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -3,7 +3,7 @@ import pytest import torch -from src.utils.training.trainer import Trainer +from quanda.utils.training.trainer import Trainer @pytest.mark.utils diff --git a/tox.ini b/tox.ini index 4626a96a..9fe29164 100644 --- a/tox.ini +++ b/tox.ini @@ -24,7 +24,7 @@ deps = coverage pytest_cov commands = - python3 -m pytest --cov=src --cov-report=term-missing --cov-fail-under 57 --cov-report html:htmlcov --cov-report xml {posargs} + python3 -m pytest --cov=quanda --cov-report=term-missing --cov-fail-under 57 --cov-report html:htmlcov --cov-report xml {posargs} [testenv:lint] @@ -41,7 +41,7 @@ deps = {[testenv]deps} mypy==1.9.0 commands = - python3 -m mypy src --check-untyped-defs + python3 -m mypy quanda --check-untyped-defs [flake8] max-line-length = 127 diff --git a/tutorials/usage_testing.py b/tutorials/usage_testing.py index 03134f05..47662217 100644 --- a/tutorials/usage_testing.py +++ b/tutorials/usage_testing.py @@ -16,19 +16,19 @@ from torchvision.utils import make_grid from tqdm import tqdm -from src.explainers.wrappers.captum_influence import ( +from quanda.explainers.wrappers.captum_influence import ( CaptumSimilarity, captum_similarity_explain, ) -from src.metrics.localization.class_detection import ClassDetectionMetric -from src.metrics.randomization.model_randomization import ( +from quanda.metrics.localization.class_detection import ClassDetectionMetric +from quanda.metrics.randomization.model_randomization import ( ModelRandomizationMetric, ) -from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric -from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric -from src.toy_benchmarks.subclass_detection import SubclassDetection -from src.utils.training.base_pl_module import BasicLightningModule -from src.utils.training.trainer import Trainer +from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric +from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric +from quanda.toy_benchmarks.subclass_detection import SubclassDetection +from quanda.utils.training.base_pl_module import BasicLightningModule +from quanda.utils.training.trainer import Trainer DEVICE = "cuda:0" # "cuda" if torch.cuda.is_available() else "cpu" torch.set_float32_matmul_precision("medium")