Trainer Usability Improvements + Lightning Compatibility #99

Merged · 11 commits · Aug 6, 2024
@@ -1,5 +1,5 @@
# .github/workflows/type-lint.yml
name: Type-lint
name: lint
on: push
jobs:
type-lint:
@@ -18,6 +18,3 @@ jobs:

- name: Run flake8
run: tox run -e lint

- name: Run mypy
run: tox run -e type
20 changes: 20 additions & 0 deletions .github/workflows/mypy.yml
@@ -0,0 +1,20 @@
# .github/workflows/type-lint.yml
name: mypy
on: push
jobs:
type-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Setup python 3.11
uses: actions/setup-python@v4
with:
cache: 'pip'
python-version: "3.11"

- name: Install tox-gh
run: pip install tox-gh

- name: Run mypy
run: tox run -e type
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ style:
find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
find . | grep -E ".pytest_cache" | xargs rm -rf
find . | grep -E ".mypy_cache" | xargs rm -rf
find . | grep -E ".checkpoints" | xargs rm -rf
find . | grep -E "./checkpoints" | xargs rm -rf
find . | grep -E "*eff-info" | xargs rm -rf
find . | grep -E ".build" | xargs rm -rf
find . | grep -E ".htmlcov" | xargs rm -rf
1 change: 0 additions & 1 deletion src/explainers/utils.py
@@ -49,7 +49,6 @@ def self_influence_fn_from_explainer(
self_influence_kwargs: dict,
**kwargs: Any,
) -> torch.Tensor:

explainer = _init_explainer(
explainer_cls=explainer_cls,
model=model,
57 changes: 40 additions & 17 deletions src/metrics/unnamed/dataset_cleaning.py
@@ -1,5 +1,7 @@
import copy
from typing import Optional, Union

import lightning as L
import torch

from src.metrics.base import GlobalMetric
@@ -20,9 +22,10 @@ class DatasetCleaningMetric(GlobalMetric):

def __init__(
self,
model: torch.nn.Module,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
trainer: BaseTrainer,
trainer: Union[L.Trainer, BaseTrainer],
init_model: Optional[torch.nn.Module] = None,
trainer_fit_kwargs: Optional[dict] = None,
global_method: Union[str, type] = "self-influence",
top_k: int = 50,
@@ -46,19 +49,18 @@ def __init__(
)
self.top_k = min(top_k, self.dataset_length - 1)
self.trainer = trainer
self.trainer_fit_kwargs = trainer_fit_kwargs
self.trainer_fit_kwargs = trainer_fit_kwargs or {}

self.clean_model: torch.nn.Module
self.clean_accuracy: int
self.original_accuracy: int
self.init_model = init_model or copy.deepcopy(model)

@classmethod
def self_influence_based(
cls,
model: torch.nn.Module,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
explainer_cls: type,
trainer: BaseTrainer,
trainer: Union[L.Trainer, BaseTrainer],
init_model: Optional[torch.nn.Module] = None,
expl_kwargs: Optional[dict] = None,
top_k: int = 50,
trainer_fit_kwargs: Optional[dict] = None,
@@ -70,6 +72,7 @@ def self_influence_based(
model=model,
train_dataset=train_dataset,
trainer=trainer,
init_model=init_model,
trainer_fit_kwargs=trainer_fit_kwargs,
global_method="self-influence",
top_k=top_k,
@@ -81,10 +84,11 @@
@classmethod
def aggr_based(
cls,
model: torch.nn.Module,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
trainer: BaseTrainer,
trainer: Union[L.Trainer, BaseTrainer],
aggregator_cls: Union[str, type],
init_model: Optional[torch.nn.Module] = None,
top_k: int = 50,
trainer_fit_kwargs: Optional[dict] = None,
device: str = "cpu",
@@ -95,6 +99,7 @@
model=model,
train_dataset=train_dataset,
trainer=trainer,
init_model=init_model,
trainer_fit_kwargs=trainer_fit_kwargs,
global_method=aggregator_cls,
top_k=top_k,
@@ -123,15 +128,33 @@ def compute(self, *args, **kwargs):
clean_subset = torch.utils.data.Subset(self.train_dataset, clean_indices)

train_dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True)
self.original_accuracy = class_accuracy(self.model, train_dl, self.device)
original_accuracy = class_accuracy(self.model, train_dl, self.device)

clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True)

self.clean_model = self.trainer.fit(
train_loader=clean_dl,
trainer_fit_kwargs=self.trainer_fit_kwargs,
)
if isinstance(self.trainer, L.Trainer):
if not isinstance(self.init_model, L.LightningModule):
raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")

self.trainer.fit(
model=self.init_model,
train_dataloaders=clean_dl,
**self.trainer_fit_kwargs,
)

elif isinstance(self.trainer, BaseTrainer):
if not isinstance(self.init_model, torch.nn.Module):
raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")

self.trainer.fit(
model=self.init_model,
train_dataloaders=clean_dl,
**self.trainer_fit_kwargs,
)

else:
raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")

self.clean_accuracy = class_accuracy(self.model, clean_dl, self.device)
clean_accuracy = class_accuracy(self.model, clean_dl, self.device)

return self.original_accuracy - self.clean_accuracy
return original_accuracy - clean_accuracy
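
For context, a minimal usage sketch of the updated metric with a Lightning trainer, based on the signatures shown in this diff. The import path follows the file location; `LitClassifier`, `train_set`, and `SomeExplainer` are hypothetical stand-ins for the user's own LightningModule, torch Dataset, and explainer class.

```python
import copy

import lightning as L

from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric

# `LitClassifier`, `train_set`, and `SomeExplainer` are placeholders, not part of this PR.
model = LitClassifier()
trainer = L.Trainer(max_epochs=5)

metric = DatasetCleaningMetric.self_influence_based(
    model=model,
    train_dataset=train_set,
    explainer_cls=SomeExplainer,
    trainer=trainer,                  # an L.Trainer is now accepted alongside BaseTrainer
    init_model=copy.deepcopy(model),  # optional; defaults to a deepcopy of `model`
    top_k=50,
    trainer_fit_kwargs={},
)

# compute() retrains init_model on the cleaned subset with the given trainer
# and returns original_accuracy - clean_accuracy.
score = metric.compute()
```

With a `BaseTrainer`, the same call applies as long as the model and `init_model` are plain `torch.nn.Module`s, matching the isinstance checks added in `compute()`.
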
58 changes: 39 additions & 19 deletions src/toy_benchmarks/localization/mislabeling_detection.py
@@ -1,5 +1,7 @@
import copy
from typing import Callable, Dict, List, Optional, Union

import lightning as L
import torch
from tqdm import tqdm

@@ -8,7 +10,7 @@
)
from src.toy_benchmarks.base import ToyBenchmark
from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset
from src.utils.training.trainer import BaseTrainer, Trainer
from src.utils.training.trainer import BaseTrainer


class MislabelingDetection(ToyBenchmark):
@@ -20,8 +22,7 @@ def __init__(
):
super().__init__(device=device)

self.trainer: Optional[BaseTrainer] = None
self.model: torch.nn.Module
self.model: Union[torch.nn.Module, L.LightningModule]
self.train_dataset: torch.utils.data.Dataset
self.poisoned_dataset: LabelFlippingDataset
self.dataset_transform: Optional[Callable]
@@ -37,10 +38,10 @@
@classmethod
def generate(
cls,
model: torch.nn.Module,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
n_classes: int,
trainer: Trainer,
trainer: Union[L.Trainer, BaseTrainer],
dataset_transform: Optional[Callable] = None,
val_dataset: Optional[torch.utils.data.Dataset] = None,
global_method: Union[str, type] = "self-influence",
@@ -58,15 +59,15 @@

obj = cls(device=device)

obj.model = model.to(device)
obj.trainer = trainer
obj._generate(
model=model.to(device),
train_dataset=train_dataset,
val_dataset=val_dataset,
p=p,
global_method=global_method,
dataset_transform=dataset_transform,
n_classes=n_classes,
trainer=trainer,
trainer_fit_kwargs=trainer_fit_kwargs,
seed=seed,
batch_size=batch_size,
@@ -75,8 +76,10 @@

def _generate(
self,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
n_classes: int,
trainer: Union[L.Trainer, BaseTrainer],
dataset_transform: Optional[Callable],
poisoned_indices: Optional[List[int]] = None,
poisoned_labels: Optional[Dict[int, int]] = None,
@@ -87,12 +90,6 @@
seed: int = 27,
batch_size: int = 8,
):
if self.trainer is None:
raise ValueError(
"Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or "
"init_trainer_from_train_arguments"
)

self.train_dataset = train_dataset
self.p = p
self.global_method = global_method
@@ -119,11 +116,34 @@
else:
self.poisoned_val_dl = None

self.model = self.trainer.fit(
train_loader=self.poisoned_train_dl,
val_loader=self.poisoned_val_dl,
trainer_fit_kwargs=trainer_fit_kwargs,
)
self.model = copy.deepcopy(model)

trainer_fit_kwargs = trainer_fit_kwargs or {}

if isinstance(trainer, L.Trainer):
if not isinstance(self.model, L.LightningModule):
raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")

trainer.fit(
model=self.model,
train_dataloaders=self.poisoned_train_dl,
val_dataloaders=self.poisoned_val_dl,
**trainer_fit_kwargs,
)

elif isinstance(trainer, BaseTrainer):
if not isinstance(self.model, torch.nn.Module):
raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")

trainer.fit(
model=self.model,
train_dataloaders=self.poisoned_train_dl,
val_dataloaders=self.poisoned_val_dl,
**trainer_fit_kwargs,
)

else:
raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")

@property
def bench_state(self):
@@ -161,7 +181,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwargs):
@classmethod
def assemble(
cls,
model: torch.nn.Module,
model: Union[torch.nn.Module, L.LightningModule],
train_dataset: torch.utils.data.Dataset,
n_classes: int,
poisoned_indices: Optional[List[int]] = None,
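
Similarly, a rough sketch of generating the mislabeling-detection benchmark with a Lightning trainer after this change. Argument names follow the `generate` signature in the diff; `LitClassifier` and `train_set` are hypothetical placeholders for the user's own LightningModule and torch Dataset.

```python
import lightning as L

from src.toy_benchmarks.localization.mislabeling_detection import MislabelingDetection

# `LitClassifier` and `train_set` are placeholders, not part of this PR.
model = LitClassifier()
trainer = L.Trainer(max_epochs=5)

# generate() flips a fraction of the training labels, trains a copy of the
# model on the poisoned data with the supplied trainer, and stores it on the
# benchmark object.
benchmark = MislabelingDetection.generate(
    model=model,
    train_dataset=train_set,
    n_classes=10,
    trainer=trainer,        # L.Trainer or BaseTrainer; the model type must match
    trainer_fit_kwargs={},
    device="cpu",
)
```

As with the dataset-cleaning metric, passing a `BaseTrainer` instead requires the model to be a plain `torch.nn.Module`; otherwise `_generate` raises a `ValueError`.
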