diff --git a/tutorials/usage_testing.py b/tutorials/usage_testing.py
index 47662217..c1fc60df 100644
--- a/tutorials/usage_testing.py
+++ b/tutorials/usage_testing.py
@@ -7,7 +7,7 @@
 import requests
 import torch
 import torchvision
-
+import lightning as L
 # from torch import nn
 # from torch import optim
 from torch.utils.data import DataLoader
@@ -16,19 +16,19 @@
 from torchvision.utils import make_grid
 from tqdm import tqdm
 
-from quanda.explainers.wrappers.captum_influence import (
+from quanda.explainers.wrappers import (
     CaptumSimilarity,
     captum_similarity_explain,
 )
-from quanda.metrics.localization.class_detection import ClassDetectionMetric
-from quanda.metrics.randomization.model_randomization import (
+from quanda.metrics.localization import ClassDetectionMetric
+from quanda.metrics.randomization import (
     ModelRandomizationMetric,
 )
-from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
-from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric
-from quanda.toy_benchmarks.subclass_detection import SubclassDetection
-from quanda.utils.training.base_pl_module import BasicLightningModule
-from quanda.utils.training.trainer import Trainer
+from quanda.metrics.unnamed import DatasetCleaningMetric
+from quanda.metrics.unnamed import TopKOverlapMetric
+from quanda.toy_benchmarks.localization import SubclassDetection
+from quanda.utils.training import BasicLightningModule
+from quanda.utils.training import Trainer
 
 DEVICE = "cuda:0"  # "cuda" if torch.cuda.is_available() else "cpu"
 torch.set_float32_matmul_precision("medium")
@@ -72,6 +72,7 @@ def main():
 
     # load model with pre-trained weights
     model = resnet18(weights=None, num_classes=10)
+    init_model = resnet18(weights=None, num_classes=10)
     model.load_state_dict(weights_pretrained)
     model.to(DEVICE)
     model.eval()
@@ -131,20 +132,42 @@ def accuracy(net, loader):
     top_k = TopKOverlapMetric(model=model, train_dataset=train_set, top_k=1, device=DEVICE)
 
     # dataset cleaning
+    max_epochs = 1
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD
+    lr = 0.1
+    optimizer_kwargs = {"momentum": 0.9, "weight_decay": 5e-4}
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
+    scheduler_kwargs = {"T_max": max_epochs}
+
     pl_module = BasicLightningModule(
-        model=copy.deepcopy(model),
-        optimizer=torch.optim.SGD,
-        lr=0.1,
-        criterion=torch.nn.CrossEntropyLoss(),
+        model=model,
+        optimizer=optimizer,
+        optimizer_kwargs=optimizer_kwargs,
+        scheduler=scheduler,
+        scheduler_kwargs=scheduler_kwargs,
+        lr=lr,
+        criterion=criterion,
     )
-    trainer = Trainer.from_lightning_module(model, pl_module)
+
+    init_pl_module = BasicLightningModule(
+        model=init_model,
+        optimizer=optimizer,
+        optimizer_kwargs=optimizer_kwargs,
+        scheduler=scheduler,
+        scheduler_kwargs=scheduler_kwargs,
+        lr=lr,
+        criterion=criterion,
+    )
+
+    trainer = L.Trainer(max_epochs=max_epochs)
 
     data_clean = DatasetCleaningMetric(
-        model=model,
+        model=pl_module,
+        init_model=copy.deepcopy(init_pl_module),
         train_dataset=train_set,
         global_method="sum_abs",
         trainer=trainer,
-        trainer_fit_kwargs={"max_epochs": 3},
         top_k=50,
         device=DEVICE,
     )
@@ -179,33 +202,16 @@ def accuracy(net, loader):
     # Subclass Detection Benchmark Generation and Evaluation
     # ++++++++++++++++++++++++++++++++++++++++++
 
-    max_epochs = 1
-    criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.SGD
-    lr = 0.1
-    optimizer_kwargs = {"momentum": 0.9, "weight_decay": 5e-4}
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
-    scheduler_kwargs = {"T_max": max_epochs}
-
-    trainer = BasicLightningModule(
-        model=model,
-        optimizer=optimizer,
-        optimizer_kwargs=optimizer_kwargs,
-        scheduler=scheduler,
-        scheduler_kwargs=scheduler_kwargs,
-        lr=lr,
-        criterion=criterion,
-    )
+    trainer = L.Trainer(max_epochs=max_epochs)
 
     bench = SubclassDetection.generate(
-        model=model,
+        model=copy.deepcopy(init_pl_module),
         train_dataset=train_set,
         trainer=trainer,
         val_dataset=val_set,
         n_classes=10,
         n_groups=2,
         class_to_group="random",
-        trainer_fit_kwargs={"max_epochs": max_epochs},
         seed=42,
         batch_size=100,
         device=DEVICE,
@@ -214,7 +220,7 @@ def accuracy(net, loader):
     score = bench.evaluate(
         expl_dataset=test_set,
         explainer_cls=CaptumSimilarity,
-        expl_kwargs={"layers": "avgpool", "batch_size": 100},
+        expl_kwargs={"layers": "model.avgpool", "batch_size": 100},
         cache_dir="./cache",
         model_id="default_model_id",
     )
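
The hunks above converge on one pattern: quanda components that retrain a model now take a BasicLightningModule (plus a freshly initialized copy of it) together with a plain lightning.Trainer, instead of a bare nn.Module with trainer_fit_kwargs. The following is a minimal sketch of the post-change DatasetCleaningMetric setup, stitched together from the hunks; weights_pretrained and train_set are assumed to exist as they do elsewhere in the tutorial.

import copy

import lightning as L
import torch
from torchvision.models import resnet18

from quanda.metrics.unnamed import DatasetCleaningMetric
from quanda.utils.training import BasicLightningModule

DEVICE = "cuda:0"
max_epochs = 1

# Training hyperparameters shared by both Lightning modules, as set in the diff.
shared = dict(
    optimizer=torch.optim.SGD,
    optimizer_kwargs={"momentum": 0.9, "weight_decay": 5e-4},
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_kwargs={"T_max": max_epochs},
    lr=0.1,
    criterion=torch.nn.CrossEntropyLoss(),
)

model = resnet18(weights=None, num_classes=10)
model.load_state_dict(weights_pretrained)  # assumed: loaded earlier in the tutorial
init_model = resnet18(weights=None, num_classes=10)  # deliberately left untrained

pl_module = BasicLightningModule(model=model, **shared)
init_pl_module = BasicLightningModule(model=init_model, **shared)

data_clean = DatasetCleaningMetric(
    model=pl_module,
    init_model=copy.deepcopy(init_pl_module),  # retrained from scratch on the cleaned set
    train_dataset=train_set,                   # assumed: tutorial's CIFAR-10 train split
    global_method="sum_abs",
    trainer=L.Trainer(max_epochs=max_epochs),  # plain Lightning trainer, per the diff
    top_k=50,
    device=DEVICE,
)

The last hunk follows from the same change: since the explainer now receives the Lightning module (which holds the network as its model attribute) rather than the raw network, the similarity layer is presumably addressed through that attribute, hence "layers": "model.avgpool" instead of "layers": "avgpool".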