Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import strategy #100

Merged
merged 22 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[run]
source = src
source = quanda
omit =
/tests/*
src/utils/explanations.py
src/utils/transforms.py
src/utils/datasets/transformed/sample.py
src/utils/cache.py
src/utils/datasets/activation_dataset.py
src/utils/datasets/indexed_subset.py
src/explainers/functional.py
quanda/utils/explanations.py
quanda/utils/transforms.py
quanda/utils/datasets/transformed/sample.py
quanda/utils/cache.py
quanda/utils/datasets/activation_dataset.py
quanda/utils/datasets/indexed_subset.py
quanda/explainers/functional.py
[report]
ignore_errors = True
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ style:
black .
python -m flake8 . --pytest-parametrize-names-type=csv
python -m isort .
python -m mypy src --check-untyped-defs
python -m mypy quanda --check-untyped-defs
rm -f .coverage
rm -f .coverage.*
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
Expand Down
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ Excerpts from `tutorials/usage_testing.py`:
<summary><b><big>Step 1. Import library components</big></b></summary>

```python
from src.explainers.wrappers.captum_influence import captum_similarity_explain, CaptumSimilarity
from src.metrics.localization.class_detection import ClassDetectionMetric
from src.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric
from quanda.explainers.wrappers import captum_similarity_explain, CaptumSimilarity
from quanda.metrics.localization import ClassDetectionMetric
from quanda.metrics.randomization import ModelRandomizationMetric
from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric
```
</details>

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ exclude = '''
testpaths = ["tests"]
python_files = "test_*.py"

[tool.setuptools]
py-modules = []

[project.optional-dependencies]
dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh
"coverage>=7.2.3",
Expand Down
3 changes: 3 additions & 0 deletions quanda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from quanda import explainers, metrics, toy_benchmarks, utils

__all__ = ["explainers", "metrics", "toy_benchmarks", "utils"]
23 changes: 23 additions & 0 deletions quanda/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from quanda.explainers import utils, wrappers
from quanda.explainers.aggregators import (
AbsSumAggregator,
BaseAggregator,
SumAggregator,
aggr_types,
)
from quanda.explainers.base import BaseExplainer
from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
from quanda.explainers.random import RandomExplainer

__all__ = [
"BaseExplainer",
"RandomExplainer",
"ExplainFunc",
"ExplainFuncMini",
"utils",
"wrappers",
"BaseAggregator",
"SumAggregator",
"AbsSumAggregator",
"aggr_types",
]
File renamed without changes.
4 changes: 2 additions & 2 deletions src/explainers/base.py → quanda/explainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch

from src.utils.common import cache_result
from src.utils.validation import validate_1d_tensor_or_int_list
from quanda.utils.common import cache_result
from quanda.utils.validation import validate_1d_tensor_or_int_list


class BaseExplainer(ABC):
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions src/explainers/random.py → quanda/explainers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

from src.explainers.base import BaseExplainer
from src.utils.common import cache_result
from quanda.explainers import BaseExplainer
from quanda.utils.common import cache_result


class RandomExplainer(BaseExplainer):
Expand Down
File renamed without changes.
25 changes: 25 additions & 0 deletions quanda/explainers/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from quanda.explainers.wrappers.captum_influence import (
CaptumArnoldi,
CaptumInfluence,
CaptumSimilarity,
CaptumTracInCP,
captum_arnoldi_explain,
captum_arnoldi_self_influence,
captum_similarity_explain,
captum_similarity_self_influence,
captum_tracincp_explain,
captum_tracincp_self_influence,
)

__all__ = [
"CaptumInfluence",
"CaptumSimilarity",
"captum_similarity_explain",
"captum_similarity_self_influence",
"CaptumArnoldi",
"captum_arnoldi_explain",
"captum_arnoldi_self_influence",
"CaptumTracInCP",
"captum_tracincp_explain",
"captum_tracincp_self_influence",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
ArnoldiInfluenceFunction,
)

from src.explainers.base import BaseExplainer
from src.explainers.utils import (
from quanda.explainers.base import BaseExplainer
from quanda.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
)
from src.utils.common import get_load_state_dict_func
from src.utils.functions.similarities import cosine_similarity
from src.utils.validation import validate_checkpoints_load_func
from quanda.utils.common import get_load_state_dict_func
from quanda.utils.functions import cosine_similarity
from quanda.utils.validation import validate_checkpoints_load_func


class CaptumInfluence(BaseExplainer, ABC):
Expand Down
16 changes: 16 additions & 0 deletions quanda/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from quanda.metrics import localization, randomization, unnamed
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
from quanda.metrics.base import GlobalMetric, Metric

__all__ = [
"Metric",
"GlobalMetric",
"GlobalAggrStrategy",
"GlobalSelfInfluenceStrategy",
"randomization",
"localization",
"unnamed",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import torch

from src.explainers.aggregators import BaseAggregator
from src.explainers.base import BaseExplainer
from quanda.explainers import BaseAggregator, BaseExplainer


class GlobalSelfInfluenceStrategy:
Expand Down
4 changes: 2 additions & 2 deletions src/metrics/base.py → quanda/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch

from src.explainers.aggregators import aggr_types
from src.metrics.aggr_strategies import (
from quanda.explainers import aggr_types
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
Expand Down
9 changes: 9 additions & 0 deletions quanda/metrics/localization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from quanda.metrics.localization.subclass_detection import (
SubclassDetectionMetric,
)

__all__ = ["ClassDetectionMetric", "SubclassDetectionMetric", "MislabelingDetectionMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from src.metrics.base import Metric
from quanda.metrics.base import Metric


class ClassDetectionMetric(Metric):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from src.metrics.base import GlobalMetric
from quanda.metrics.base import GlobalMetric


class MislabelingDetectionMetric(GlobalMetric):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from src.metrics.localization.class_detection import ClassDetectionMetric
from quanda.metrics.localization import ClassDetectionMetric


class SubclassDetectionMetric(ClassDetectionMetric):
Expand Down
5 changes: 5 additions & 0 deletions quanda/metrics/randomization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from quanda.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)

__all__ = ["ModelRandomizationMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

import torch

from src.metrics.base import Metric
from src.utils.common import get_parent_module_from_name
from src.utils.functions.correlations import (
CorrelationFnLiterals,
correlation_functions,
)
from quanda.metrics.base import Metric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import CorrelationFnLiterals, correlation_functions


class ModelRandomizationMetric(Metric):
Expand Down
4 changes: 4 additions & 0 deletions quanda/metrics/unnamed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric

__all__ = ["DatasetCleaningMetric", "TopKOverlapMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import lightning as L
import torch

from src.metrics.base import GlobalMetric
from src.utils.common import class_accuracy
from src.utils.training.trainer import BaseTrainer
from quanda.metrics.base import GlobalMetric
from quanda.utils.common import class_accuracy
from quanda.utils.training import BaseTrainer


class DatasetCleaningMetric(GlobalMetric):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from src.metrics.base import Metric
from quanda.metrics.base import Metric


class TopKOverlapMetric(Metric):
Expand Down
4 changes: 4 additions & 0 deletions quanda/toy_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from quanda.toy_benchmarks import localization, randomization, unnamed
from quanda.toy_benchmarks.base import ToyBenchmark

__all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"]
File renamed without changes.
9 changes: 9 additions & 0 deletions quanda/toy_benchmarks/localization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from quanda.toy_benchmarks.localization.class_detection import ClassDetection
from quanda.toy_benchmarks.localization.mislabeling_detection import (
MislabelingDetection,
)
from quanda.toy_benchmarks.localization.subclass_detection import (
SubclassDetection,
)

__all__ = ["ClassDetection", "SubclassDetection", "MislabelingDetection"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from tqdm import tqdm

from src.metrics.localization.class_detection import ClassDetectionMetric
from src.toy_benchmarks.base import ToyBenchmark
from quanda.metrics.localization import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark


class ClassDetection(ToyBenchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import torch
from tqdm import tqdm

from src.metrics.localization.mislabeling_detection import (
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from src.toy_benchmarks.base import ToyBenchmark
from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset
from src.utils.training.trainer import BaseTrainer
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.utils.datasets.transformed.label_flipping import (
LabelFlippingDataset,
)
from quanda.utils.training.trainer import BaseTrainer


class MislabelingDetection(ToyBenchmark):
Expand Down Expand Up @@ -256,16 +258,16 @@ def evaluate(
pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
inputs, labels = inputs.to(device), labels.to(device)
if use_predictions:
with torch.no_grad():
targets = self.model(input).argmax(dim=-1)
targets = self.model(inputs).argmax(dim=-1)
else:
targets = labels
explanations = explainer.explain(test=input, targets=targets)
explanations = explainer.explain(test=inputs, targets=targets)
metric.update(explanations)
else:
metric = MislabelingDetectionMetric.self_influence_based(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import torch
from tqdm import tqdm

from src.metrics.localization.class_detection import ClassDetectionMetric
from src.toy_benchmarks.base import ToyBenchmark
from src.utils.datasets.transformed.label_grouping import (
from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.utils.datasets.transformed.label_grouping import (
ClassToGroupLiterals,
LabelGroupingDataset,
)
from src.utils.training.trainer import BaseTrainer
from quanda.utils.training.trainer import BaseTrainer


class SubclassDetection(ToyBenchmark):
Expand Down Expand Up @@ -231,19 +231,19 @@ def evaluate(
pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
inputs, labels = inputs.to(device), labels.to(device)
grouped_labels = torch.tensor([self.class_to_group[i.item()] for i in labels], device=labels.device)
if use_predictions:
with torch.no_grad():
output = self.group_model(input)
output = self.group_model(inputs)
targets = output.argmax(dim=-1)
else:
targets = grouped_labels
explanations = explainer.explain(
test=input,
test=inputs,
targets=targets,
)

Expand Down
5 changes: 5 additions & 0 deletions quanda/toy_benchmarks/randomization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from quanda.toy_benchmarks.randomization.model_randomization import (
ModelRandomization,
)

__all__ = ["ModelRandomization"]
Loading
Loading