forked from harvitronix/five-video-classification-methods
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvalidate_rnn.py
More file actions
54 lines (44 loc) · 1.45 KB
/
validate_rnn.py
File metadata and controls
54 lines (44 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Validate our RNN. Basically just runs a validation generator on
about the same number of videos as we have in our test set.
"""
from keras.callbacks import TensorBoard, ModelCheckpoint, CSVLogger
from models import ResearchModels
from data import DataSet
def validate(data_type, model, seq_length=40, saved_model=None,
             class_limit=None, image_shape=None, batch_size=32,
             val_samples=3200):
    """Evaluate a saved model on the 'test' split and print its metrics.

    Args:
        data_type: Data type forwarded to ``DataSet.frame_generator``
            (``main`` passes ``'images'`` or ``'features'``).
        model: Model name forwarded to ``ResearchModels`` (e.g. ``'lstm'``).
        seq_length: Sequence length used for both the dataset and the model.
        saved_model: Forwarded to ``ResearchModels`` as its saved-model
            argument (``main`` passes an ``.hdf5`` checkpoint path).
        class_limit: Optional limit forwarded to ``DataSet``.
        image_shape: Optional shape forwarded to ``DataSet``. When ``None``
            the argument is omitted entirely, so ``DataSet`` applies its own
            default.
        batch_size: Batch size for the validation generator (default 32,
            the value this function previously hard-coded).
        val_samples: Number of samples passed to ``evaluate_generator``
            (default 3200, the value this function previously hard-coded).

    Returns:
        Whatever ``rm.model.evaluate_generator`` returns. It is also printed,
        followed by ``rm.model.metrics_names`` so the values can be matched to
        metric names.
    """
    # Build DataSet once. image_shape is forwarded only when the caller
    # supplied it, which preserves DataSet's own default otherwise.
    dataset_kwargs = {'seq_length': seq_length, 'class_limit': class_limit}
    if image_shape is not None:
        dataset_kwargs['image_shape'] = image_shape
    data = DataSet(**dataset_kwargs)

    val_generator = data.frame_generator(batch_size, 'test', data_type)

    # Get the model, sized to the number of classes the dataset exposes.
    rm = ResearchModels(len(data.classes), model, seq_length, saved_model)

    # NOTE(review): `val_samples` is the Keras 1.x keyword. Keras 2 replaced it
    # with `steps`, which counts batches rather than samples. Confirm which
    # Keras version this project targets before changing it.
    results = rm.model.evaluate_generator(
        generator=val_generator,
        val_samples=val_samples)
    print(results)
    print(rm.model.metrics_names)
    return results
def main():
    """Run ``validate`` on a fixed LSTM checkpoint, limited to 4 classes.

    The model name picks the input kind. 'conv_3d' and 'lrcn' read
    'images' with an (80, 80, 3) image shape. Every other model reads
    'features' and passes no image shape.
    """
    model = 'lstm'
    saved_model = 'data/checkpoints/lstm-features.026-0.239.hdf5'

    reads_images = model in ('conv_3d', 'lrcn')
    data_type = 'images' if reads_images else 'features'
    image_shape = (80, 80, 3) if reads_images else None

    validate(
        data_type,
        model,
        saved_model=saved_model,
        image_shape=image_shape,
        class_limit=4,
    )


# Script entry point: run the validation only when executed directly.
if __name__ == '__main__':
    main()