unnecessarily duplicating code to make mypy happy
dilyabareeva committed Aug 2, 2024
1 parent 67da39e commit baa05b6
Showing 4 changed files with 88 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -13,7 +13,7 @@ style:
 	find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
 	find . | grep -E ".pytest_cache" | xargs rm -rf
 	find . | grep -E ".mypy_cache" | xargs rm -rf
-	find . | grep -E ".checkpoints" | xargs rm -rf
+	find . | grep -E "./checkpoints" | xargs rm -rf
 	find . | grep -E "*eff-info" | xargs rm -rf
 	find . | grep -E ".build" | xargs rm -rf
 	find . | grep -E ".htmlcov" | xargs rm -rf
33 changes: 25 additions & 8 deletions src/metrics/unnamed/dataset_cleaning.py
@@ -22,7 +22,7 @@ class DatasetCleaningMetric(GlobalMetric):
 
     def __init__(
         self,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         trainer: Union[L.Trainer, BaseTrainer],
         init_model: Optional[torch.nn.Module] = None,
@@ -56,7 +56,7 @@ def __init__(
     @classmethod
     def self_influence_based(
         cls,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         explainer_cls: type,
         trainer: Union[L.Trainer, BaseTrainer],
@@ -84,7 +84,7 @@ def self_influence_based(
     @classmethod
     def aggr_based(
         cls,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         trainer: Union[L.Trainer, BaseTrainer],
         aggregator_cls: Union[str, type],
@@ -132,11 +132,28 @@ def compute(self, *args, **kwargs):
 
         clean_dl = torch.utils.data.DataLoader(clean_subset, batch_size=32, shuffle=True)
 
-        self.trainer.fit(
-            model=self.init_model,  # type: ignore
-            train_dataloaders=clean_dl,
-            **self.trainer_fit_kwargs,
-        )
+        if isinstance(self.trainer, L.Trainer):
+            if not isinstance(self.init_model, L.LightningModule):
+                raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")
+
+            self.trainer.fit(
+                model=self.init_model,
+                train_dataloaders=clean_dl,
+                **self.trainer_fit_kwargs,
+            )
+
+        elif isinstance(self.trainer, BaseTrainer):
+            if not isinstance(self.init_model, torch.nn.Module):
+                raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")
+
+            self.trainer.fit(
+                model=self.init_model,
+                train_dataloaders=clean_dl,
+                **self.trainer_fit_kwargs,
+            )
+
+        else:
+            raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")
 
         clean_accuracy = class_accuracy(self.model, clean_dl, self.device)
 
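All three Python files in this commit apply the same pattern: widen the model annotation to Union[torch.nn.Module, L.LightningModule] and replace a single fit call guarded by "# type: ignore" with one isinstance branch per trainer type. The sketch below is my own illustration, not code from the repository; BaseTrainer here is a simplified stand-in for the project's custom trainer, and its signature is assumed. It shows in isolation why mypy rejects the un-narrowed call and why the duplicated branches satisfy it.

# Illustration only: BaseTrainer and fit_with_either are stand-ins, not repository code.
from typing import Union

import lightning as L
import torch


class BaseTrainer:
    """Simplified stand-in for the project's custom trainer, which works on plain torch modules."""

    def fit(self, model: torch.nn.Module, train_dataloaders: torch.utils.data.DataLoader, **kwargs) -> None:
        ...


def fit_with_either(
    trainer: Union[L.Trainer, BaseTrainer],
    model: Union[torch.nn.Module, L.LightningModule],
    train_dl: torch.utils.data.DataLoader,
) -> None:
    # A bare trainer.fit(model=model, ...) fails mypy here: when trainer may be an
    # L.Trainer, fit() expects a LightningModule, but the model union also allows a
    # plain torch.nn.Module, and mypy cannot know that a Lightning trainer always
    # arrives paired with a LightningModule. Narrowing both variables in each branch
    # makes the call type-check without "# type: ignore".
    if isinstance(trainer, L.Trainer):
        if not isinstance(model, L.LightningModule):
            raise ValueError("Lightning Trainer requires a LightningModule")
        trainer.fit(model=model, train_dataloaders=train_dl)
    elif isinstance(trainer, BaseTrainer):
        if not isinstance(model, torch.nn.Module):
            raise ValueError("BaseTrainer requires a torch.nn.Module")
        trainer.fit(model=model, train_dataloaders=train_dl)
    else:
        raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")

The two branches end up structurally identical, which is the duplication the commit message grumbles about: within plain Union annotations, mypy has no way to express the constraint that the trainer and model types always arrive as a matching pair.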
49 changes: 31 additions & 18 deletions src/toy_benchmarks/localization/mislabeling_detection.py
@@ -22,8 +22,7 @@ def __init__(
     ):
         super().__init__(device=device)
 
-        self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None
-        self.model: torch.nn.Module
+        self.model: Union[torch.nn.Module, L.LightningModule]
         self.train_dataset: torch.utils.data.Dataset
         self.poisoned_dataset: LabelFlippingDataset
         self.dataset_transform: Optional[Callable]
@@ -39,7 +38,7 @@ def __init__(
     @classmethod
     def generate(
         cls,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
        train_dataset: torch.utils.data.Dataset,
        n_classes: int,
        trainer: Union[L.Trainer, BaseTrainer],
@@ -60,7 +59,6 @@ def generate(
 
         obj = cls(device=device)
 
-        obj.trainer = trainer
         obj._generate(
             model=model.to(device),
             train_dataset=train_dataset,
@@ -69,6 +67,7 @@ def generate(
             global_method=global_method,
             dataset_transform=dataset_transform,
             n_classes=n_classes,
+            trainer=trainer,
             trainer_fit_kwargs=trainer_fit_kwargs,
             seed=seed,
             batch_size=batch_size,
@@ -77,9 +76,10 @@ def generate(
 
     def _generate(
         self,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         n_classes: int,
+        trainer: Union[L.Trainer, BaseTrainer],
         dataset_transform: Optional[Callable],
         poisoned_indices: Optional[List[int]] = None,
         poisoned_labels: Optional[Dict[int, int]] = None,
@@ -90,12 +90,6 @@ def _generate(
         seed: int = 27,
         batch_size: int = 8,
     ):
-        if self.trainer is None:
-            raise ValueError(
-                "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or "
-                "init_trainer_from_train_arguments"
-            )
-
         self.train_dataset = train_dataset
         self.p = p
         self.global_method = global_method
@@ -125,12 +119,31 @@ def _generate(
         self.model = copy.deepcopy(model)
 
         trainer_fit_kwargs = trainer_fit_kwargs or {}
-        self.trainer.fit(
-            model=self.model,  # type: ignore
-            train_dataloaders=self.poisoned_train_dl,
-            val_dataloaders=self.poisoned_val_dl,
-            **trainer_fit_kwargs,
-        )
+
+        if isinstance(trainer, L.Trainer):
+            if not isinstance(self.model, L.LightningModule):
+                raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")
+
+            trainer.fit(
+                model=self.model,
+                train_dataloaders=self.poisoned_train_dl,
+                val_dataloaders=self.poisoned_val_dl,
+                **trainer_fit_kwargs,
+            )
+
+        elif isinstance(trainer, BaseTrainer):
+            if not isinstance(self.model, torch.nn.Module):
+                raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")
+
+            trainer.fit(
+                model=self.model,
+                train_dataloaders=self.poisoned_train_dl,
+                val_dataloaders=self.poisoned_val_dl,
+                **trainer_fit_kwargs,
+            )
+
+        else:
+            raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")
 
     @property
     def bench_state(self):
@@ -168,7 +181,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar
     @classmethod
     def assemble(
         cls,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         n_classes: int,
         poisoned_indices: Optional[List[int]] = None,
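Besides the Union annotations, this file (and subclass_detection.py below) stops storing the trainer on the instance: the Optional self.trainer attribute and its "Trainer not initialized" guard are removed, and the trainer is threaded through generate() into _generate() as a required argument. Below is a minimal sketch of that design difference, my own illustration rather than repository code, using only L.Trainer for brevity (the real code also accepts the project's custom BaseTrainer).

# Illustration only, not repository code.
from typing import Optional

import lightning as L


class AttributeStyle:
    """Old shape: the trainer lives on the instance and may still be None when used."""

    def __init__(self) -> None:
        self.trainer: Optional[L.Trainer] = None  # mypy sees Optional at every use site

    def _generate(self) -> None:
        if self.trainer is None:  # runtime guard the old code had to carry
            raise ValueError("Trainer not initialized.")
        # ... self.trainer.fit(...) ...


class ParameterStyle:
    """New shape: the trainer is a required argument, so it can never be missing."""

    def _generate(self, trainer: L.Trainer) -> None:
        # No Optional and no None check: both mypy and the caller see a required value.
        # ... trainer.fit(...) ...
        pass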
50 changes: 31 additions & 19 deletions src/toy_benchmarks/localization/subclass_detection.py
@@ -23,9 +23,8 @@ def __init__(
     ):
         super().__init__(device=device)
 
-        self.trainer: Optional[Union[L.Trainer, BaseTrainer]] = None
-        self.model: torch.nn.Module
-        self.group_model: torch.nn.Module
+        self.model: Union[torch.nn.Module, L.LightningModule]
+        self.group_model: Union[torch.nn.Module, L.LightningModule]
         self.train_dataset: torch.utils.data.Dataset
         self.dataset_transform: Optional[Callable]
         self.grouped_train_dl: torch.utils.data.DataLoader
@@ -38,7 +37,7 @@ def __init__(
     @classmethod
     def generate(
         cls,
-        model: torch.nn.Module,
+        model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         trainer: Union[L.Trainer, BaseTrainer],
         val_dataset: Optional[torch.utils.data.Dataset] = None,
@@ -58,11 +57,10 @@ def generate(
         """
 
         obj = cls(device=device)
-        trainer_fit_kwargs = trainer_fit_kwargs or {}
 
         obj.model = model
-        obj.trainer = trainer
         obj._generate(
+            trainer=trainer,
             train_dataset=train_dataset,
             dataset_transform=dataset_transform,
             val_dataset=val_dataset,
@@ -77,6 +75,7 @@ def generate(
 
     def _generate(
         self,
+        trainer: Union[L.Trainer, BaseTrainer],
         train_dataset: torch.utils.data.Dataset,
         val_dataset: Optional[torch.utils.data.Dataset] = None,
         dataset_transform: Optional[Callable] = None,
@@ -89,12 +88,6 @@ def _generate(
         *args,
         **kwargs,
     ):
-        if self.trainer is None:
-            raise ValueError(
-                "Trainer not initialized. Please initialize trainer using init_trainer_from_lightning_module or "
-                "init_trainer_from_train_arguments"
-            )
-
         self.train_dataset = train_dataset
         self.grouped_dataset = LabelGroupingDataset(
             dataset=train_dataset,
@@ -127,12 +120,31 @@ def _generate(
         self.group_model = copy.deepcopy(self.model)
 
         trainer_fit_kwargs = trainer_fit_kwargs or {}
-        self.trainer.fit(
-            model=self.group_model,  # type: ignore
-            train_dataloaders=self.grouped_train_dl,
-            val_dataloaders=self.grouped_val_dl,
-            **trainer_fit_kwargs,
-        )
+
+        if isinstance(trainer, L.Trainer):
+            if not isinstance(self.group_model, L.LightningModule):
+                raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")
+
+            trainer.fit(
+                model=self.group_model,
+                train_dataloaders=self.grouped_train_dl,
+                val_dataloaders=self.grouped_val_dl,
+                **trainer_fit_kwargs,
+            )
+
+        elif isinstance(trainer, BaseTrainer):
+            if not isinstance(self.group_model, torch.nn.Module):
+                raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")
+
+            trainer.fit(
+                model=self.group_model,
+                train_dataloaders=self.grouped_train_dl,
+                val_dataloaders=self.grouped_val_dl,
+                **trainer_fit_kwargs,
+            )
+
+        else:
+            raise ValueError("Trainer should be a Lightning Trainer or a BaseTrainer")
 
     @classmethod
     def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwargs):
@@ -155,7 +167,7 @@ def load(cls, path: str, device: str = "cpu", batch_size: int = 8, *args, **kwar
     @classmethod
     def assemble(
         cls,
-        group_model: torch.nn.Module,
+        group_model: Union[torch.nn.Module, L.LightningModule],
         train_dataset: torch.utils.data.Dataset,
         n_classes: int,
         n_groups: int,
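As the commit message notes, the per-trainer branches repeat an identical fit call in three files. A hypothetical alternative, sketched below and not part of this commit, is to validate the trainer/model pairing once and keep a single call; the trade-off is that the call itself then still needs a suppression (or a cast), which is exactly what this change removes. Names and signatures here are assumptions, with BaseTrainer again a simplified stand-in for the project's custom trainer.

# Hypothetical alternative, not part of this commit.
from typing import Any, Dict, Optional, Union

import lightning as L
import torch


class BaseTrainer:
    """Simplified stand-in for the project's custom trainer."""

    def fit(
        self,
        model: torch.nn.Module,
        train_dataloaders: torch.utils.data.DataLoader,
        val_dataloaders: Optional[torch.utils.data.DataLoader] = None,
        **kwargs,
    ) -> None:
        ...


def fit_once(
    trainer: Union[L.Trainer, BaseTrainer],
    model: Union[torch.nn.Module, L.LightningModule],
    train_dl: torch.utils.data.DataLoader,
    val_dl: Optional[torch.utils.data.DataLoader] = None,
    fit_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    # Validate the pairing up front so the single call below is safe at runtime.
    if isinstance(trainer, L.Trainer) and not isinstance(model, L.LightningModule):
        raise ValueError("Model should be a LightningModule if Trainer is a Lightning Trainer")
    if isinstance(trainer, BaseTrainer) and not isinstance(model, torch.nn.Module):
        raise ValueError("Model should be a torch.nn.Module if Trainer is a BaseTrainer")

    # mypy cannot see that the pairing was verified above, so the call still needs a suppression.
    trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=val_dl, **(fit_kwargs or {}))  # type: ignore

Functionally both shapes run the same checks; the commit simply prefers branches that mypy can verify on its own over a shorter call it cannot.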
