diff --git a/README.md b/README.md index af18a81..f937e1e 100644 --- a/README.md +++ b/README.md @@ -56,3 +56,9 @@ python daemon.py ``` and go to http://127.0.0.1:8000 to configure devicehive connection. Video stream is available on http://127.0.0.1:8000/events/ + +### Tested Tensorflow version +1. Tensorflow=1.4.0 +2. Tensorflow=1.9.0 +3. Tensorflow=1.12.0 (GPU version, cuda100=1.0, cudatoolkit=9.0) + diff --git a/eval.py b/eval.py index d66179b..4695355 100644 --- a/eval.py +++ b/eval.py @@ -18,6 +18,7 @@ import cv2 import pafy import tensorflow as tf +import os from models import yolo from log_config import LOGGING @@ -38,8 +39,12 @@ def evaluate(_): if is_url(video): videoPafy = pafy.new(video) video = videoPafy.getbest(preftype="mp4").url - - cam = cv2.VideoCapture(video) + cam = cv2.VideoCapture(video) + elif os.path.isfile(video): + cam = cv2.VideoCapture(video) + else: + cam = cv2.VideoCapture(0) + if not cam.isOpened(): raise IOError('Can\'t open "{}"'.format(FLAGS.video)) @@ -120,7 +125,7 @@ def evaluate(_): if __name__ == '__main__': - tf.flags.DEFINE_string('video', 0, 'Path to the video file.') + tf.flags.DEFINE_string('video', '0', 'Path to the video file.') tf.flags.DEFINE_string('model_name', 'Yolo2Model', 'Model name to use.') tf.app.run(main=evaluate) diff --git a/models/yolo.py b/models/yolo.py index 673f8f8..a5affba 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -14,6 +14,7 @@ import tensorflow as tf +from tensorflow import keras from models.base import BaseModel from utils import yolo, general @@ -98,9 +99,21 @@ def init(self): self._raw_inp = raw_inp self._raw_out = raw_out self._eval_inp = eval_inp + self._eval_out = eval_out self._sess.run(tf.global_variables_initializer()) + def save(self): + export_path = './2' + print('Exporting trained model to', export_path) + #builder = tf.saved_model.builder.SavedModelBuilder(export_path) + tf.saved_model.simple_save( + self._sess, + export_path, + inputs={'input_image': self._eval_inp}, + outputs={'output': self._eval_out}) + + def close(self): self._sess.close() diff --git a/save.py b/save.py new file mode 100644 index 0000000..6e7a39f --- /dev/null +++ b/save.py @@ -0,0 +1,39 @@ +# Copyright (C) 2017 DataArt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import tensorflow as tf + +from models import yolo +from log_config import LOGGING +from utils.general import format_predictions, find_class_by_name, is_url + + + + +def save(_): + + + model_cls = find_class_by_name('Yolo2Model', [yolo]) + model = model_cls(input_shape=(640, 480, 3)) + model.init() + + model.save() + model.close() + + +if __name__ == '__main__': + + + tf.app.run(main=save)