From baa05b6b27f9844beb36ef62074c6256d64d860c Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Fri, 2 Aug 2024 16:05:46 +0200 Subject: [PATCH] unnecessarily duplicating code to make mypy happy --- Makefile | 2 +- src/metrics/unnamed/dataset_cleaning.py | 33 +++++++++--- .../localization/mislabeling_detection.py | 49 +++++++++++------- .../localization/subclass_detection.py | 50 ++++++++++++------- 4 files changed, 88 insertions(+), 46 deletions(-) diff --git a/Makefile b/Makefile index 3ae6aa8d..92349780 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ style: find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf find . | grep -E ".pytest_cache" | xargs rm -rf find . | grep -E ".mypy_cache" | xargs rm -rf - find . | grep -E ".checkpoints" | xargs rm -rf + find . | grep -E "./checkpoints" | xargs rm -rf find . | grep -E "*eff-info" | xargs rm -rf find . | grep -E ".build" | xargs rm -rf find . | grep -E ".htmlcov" | xargs rm -rf diff --git a/src/metrics/unnamed/dataset_cleaning.py b/src/metrics/unnamed/dataset_cleaning.py index 8b548250..d6d34e6b 100644 --- a/src/metrics/unnamed/dataset_cleaning.py +++ b/src/metrics/unnamed/dataset_cleaning.py @@ -22,7 +22,7 @@ class DatasetCleaningMetric(GlobalMetric): def __init__( self, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], init_model: Optional[torch.nn.Module] = None, @@ -56,7 +56,7 @@ def __init__( @classmethod def self_influence_based( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, explainer_cls: type, trainer: Union[L.Trainer, BaseTrainer], @@ -84,7 +84,7 @@ def self_influence_based( @classmethod def aggr_based( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], aggregator_cls: Union[str, type], @@ -132,11 +132,28 @@ def compute(self, *args, **kwargs): clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True) - self.trainer.fit( - model=self.init_model, # type: ignore - train_dataloaders=clean_dl, - **self.trainer_fit_kwargs, - ) + if isinstance(self.trainer, L.Trainer): + if not isinstance(self.init_model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + self.trainer.fit( + model=self.init_model, + train_dataloaders=clean_dl, + **self.trainer_fit_kwargs, + ) + + elif isinstance(self.trainer, BaseTrainer): + if not isinstance(self.init_model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + self.trainer.fit( + model=self.init_model, + train_dataloaders=clean_dl, + **self.trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") clean_accuracy = class_accuracy(self.model, clean_dl, self.device) diff --git a/src/toy_benchmarks/localization/mislabeling_detection.py b/src/toy_benchmarks/localization/mislabeling_detection.py index 4be135e7..d2a16127 100644 --- a/src/toy_benchmarks/localization/mislabeling_detection.py +++ b/src/toy_benchmarks/localization/mislabeling_detection.py @@ -22,8 +22,7 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None - self.model: torch.nn.Module + self.model: Union[torch.nn.Module, L.LightningModule] self.train_dataset: torch.utils.data.Dataset self.poisoned_dataset: LabelFlippingDataset self.dataset_transform: Optional[Callable] @@ -39,7 +38,7 @@ def __init__( @classmethod def generate( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, trainer: Union[L.Trainer, BaseTrainer], @@ -60,7 +59,6 @@ def generate( obj = cls(device=device) - obj.trainer = trainer obj._generate( model=model.to(device), train_dataset=train_dataset, @@ -69,6 +67,7 @@ def generate( global_method=global_method, dataset_transform=dataset_transform, n_classes=n_classes, + trainer=trainer, trainer_fit_kwargs=trainer_fit_kwargs, seed=seed, batch_size=batch_size, @@ -77,9 +76,10 @@ def generate( def _generate( self, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, + trainer: Union[L.Trainer, BaseTrainer], dataset_transform: Optional[Callable], poisoned_indices: Optional[List[int]] = None, poisoned_labels: Optional[Dict[int, int]] = None, @@ -90,12 +90,6 @@ def _generate( seed: int = 27, batch_size: int = 8, ): - if self.trainer is None: - raise ValueError( - "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or " - "init_trainer_from_train_arguments" - ) - self.train_dataset = train_dataset self.p = p self.global_method = global_method @@ -125,12 +119,31 @@ def _generate( self.model = copy.deepcopy(model) trainer_fit_kwargs = trainer_fit_kwargs or {} - self.trainer.fit( - model=self.model, # type: ignore - train_dataloaders=self.poisoned_train_dl, - val_dataloaders=self.poisoned_val_dl, - **trainer_fit_kwargs, - ) + + if isinstance(trainer, L.Trainer): + if not isinstance(self.model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + trainer.fit( + model=self.model, + train_dataloaders=self.poisoned_train_dl, + val_dataloaders=self.poisoned_val_dl, + **trainer_fit_kwargs, + ) + + elif isinstance(trainer, BaseTrainer): + if not isinstance(self.model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + trainer.fit( + model=self.model, + train_dataloaders=self.poisoned_train_dl, + val_dataloaders=self.poisoned_val_dl, + **trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") @property def bench_state(self): @@ -168,7 +181,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar @classmethod def assemble( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, poisoned_indices: Optional[List[int]] = None, diff --git a/src/toy_benchmarks/localization/subclass_detection.py b/src/toy_benchmarks/localization/subclass_detection.py index edac9596..ffbe88f4 100644 --- a/src/toy_benchmarks/localization/subclass_detection.py +++ b/src/toy_benchmarks/localization/subclass_detection.py @@ -23,9 +23,8 @@ def __init__( ): super().__init__(device=device) - self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None - self.model: torch.nn.Module - self.group_model: torch.nn.Module + self.model: Union[torch.nn.Module, L.LightningModule] + self.group_model: Union[torch.nn.Module, L.LightningModule] self.train_dataset: torch.utils.data.Dataset self.dataset_transform: Optional[Callable] self.grouped_train_dl: torch.utils.data.DataLoader @@ -38,7 +37,7 @@ def __init__( @classmethod def generate( cls, - model: torch.nn.Module, + model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, trainer: Union[L.Trainer, BaseTrainer], val_dataset: Optional[torch.utils.data.Dataset] = None, @@ -58,11 +57,10 @@ def generate( """ obj = cls(device=device) - trainer_fit_kwargs = trainer_fit_kwargs or {} obj.model = model - obj.trainer = trainer obj._generate( + trainer=trainer, train_dataset=train_dataset, dataset_transform=dataset_transform, val_dataset=val_dataset, @@ -77,6 +75,7 @@ def generate( def _generate( self, + trainer: Union[L.Trainer, BaseTrainer], train_dataset: torch.utils.data.Dataset, val_dataset: Optional[torch.utils.data.Dataset] = None, dataset_transform: Optional[Callable] = None, @@ -89,12 +88,6 @@ def _generate( *args, **kwargs, ): - if self.trainer is None: - raise ValueError( - "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or " - "init_trainer_from_train_arguments" - ) - self.train_dataset = train_dataset self.grouped_dataset = LabelGroupingDataset( dataset=train_dataset, @@ -127,12 +120,31 @@ def _generate( self.group_model = copy.deepcopy(self.model) trainer_fit_kwargs = trainer_fit_kwargs or {} - self.trainer.fit( - model=self.group_model, # type: ignore - train_dataloaders=self.grouped_train_dl, - val_dataloaders=self.grouped_val_dl, - **trainer_fit_kwargs, - ) + + if isinstance(trainer, L.Trainer): + if not isinstance(self.group_model, L.LightningModule): + raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer") + + trainer.fit( + model=self.group_model, + train_dataloaders=self.grouped_train_dl, + val_dataloaders=self.grouped_val_dl, + **trainer_fit_kwargs, + ) + + elif isinstance(trainer, BaseTrainer): + if not isinstance(self.group_model, torch.nn.Module): + raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer") + + trainer.fit( + model=self.group_model, + train_dataloaders=self.grouped_train_dl, + val_dataloaders=self.grouped_val_dl, + **trainer_fit_kwargs, + ) + + else: + raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer") @classmethod def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwargs): @@ -155,7 +167,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar @classmethod def assemble( cls, - group_model: torch.nn.Module, + group_model: Union[torch.nn.Module, L.LightningModule], train_dataset: torch.utils.data.Dataset, n_classes: int, n_groups: int,