Skip to content

Commit

Permalink
fix usage_testing.py bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Aug 9, 2024
1 parent dfdbf62 commit 5636102
Showing 1 changed file with 42 additions and 36 deletions.
78 changes: 42 additions & 36 deletions tutorials/usage_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import requests
import torch
import torchvision

import lightning as L
# from torch import nn
# from torch import optim
from torch.utils.data import DataLoader
Expand All @@ -16,19 +16,19 @@
from torchvision.utils import make_grid
from tqdm import tqdm

from quanda.explainers.wrappers.captum_influence import (
from quanda.explainers.wrappers import (
CaptumSimilarity,
captum_similarity_explain,
)
from quanda.metrics.localization.class_detection import ClassDetectionMetric
from quanda.metrics.randomization.model_randomization import (
from quanda.metrics.localization import ClassDetectionMetric
from quanda.metrics.randomization import (
ModelRandomizationMetric,
)
from quanda.metrics.unnamed.dataset_cleaning import DatasetCleaningMetric
from quanda.metrics.unnamed.top_k_overlap import TopKOverlapMetric
from quanda.toy_benchmarks.subclass_detection import SubclassDetection
from quanda.utils.training.base_pl_module import BasicLightningModule
from quanda.utils.training.trainer import Trainer
from quanda.metrics.unnamed import DatasetCleaningMetric
from quanda.metrics.unnamed import TopKOverlapMetric
from quanda.toy_benchmarks.localization import SubclassDetection
from quanda.utils.training import BasicLightningModule
from quanda.utils.training import Trainer

DEVICE = "cuda:0" # "cuda" if torch.cuda.is_available() else "cpu"
torch.set_float32_matmul_precision("medium")
Expand Down Expand Up @@ -72,6 +72,7 @@ def main():

# load model with pre-trained weights
model = resnet18(weights=None, num_classes=10)
init_model = resnet18(weights=None, num_classes=10)
model.load_state_dict(weights_pretrained)
model.to(DEVICE)
model.eval()
Expand Down Expand Up @@ -131,20 +132,42 @@ def accuracy(net, loader):
top_k = TopKOverlapMetric(model=model, train_dataset=train_set, top_k=1, device=DEVICE)

# dataset cleaning
max_epochs = 1
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD
lr = 0.1
optimizer_kwargs = {"momentum": 0.9, "weight_decay": 5e-4}
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
scheduler_kwargs = {"T_max": max_epochs}

pl_module = BasicLightningModule(
model=copy.deepcopy(model),
optimizer=torch.optim.SGD,
lr=0.1,
criterion=torch.nn.CrossEntropyLoss(),
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr=lr,
criterion=criterion,
)
trainer = Trainer.from_lightning_module(model, pl_module)

init_pl_module = BasicLightningModule(
model=init_model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr=lr,
criterion=criterion,
)

trainer = L.Trainer(max_epochs=max_epochs)

data_clean = DatasetCleaningMetric(
model=model,
model=pl_module,
init_model=copy.deepcopy(init_pl_module),
train_dataset=train_set,
global_method="sum_abs",
trainer=trainer,
trainer_fit_kwargs={"max_epochs": 3},
top_k=50,
device=DEVICE,
)
Expand Down Expand Up @@ -179,33 +202,16 @@ def accuracy(net, loader):
# Subclass Detection Benchmark Generation and Evaluation
# ++++++++++++++++++++++++++++++++++++++++++

max_epochs = 1
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD
lr = 0.1
optimizer_kwargs = {"momentum": 0.9, "weight_decay": 5e-4}
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
scheduler_kwargs = {"T_max": max_epochs}

trainer = BasicLightningModule(
model=model,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
lr=lr,
criterion=criterion,
)
trainer = L.Trainer(max_epochs=max_epochs)

bench = SubclassDetection.generate(
model=model,
model=copy.deepcopy(init_pl_module),
train_dataset=train_set,
trainer=trainer,
val_dataset=val_set,
n_classes=10,
n_groups=2,
class_to_group="random",
trainer_fit_kwargs={"max_epochs": max_epochs},
seed=42,
batch_size=100,
device=DEVICE,
Expand All @@ -214,7 +220,7 @@ def accuracy(net, loader):
score = bench.evaluate(
expl_dataset=test_set,
explainer_cls=CaptumSimilarity,
expl_kwargs={"layers": "avgpool", "batch_size": 100},
expl_kwargs={"layers": "model.avgpool", "batch_size": 100},
cache_dir="./cache",
model_id="default_model_id",
)
Expand Down

0 comments on commit 5636102

Please sign in to comment.