diff --git a/predictor.py b/predictor.py index b1c1eba..a907990 100644 --- a/predictor.py +++ b/predictor.py @@ -97,6 +97,7 @@ def restore(self, model_path): def predict(self, image): pos = self.sess.run(self.x_op, feed_dict = {self.x: image[np.newaxis, :,:,:]}) + pos = np.squeeze(pos) return pos*self.MaxPos def predict_batch(self, images):