-
Notifications
You must be signed in to change notification settings - Fork 1
/
load_model.py
69 lines (62 loc) · 2.77 KB
/
load_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import yaml
import torch
import torchvision.models as models
from robustbench.utils import load_model as load_clf
from score_sde.losses import get_optimizer
from score_sde.models import utils as mutils
from score_sde.models.ema import ExponentialMovingAverage
from score_sde import sde_lib
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
from svhn.diffusion import Model as SVHNDiffusion
from svhn.resnet import resnet, SVHNClf
from utils import dict2namespace, restore_checkpoint
def load_models(args, model_src, device):
if args.dataset == 'cifar10':
# Diffusion model
with open('./diffusion_configs/cifar10.yml', 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
config = dict2namespace(config)
diffusion = mutils.create_model(config)
optimizer = get_optimizer(config, diffusion.parameters())
ema = ExponentialMovingAverage(
diffusion.parameters(), decay=config.model.ema_rate)
state = dict(step=0, optimizer=optimizer, model=diffusion, ema=ema)
restore_checkpoint(model_src, state, device)
ema.copy_to(diffusion.parameters())
diffusion.eval().to(device)
# Underlying classifier
clf = load_clf(model_name='Standard',
dataset='cifar10').to(device).eval()
elif args.dataset == 'imagenet':
with open('./diffusion_configs/imagenet.yml', 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
config = dict2namespace(config)
model_config = model_and_diffusion_defaults()
model_config.update(vars(config.model))
diffusion, _ = create_model_and_diffusion(**model_config)
diffusion.load_state_dict(torch.load(model_src, map_location='cpu'))
diffusion.eval().to(device)
# Underlying classifier
clf = models.resnet50(pretrained=True).to(device).eval()
elif args.dataset == 'svhn':
with open('./svhn/config.yml', 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
config = dict2namespace(config)
diffusion = SVHNDiffusion(config)
diffusion.to(device)
state = torch.load(model_src, map_location=device)[0]
for key in list(state.keys()):
state[key[7:]] = state.pop(key)
diffusion.load_state_dict(state)
diffusion.eval()
# Underlying classifier
from path import svhn_clf_path
clf_forward, params = resnet(28, 10, 10)
state_dict = torch.load(svhn_clf_path)
params_tensors = state_dict['params']
for k, v in params.items():
v.data.copy_(params_tensors[k])
for k, v in params.items():
v.data = v.data.to(device)
clf = SVHNClf(clf_forward, params)
return clf, diffusion