forked from harvitronix/five-video-classification-methods
-
Notifications
You must be signed in to change notification settings - Fork 0
/
validate_rnn.py
59 lines (48 loc) · 1.55 KB
/
validate_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger
from models import ResearchModels
from data import DataSet
def validate(data_type, model, seq_length=40, saved_model=None,
concat=False, class_limit=None, image_shape=None):
batch_size = 32
# Get the data and process it.
if image_shape is None:
data = DataSet(
seq_length=seq_length,
class_limit=class_limit
)
else:
data = DataSet(
seq_length=seq_length,
class_limit=class_limit,
image_shape=image_shape
)
val_generator = data.frame_generator(batch_size, 'test', data_type, concat)
# Get the model.
rm = ResearchModels(len(data.classes), model, seq_length, saved_model)
# Evaluate!
results = rm.model.evaluate_generator(
generator=val_generator,
val_samples=3200)
print(results)
print(rm.model.metrics_names)
def main():
model = 'mlp'
saved_model = 'data/ucf101/checkpoints/mlp-features.023-0.926.hdf5'
if model == 'conv_3d' or model == 'crnn':
data_type = 'images'
image_shape = (80, 80, 3)
else:
data_type = 'features'
image_shape = None
if model == 'mlp':
concat = True
else:
concat = False
validate(data_type, model, saved_model=saved_model,
concat=concat, image_shape=image_shape)
if __name__ == '__main__':
main()