-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_video_cls.py
114 lines (105 loc) · 4.22 KB
/
train_video_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Extracting images according to annotations from given videos."""
from __future__ import absolute_import, division, print_function
import os
import pathlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from smile import app, flags, logging
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from dataset import get_segment_dense_loader
from metrics import m2cai_map
from models import get_video_model
from train_utils import evaluate, final_evaluate, train
from utils import get_gt_from_files
# Command-line flags for the training script. Names, types and default
# values are part of the script's interface and are kept as-is; only
# incorrect help texts are fixed.

# Model / task.
flags.DEFINE_integer("num_classes", 7, "Number of classes.")

# Data set lists, passed to get_segment_dense_loader() in main().
flags.DEFINE_string(
    "train_list",
    "/mnt/data/m2cai/m2cai_tool/video_features_cropped_long/train_video.pkl",
    "Train data set list.")
flags.DEFINE_string(
    "valid_list",
    "/mnt/data/m2cai/m2cai_tool/video_features_cropped_long/valid_video.pkl",
    "Valid data set list.")
flags.DEFINE_string(
    "test_list",
    "/mnt/data/m2cai/m2cai_tool/video_features_cropped_long/test_video.pkl",
    "Test data set list.")

# Optimisation.
flags.DEFINE_integer("num_gpu", 1, "Number of gpus to use.")
flags.DEFINE_integer("batch_size", 64, "Batch size.")
flags.DEFINE_float("lr", 0.000001, "Optimizer Learning Rate.")
# NOTE(review): `momentum` is not read anywhere in the visible code; main()
# builds Adam with `lr` only. Kept so existing command lines still parse.
flags.DEFINE_float("momentum", 0.5, "Optimizer momentum.")
flags.DEFINE_integer("epoch_num", 200, "Epoch numbers to train.")

# Checkpointing.
flags.DEFINE_string("save_model_path", "saved/saved_video_new_avg_11_contrast",
                    "Path to save models.")
# Help text previously read "Save model parameters.", which described the
# opposite direction; this flag is handed to get_video_model() as
# `load_model_path`.
flags.DEFINE_string("load_model_path", "",
                    "Path to load model parameters from.")

# Architecture / loss options.
flags.DEFINE_string("pool_type", "avg", "Temporal pooling type.")
flags.DEFINE_string("gt_path", "gt", "Ground truth path.")
flags.DEFINE_boolean("moe", False, "If use MOE in the model.")
flags.DEFINE_string("loss_type", "bce", "Either BCE or MultiSoftMargin")
flags.DEFINE_integer("frame_num", 11, "Segment length.")

FLAGS = flags.FLAGS
def main(_):
    """Train a video classification model, then score it on the test set.

    Steps, all configured through ``FLAGS``:

    1. Build the model with ``get_video_model``.
    2. Build the train / valid / test loaders with
       ``get_segment_dense_loader``.
    3. For ``FLAGS.epoch_num`` epochs call ``train`` then ``evaluate``,
       saving the model's ``state_dict`` on every 10th epoch
       (0, 10, 20, ...).
    4. Call ``final_evaluate`` on the test loader, compare against the
       ground truth from ``get_gt_from_files`` using ``m2cai_map``, and log
       the returned values together with their mean.

    Args:
        _: Unused. Presumably the leftover argv handed over by ``app.run``
            -- confirm against ``smile.app``.
    """
    # Get model.
    logging.info("Creating model.")
    model = get_video_model(
        num_gpus=FLAGS.num_gpu,
        load_model_path=FLAGS.load_model_path,
        num_classes=FLAGS.num_classes,
        pool_type=FLAGS.pool_type,
        moe=FLAGS.moe,
        frame_num=FLAGS.frame_num)
    logging.info("Model is ready.")

    # Get data loaders. Only the test loader's second return value (the
    # data set length) is needed later, by final_evaluate().
    logging.info("Creating data loaders.")
    train_loader, _ = get_segment_dense_loader(
        FLAGS.train_list, batch_size=FLAGS.batch_size, shuffle=True)
    valid_loader, _ = get_segment_dense_loader(
        FLAGS.valid_list, batch_size=FLAGS.batch_size, shuffle=True)
    test_loader, test_data_len = get_segment_dense_loader(
        FLAGS.test_list, batch_size=FLAGS.batch_size, shuffle=False)
    logging.info("Data loaders are ready.")

    # Get optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.lr)

    # Get criterion.
    # Compare by value: `is` tests object identity, which is not guaranteed
    # for equal strings (e.g. a value parsed from the command line), so the
    # previous `is "bce"` check could silently fall through to the other
    # loss.
    if FLAGS.loss_type == "bce":
        # Training uses the logits variant; validation uses plain BCELoss,
        # which expects probabilities, hence need_sigmoid is forwarded to
        # evaluate().
        criterion_train = nn.BCEWithLogitsLoss()
        criterion_val = nn.BCELoss()
        need_sigmoid = True
    else:
        criterion_train = criterion_val = nn.MultiLabelSoftMarginLoss()
        need_sigmoid = False

    # Scheduler. It is handed to train(); this function never steps it.
    scheduler = ReduceLROnPlateau(optimizer, factor=0.9, patience=5, mode="min")

    # Start training.
    # exist_ok=True already makes this a no-op for an existing directory, so
    # no separate isdir() pre-check is needed.
    pathlib.Path(FLAGS.save_model_path).mkdir(parents=True, exist_ok=True)
    for i in range(FLAGS.epoch_num):
        train(
            i,
            model,
            train_loader,
            optimizer,
            criterion_train,
            scheduler=scheduler)
        evaluate(model, valid_loader, criterion_val, need_sigmoid=need_sigmoid)
        # Checkpoint every 10th epoch, starting with epoch 0.
        if i % 10 == 0:
            path_to_save = os.path.join(FLAGS.save_model_path,
                                        "params_epoch_%04d.pkl" % i)
            torch.save(model.state_dict(), path_to_save)

    # Final evaluation after training finishes.
    pred = final_evaluate(model, test_loader, test_data_len,
                          FLAGS.num_classes, FLAGS.batch_size)
    gt = get_gt_from_files(FLAGS.gt_path)
    ap = m2cai_map(pred, gt)
    logging.info("Average precision:")
    logging.info(ap)
    # Mean of the values returned by m2cai_map.
    logging.info(sum(ap) / len(ap))
# Script entry point: only runs when the file is executed directly.
# NOTE(review): app.run() presumably parses the flags defined above and then
# calls main() -- confirm against smile.app.
if __name__ == "__main__":
    app.run()