Skip to content

Commit

Permalink
Merge pull request #100 from dilyabareeva/import_strategy
Browse files Browse the repository at this point in the history
Import strategy
  • Loading branch information
dilyabareeva authored Aug 6, 2024
2 parents 6ddd4a8 + a1f56ff commit ca2b0dc
Show file tree
Hide file tree
Showing 94 changed files with 302 additions and 159 deletions.
16 changes: 8 additions & 8 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
[run]
source = src
source = quanda
omit =
/tests/*
src/utils/explanations.py
src/utils/transforms.py
src/utils/datasets/transformed/sample.py
src/utils/cache.py
src/utils/datasets/activation_dataset.py
src/utils/datasets/indexed_subset.py
src/explainers/functional.py
quanda/utils/explanations.py
quanda/utils/transforms.py
quanda/utils/datasets/transformed/sample.py
quanda/utils/cache.py
quanda/utils/datasets/activation_dataset.py
quanda/utils/datasets/indexed_subset.py
quanda/explainers/functional.py
[report]
ignore_errors = True
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ style:
black .
python -m flake8 . --pytest-parametrize-names-type=csv
python -m isort .
python -m mypy src --check-untyped-defs
python -m mypy quanda --check-untyped-defs
rm -f .coverage
rm -f .coverage.*
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
Expand Down
10 changes: 4 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@ Excerpts from `tutorials/usage_testing.py`:
<summary><b><big>Step 1. Import library components</big></b></summary>

```python
from src.explainers.wrappers.captum_influence import captum_similarity_explain, CaptumSimilarity
from src.metrics.localization.class_detection import ClassDetectionMetric
from src.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric
from quanda.explainers.wrappers import captum_similarity_explain, CaptumSimilarity
from quanda.metrics.localization import ClassDetectionMetric
from quanda.metrics.randomization import ModelRandomizationMetric
from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric
```
</details>

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ exclude = '''
testpaths = ["tests"]
python_files = "test_*.py"

[tool.setuptools]
py-modules = []

[project.optional-dependencies]
dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh
"coverage>=7.2.3",
Expand Down
3 changes: 3 additions & 0 deletions quanda/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from quanda import explainers, metrics, toy_benchmarks, utils

__all__ = ["explainers", "metrics", "toy_benchmarks", "utils"]
23 changes: 23 additions & 0 deletions quanda/explainers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from quanda.explainers import utils, wrappers
from quanda.explainers.aggregators import (
AbsSumAggregator,
BaseAggregator,
SumAggregator,
aggr_types,
)
from quanda.explainers.base import BaseExplainer
from quanda.explainers.functional import ExplainFunc, ExplainFuncMini
from quanda.explainers.random import RandomExplainer

__all__ = [
"BaseExplainer",
"RandomExplainer",
"ExplainFunc",
"ExplainFuncMini",
"utils",
"wrappers",
"BaseAggregator",
"SumAggregator",
"AbsSumAggregator",
"aggr_types",
]
File renamed without changes.
4 changes: 2 additions & 2 deletions src/explainers/base.py → quanda/explainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch

from src.utils.common import cache_result
from src.utils.validation import validate_1d_tensor_or_int_list
from quanda.utils.common import cache_result
from quanda.utils.validation import validate_1d_tensor_or_int_list


class BaseExplainer(ABC):
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions src/explainers/random.py → quanda/explainers/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

from src.explainers.base import BaseExplainer
from src.utils.common import cache_result
from quanda.explainers import BaseExplainer
from quanda.utils.common import cache_result


class RandomExplainer(BaseExplainer):
Expand Down
File renamed without changes.
25 changes: 25 additions & 0 deletions quanda/explainers/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from quanda.explainers.wrappers.captum_influence import (
CaptumArnoldi,
CaptumInfluence,
CaptumSimilarity,
CaptumTracInCP,
captum_arnoldi_explain,
captum_arnoldi_self_influence,
captum_similarity_explain,
captum_similarity_self_influence,
captum_tracincp_explain,
captum_tracincp_self_influence,
)

__all__ = [
"CaptumInfluence",
"CaptumSimilarity",
"captum_similarity_explain",
"captum_similarity_self_influence",
"CaptumArnoldi",
"captum_arnoldi_explain",
"captum_arnoldi_self_influence",
"CaptumTracInCP",
"captum_tracincp_explain",
"captum_tracincp_self_influence",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
ArnoldiInfluenceFunction,
)

from src.explainers.base import BaseExplainer
from src.explainers.utils import (
from quanda.explainers.base import BaseExplainer
from quanda.explainers.utils import (
explain_fn_from_explainer,
self_influence_fn_from_explainer,
)
from src.utils.common import get_load_state_dict_func
from src.utils.functions.similarities import cosine_similarity
from src.utils.validation import validate_checkpoints_load_func
from quanda.utils.common import get_load_state_dict_func
from quanda.utils.functions import cosine_similarity
from quanda.utils.validation import validate_checkpoints_load_func


class CaptumInfluence(BaseExplainer, ABC):
Expand Down
16 changes: 16 additions & 0 deletions quanda/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from quanda.metrics import localization, randomization, unnamed
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
from quanda.metrics.base import GlobalMetric, Metric

__all__ = [
"Metric",
"GlobalMetric",
"GlobalAggrStrategy",
"GlobalSelfInfluenceStrategy",
"randomization",
"localization",
"unnamed",
]
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import torch

from src.explainers.aggregators import BaseAggregator
from src.explainers.base import BaseExplainer
from quanda.explainers import BaseAggregator, BaseExplainer


class GlobalSelfInfluenceStrategy:
Expand Down
4 changes: 2 additions & 2 deletions src/metrics/base.py → quanda/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import torch

from src.explainers.aggregators import aggr_types
from src.metrics.aggr_strategies import (
from quanda.explainers import aggr_types
from quanda.metrics.aggr_strategies import (
GlobalAggrStrategy,
GlobalSelfInfluenceStrategy,
)
Expand Down
9 changes: 9 additions & 0 deletions quanda/metrics/localization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from quanda.metrics.localization.subclass_detection import (
SubclassDetectionMetric,
)

__all__ = ["ClassDetectionMetric", "SubclassDetectionMetric", "MislabelingDetectionMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from src.metrics.base import Metric
from quanda.metrics.base import Metric


class ClassDetectionMetric(Metric):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from src.metrics.base import GlobalMetric
from quanda.metrics.base import GlobalMetric


class MislabelingDetectionMetric(GlobalMetric):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from src.metrics.localization.class_detection import ClassDetectionMetric
from quanda.metrics.localization import ClassDetectionMetric


class SubclassDetectionMetric(ClassDetectionMetric):
Expand Down
5 changes: 5 additions & 0 deletions quanda/metrics/randomization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from quanda.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)

__all__ = ["ModelRandomizationMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

import torch

from src.metrics.base import Metric
from src.utils.common import get_parent_module_from_name
from src.utils.functions.correlations import (
CorrelationFnLiterals,
correlation_functions,
)
from quanda.metrics.base import Metric
from quanda.utils.common import get_parent_module_from_name
from quanda.utils.functions import CorrelationFnLiterals, correlation_functions


class ModelRandomizationMetric(Metric):
Expand Down
4 changes: 4 additions & 0 deletions quanda/metrics/unnamed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric

__all__ = ["DatasetCleaningMetric", "TopKOverlapMetric"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import lightning as L
import torch

from src.metrics.base import GlobalMetric
from src.utils.common import class_accuracy
from src.utils.training.trainer import BaseTrainer
from quanda.metrics.base import GlobalMetric
from quanda.utils.common import class_accuracy
from quanda.utils.training import BaseTrainer


class DatasetCleaningMetric(GlobalMetric):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from src.metrics.base import Metric
from quanda.metrics.base import Metric


class TopKOverlapMetric(Metric):
Expand Down
4 changes: 4 additions & 0 deletions quanda/toy_benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from quanda.toy_benchmarks import localization, randomization, unnamed
from quanda.toy_benchmarks.base import ToyBenchmark

__all__ = ["ToyBenchmark", "randomization", "localization", "unnamed"]
File renamed without changes.
9 changes: 9 additions & 0 deletions quanda/toy_benchmarks/localization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from quanda.toy_benchmarks.localization.class_detection import ClassDetection
from quanda.toy_benchmarks.localization.mislabeling_detection import (
MislabelingDetection,
)
from quanda.toy_benchmarks.localization.subclass_detection import (
SubclassDetection,
)

__all__ = ["ClassDetection", "SubclassDetection", "MislabelingDetection"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
from tqdm import tqdm

from src.metrics.localization.class_detection import ClassDetectionMetric
from src.toy_benchmarks.base import ToyBenchmark
from quanda.metrics.localization import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark


class ClassDetection(ToyBenchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import torch
from tqdm import tqdm

from src.metrics.localization.mislabeling_detection import (
from quanda.metrics.localization.mislabeling_detection import (
MislabelingDetectionMetric,
)
from src.toy_benchmarks.base import ToyBenchmark
from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset
from src.utils.training.trainer import BaseTrainer
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.utils.datasets.transformed.label_flipping import (
LabelFlippingDataset,
)
from quanda.utils.training.trainer import BaseTrainer


class MislabelingDetection(ToyBenchmark):
Expand Down Expand Up @@ -256,16 +258,16 @@ def evaluate(
pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
inputs, labels = inputs.to(device), labels.to(device)
if use_predictions:
with torch.no_grad():
targets = self.model(input).argmax(dim=-1)
targets = self.model(inputs).argmax(dim=-1)
else:
targets = labels
explanations = explainer.explain(test=input, targets=targets)
explanations = explainer.explain(test=inputs, targets=targets)
metric.update(explanations)
else:
metric = MislabelingDetectionMetric.self_influence_based(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import torch
from tqdm import tqdm

from src.metrics.localization.class_detection import ClassDetectionMetric
from src.toy_benchmarks.base import ToyBenchmark
from src.utils.datasets.transformed.label_grouping import (
from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.toy_benchmarks.base import ToyBenchmark
from quanda.utils.datasets.transformed.label_grouping import (
ClassToGroupLiterals,
LabelGroupingDataset,
)
from src.utils.training.trainer import BaseTrainer
from quanda.utils.training.trainer import BaseTrainer


class SubclassDetection(ToyBenchmark):
Expand Down Expand Up @@ -231,19 +231,19 @@ def evaluate(
pbar = tqdm(expl_dl)
n_batches = len(expl_dl)

for i, (input, labels) in enumerate(pbar):
for i, (inputs, labels) in enumerate(pbar):
pbar.set_description("Metric evaluation, batch %d/%d" % (i + 1, n_batches))

input, labels = input.to(device), labels.to(device)
inputs, labels = inputs.to(device), labels.to(device)
grouped_labels = torch.tensor([self.class_to_group[i.item()] for i in labels], device=labels.device)
if use_predictions:
with torch.no_grad():
output = self.group_model(input)
output = self.group_model(inputs)
targets = output.argmax(dim=-1)
else:
targets = grouped_labels
explanations = explainer.explain(
test=input,
test=inputs,
targets=targets,
)

Expand Down
5 changes: 5 additions & 0 deletions quanda/toy_benchmarks/randomization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from quanda.toy_benchmarks.randomization.model_randomization import (
ModelRandomization,
)

__all__ = ["ModelRandomization"]
Loading

0 comments on commit ca2b0dc

Please sign in to comment.