finetune.py
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from models import utils, caption
from datasets import coco
from configuration import Config
from engine import train_one_epoch, evaluate


def finetune(config):
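    """Fine-tune the pretrained CATR captioning model.

    Downloads the released pretrained weights, overrides a few
    hyperparameters for fine-tuning, trains on the dataset built by
    coco.build_dataset, and resumes from config.checkpoint if it exists.
    """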
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')
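
    # Seed PyTorch and NumPy for reproducibility, offset by the process rank
    # returned by utils.get_rank().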
    seed = config.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
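
    # Build the caption model and initialize it from the pretrained checkpoint
    # hosted on the repository's GitHub releases.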
    model, criterion = caption.build_model(config)
    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth",
        map_location=device
    )
    model.to(device)
    model.load_state_dict(checkpoint['model'])
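
    # Override the optimization hyperparameters for fine-tuning.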
    config.lr = 1e-5
    config.epochs = 10
    config.lr_drop = 8

    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    print(f"Number of params: {n_parameters}")
    param_dicts = [
        {"params": [p for n, p in model.named_parameters()
                    if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model.named_parameters()
                       if "backbone" in n and p.requires_grad],
            "lr": config.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(
        param_dicts, lr=config.lr, weight_decay=config.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop)
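
    # Build the COCO-style training and validation datasets and wrap them in
    # data loaders.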
    dataset_train = coco.build_dataset(config, mode='training')
    dataset_val = coco.build_dataset(config, mode='validation')
    print(f"Train: {len(dataset_train)}")
    print(f"Valid: {len(dataset_val)}")

    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, config.batch_size, drop_last=True
    )

    data_loader_train = DataLoader(
        dataset_train, batch_sampler=batch_sampler_train,
        num_workers=config.num_workers)
    data_loader_val = DataLoader(
        dataset_val, config.batch_size, sampler=sampler_val,
        drop_last=False, num_workers=config.num_workers)
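
    # Resume model, optimizer, and scheduler state from a local checkpoint if
    # one already exists.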
    if os.path.exists(config.checkpoint):
        print("Loading Checkpoint...")
        checkpoint = torch.load(config.checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.start_epoch = checkpoint['epoch'] + 1
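
    # Fine-tuning loop: train one epoch, step the LR schedule, save a
    # checkpoint, then evaluate on the validation set.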
print("Start Training..")
for epoch in range(config.start_epoch, config.epochs):
print(f"Epoch: {epoch}")
epoch_loss = train_one_epoch(
model, criterion, data_loader_train, optimizer, device, epoch, config.clip_max_norm)
lr_scheduler.step()
print(f"Training Loss: {epoch_loss}")
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
        }, config.checkpoint)

        validation_loss = evaluate(model, criterion, data_loader_val, device)
        print(f"Validation Loss: {validation_loss}")

        print()
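

# Entry point: build the default Config and run fine-tuning.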
if __name__ == "__main__":
    config = Config()
    finetune(config)