-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
121 lines (103 loc) · 4.87 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 21 18:44:46 2018
@author: wmy
"""
import numpy as np
import keras
import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
from model import preprocess_true_boxes, yolo_body, yolo_loss, get_random_data
# Annotation list files; train() reads each one line by line and hands the
# lines to the batch generator.
train_path = 'info_gpu/train.txt'
val_path = 'info_gpu/val.txt'
# NOTE(review): test_path is not referenced anywhere in the visible code.
test_path = 'info_gpu/test.txt'
# Output location. It is concatenated directly with file names (for example in
# create_model's default weights_path), so keep the trailing slash.
save_path = 'my_gpu_model/'
# Text files parsed by get_classes() and get_anchors() in the __main__ block.
class_path = 'infos/classes.txt'
anchor_path = 'infos/anchors.txt'
# Network input size, unpacked as (height, width) in create_model().
input_shape = (416,416)
def get_anchors(anchors_path):
    """Load anchor values from the first line of a text file.

    Only the first line is read. It is split on commas, every field is
    converted to ``float``, and the flat sequence is reshaped into rows of
    two values.

    Args:
        anchors_path: Path of the anchors text file.

    Returns:
        A NumPy float array of shape ``(N, 2)``.
    """
    with open(anchors_path) as anchor_file:
        first_line = anchor_file.readline()
    flat_values = np.array(list(map(float, first_line.split(','))))
    return flat_values.reshape(-1, 2)
def get_classes(classes_path):
    """Read class names, one per line, from a text file.

    Each line is stripped of surrounding whitespace. Blank lines at the end
    of the file are discarded so that a stray trailing newline does not add
    an empty class name (the caller uses ``len()`` of the result as the
    number of classes). Blank lines in the middle of the file are kept so
    that the positions of the following names do not shift.

    Args:
        classes_path: Path of the class-names text file.

    Returns:
        A list of class-name strings in file order.
    """
    with open(classes_path) as f:
        class_names = [line.strip() for line in f]
    # Drop only trailing blanks; interior entries keep their index.
    while class_names and not class_names[-1]:
        class_names.pop()
    return class_names
def create_model(input_shape, anchors, num_classes, load_pretrained=False, freeze_body=False,
                 weights_path=save_path+'weights.h5'):
    """Build the training model: the YOLO body plus a loss layer.

    The returned model takes the image and three ``y_true`` tensors as
    inputs and outputs the result of the ``yolo_loss`` Lambda layer.

    Args:
        input_shape: ``(height, width)`` used to size the ``y_true`` inputs.
        anchors: Sequence of anchors; a third of them is assigned to each of
            the three output scales.
        num_classes: Number of object classes.
        load_pretrained: When true, load weights from ``weights_path`` by
            layer name, skipping layers whose shapes do not match.
        freeze_body: When true, mark all but the last 7 layers of the body
            as non-trainable.
        weights_path: Weights file used when ``load_pretrained`` is set.

    Returns:
        A Keras ``Model`` whose output is the ``yolo_loss`` layer.
    """
    K.clear_session()  # get a new session
    image_input = Input(shape=(None, None, 3))
    height, width = input_shape
    num_anchors = len(anchors)
    anchors_per_scale = num_anchors // 3

    # One ground-truth input per output scale; the grid is the input size
    # divided by 32, 16 and 8 respectively.
    y_true = [
        Input(shape=(height // stride, width // stride, anchors_per_scale, num_classes + 5))
        for stride in (32, 16, 8)
    ]

    model_body = yolo_body(image_input, anchors_per_scale, num_classes)
    print('Create YOLOv3 model with {} anchors and {} classes.'.format(num_anchors, num_classes))

    if load_pretrained:
        # The weights file is expected to be an .h5 file from an already
        # trained model (per the original author's note).
        model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
        print('Load weights {}.'.format(weights_path))
        if freeze_body:
            # NOTE(review): the original comment said "do not freeze 3 output
            # layers", but the code leaves the last 7 layers trainable -
            # confirm which is intended.
            num = len(model_body.layers) - 7
            # max(..., 0) keeps the "freeze nothing" behaviour for bodies
            # with fewer than 7 layers.
            for layer in model_body.layers[:max(num, 0)]:
                layer.trainable = False
            print('Freeze the first {} layers of total {} layers.'.format(num, len(model_body.layers)))

    loss_arguments = {'anchors': anchors, 'num_classes': num_classes, 'ignore_thresh': 0.5}
    loss_layer = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss', arguments=loss_arguments)
    model_loss = loss_layer([*model_body.output, *y_true])
    return Model([model_body.input, *y_true], model_loss)
def train(model, input_shape, anchors, num_classes, lr=0.001, epochs=20, log_dir=save_path):
    """Compile ``model``, fit it on the annotation files, and save the result.

    Training annotations are read from the module-level ``train_path`` and
    validation annotations from ``val_path``. The model is compiled with a
    pass-through loss: the ``yolo_loss`` output of the model is itself used
    as the loss value.

    Changes from the previous version:
      * The TensorBoard and ModelCheckpoint callbacks are now actually passed
        to ``fit_generator``; before, they were created and then dropped.
      * With no validation samples the checkpoint monitors ``loss`` and its
        file name omits ``val_loss``, which would otherwise be missing from
        the logs.
      * ``log_dir`` is created up front and joined with ``os.path.join``, so
        a missing directory or a missing trailing slash no longer breaks the
        final save.
      * Blank annotation lines are ignored, and an empty training set raises
        a clear error.

    Args:
        model: Keras model with an output named ``yolo_loss``.
        input_shape: ``(height, width)`` forwarded to the data generator.
        anchors: Anchors forwarded to the data generator.
        num_classes: Number of classes forwarded to the data generator.
        lr: Learning rate for the Adam optimizer.
        epochs: Number of epochs to train.
        log_dir: Directory for TensorBoard logs, checkpoints and final files.

    Raises:
        ValueError: If the training annotation file has no usable lines.
    """
    import os  # local import so this block does not depend on file-level imports

    batch_size = 2
    with open(train_path) as f:
        train_lines = [line for line in f if line.strip()]
    with open(val_path) as f:
        val_lines = [line for line in f if line.strip()]
    num_train = len(train_lines)
    num_val = len(val_lines)
    if num_train == 0:
        raise ValueError('No training samples found in {}.'.format(train_path))

    # Create the output directory before training so the final save cannot
    # fail after all epochs have run.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Compile the model; the loss simply returns the model's own output.
    model.compile(optimizer=keras.optimizers.Adam(lr=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8),
                  loss={'yolo_loss': lambda y_true, y_pred: y_pred})

    tensorboard = TensorBoard(log_dir=log_dir)
    if num_val > 0:
        checkpoint_name = "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5"
        monitor = 'val_loss'
    else:
        # No validation pass means no 'val_loss' entry in the epoch logs.
        checkpoint_name = "ep{epoch:03d}-loss{loss:.3f}.h5"
        monitor = 'loss'
    checkpoint = ModelCheckpoint(os.path.join(log_dir, checkpoint_name),
                                 monitor=monitor, save_weights_only=True, save_best_only=True, period=1)

    print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
    model.fit_generator(data_generator_wrap(train_lines, batch_size, input_shape, anchors, num_classes),
                        steps_per_epoch=max(1, num_train//batch_size),
                        # data_generator_wrap returns None when there is no validation data.
                        validation_data=data_generator_wrap(val_lines, batch_size, input_shape, anchors, num_classes),
                        validation_steps=max(1, num_val//batch_size),
                        epochs=epochs,
                        initial_epoch=0,
                        callbacks=[tensorboard, checkpoint])
    model.save_weights(os.path.join(log_dir, 'weights.h5'))
    model.save(os.path.join(log_dir, 'model.h5'))
# Data preparation
def data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes):
    """Yield training batches forever, cycling through ``annotation_lines``.

    The annotation list is shuffled in place once, when the generator is
    first advanced. Each batch is then filled from consecutive entries,
    wrapping around at the end of the list. Every entry is passed through
    ``get_random_data`` with ``random=True``, and the collected boxes are
    converted by ``preprocess_true_boxes``.

    Args:
        annotation_lines: Mutable sequence of annotation strings; must not be
            empty (the wrap-around uses a modulo by its length).
        batch_size: Number of samples per batch.
        input_shape: Forwarded to the two helpers.
        anchors: Forwarded to ``preprocess_true_boxes``.
        num_classes: Forwarded to ``preprocess_true_boxes``.

    Yields:
        ``([image_data, *y_true], zeros)``, where ``zeros`` is
        ``np.zeros(batch_size)``.
    """
    total = len(annotation_lines)
    np.random.shuffle(annotation_lines)
    cursor = 0
    while True:
        images, boxes = [], []
        for _ in range(batch_size):
            cursor %= total  # wrap around to the start of the list
            sample_image, sample_boxes = get_random_data(
                annotation_lines[cursor], input_shape, random=True)
            images.append(sample_image)
            boxes.append(sample_boxes)
            cursor += 1
        image_batch = np.array(images)
        box_batch = np.array(boxes)
        y_true = preprocess_true_boxes(box_batch, input_shape, anchors, num_classes)
        yield [image_batch, *y_true], np.zeros(batch_size)
def data_generator_wrap(annotation_lines, batch_size, input_shape, anchors, num_classes):
    """Return a batch generator, or ``None`` when there is nothing to generate.

    ``None`` is returned if ``annotation_lines`` is empty or ``batch_size``
    is zero or negative; otherwise the arguments are forwarded unchanged to
    ``data_generator``.
    """
    sample_count = len(annotation_lines)
    if sample_count != 0 and not batch_size <= 0:
        return data_generator(annotation_lines, batch_size, input_shape, anchors, num_classes)
    return None
if __name__ == '__main__':
    # Class names; the original author notes this dataset has a single class.
    classes = get_classes(class_path)
    num_classes = len(classes)
    # Anchors parsed from the anchors file.
    anchors = get_anchors(anchor_path)
    # Build the model and load the existing weights file into it.
    model = create_model(input_shape, anchors, num_classes, load_pretrained=True)
    train(model, input_shape, anchors, num_classes, lr=0.00075, epochs=20, log_dir=save_path)