Skip to content

Commit

Permalink
Merge pull request #46 from dilyabareeva/small_trainer_changes_mypy_f…
Browse files Browse the repository at this point in the history
…ixes

Small trainer changes & MyPy fixes
  • Loading branch information
dilyabareeva authored Jun 24, 2024
2 parents 42e8ac5 + 4ec56f9 commit d6e549a
Show file tree
Hide file tree
Showing 27 changed files with 339 additions and 141 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,10 @@ coverage.xml
.pytest_cache/
cover/
/scratch.py

# Lightning
lightning_logs/
checkpoints/

# data_attribution_evaluation
cache/
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,9 @@ style:
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
find . | grep -E ".pytest_cache" | xargs rm -rf
find . | grep -E ".mypy_cache" | xargs rm -rf
find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
find . | grep -E ".checkpoints" | xargs rm -rf
find . | grep -E "*eff-info" | xargs rm -rf
find . | grep -E ".build" | xargs rm -rf
find . | grep -E ".htmlcov" | xargs rm -rf
find . | grep -E ".lightning_logs" | xargs rm -rf
find . -name '*~' -exec rm -f {} +
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ include_trailing_comma = true
python_version = "3.11"
warn_return_any = false
warn_unused_configs = true
ignore_errors = true # TODO: change this
check_untyped_defs = true
#ignore_errors = true # TODO: change this

# Black formatting
[tool.black]
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion src/downstream_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, device: str = "cpu", *args, **kwargs):
def evaluate(
self,
model: torch.nn.Module,
dataset: torch.utils.data.dataset,
dataset: torch.utils.data.Dataset,
*args,
**kwargs,
):
Expand Down
52 changes: 29 additions & 23 deletions src/downstream_tasks/subclass_identification.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,48 @@
import os
from typing import Callable, Dict, Optional, Union

import lightning as L
import torch

