From a01ce6fe52f3b1f8caebbec3939827d75c2315b4 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 10:35:59 +0200 Subject: [PATCH 01/11] make model a trainer.fit arg --- src/explainers/utils.py | 1 - src/metrics/unnamed/dataset_cleaning.py | 2 + .../localization/mislabeling_detection.py | 2 + .../localization/subclass_detection.py | 5 ++ src/utils/training/trainer.py | 60 +++++++++++-------- tests/metrics/test_unnamed_metrics.py | 6 +- .../test_mislabeling_detection.py | 2 +- .../localization/test_subclass_detection.py | 2 +- .../unnamed/test_dataset_cleaning.py | 2 +- tests/utils/test_training.py | 8 +-- 10 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/explainers/utils.py b/src/explainers/utils.py index 384afc72..135e8152 100644 --- a/src/explainers/utils.py +++ b/src/explainers/utils.py @@ -49,7 +49,6 @@ def self_influence_fn_from_explainer( self_influence_kwargs: dict, **kwargs: Any, ) -> torch.Tensor: - explainer = _init_explainer( explainer_cls=explainer_cls, model=model, diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index f4c7b453..bfa1c1a0 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -1,3 +1,4 @@ +import copy from typing import Optional, Union import torch @@ -128,6 +129,7 @@ def compute(self, *args, **kwargs): clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) self.clean_model = self.trainer.fit( + model=copy.deepcopy(self.model), train_loader=clean_dl, trainer_fit_kwargs=self.trainer_fit_kwargs, ) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index 5d0ed435..bdd2a95b 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -1,3 +1,4 @@ +import copy from typing import Callable, Dict, List, Optional, Union import torch @@ -120,6 +121,7 @@ def _generate( self.poisoned_val_dl = None self.model = self.trainer.fit( + model=copy.deepcopy(self.model), train_loader=self.poisoned_train_dl, val_loader=self.poisoned_val_dl, trainer_fit_kwargs=trainer_fit_kwargs, diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/src/toy_benchmarks/localization/subclass_detection.py index e8c5a2a1..8cf4476d 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/src/toy_benchmarks/localization/subclass_detection.py @@ -1,3 +1,4 @@ +import copy from typing import Callable, Dict, Optional, Union import torch @@ -22,6 +23,7 @@ def __init__( super().__init__(device=device) self.trainer: Optional[BaseTrainer] = None + self.model: torch.nn.Module self.group_model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset self.dataset_transform: Optional[Callable] @@ -35,6 +37,7 @@ def __init__( @classmethod def generate( cls, + model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, trainer: Trainer, val_dataset: Optional[torch.utils.data.Dataset] = None, @@ -56,6 +59,7 @@ def generate( obj = cls(device=device) trainer_fit_kwargs = trainer_fit_kwargs or {"max_epochs": 5} + obj.model = model obj.trainer = trainer obj._generate( train_dataset=train_dataset, @@ -120,6 +124,7 @@ def _generate( self.grouped_val_dl = None self.group_model = self.trainer.fit( + model=copy.deepcopy(self.model), train_loader=self.grouped_train_dl, val_loader=self.grouped_val_dl, trainer_fit_kwargs=trainer_fit_kwargs, diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index 375efb75..a1374ee6 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -12,6 +12,7 @@ class BaseTrainer(metaclass=abc.ABCMeta): @abstractmethod def fit( self, + model: torch.nn.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, trainer_fit_kwargs: Optional[dict] = None, @@ -26,13 +27,17 @@ def get_model(self) -> torch.nn.Module: class Trainer(BaseTrainer): def __init__(self): - self.model: torch.nn.Module self.module: Optional[L.LightningModule] = None + self.optimizer: Optional[Callable] + self.lr: Optional[float] + self.criterion: Optional[torch.nn.modules.loss._Loss] + self.scheduler: Optional[Callable] + self.optimizer_kwargs: Optional[dict] + self.scheduler_kwargs: Optional[dict] @classmethod def from_arguments( cls, - model: torch.nn.Module, optimizer: Callable, lr: float, criterion: torch.nn.modules.loss._Loss, @@ -40,56 +45,61 @@ def from_arguments( optimizer_kwargs: Optional[dict] = None, scheduler_kwargs: Optional[dict] = None, ): + cls.optimizer = optimizer + cls.lr = lr + cls.criterion = criterion + cls.scheduler = scheduler + cls.optimizer_kwargs = optimizer_kwargs or {} + cls.scheduler_kwargs = scheduler_kwargs or {} + cls.module = None + obj = cls.__new__(cls) super(Trainer, obj).__init__() - obj.model = model - if optimizer_kwargs is None: - optimizer_kwargs = {} - obj.module = BasicLightningModule( - model=model, - optimizer=optimizer, - lr=lr, - criterion=criterion, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, - ) + return obj @classmethod def from_lightning_module( cls, - model: torch.nn.Module, pl_module: L.LightningModule, ): obj = cls.__new__(cls) super(Trainer, obj).__init__() - obj.model = model obj.module = pl_module return obj def fit( self, + model: torch.nn.Module, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, trainer_fit_kwargs: Optional[dict] = None, *args, **kwargs, ): - if self.model is None: - raise ValueError( - "Lightning module not initialized. Please initialize using from_arguments or from_lightning_module" - ) if self.module is None: - raise ValueError("Model not initialized. Please initialize using from_arguments or from_lightning_module") + if self.optimizer is None: + raise ValueError("Optimizer not initialized. Please initialize optimizer using from_arguments") + if self.lr is None: + raise ValueError("Learning rate not initialized. Please initialize lr using from_arguments") + if self.criterion is None: + raise ValueError("Criterion not initialized. Please initialize criterion using from_arguments") + + self.module = BasicLightningModule( + model=model, + optimizer=self.optimizer, + lr=self.lr, + criterion=self.criterion, + optimizer_kwargs=self.optimizer_kwargs, + scheduler=self.scheduler, + scheduler_kwargs=self.scheduler_kwargs, + ) if trainer_fit_kwargs is None: trainer_fit_kwargs = {} trainer = L.Trainer(**trainer_fit_kwargs) trainer.fit(self.module, train_loader, val_loader) - self.model.load_state_dict(self.module.model.state_dict()) - return self.model + model.load_state_dict(self.module.model.state_dict()) - def get_model(self) -> torch.nn.Module: - return self.model + return model diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index b439c303..625d489e 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -107,7 +107,7 @@ def test_dataset_cleaning( lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) if global_method != "self-influence": metric = DatasetCleaningMetric( @@ -189,7 +189,7 @@ def test_dataset_cleaning_self_influence_based( lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) expl_kwargs = expl_kwargs or {} @@ -253,7 +253,7 @@ def test_dataset_cleaning_aggr_based( lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) metric = DatasetCleaningMetric.aggr_based( model=model, diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index 67c6c0c3..864a9a54 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -132,7 +132,7 @@ def test_mislabeling_detection( criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) dst_eval = MislabelingDetection.generate( model=model, diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index 6de47fee..d158e922 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -120,7 +120,7 @@ def test_subclass_detection( criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) dst_eval = SubclassDetection.generate( model=model, diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index d3b27c47..bf48516c 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -136,7 +136,7 @@ def test_dataset_cleaning( criterion=criterion, ) - trainer = Trainer.from_lightning_module(model, pl_module) + trainer = Trainer.from_lightning_module(pl_module) score = dst_eval.evaluate( expl_dataset=dataset, explainer_cls=explainer_cls, diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index 884c3feb..b7cf165c 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -69,7 +69,6 @@ def test_trainer( trainer = Trainer() if mode == "from_arguments": trainer = trainer.from_arguments( - model=model, optimizer=optimizer, lr=lr, criterion=criterion, @@ -83,10 +82,11 @@ def test_trainer( lr=lr, criterion=criterion, ) - trainer = trainer.from_lightning_module(model=model, pl_module=pl_module) + trainer = trainer.from_lightning_module(pl_module=pl_module) model = trainer.fit( - dataloader, - dataloader, + model=model, + train_loader=dataloader, + val_loader=dataloader, trainer_fit_kwargs={"max_epochs": max_epochs}, ) From d2d4c2c1a25ff445969acd4e0dda41f27a6b5225 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 11:58:38 +0200 Subject: [PATCH 02/11] allow lightning trainer instead of quanda base trainer --- src/metrics/unnamed/dataset_cleaning.py | 15 ++-- .../localization/mislabeling_detection.py | 19 +++-- .../localization/subclass_detection.py | 21 ++--- .../unnamed/dataset_cleaning.py | 9 ++- src/utils/training/trainer.py | 10 +-- tests/conftest.py | 17 ++++ .../localization/test_subclass_detection.py | 78 ++++++++++++++++++- tests/utils/test_training.py | 8 +- tutorials/usage_testing.py | 6 +- 9 files changed, 143 insertions(+), 40 deletions(-) diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index bfa1c1a0..defc2cd4 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -2,6 +2,7 @@ from typing import Optional, Union import torch +import lightning as L from src.metrics.base import GlobalMetric from src.utils.common import class_accuracy @@ -23,7 +24,7 @@ def __init__( self, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - trainer: BaseTrainer, + trainer: Union[L.Trainer, BaseTrainer], trainer_fit_kwargs: Optional[dict] = None, global_method: Union[str, type] = "self-influence", top_k: int = 50, @@ -59,7 +60,7 @@ def self_influence_based( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, explainer_cls: type, - trainer: BaseTrainer, + trainer: Union[L.Trainer, BaseTrainer], expl_kwargs: Optional[dict] = None, top_k: int = 50, trainer_fit_kwargs: Optional[dict] = None, @@ -84,7 +85,7 @@ def aggr_based( cls, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - trainer: BaseTrainer, + trainer: Union[L.Trainer, BaseTrainer], aggregator_cls: Union[str, type], top_k: int = 50, trainer_fit_kwargs: Optional[dict] = None, @@ -128,9 +129,11 @@ def compute(self, *args, **kwargs): clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) - self.clean_model = self.trainer.fit( - model=copy.deepcopy(self.model), - train_loader=clean_dl, + self.clean_model = copy.deepcopy(self.model) + + self.trainer.fit( + model=self.clean_model, + train_dataloaders=clean_dl, trainer_fit_kwargs=self.trainer_fit_kwargs, ) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index bdd2a95b..432a7acc 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -3,13 +3,14 @@ import torch from tqdm import tqdm +import lightning as L from src.metrics.localization.mislabeling_detection import ( MislabelingDetectionMetric, ) from src.toy_benchmarks.base import ToyBenchmark from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset -from src.utils.training.trainer import BaseTrainer, Trainer +from src.utils.training.trainer import BaseTrainer class MislabelingDetection(ToyBenchmark): @@ -21,7 +22,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[BaseTrainer] = None + self.trainer: Optional[L.Trainer, BaseTrainer] = None self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset self.poisoned_dataset: LabelFlippingDataset @@ -41,7 +42,7 @@ def generate( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, n_classes: int, - trainer: Trainer, + trainer: Union[L.Trainer, BaseTrainer], dataset_transform: Optional[Callable] = None, val_dataset: Optional[torch.utils.data.Dataset] = None, global_method: Union[str, type] = "self-influence", @@ -59,9 +60,9 @@ def generate( obj = cls(device=device) - obj.model = model.to(device) obj.trainer = trainer obj._generate( + model=model.to(device), train_dataset=train_dataset, val_dataset=val_dataset, p=p, @@ -76,6 +77,7 @@ def generate( def _generate( self, + model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, n_classes: int, dataset_transform: Optional[Callable], @@ -120,10 +122,11 @@ def _generate( else: self.poisoned_val_dl = None - self.model = self.trainer.fit( - model=copy.deepcopy(self.model), - train_loader=self.poisoned_train_dl, - val_loader=self.poisoned_val_dl, + self.model = copy.deepcopy(model) + self.trainer.fit( + model=self.model, + train_dataloaders=self.poisoned_train_dl, + val_dataloaders=self.poisoned_val_dl, trainer_fit_kwargs=trainer_fit_kwargs, ) diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/src/toy_benchmarks/localization/subclass_detection.py index 8cf4476d..badd7dc6 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/src/toy_benchmarks/localization/subclass_detection.py @@ -3,6 +3,7 @@ import torch from tqdm import tqdm +import lightning as L from src.metrics.localization.class_detection import ClassDetectionMetric from src.toy_benchmarks.base import ToyBenchmark @@ -10,7 +11,7 @@ ClassToGroupLiterals, LabelGroupingDataset, ) -from src.utils.training.trainer import BaseTrainer, Trainer +from src.utils.training.trainer import BaseTrainer class SubclassDetection(ToyBenchmark): @@ -22,7 +23,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[BaseTrainer] = None + self.trainer: Optional[L.Trainer, BaseTrainer] = None self.model: torch.nn.Module self.group_model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -39,7 +40,7 @@ def generate( cls, model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, - trainer: Trainer, + trainer: Union[L.Trainer, BaseTrainer], val_dataset: Optional[torch.utils.data.Dataset] = None, dataset_transform: Optional[Callable] = None, n_classes: int = 10, @@ -57,7 +58,7 @@ def generate( """ obj = cls(device=device) - trainer_fit_kwargs = trainer_fit_kwargs or {"max_epochs": 5} + trainer_fit_kwargs = trainer_fit_kwargs or {} obj.model = model obj.trainer = trainer @@ -123,11 +124,13 @@ def _generate( else: self.grouped_val_dl = None - self.group_model = self.trainer.fit( - model=copy.deepcopy(self.model), - train_loader=self.grouped_train_dl, - val_loader=self.grouped_val_dl, - trainer_fit_kwargs=trainer_fit_kwargs, + self.group_model = copy.deepcopy(self.model) + + self.trainer.fit( + model=self.group_model, + train_dataloaders=self.grouped_train_dl, + val_dataloaders=self.grouped_val_dl, + **trainer_fit_kwargs, ) @classmethod diff --git a/src/toy_benchmarks/unnamed/dataset_cleaning.py b/src/toy_benchmarks/unnamed/dataset_cleaning.py index 93643f29..67b7d72e 100644 --- a/src/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/src/toy_benchmarks/unnamed/dataset_cleaning.py @@ -2,10 +2,11 @@ import torch from tqdm import tqdm +import lightning as L from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric from src.toy_benchmarks.base import ToyBenchmark -from src.utils.training.trainer import Trainer +from src.utils.training.trainer import BaseTrainer class DatasetCleaning(ToyBenchmark): @@ -17,8 +18,8 @@ def __init__( ): super().__init__(device=device) - self.model: torch.nn.Module - self.train_dataset: torch.utils.data.Dataset + self.model: Optional[torch.nn.Module] = None + self.train_dataset: Optional[torch.utils.data.Dataset] = None @classmethod def generate( @@ -85,7 +86,7 @@ def evaluate( self, expl_dataset: torch.utils.data.Dataset, explainer_cls: type, - trainer: Trainer, + trainer: Union[L.Trainer, BaseTrainer], use_predictions: bool = False, expl_kwargs: Optional[dict] = None, trainer_fit_kwargs: Optional[dict] = None, diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index a1374ee6..6fdebf95 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -13,8 +13,8 @@ class BaseTrainer(metaclass=abc.ABCMeta): def fit( self, model: torch.nn.Module, - train_loader: torch.utils.data.dataloader.DataLoader, - val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, + train_dataloaders: torch.utils.data.dataloader.DataLoader, + val_dataloaders: Optional[torch.utils.data.dataloader.DataLoader] = None, trainer_fit_kwargs: Optional[dict] = None, *args, **kwargs, @@ -71,8 +71,8 @@ def from_lightning_module( def fit( self, model: torch.nn.Module, - train_loader: torch.utils.data.dataloader.DataLoader, - val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, + train_dataloaders: torch.utils.data.dataloader.DataLoader, + val_dataloaders: Optional[torch.utils.data.dataloader.DataLoader] = None, trainer_fit_kwargs: Optional[dict] = None, *args, **kwargs, @@ -98,7 +98,7 @@ def fit( if trainer_fit_kwargs is None: trainer_fit_kwargs = {} trainer = L.Trainer(**trainer_fit_kwargs) - trainer.fit(self.module, train_loader, val_loader) + trainer.fit(self.module, train_dataloaders, val_dataloaders) model.load_state_dict(self.module.model.state_dict()) diff --git a/tests/conftest.py b/tests/conftest.py index c86ad489..e4e6439c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ from src.utils.datasets.transformed.label_flipping import LabelFlippingDataset from src.utils.datasets.transformed.label_grouping import LabelGroupingDataset +from src.utils.training.base_pl_module import BasicLightningModule from tests.models import LeNet MNIST_IMAGE_SIZE = 28 @@ -85,6 +86,22 @@ def load_mnist_model(): return model +@pytest.fixture +def load_mnist_pl_module(): + """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" + model = LeNet() + model.load_state_dict(torch.load("tests/assets/mnist", map_location="cpu", pickle_module=pickle)) + + pl_module = BasicLightningModule( + model=model, + optimizer=torch.optim.SGD, + lr=0.01, + criterion=torch.nn.CrossEntropyLoss(), + ) + + return pl_module + + @pytest.fixture def load_mnist_grouped_model(): """Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).""" diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index d158e922..365d65a3 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -1,5 +1,5 @@ import pytest - +import lightning as L from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.localization.subclass_detection import ( SubclassDetection, @@ -157,3 +157,79 @@ def test_subclass_detection( ) assert score == expected_score + + +@pytest.mark.toy_benchmarks +@pytest.mark.parametrize( + "test_id, pl_module, max_epochs, dataset, n_classes, n_groups, seed, " + "class_to_group, batch_size, explainer_cls, expl_kwargs, use_pred, load_path, expected_score", + [ + ( + "mnist", + "load_mnist_pl_module", + 3, + "load_mnist_dataset", + 10, + 2, + 27, + {i: i % 2 for i in range(10)}, + 8, + CaptumSimilarity, + { + "layers": "model.fc_2", + "similarity_metric": cosine_similarity, + }, + False, + None, + 1.0, + ), + ], +) +def test_subclass_detection_generate_lightning_model( + test_id, + pl_module, + max_epochs, + dataset, + n_classes, + n_groups, + seed, + class_to_group, + batch_size, + explainer_cls, + expl_kwargs, + use_pred, + load_path, + expected_score, + tmp_path, + request, +): + pl_module = request.getfixturevalue(pl_module) + dataset = request.getfixturevalue(dataset) + + trainer = L.Trainer(max_epochs=max_epochs) + + dst_eval = SubclassDetection.generate( + model=pl_module, + trainer=trainer, + train_dataset=dataset, + n_classes=n_classes, + n_groups=n_groups, + class_to_group=class_to_group, + trainer_fit_kwargs={}, + seed=seed, + batch_size=batch_size, + device="cpu", + ) + + score = dst_eval.evaluate( + expl_dataset=dataset, + explainer_cls=explainer_cls, + expl_kwargs=expl_kwargs, + cache_dir=str(tmp_path), + model_id="default_model_id", + use_predictions=use_pred, + batch_size=batch_size, + device="cpu", + ) + + assert score == expected_score \ No newline at end of file diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index b7cf165c..344a6fc8 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -10,7 +10,7 @@ @pytest.mark.utils @pytest.mark.parametrize( "test_id, init_model, dataloader, optimizer, lr, criterion, scheduler, scheduler_kwargs, \ - max_epochs, val_loader, early_stopping, early_stopping_kwargs, mode", + max_epochs, val_dataloaders, early_stopping, early_stopping_kwargs, mode", [ ( "mnist", @@ -54,7 +54,7 @@ def test_trainer( scheduler, scheduler_kwargs, max_epochs, - val_loader, + val_dataloaders, early_stopping, early_stopping_kwargs, mode, @@ -85,8 +85,8 @@ def test_trainer( trainer = trainer.from_lightning_module(pl_module=pl_module) model = trainer.fit( model=model, - train_loader=dataloader, - val_loader=dataloader, + train_dataloaders=dataloader, + val_dataloaders=dataloader, trainer_fit_kwargs={"max_epochs": max_epochs}, ) diff --git a/tutorials/usage_testing.py b/tutorials/usage_testing.py index d3fa4c26..03134f05 100644 --- a/tutorials/usage_testing.py +++ b/tutorials/usage_testing.py @@ -54,13 +54,13 @@ def main(): ) train_set = torchvision.datasets.CIFAR10(root="./tutorials/data", train=True, download=True, transform=normalize) - train_loader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=8) + train_dataloader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=8) # we split held out data into test and validation set held_out = torchvision.datasets.CIFAR10(root="./tutorials/data", train=False, download=True, transform=normalize) test_set, val_set = torch.utils.data.random_split(held_out, [0.1, 0.9], generator=RNG) test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=8) - # val_loader = DataLoader(val_set, batch_size=100, shuffle=False, num_workers=8) + # val_dataloader = DataLoader(val_set, batch_size=100, shuffle=False, num_workers=8) # download pre-trained weights local_path = "./tutorials/model_weights_resnet18_cifar10.pth" @@ -103,7 +103,7 @@ def accuracy(net, loader): correct += predicted.eq(targets).sum().item() return correct / total - print(f"Train set accuracy: {100.0 * accuracy(model, train_loader):0.1f}%") + print(f"Train set accuracy: {100.0 * accuracy(model, train_dataloader):0.1f}%") print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%") # ++++++++++++++++++++++++++++++++++++++++++ From 86fb1e4bfac754bf909140cca49f2a8c2d5e4800 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 12:48:37 +0200 Subject: [PATCH 03/11] add pl module tests for benchmarks --- src/metrics/unnamed/dataset_cleaning.py | 2 +- .../localization/mislabeling_detection.py | 2 +- .../test_mislabeling_detection.py | 76 +++++++++++++++++++ .../localization/test_subclass_detection.py | 2 +- .../unnamed/test_dataset_cleaning.py | 75 ++++++++++++++++++ 5 files changed, 154 insertions(+), 3 deletions(-) diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index defc2cd4..63d3747f 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -134,7 +134,7 @@ def compute(self, *args, **kwargs): self.trainer.fit( model=self.clean_model, train_dataloaders=clean_dl, - trainer_fit_kwargs=self.trainer_fit_kwargs, + **self.trainer_fit_kwargs, ) self.clean_accuracy = class_accuracy(self.model, clean_dl, self.device) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index 432a7acc..c1cfbf83 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -127,7 +127,7 @@ def _generate( model=self.model, train_dataloaders=self.poisoned_train_dl, val_dataloaders=self.poisoned_val_dl, - trainer_fit_kwargs=trainer_fit_kwargs, + **trainer_fit_kwargs, ) @property diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index 864a9a54..6b09ae44 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -1,5 +1,7 @@ import pytest +import lightning as L + from src.explainers.aggregators import SumAggregator from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.localization.mislabeling_detection import ( @@ -169,3 +171,77 @@ def test_mislabeling_detection( )["score"] assert score == expected_score + + +@pytest.mark.toy_benchmarks +@pytest.mark.parametrize( + "test_id, pl_module, max_epochs, dataset, n_classes, p, seed, " + "global_method, batch_size, explainer_cls, expl_kwargs, use_pred, load_path, expected_score", + [ + ( + "mnist", + "load_mnist_pl_module", + 3, + "load_mnist_dataset", + 10, + 1.0, + 27, + "self-influence", + 8, + CaptumSimilarity, + {"layers": "model.fc_2", "similarity_metric": cosine_similarity, "cache_dir": "cache", "model_id": "test"}, + False, + None, + 0.4921875, + ), + ], +) +def test_mislabeling_detection_generate_from_pl_module( + test_id, + pl_module, + max_epochs, + dataset, + n_classes, + p, + seed, + batch_size, + global_method, + explainer_cls, + expl_kwargs, + use_pred, + load_path, + expected_score, + tmp_path, + request, +): + pl_module = request.getfixturevalue(pl_module) + dataset = request.getfixturevalue(dataset) + + trainer = L.Trainer(max_epochs=max_epochs) + + dst_eval = MislabelingDetection.generate( + model=pl_module, + trainer=trainer, + train_dataset=dataset, + n_classes=n_classes, + p=p, + global_method=global_method, + class_to_group="random", + trainer_fit_kwargs={}, + seed=seed, + batch_size=batch_size, + device="cpu", + ) + + score = dst_eval.evaluate( + expl_dataset=dataset, + explainer_cls=explainer_cls, + expl_kwargs=expl_kwargs, + cache_dir=str(tmp_path), + model_id="default_model_id", + use_predictions=use_pred, + batch_size=batch_size, + device="cpu", + )["score"] + + assert score == expected_score diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index 365d65a3..a57483e4 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -232,4 +232,4 @@ def test_subclass_detection_generate_lightning_model( device="cpu", ) - assert score == expected_score \ No newline at end of file + assert score == expected_score diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index bf48516c..02ac8084 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -1,5 +1,7 @@ import pytest +import lightning as L + from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning from src.utils.functions.similarities import cosine_similarity @@ -152,3 +154,76 @@ def test_dataset_cleaning( ) assert score == expected_score + + +@pytest.mark.toy_benchmarks +@pytest.mark.parametrize( + "test_id, pl_module, max_epochs, dataset, n_classes, n_groups, seed, " + "global_method, batch_size, explainer_cls, expl_kwargs, use_pred, load_path, expected_score", + [ + ( + "mnist1", + "load_mnist_pl_module", + 3, + "load_mnist_dataset", + 10, + 2, + 27, + "self-influence", + 8, + CaptumSimilarity, + { + "layers": "model.fc_2", + "similarity_metric": cosine_similarity, + }, + False, + None, + 0.0, + ), + ], +) +def test_dataset_cleaning_generate_from_pl_module( + test_id, + pl_module, + max_epochs, + dataset, + n_classes, + n_groups, + seed, + global_method, + batch_size, + explainer_cls, + expl_kwargs, + use_pred, + load_path, + expected_score, + tmp_path, + request, +): + pl_module = request.getfixturevalue(pl_module) + dataset = request.getfixturevalue(dataset) + + trainer = L.Trainer(max_epochs=max_epochs) + + dst_eval = DatasetCleaning.generate( + model=pl_module, + train_dataset=dataset, + device="cpu", + ) + dst_eval.save("tests/assets/mnist_dataset_cleaning_state_dict") + + score = dst_eval.evaluate( + expl_dataset=dataset, + explainer_cls=explainer_cls, + trainer=trainer, + expl_kwargs=expl_kwargs, + trainer_fit_kwargs={}, + cache_dir=str(tmp_path), + model_id="default_model_id", + use_predictions=use_pred, + global_method=global_method, + batch_size=batch_size, + device="cpu", + ) + + assert score == expected_score From 106bd9de90207ddacda84245a63a25c7b312118f Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 13:12:43 +0200 Subject: [PATCH 04/11] simplify trainer - no more pl initialization --- src/metrics/unnamed/dataset_cleaning.py | 1 - src/utils/training/trainer.py | 77 ++++++------------- tests/metrics/test_unnamed_metrics.py | 33 ++++---- .../test_mislabeling_detection.py | 6 +- .../localization/test_subclass_detection.py | 6 +- .../unnamed/test_dataset_cleaning.py | 18 ++--- tests/utils/test_training.py | 48 +++--------- 7 files changed, 65 insertions(+), 124 deletions(-) diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index 63d3747f..8c4c3948 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -134,7 +134,6 @@ def compute(self, *args, **kwargs): self.trainer.fit( model=self.clean_model, train_dataloaders=clean_dl, - **self.trainer_fit_kwargs, ) self.clean_accuracy = class_accuracy(self.model, clean_dl, self.device) diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index 6fdebf95..91b856af 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -26,80 +26,49 @@ def get_model(self) -> torch.nn.Module: class Trainer(BaseTrainer): - def __init__(self): - self.module: Optional[L.LightningModule] = None - self.optimizer: Optional[Callable] - self.lr: Optional[float] - self.criterion: Optional[torch.nn.modules.loss._Loss] - self.scheduler: Optional[Callable] - self.optimizer_kwargs: Optional[dict] - self.scheduler_kwargs: Optional[dict] - - @classmethod - def from_arguments( - cls, + def __init__( + self, optimizer: Callable, lr: float, + max_epochs: int, criterion: torch.nn.modules.loss._Loss, scheduler: Optional[Callable] = None, optimizer_kwargs: Optional[dict] = None, scheduler_kwargs: Optional[dict] = None, ): - cls.optimizer = optimizer - cls.lr = lr - cls.criterion = criterion - cls.scheduler = scheduler - cls.optimizer_kwargs = optimizer_kwargs or {} - cls.scheduler_kwargs = scheduler_kwargs or {} - cls.module = None - obj = cls.__new__(cls) - super(Trainer, obj).__init__() + self.optimizer = optimizer + self.lr = lr + self.max_epochs = max_epochs + self.criterion = criterion + self.scheduler = scheduler + self.optimizer_kwargs = optimizer_kwargs or {} + self.scheduler_kwargs = scheduler_kwargs or {} - return obj - - @classmethod - def from_lightning_module( - cls, - pl_module: L.LightningModule, - ): - obj = cls.__new__(cls) - super(Trainer, obj).__init__() - obj.module = pl_module - return obj + super(Trainer, self).__init__() def fit( self, model: torch.nn.Module, train_dataloaders: torch.utils.data.dataloader.DataLoader, val_dataloaders: Optional[torch.utils.data.dataloader.DataLoader] = None, - trainer_fit_kwargs: Optional[dict] = None, *args, **kwargs, ): - if self.module is None: - if self.optimizer is None: - raise ValueError("Optimizer not initialized. Please initialize optimizer using from_arguments") - if self.lr is None: - raise ValueError("Learning rate not initialized. Please initialize lr using from_arguments") - if self.criterion is None: - raise ValueError("Criterion not initialized. Please initialize criterion using from_arguments") - self.module = BasicLightningModule( - model=model, - optimizer=self.optimizer, - lr=self.lr, - criterion=self.criterion, - optimizer_kwargs=self.optimizer_kwargs, - scheduler=self.scheduler, - scheduler_kwargs=self.scheduler_kwargs, - ) + module = BasicLightningModule( + model=model, + optimizer=self.optimizer, + lr=self.lr, + criterion=self.criterion, + optimizer_kwargs=self.optimizer_kwargs, + scheduler=self.scheduler, + scheduler_kwargs=self.scheduler_kwargs, + ) - if trainer_fit_kwargs is None: - trainer_fit_kwargs = {} - trainer = L.Trainer(**trainer_fit_kwargs) - trainer.fit(self.module, train_dataloaders, val_dataloaders) + trainer = L.Trainer(max_epochs=self.max_epochs) + trainer.fit(module, train_dataloaders, val_dataloaders) - model.load_state_dict(self.module.model.state_dict()) + model.load_state_dict(module.model.state_dict()) return model diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index 625d489e..cb0627bb 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -101,13 +101,12 @@ def test_dataset_cleaning( optimizer = request.getfixturevalue(optimizer) criterion = request.getfixturevalue(criterion) - pl_module = BasicLightningModule( - model=model, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) - trainer = Trainer.from_lightning_module(pl_module) + trainer = Trainer( + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) if global_method != "self-influence": metric = DatasetCleaningMetric( @@ -189,7 +188,12 @@ def test_dataset_cleaning_self_influence_based( lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(pl_module) + trainer = Trainer( + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) expl_kwargs = expl_kwargs or {} @@ -247,13 +251,12 @@ def test_dataset_cleaning_aggr_based( optimizer = request.getfixturevalue(optimizer) criterion = request.getfixturevalue(criterion) - pl_module = BasicLightningModule( - model=model, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) - trainer = Trainer.from_lightning_module(pl_module) + trainer = Trainer( + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) metric = DatasetCleaningMetric.aggr_based( model=model, diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index 6b09ae44..1068b853 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -127,15 +127,13 @@ def test_mislabeling_detection( dataset = request.getfixturevalue(dataset) if init_method == "generate": - pl_module = BasicLightningModule( - model=model, + trainer = Trainer( + max_epochs=max_epochs, optimizer=optimizer, lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(pl_module) - dst_eval = MislabelingDetection.generate( model=model, trainer=trainer, diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index a57483e4..8ceb4aec 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -113,15 +113,13 @@ def test_subclass_detection( dataset = request.getfixturevalue(dataset) if init_method == "generate": - pl_module = BasicLightningModule( - model=model, + trainer = Trainer( + max_epochs=max_epochs, optimizer=optimizer, lr=lr, criterion=criterion, ) - trainer = Trainer.from_lightning_module(pl_module) - dst_eval = SubclassDetection.generate( model=model, trainer=trainer, diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index 02ac8084..ba822651 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -40,7 +40,7 @@ ( "mnist2", "assemble", - "load_mnist_grouped_model", + "load_mnist_model", "torch_sgd_optimizer", 0.01, "torch_cross_entropy_loss_object", @@ -58,7 +58,7 @@ }, False, None, - -0.875, + 0.0, ), ( "mnist3", @@ -131,14 +131,12 @@ def test_dataset_cleaning( else: raise ValueError(f"Invalid init_method: {init_method}") - pl_module = BasicLightningModule( - model=model, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) - - trainer = Trainer.from_lightning_module(pl_module) + trainer = Trainer( + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) score = dst_eval.evaluate( expl_dataset=dataset, explainer_cls=explainer_cls, diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index 344a6fc8..ed529dc3 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -10,7 +10,7 @@ @pytest.mark.utils @pytest.mark.parametrize( "test_id, init_model, dataloader, optimizer, lr, criterion, scheduler, scheduler_kwargs, \ - max_epochs, val_dataloaders, early_stopping, early_stopping_kwargs, mode", + max_epochs, val_dataloaders, early_stopping, early_stopping_kwargs", [ ( "mnist", @@ -25,22 +25,6 @@ "load_mnist_dataloader", False, {}, - "from_arguments", - ), - ( - "mnist", - "load_init_mnist_model", - "load_mnist_dataloader", - "torch_sgd_optimizer", - 0.01, - "torch_cross_entropy_loss_object", - "torch_constant_lr_scheduler_type", - {"last_epoch": -1}, - 3, - "load_mnist_dataloader", - False, - {}, - "from_pl_module", ), ], ) @@ -57,7 +41,6 @@ def test_trainer( val_dataloaders, early_stopping, early_stopping_kwargs, - mode, request, ): model = request.getfixturevalue(init_model) @@ -66,28 +49,21 @@ def test_trainer( scheduler = request.getfixturevalue(scheduler) criterion = request.getfixturevalue(criterion) old_model = copy.deepcopy(model) - trainer = Trainer() - if mode == "from_arguments": - trainer = trainer.from_arguments( - optimizer=optimizer, - lr=lr, - criterion=criterion, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, - ) - else: - pl_module = BasicLightningModule( - model=model, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) - trainer = trainer.from_lightning_module(pl_module=pl_module) + + trainer = Trainer( + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + ) + model = trainer.fit( model=model, train_dataloaders=dataloader, val_dataloaders=dataloader, - trainer_fit_kwargs={"max_epochs": max_epochs}, + max_epochs=max_epochs, ) for param1, param2 in zip(old_model.parameters(), model.parameters()): From 50bd0c0969b05abf0bce49bf53095583bd90b911 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 13:19:55 +0200 Subject: [PATCH 05/11] add an option to load freshly initialized model for dataset cleaning --- src/metrics/unnamed/dataset_cleaning.py | 20 ++++++++++--------- .../unnamed/dataset_cleaning.py | 7 +++++++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index 8c4c3948..f3dde619 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -25,6 +25,7 @@ def __init__( model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], + init_model: Optional[torch.nn.Module] = None, trainer_fit_kwargs: Optional[dict] = None, global_method: Union[str, type] = "self-influence", top_k: int = 50, @@ -50,9 +51,7 @@ def __init__( self.trainer = trainer self.trainer_fit_kwargs = trainer_fit_kwargs - self.clean_model: torch.nn.Module - self.clean_accuracy: int - self.original_accuracy: int + self.init_model = init_model or copy.deepcopy(model) @classmethod def self_influence_based( @@ -61,6 +60,7 @@ def self_influence_based( train_dataset: torch.utils.data.Dataset, explainer_cls: type, trainer: Union[L.Trainer, BaseTrainer], + init_model: Optional[torch.nn.Module] = None, expl_kwargs: Optional[dict] = None, top_k: int = 50, trainer_fit_kwargs: Optional[dict] = None, @@ -72,6 +72,7 @@ def self_influence_based( model=model, train_dataset=train_dataset, trainer=trainer, + init_model=init_model, trainer_fit_kwargs=trainer_fit_kwargs, global_method="self-influence", top_k=top_k, @@ -87,6 +88,7 @@ def aggr_based( train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], aggregator_cls: Union[str, type], + init_model: Optional[torch.nn.Module] = None, top_k: int = 50, trainer_fit_kwargs: Optional[dict] = None, device: str = "cpu", @@ -97,6 +99,7 @@ def aggr_based( model=model, train_dataset=train_dataset, trainer=trainer, + init_model=init_model, trainer_fit_kwargs=trainer_fit_kwargs, global_method=aggregator_cls, top_k=top_k, @@ -125,17 +128,16 @@ def compute(self, *args, **kwargs): clean_subset = torch.utils.data.Subset(self.train_dataset, clean_indices) train_dl = torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True) - self.original_accuracy = class_accuracy(self.model, train_dl, self.device) + original_accuracy = class_accuracy(self.model, train_dl, self.device) clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) - self.clean_model = copy.deepcopy(self.model) - self.trainer.fit( - model=self.clean_model, + model=self.init_model, train_dataloaders=clean_dl, + **self.trainer_fit_kwargs, ) - self.clean_accuracy = class_accuracy(self.model, clean_dl, self.device) + clean_accuracy = class_accuracy(self.model, clean_dl, self.device) - return self.original_accuracy - self.clean_accuracy + return original_accuracy - clean_accuracy diff --git a/src/toy_benchmarks/unnamed/dataset_cleaning.py b/src/toy_benchmarks/unnamed/dataset_cleaning.py index 67b7d72e..cf36bec8 100644 --- a/src/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/src/toy_benchmarks/unnamed/dataset_cleaning.py @@ -1,3 +1,4 @@ +import copy from typing import Optional, Union import torch @@ -87,6 +88,7 @@ def evaluate( expl_dataset: torch.utils.data.Dataset, explainer_cls: type, trainer: Union[L.Trainer, BaseTrainer], + init_model: Optional[torch.nn.Module] = None, use_predictions: bool = False, expl_kwargs: Optional[dict] = None, trainer_fit_kwargs: Optional[dict] = None, @@ -99,6 +101,9 @@ def evaluate( *args, **kwargs, ): + + init_model = init_model or copy.deepcopy(self.model) + expl_kwargs = expl_kwargs or {} explainer = explainer_cls( model=self.model, train_dataset=self.train_dataset, model_id=model_id, cache_dir=cache_dir, **expl_kwargs @@ -108,6 +113,7 @@ def evaluate( if global_method != "self-influence": metric = DatasetCleaningMetric.aggr_based( model=self.model, + init_model=init_model, train_dataset=self.train_dataset, aggregator_cls=global_method, trainer=trainer, @@ -139,6 +145,7 @@ def evaluate( else: metric = DatasetCleaningMetric.self_influence_based( model=self.model, + init_model=init_model, train_dataset=self.train_dataset, trainer=trainer, trainer_fit_kwargs=trainer_fit_kwargs, From c981dcbd96fe4aa94fa316b33641076a735eb301 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 13:23:13 +0200 Subject: [PATCH 06/11] run black --- .../unnamed/dataset_cleaning.py | 1 - src/utils/training/trainer.py | 2 -- tests/metrics/test_unnamed_metrics.py | 28 ++++++++----------- .../test_mislabeling_detection.py | 1 - .../localization/test_subclass_detection.py | 1 - .../unnamed/test_dataset_cleaning.py | 11 ++++---- tests/utils/test_training.py | 1 - 7 files changed, 16 insertions(+), 29 deletions(-) diff --git a/src/toy_benchmarks/unnamed/dataset_cleaning.py b/src/toy_benchmarks/unnamed/dataset_cleaning.py index cf36bec8..1a85a5a6 100644 --- a/src/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/src/toy_benchmarks/unnamed/dataset_cleaning.py @@ -101,7 +101,6 @@ def evaluate( *args, **kwargs, ): - init_model = init_model or copy.deepcopy(self.model) expl_kwargs = expl_kwargs or {} diff --git a/src/utils/training/trainer.py b/src/utils/training/trainer.py index 91b856af..a86a43dc 100644 --- a/src/utils/training/trainer.py +++ b/src/utils/training/trainer.py @@ -36,7 +36,6 @@ def __init__( optimizer_kwargs: Optional[dict] = None, scheduler_kwargs: Optional[dict] = None, ): - self.optimizer = optimizer self.lr = lr self.max_epochs = max_epochs @@ -55,7 +54,6 @@ def fit( *args, **kwargs, ): - module = BasicLightningModule( model=model, optimizer=self.optimizer, diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index cb0627bb..44cb15b3 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -102,11 +102,11 @@ def test_dataset_cleaning( criterion = request.getfixturevalue(criterion) trainer = Trainer( - max_epochs=max_epochs, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) if global_method != "self-influence": metric = DatasetCleaningMetric( @@ -182,18 +182,12 @@ def test_dataset_cleaning_self_influence_based( optimizer = request.getfixturevalue(optimizer) criterion = request.getfixturevalue(criterion) - pl_module = BasicLightningModule( - model=model, + trainer = Trainer( + max_epochs=max_epochs, optimizer=optimizer, lr=lr, criterion=criterion, ) - trainer = Trainer( - max_epochs=max_epochs, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) expl_kwargs = expl_kwargs or {} @@ -253,10 +247,10 @@ def test_dataset_cleaning_aggr_based( trainer = Trainer( max_epochs=max_epochs, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) metric = DatasetCleaningMetric.aggr_based( model=model, diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index 1068b853..00a445d5 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -8,7 +8,6 @@ MislabelingDetection, ) from src.utils.functions.similarities import cosine_similarity -from src.utils.training.base_pl_module import BasicLightningModule from src.utils.training.trainer import Trainer diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index 8ceb4aec..3a0ea49f 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -5,7 +5,6 @@ SubclassDetection, ) from src.utils.functions.similarities import cosine_similarity -from src.utils.training.base_pl_module import BasicLightningModule from src.utils.training.trainer import Trainer diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index ba822651..41c8f339 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -5,7 +5,6 @@ from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning from src.utils.functions.similarities import cosine_similarity -from src.utils.training.base_pl_module import BasicLightningModule from src.utils.training.trainer import Trainer @@ -132,11 +131,11 @@ def test_dataset_cleaning( raise ValueError(f"Invalid init_method: {init_method}") trainer = Trainer( - max_epochs=max_epochs, - optimizer=optimizer, - lr=lr, - criterion=criterion, - ) + max_epochs=max_epochs, + optimizer=optimizer, + lr=lr, + criterion=criterion, + ) score = dst_eval.evaluate( expl_dataset=dataset, explainer_cls=explainer_cls, diff --git a/tests/utils/test_training.py b/tests/utils/test_training.py index ed529dc3..eb72b8f4 100644 --- a/tests/utils/test_training.py +++ b/tests/utils/test_training.py @@ -3,7 +3,6 @@ import pytest import torch -from src.utils.training.base_pl_module import BasicLightningModule from src.utils.training.trainer import Trainer From e80a9ac1ec2da39517498dca51912a75447f5ad2 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 13:23:32 +0200 Subject: [PATCH 07/11] mypy error fix --- tests/metrics/test_unnamed_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/metrics/test_unnamed_metrics.py b/tests/metrics/test_unnamed_metrics.py index 44cb15b3..c7146514 100644 --- a/tests/metrics/test_unnamed_metrics.py +++ b/tests/metrics/test_unnamed_metrics.py @@ -4,7 +4,6 @@ from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric from src.metrics.unnamed.top_k_overlap import TopKOverlapMetric from src.utils.functions.similarities import cosine_similarity -from src.utils.training.base_pl_module import BasicLightningModule from src.utils.training.trainer import Trainer From 67da39e2bd7f385adaa77deeb332ef9dd076a743 Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 14:19:57 +0200 Subject: [PATCH 08/11] mypy error fix v2 --- src/metrics/unnamed/dataset_cleaning.py | 6 +++--- src/toy_benchmarks/localization/mislabeling_detection.py | 8 +++++--- src/toy_benchmarks/localization/subclass_detection.py | 7 ++++--- src/toy_benchmarks/unnamed/dataset_cleaning.py | 6 +++--- .../localization/test_mislabeling_detection.py | 3 +-- .../localization/test_subclass_detection.py | 3 ++- tests/toy_benchmarks/unnamed/test_dataset_cleaning.py | 3 +-- 7 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index f3dde619..8b548250 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -1,8 +1,8 @@ import copy from typing import Optional, Union -import torch import lightning as L +import torch from src.metrics.base import GlobalMetric from src.utils.common import class_accuracy @@ -49,7 +49,7 @@ def __init__( ) self.top_k = min(top_k, self.dataset_length - 1) self.trainer = trainer - self.trainer_fit_kwargs = trainer_fit_kwargs + self.trainer_fit_kwargs = trainer_fit_kwargs or {} self.init_model = init_model or copy.deepcopy(model) @@ -133,7 +133,7 @@ def compute(self, *args, **kwargs): clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) self.trainer.fit( - model=self.init_model, + model=self.init_model, # type: ignore train_dataloaders=clean_dl, **self.trainer_fit_kwargs, ) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index c1cfbf83..4be135e7 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -1,9 +1,9 @@ import copy from typing import Callable, Dict, List, Optional, Union +import lightning as L import torch from tqdm import tqdm -import lightning as L from src.metrics.localization.mislabeling_detection import ( MislabelingDetectionMetric, @@ -22,7 +22,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[L.Trainer, BaseTrainer] = None + self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None self.model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset self.poisoned_dataset: LabelFlippingDataset @@ -123,8 +123,10 @@ def _generate( self.poisoned_val_dl = None self.model = copy.deepcopy(model) + + trainer_fit_kwargs = trainer_fit_kwargs or {} self.trainer.fit( - model=self.model, + model=self.model, # type: ignore train_dataloaders=self.poisoned_train_dl, val_dataloaders=self.poisoned_val_dl, **trainer_fit_kwargs, diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/src/toy_benchmarks/localization/subclass_detection.py index badd7dc6..edac9596 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/src/toy_benchmarks/localization/subclass_detection.py @@ -1,9 +1,9 @@ import copy from typing import Callable, Dict, Optional, Union +import lightning as L import torch from tqdm import tqdm -import lightning as L from src.metrics.localization.class_detection import ClassDetectionMetric from src.toy_benchmarks.base import ToyBenchmark @@ -23,7 +23,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[L.Trainer, BaseTrainer] = None + self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None self.model: torch.nn.Module self.group_model: torch.nn.Module self.train_dataset: torch.utils.data.Dataset @@ -126,8 +126,9 @@ def _generate( self.group_model = copy.deepcopy(self.model) + trainer_fit_kwargs = trainer_fit_kwargs or {} self.trainer.fit( - model=self.group_model, + model=self.group_model, # type: ignore train_dataloaders=self.grouped_train_dl, val_dataloaders=self.grouped_val_dl, **trainer_fit_kwargs, diff --git a/src/toy_benchmarks/unnamed/dataset_cleaning.py b/src/toy_benchmarks/unnamed/dataset_cleaning.py index 1a85a5a6..c89d9484 100644 --- a/src/toy_benchmarks/unnamed/dataset_cleaning.py +++ b/src/toy_benchmarks/unnamed/dataset_cleaning.py @@ -1,9 +1,9 @@ import copy from typing import Optional, Union +import lightning as L import torch from tqdm import tqdm -import lightning as L from src.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric from src.toy_benchmarks.base import ToyBenchmark @@ -19,8 +19,8 @@ def __init__( ): super().__init__(device=device) - self.model: Optional[torch.nn.Module] = None - self.train_dataset: Optional[torch.utils.data.Dataset] = None + self.model: torch.nn.Module + self.train_dataset: torch.utils.data.Dataset @classmethod def generate( diff --git a/tests/toy_benchmarks/localization/test_mislabeling_detection.py b/tests/toy_benchmarks/localization/test_mislabeling_detection.py index 00a445d5..de9719be 100644 --- a/tests/toy_benchmarks/localization/test_mislabeling_detection.py +++ b/tests/toy_benchmarks/localization/test_mislabeling_detection.py @@ -1,6 +1,5 @@ -import pytest - import lightning as L +import pytest from src.explainers.aggregators import SumAggregator from src.explainers.wrappers.captum_influence import CaptumSimilarity diff --git a/tests/toy_benchmarks/localization/test_subclass_detection.py b/tests/toy_benchmarks/localization/test_subclass_detection.py index 3a0ea49f..63e4f1fd 100644 --- a/tests/toy_benchmarks/localization/test_subclass_detection.py +++ b/tests/toy_benchmarks/localization/test_subclass_detection.py @@ -1,5 +1,6 @@ -import pytest import lightning as L +import pytest + from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.localization.subclass_detection import ( SubclassDetection, diff --git a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py index 41c8f339..f3c8d5a1 100644 --- a/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py +++ b/tests/toy_benchmarks/unnamed/test_dataset_cleaning.py @@ -1,6 +1,5 @@ -import pytest - import lightning as L +import pytest from src.explainers.wrappers.captum_influence import CaptumSimilarity from src.toy_benchmarks.unnamed.dataset_cleaning import DatasetCleaning From baa05b6b27f9844beb36ef62074c6256d64d860c Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 16:05:46 +0200 Subject: [PATCH 09/11] unnecessarily duplicating code to make mypy happy --- Makefile | 2 +- src/metrics/unnamed/dataset_cleaning.py | 33 +++++++++--- .../localization/mislabeling_detection.py | 49 +++++++++++------- .../localization/subclass_detection.py | 50 ++++++++++++------- 4 files changed, 88 insertions(+), 46 deletions(-) diff --git a/Makefile b/Makefile index 3ae6aa8d..92349780 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ style: find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf find . | grep -E ".pytest_cache" | xargs rm -rf find . | grep -E ".mypy_cache" | xargs rm -rf - find . | grep -E ".checkpoints" | xargs rm -rf + find . | grep -E "./checkpoints" | xargs rm -rf find . | grep -E "*eff-info" | xargs rm -rf find . | grep -E ".build" | xargs rm -rf find . | grep -E ".htmlcov" | xargs rm -rf diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index 8b548250..d6d34e6b 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -22,7 +22,7 @@ class DatasetCleaningMetric(GlobalMetric): def __init__( self, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], init_model: Optional[torch.nn.Module] = None, @@ -56,7 +56,7 @@ def __init__( @classmethod def self_influence_based( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, explainer_cls: type, trainer: Union[L.Trainer, BaseTrainer], @@ -84,7 +84,7 @@ def self_influence_based( @classmethod def aggr_based( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], aggregator_cls: Union[str, type], @@ -132,11 +132,28 @@ def compute(self, *args, **kwargs): clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) - self.trainer.fit( - model=self.init_model, # type: ignore - train_dataloaders=clean_dl, - **self.trainer_fit_kwargs, - ) + if isinstance(self.trainer, L.Trainer): + if not isinstance(self.init_model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + self.trainer.fit( + model=self.init_model, + train_dataloaders=clean_dl, + **self.trainer_fit_kwargs, + ) + + elif isinstance(self.trainer, BaseTrainer): + if not isinstance(self.init_model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + self.trainer.fit( + model=self.init_model, + train_dataloaders=clean_dl, + **self.trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") clean_accuracy = class_accuracy(self.model, clean_dl, self.device) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index 4be135e7..d2a16127 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -22,8 +22,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None - self.model: torch.nn.Module + self.model: Union[torch.nn.Module, L.LightningModule] self.train_dataset: torch.utils.data.Dataset self.poisoned_dataset: LabelFlippingDataset self.dataset_transform: Optional[Callable] @@ -39,7 +38,7 @@ def __init__( @classmethod def generate( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, trainer: Union[L.Trainer, BaseTrainer], @@ -60,7 +59,6 @@ def generate( obj = cls(device=device) - obj.trainer = trainer obj._generate( model=model.to(device), train_dataset=train_dataset, @@ -69,6 +67,7 @@ def generate( global_method=global_method, dataset_transform=dataset_transform, n_classes=n_classes, + trainer=trainer, trainer_fit_kwargs=trainer_fit_kwargs, seed=seed, batch_size=batch_size, @@ -77,9 +76,10 @@ def generate( def _generate( self, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, + trainer: Union[L.Trainer, BaseTrainer], dataset_transform: Optional[Callable], poisoned_indices: Optional[List[int]] = None, poisoned_labels: Optional[Dict[int, int]] = None, @@ -90,12 +90,6 @@ def _generate( seed: int = 27, batch_size: int = 8, ): - if self.trainer is None: - raise ValueError( - "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or " - "init_trainer_from_train_arguments" - ) - self.train_dataset = train_dataset self.p = p self.global_method = global_method @@ -125,12 +119,31 @@ def _generate( self.model = copy.deepcopy(model) trainer_fit_kwargs = trainer_fit_kwargs or {} - self.trainer.fit( - model=self.model, # type: ignore - train_dataloaders=self.poisoned_train_dl, - val_dataloaders=self.poisoned_val_dl, - **trainer_fit_kwargs, - ) + + if isinstance(trainer, L.Trainer): + if not isinstance(self.model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + trainer.fit( + model=self.model, + train_dataloaders=self.poisoned_train_dl, + val_dataloaders=self.poisoned_val_dl, + **trainer_fit_kwargs, + ) + + elif isinstance(trainer, BaseTrainer): + if not isinstance(self.model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + trainer.fit( + model=self.model, + train_dataloaders=self.poisoned_train_dl, + val_dataloaders=self.poisoned_val_dl, + **trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") @property def bench_state(self): @@ -168,7 +181,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar @classmethod def assemble( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, poisoned_indices: Optional[List[int]] = None, diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/src/toy_benchmarks/localization/subclass_detection.py index edac9596..ffbe88f4 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/src/toy_benchmarks/localization/subclass_detection.py @@ -23,9 +23,8 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None - self.model: torch.nn.Module - self.group_model: torch.nn.Module + self.model: Union[torch.nn.Module, L.LightningModule] + self.group_model: Union[torch.nn.Module, L.LightningModule] self.train_dataset: torch.utils.data.Dataset self.dataset_transform: Optional[Callable] self.grouped_train_dl: torch.utils.data.DataLoader @@ -38,7 +37,7 @@ def __init__( @classmethod def generate( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], val_dataset: Optional[torch.utils.data.Dataset] = None, @@ -58,11 +57,10 @@ def generate( """ obj = cls(device=device) - trainer_fit_kwargs = trainer_fit_kwargs or {} obj.model = model - obj.trainer = trainer obj._generate( + trainer=trainer, train_dataset=train_dataset, dataset_transform=dataset_transform, val_dataset=val_dataset, @@ -77,6 +75,7 @@ def generate( def _generate( self, + trainer: Union[L.Trainer, BaseTrainer], train_dataset: torch.utils.data.Dataset, val_dataset: Optional[torch.utils.data.Dataset] = None, dataset_transform: Optional[Callable] = None, @@ -89,12 +88,6 @@ def _generate( *args, **kwargs, ): - if self.trainer is None: - raise ValueError( - "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or " - "init_trainer_from_train_arguments" - ) - self.train_dataset = train_dataset self.grouped_dataset = LabelGroupingDataset( dataset=train_dataset, @@ -127,12 +120,31 @@ def _generate( self.group_model = copy.deepcopy(self.model) trainer_fit_kwargs = trainer_fit_kwargs or {} - self.trainer.fit( - model=self.group_model, # type: ignore - train_dataloaders=self.grouped_train_dl, - val_dataloaders=self.grouped_val_dl, - **trainer_fit_kwargs, - ) + + if isinstance(trainer, L.Trainer): + if not isinstance(self.group_model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + trainer.fit( + model=self.group_model, + train_dataloaders=self.grouped_train_dl, + val_dataloaders=self.grouped_val_dl, + **trainer_fit_kwargs, + ) + + elif isinstance(trainer, BaseTrainer): + if not isinstance(self.group_model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + trainer.fit( + model=self.group_model, + train_dataloaders=self.grouped_train_dl, + val_dataloaders=self.grouped_val_dl, + **trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") @classmethod def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwargs): @@ -155,7 +167,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar @classmethod def assemble( cls, - group_model: torch.nn.Module, + group_model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, n_groups: int, From 7a4995d9dc6dcd778e05bb6bf6391632e746675c Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 16:26:44 +0200 Subject: [PATCH 10/11] break up actions jobs --- .github/workflows/{type-lint.yml => lint.yml} | 3 --- .github/workflows/mypy.yml | 20 +++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) rename .github/workflows/{type-lint.yml => lint.yml} (88%) create mode 100644 .github/workflows/mypy.yml diff --git a/.github/workflows/type-lint.yml b/.github/workflows/lint.yml similarity index 88% rename from .github/workflows/type-lint.yml rename to .github/workflows/lint.yml index 52160d48..3691cdd4 100644 --- a/.github/workflows/type-lint.yml +++ b/.github/workflows/lint.yml @@ -18,6 +18,3 @@ jobs: - name: Run flake8 run: tox run -e lint - - - name: Run mypy - run: tox run -e type diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 00000000..32a88567 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,20 @@ +# .github/workflows/type-lint.yml +name: Type-lint +on: push +jobs: + type-lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Setup python 3.11 + uses: actions/setup-python@v4 + with: + cache: 'pip' + python-version: "3.11" + + - name: Install tox-gh + run: pip install tox-gh + + - name: Run mypy + run: tox run -e type From 657a0a1f9268df2266242d87dc8fc0b82d3fd0ca Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 16:31:28 +0200 Subject: [PATCH 11/11] rename actions jobs --- .github/workflows/lint.yml | 2 +- .github/workflows/mypy.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3691cdd4..b8453b73 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,5 +1,5 @@ # .github/workflows/type-lint.yml -name: Type-lint +name: lint on: push jobs: type-lint: diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 32a88567..784ae833 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,5 +1,5 @@ # .github/workflows/type-lint.yml -name: Type-lint +name: mypy on: push jobs: type-lint: