-
Notifications
You must be signed in to change notification settings - Fork 16
/
experiment.py
110 lines (95 loc) · 3.68 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from comet_ml import Experiment
import torch
from torchvision import transforms
from src.dataset import Places2
from src.model import PConvUNet
from src.loss import InpaintingLoss, VGG16FeatureExtractor
from src.train import Trainer
from src.utils import Config, load_ckpt, create_ckpt_dir
# --- Configuration -------------------------------------------------------------
# Load the run settings from default_config.yml and record the checkpoint
# directory returned by `create_ckpt_dir()` on the config object.
config = Config("default_config.yml")
config.ckpt = create_ckpt_dir()
print(f"Check Point is '{config.ckpt}'")

# --- Device ----------------------------------------------------------------------
# Use the CUDA device selected by `config.cuda_id` when CUDA is available;
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{config.cuda_id}")
else:
    device = torch.device("cpu")
# --- Model -----------------------------------------------------------------------
print("Loading the Model...")
model = PConvUNet(finetune=config.finetune,
                  layer_size=config.layer_size)
if config.finetune:
    # NOTE(review): `config.finetune` is passed to PConvUNet as a flag and to
    # torch.load as a path, so it is presumably a checkpoint path that is
    # truthy when set -- confirm against default_config.yml.
    #
    # Load onto the CPU first: without `map_location`, tensors saved from a GPU
    # run are restored to that GPU, which fails on a CPU-only machine (a case
    # this script supports via its CPU device fallback). `model.to(device)`
    # below moves the loaded weights to the selected device.
    checkpoint = torch.load(config.finetune, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
model.to(device)
# --- Data transformations ----------------------------------------------------------
# Images are only converted to tensors.
img_tf = transforms.Compose([transforms.ToTensor()])

# Masks are converted to tensors too. When `config.mask_augment` is set, a
# RandomResizedCrop with output size 256 is applied before the conversion.
mask_tf = transforms.Compose(
    ([transforms.RandomResizedCrop(256)] if config.mask_augment else [])
    + [transforms.ToTensor()]
)
# --- Validation data ---------------------------------------------------------------
# The validation split is built in every mode; it is passed to the Trainer
# when training.
print("Loading the Validation Dataset...")
dataset_val = Places2(
    config.data_root,
    img_tf,
    mask_tf,
    data="val",
)
# --- Run: dispatch on config.mode ---------------------------------------------------
if config.mode == "train":
    # Optional Comet ML experiment tracking.
    if config.comet:
        print("Connecting to Comet ML...")
        experiment = Experiment(api_key=config.api_key,
                                project_name=config.project_name,
                                workspace=config.workspace)
        # Log every config value except the API key, so the credential is not
        # written into the experiment's parameter log.
        experiment.log_parameters(
            {key: value for key, value in vars(config).items()
             if key != "api_key"})
    else:
        experiment = None

    # Training split of the Places2 dataset, using the same transforms as the
    # validation split.
    print("Loading the Training Dataset...")
    dataset_train = Places2(config.data_root,
                            img_tf,
                            mask_tf,
                            data="train")

    # Define the loss function and move it to the training device.
    criterion = InpaintingLoss(VGG16FeatureExtractor(),
                               tv_loss=config.tv_loss).to(device)

    # Define the optimizer over trainable parameters only. Fine-tuning uses its
    # own learning rate.
    lr = config.finetune_lr if config.finetune else config.initial_lr
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if config.optim == "Adam":
        optimizer = torch.optim.Adam(trainable_params,
                                     lr=lr,
                                     weight_decay=config.weight_decay)
    elif config.optim == "SGD":
        optimizer = torch.optim.SGD(trainable_params,
                                    lr=lr,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    else:
        # Fail early. Otherwise `optimizer` would be unbound and the script
        # would crash later with an unhelpful NameError.
        raise ValueError(
            f"Unsupported optimizer {config.optim!r}; expected 'Adam' or 'SGD'")

    start_iter = 0
    if config.resume:
        print("Loading the trained params and the state of optimizer...")
        start_iter = load_ckpt(config.resume,
                               [("model", model)],
                               [("optimizer", optimizer)])
        # Force the learning rate chosen above. load_ckpt presumably restores
        # the optimizer state (including its lr) from the checkpoint -- verify
        # in src.utils.
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        print("Starting from iter ", start_iter)

    trainer = Trainer(start_iter, config, device, model, dataset_train,
                      dataset_val, criterion, optimizer, experiment=experiment)
    if config.comet:
        # Run training inside Comet's train() context.
        with experiment.train():
            trainer.iterate()
    else:
        trainer.iterate()

elif config.mode == "test":
    # TODO: evaluation is not implemented yet. The planned steps are to load
    # the trained weights into `model` and then evaluate it on `dataset_val`.
    pass

else:
    # Previously an unrecognized mode silently did nothing after loading the
    # model and data.
    raise ValueError(
        f"Unsupported mode {config.mode!r}; expected 'train' or 'test'")