from explainers.functional import ExplainFunc
from explainers.wrappers.captum_influence import captum_similarity_explain
from metrics.localization.identical_class import IdenticalClass
from utils.datasets.group_label_dataset import (
from src.explainers.functional import ExplainFunc
from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.localization.identical_class import IdenticalClass
from src.utils.datasets.group_label_dataset import (
ClassToGroupLiterals,
GroupLabelDataset,
)
from utils.training.trainer import BaseTrainer, Trainer
from src.utils.training.trainer import BaseTrainer, Trainer


class SubclassIdentification:
def __init__(self, device: str = "cpu", *args, **kwargs):
self.device = device
self.trainer: Optional[BaseTrainer] = None

def init_trainer_from_lightning_module(self, pl_module):
trainer = Trainer()
trainer.from_lightning_module(pl_module)
self.trainer = trainer

def init_trainer_from_train_arguments(
def __init__(
self,
model: torch.nn.Module,
optimizer: Callable,
lr: float,
criterion: torch.nn.modules.loss._Loss,
optimizer_kwargs: Optional[dict] = None,
device: str = "cpu",
*args,
**kwargs,
):
trainer = Trainer()
trainer.from_train_arguments(model, optimizer, lr, criterion, optimizer_kwargs)
self.trainer = trainer
self.device = device
self.trainer: Optional[BaseTrainer] = Trainer.from_arguments(
model=model, optimizer=optimizer, lr=lr, criterion=criterion, optimizer_kwargs=optimizer_kwargs
)

@classmethod
def from_pl_module(cls, model: torch.nn.Module, pl_module: L.LightningModule, device: str = "cpu", *args, **kwargs):
obj = cls.__new__(cls)
super(SubclassIdentification, obj).__init__()
obj.device = device
obj.trainer = Trainer.from_lightning_module(model, pl_module)
return obj

def evaluate(
self,
train_dataset: torch.utils.data.dataset,
val_dataset: Optional[torch.utils.data.dataset] = None,
train_dataset: torch.utils.data.Dataset,
val_dataset: Optional[torch.utils.data.Dataset] = None,
n_classes: int = 10,
n_groups: int = 2,
class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
Expand Down Expand Up @@ -81,13 +85,15 @@ def evaluate(
class_to_group=grouped_dataset.class_to_group,
seed=seed,
)
val_loader = torch.utils.data.DataLoader(grouped_val_dataset, batch_size=batch_size)
val_loader: Optional[torch.utils.data.DataLoader] = torch.utils.data.DataLoader(
grouped_val_dataset, batch_size=batch_size
)
else:
val_loader = None

model = self.trainer.fit(
grouped_train_loader,
val_loader,
train_loader=grouped_train_loader,
val_loader=val_loader,
trainer_kwargs=trainer_kwargs,
)
metric = IdenticalClass(model=model, train_dataset=train_dataset, device="cpu")
Expand Down
6 changes: 4 additions & 2 deletions src/explainers/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def reset(self, *args, **kwargs):
"""
Used to reset the aggregator state.
"""
self.scores: torch.Tensor = None
self.scores = None

def load_state_dict(self, state_dict: dict, *args, **kwargs):
"""
Expand All @@ -38,7 +38,9 @@ def state_dict(self, *args, **kwargs):
return {"scores": self.scores}

def compute(self) -> torch.Tensor:
return self.scores.argsort()
if self.scores is None:
raise ValueError("No scores to aggregate.")
return self.scores


class SumAggregator(BaseAggregator):
Expand Down
19 changes: 15 additions & 4 deletions src/explainers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Sized, Union

import torch

Expand Down Expand Up @@ -28,8 +28,19 @@ def __init__(
def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any):
raise NotImplementedError

@property
def dataset_length(self) -> int:
"""
By default, the Dataset class does not always have a __len__ method.
:return:
"""
if isinstance(self.train_dataset, Sized):
return len(self.train_dataset)
dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=1)
return len(dl)

@cache_result
def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch.Tensor:
def self_influence(self, batch_size: int = 32, **kwargs: Any) -> torch.Tensor:
"""
Base class implements computing self influences by explaining the train dataset one by one
Expand All @@ -39,10 +50,10 @@ def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch
"""

# Pre-allocate memory for influences, because torch.cat is slow
influences = torch.empty((len(self.train_dataset),), device=self.device)
influences = torch.empty((self.dataset_length,), device=self.device)
ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size)

for i, (x, y) in zip(range(0, len(self.train_dataset), batch_size), ldr):
for i, (x, y) in zip(range(0, self.dataset_length, batch_size), ldr):
explanations = self.explain(test=x.to(self.device), **kwargs)
influences[i : i + batch_size] = explanations.diag(diagonal=i)

Expand Down
6 changes: 3 additions & 3 deletions src/explainers/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ def __call__(
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
explanation_targets: Optional[Union[List[int], torch.Tensor]],
train_dataset: torch.utils.data.Dataset,
explain_kwargs: Dict,
init_kwargs: Dict,
device: Union[str, torch.device],
explanation_targets: Optional[Union[List[int], torch.Tensor]] = None,
init_kwargs: Optional[Dict] = None,
explain_kwargs: Optional[Dict] = None,
) -> torch.Tensor:
pass
4 changes: 2 additions & 2 deletions src/explainers/wrappers/captum_influence.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Optional, Union

import torch
from captum.influence import SimilarityInfluence
from captum.influence import SimilarityInfluence # type: ignore

from src.explainers.base import BaseExplainer
from src.explainers.utils import (
Expand Down Expand Up @@ -117,7 +117,7 @@ def layer(self, layers: Any):

def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any):
# We might want to pass the top_k as an argument in some scenarios
top_k = kwargs.get("top_k", len(self.train_dataset))
top_k = kwargs.get("top_k", self.dataset_length)

topk_idx, topk_val = super().explain(test=test, top_k=top_k, **kwargs)[self.layer]
inverted_idx = topk_idx.argsort()
Expand Down
20 changes: 16 additions & 4 deletions src/metrics/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Sized

import torch

Expand All @@ -7,14 +8,14 @@ class Metric(ABC):
def __init__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.dataset,
train_dataset: torch.utils.data.Dataset,
device: str = "cpu",
*args,
**kwargs,
):
self.model = model.to(device)
self.train_dataset = train_dataset
self.device = device
self.model: torch.nn.Module = model.to(device)
self.train_dataset: torch.utils.data.Dataset = train_dataset
self.device: str = device

@abstractmethod
def update(
Expand Down Expand Up @@ -59,3 +60,14 @@ def state_dict(self, *args, **kwargs):
"""

raise NotImplementedError

@property
def dataset_length(self) -> int:
"""
By default, the Dataset class does not always have a __len__ method.
:return:
"""
if isinstance(self.train_dataset, Sized):
return len(self.train_dataset)
dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=1)
return len(dl)
12 changes: 7 additions & 5 deletions src/metrics/localization/identical_class.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import torch

from src.metrics.base import Metric
Expand All @@ -12,8 +14,8 @@ def __init__(
*args,
**kwargs,
):
super().__init__(model=model, train_dataset=train_dataset, device=device, *args, **kwargs)
self.scores = []
super().__init__(model=model, train_dataset=train_dataset, device=device)
self.scores: List[torch.Tensor] = []

def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
"""
Expand Down Expand Up @@ -65,11 +67,11 @@ def __init__(
*args,
**kwargs,
):
assert len(subclass_labels) == len(train_dataset), (
super().__init__(model, train_dataset, device, *args, **kwargs)
assert len(subclass_labels) == self.dataset_length, (
f"Number of subclass labels ({len(subclass_labels)}) "
f"does not match the number of train dataset samples ({len(train_dataset)})."
f"does not match the number of train dataset samples ({self.dataset_length})."
)
super().__init__(model, train_dataset, device, *args, **kwargs)
self.subclass_labels = subclass_labels

def update(self, test_subclasses: torch.Tensor, explanations: torch.Tensor):
Expand Down
14 changes: 7 additions & 7 deletions src/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import Callable, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch

Expand All @@ -22,8 +22,8 @@ def __init__(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
explain_fn: ExplainFunc,
explain_init_kwargs: Optional[dict] = {},
explain_fn_kwargs: Optional[dict] = {},
explain_init_kwargs: Optional[dict] = None,
explain_fn_kwargs: Optional[dict] = None,
correlation_fn: Union[Callable, CorrelationFnLiterals] = "spearman",
seed: int = 42,
model_id: str = "0",
Expand All @@ -39,8 +39,8 @@ def __init__(
)
self.model = model
self.train_dataset = train_dataset
self.explain_fn_kwargs = explain_fn_kwargs
self.explain_init_kwargs = explain_init_kwargs
self.explain_fn_kwargs = explain_fn_kwargs or {}
self.explain_init_kwargs = explain_init_kwargs or {}
self.seed = seed
self.model_id = model_id
self.cache_dir = cache_dir
Expand All @@ -64,11 +64,11 @@ def __init__(
train_dataset=self.train_dataset,
)

self.results = {"scores": []}
self.results: Dict[str, List] = {"scores": []}

# TODO: create a validation utility function
if isinstance(correlation_fn, str) and correlation_fn in correlation_functions:
self.corr_measure = correlation_functions.get(correlation_fn)
self.corr_measure = correlation_functions[correlation_fn]
elif callable(correlation_fn):
self.corr_measure = correlation_fn
else:
Expand Down
Loading

0 comments on commit d6e549a

Please sign in to comment.