Skip to content

Commit

Permalink
Tensorflow 2 Support for Detection Models (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
objorkman authored Sep 26, 2021
1 parent fdba024 commit 288bcf9
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 103 deletions.
2 changes: 1 addition & 1 deletion configs/detection.conf
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
--obstacle_detection
--nosimulator_obstacle_detection
--obstacle_detection_model_paths=dependencies/models/obstacle_detection/faster-rcnn/frozen_inference_graph.pb
--obstacle_detection_model_paths=dependencies/models/obstacle_detection/faster-rcnn/
--obstacle_detection_model_names=faster-rcnn
--obstacle_detection_min_score_threshold=0.3
--obstacle_detection_gpu_memory_fraction=0.6
Expand Down
4 changes: 2 additions & 2 deletions install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ mkdir -p dependencies/models
###### Download CARLA-trained object detection models ######
echo "[x] Downloading the obstacle detection models..."
cd $PYLOT_HOME/dependencies/models
~/.local/bin/gdown https://drive.google.com/uc?id=1KL4jD1TNlWLz4199wzAAw-olquBCOGXe
~/.local/bin/gdown https://drive.google.com/uc?id=1aT0q-HCz3KutvNGcc0Tleh88nK05unSe
unzip obstacle_detection.zip ; rm obstacle_detection.zip

###### Download the traffic light model ######
echo "[x] Downloading the traffic light detection models..."
cd $PYLOT_HOME/dependencies/models
mkdir -p traffic_light_detection/faster-rcnn ; cd traffic_light_detection/faster-rcnn
~/.local/bin/gdown https://drive.google.com/uc?id=1LVLb_0R7LwM_pSY4dw7e2_06LO0tGtl-
~/.local/bin/gdown https://drive.google.com/uc?id=1MbTIkh4KJubJN66-SurH1x725D9S-w50

###### Download the Lanenet lane detection model ######
echo "[x] Downloading the lane detection models..."
Expand Down
59 changes: 23 additions & 36 deletions pylot/perception/detection/detection_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,42 +40,27 @@ def __init__(self, camera_stream: erdos.ReadStream,
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
self._obstacles_stream = obstacles_stream
self._detection_graph = tf.Graph()
# Load the model from the model file.

pylot.utils.set_tf_loglevel(logging.ERROR)
with self._detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(model_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')

self._gpu_options = tf.GPUOptions(
allow_growth=True,
visible_device_list=str(self._flags.obstacle_detection_gpu_index),
per_process_gpu_memory_fraction=flags.
obstacle_detection_gpu_memory_fraction)
# Create a TensorFlow session.
self._tf_session = tf.Session(
graph=self._detection_graph,
config=tf.ConfigProto(gpu_options=self._gpu_options))
# Get the tensors we're interested in.
self._image_tensor = self._detection_graph.get_tensor_by_name(
'image_tensor:0')
self._detection_boxes = self._detection_graph.get_tensor_by_name(
'detection_boxes:0')
self._detection_scores = self._detection_graph.get_tensor_by_name(
'detection_scores:0')
self._detection_classes = self._detection_graph.get_tensor_by_name(
'detection_classes:0')
self._num_detections = self._detection_graph.get_tensor_by_name(
'num_detections:0')

# Only sets memory growth for flagged GPU
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(
[physical_devices[self._flags.obstacle_detection_gpu_index]],
'GPU')
tf.config.experimental.set_memory_growth(
physical_devices[self._flags.obstacle_detection_gpu_index], True)

# Load the model from the saved_model format file.
self._model = tf.saved_model.load(model_path)

self._coco_labels = load_coco_labels(self._flags.path_coco_labels)
self._bbox_colors = load_coco_bbox_colors(self._coco_labels)
# Unique bounding box id. Incremented for each bounding box.
self._unique_id = 0

# Serve some junk image to load up the model.
self.__run_model(np.zeros((108, 192, 3)))
self.__run_model(np.zeros((108, 192, 3), dtype='uint8'))

@staticmethod
def connect(camera_stream: erdos.ReadStream,
Expand Down Expand Up @@ -172,12 +157,14 @@ def __run_model(self, image_np):
# Expand dimensions since the model expects images to have
# shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
(boxes, scores, classes, num_detections) = self._tf_session.run(
[
self._detection_boxes, self._detection_scores,
self._detection_classes, self._num_detections
],
feed_dict={self._image_tensor: image_np_expanded})

infer = self._model.signatures['serving_default']
result = infer(tf.convert_to_tensor(value=image_np_expanded))

boxes = result['boxes']
scores = result['scores']
classes = result['classes']
num_detections = result['detections']

num_detections = int(num_detections[0])
res_classes = [int(cls) for cls in classes[0][:num_detections]]
Expand Down
11 changes: 6 additions & 5 deletions pylot/perception/detection/efficientdet_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import tensorflow as tf


# TODO: Remove once transition to TF2 is complete
class EfficientDetOperator(erdos.Operator):
""" Detects obstacles using the EfficientDet set of models.
Expand Down Expand Up @@ -78,17 +79,17 @@ def load_serving_model(self, model_name, model_path, gpu_memory_fraction):
detection_graph = tf.Graph()
with detection_graph.as_default():
# Load a frozen graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(model_path, 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
gpu_options = tf.GPUOptions(
gpu_options = tf.compat.v1.GPUOptions(
allow_growth=True,
visible_device_list=str(self._flags.obstacle_detection_gpu_index),
per_process_gpu_memory_fraction=gpu_memory_fraction)
return model_name, tf.Session(
return model_name, tf.compat.v1.Session(
graph=detection_graph,
config=tf.ConfigProto(gpu_options=gpu_options))
config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))

@staticmethod
def connect(camera_stream: erdos.ReadStream,
Expand Down
18 changes: 10 additions & 8 deletions pylot/perception/detection/lanenet_detection_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,29 @@ def __init__(self, camera_stream: erdos.ReadStream,
self._flags = flags
self._logger = erdos.utils.setup_logging(self.config.name,
self.config.log_file_name)
tf.compat.v1.disable_eager_execution()
pylot.utils.set_tf_loglevel(logging.ERROR)
self._input_tensor = tf.placeholder(dtype=tf.float32,
shape=[1, 256, 512, 3],
name='input_tensor')
self._input_tensor = tf.compat.v1.placeholder(dtype=tf.float32,
shape=[1, 256, 512, 3],
name='input_tensor')
net = lanenet.LaneNet(phase='test')
self._binary_seg_ret, self._instance_seg_ret = net.inference(
input_tensor=self._input_tensor, name='LaneNet')
self._gpu_options = tf.GPUOptions(
self._gpu_options = tf.compat.v1.GPUOptions(
allow_growth=True,
visible_device_list=str(self._flags.lane_detection_gpu_index),
per_process_gpu_memory_fraction=flags.
lane_detection_gpu_memory_fraction,
allocator_type='BFC')
self._tf_session = tf.Session(config=tf.ConfigProto(
gpu_options=self._gpu_options, allow_soft_placement=True))
with tf.variable_scope(name_or_scope='moving_avg'):
self._tf_session = tf.compat.v1.Session(
config=tf.compat.v1.ConfigProto(gpu_options=self._gpu_options,
allow_soft_placement=True))
with tf.compat.v1.variable_scope(name_or_scope='moving_avg'):
variable_averages = tf.train.ExponentialMovingAverage(0.9995)
variables_to_restore = variable_averages.variables_to_restore()

self._postprocessor = lanenet_postprocess.LaneNetPostProcessor()
saver = tf.train.Saver(variables_to_restore)
saver = tf.compat.v1.train.Saver(variables_to_restore)
with self._tf_session.as_default():
saver.restore(sess=self._tf_session,
save_path=flags.lanenet_detection_model_path)
Expand Down
71 changes: 29 additions & 42 deletions pylot/perception/detection/traffic_light_det_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,20 @@ def __init__(self, camera_stream: erdos.ReadStream,
self.config.log_file_name)
self._flags = flags
self._traffic_lights_stream = traffic_lights_stream
self._detection_graph = tf.Graph()
# Load the model from the model file.
pylot.utils.set_tf_loglevel(logging.ERROR)
with self._detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(self._flags.traffic_light_det_model_path,
'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')

self._gpu_options = tf.GPUOptions(
allow_growth=True,
visible_device_list=str(self._flags.traffic_light_det_gpu_index),
per_process_gpu_memory_fraction=flags.
traffic_light_det_gpu_memory_fraction)
# Create a TensorFlow session.
self._tf_session = tf.Session(
graph=self._detection_graph,
config=tf.ConfigProto(gpu_options=self._gpu_options))
# Get the tensors we're interested in.
self._image_tensor = self._detection_graph.get_tensor_by_name(
'image_tensor:0')
self._detection_boxes = self._detection_graph.get_tensor_by_name(
'detection_boxes:0')
self._detection_scores = self._detection_graph.get_tensor_by_name(
'detection_scores:0')
self._detection_classes = self._detection_graph.get_tensor_by_name(
'detection_classes:0')
self._num_detections = self._detection_graph.get_tensor_by_name(
'num_detections:0')

# Only sets memory growth for flagged GPU
physical_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_visible_devices(
[physical_devices[self._flags.traffic_light_det_gpu_index]], 'GPU')
tf.config.experimental.set_memory_growth(
physical_devices[self._flags.traffic_light_det_gpu_index], True)

# Load the model from the saved_model format file.
self._model = tf.saved_model.load(
self._flags.traffic_light_det_model_path)

self._labels = {
1: TrafficLightColor.GREEN,
2: TrafficLightColor.YELLOW,
Expand All @@ -79,7 +62,7 @@ def __init__(self, camera_stream: erdos.ReadStream,
# Unique bounding box id. Incremented for each bounding box.
self._unique_id = 0
# Serve some junk image to load up the model.
self.__run_model(np.zeros((108, 192, 3)))
self.__run_model(np.zeros((108, 192, 3), dtype='uint8'))

@staticmethod
def connect(camera_stream: erdos.ReadStream,
Expand Down Expand Up @@ -146,18 +129,22 @@ def __run_model(self, image_np):
# Expand dimensions since the model expects images to have
# shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
(boxes, scores, classes, num) = self._tf_session.run(
[
self._detection_boxes, self._detection_scores,
self._detection_classes, self._num_detections
],
feed_dict={self._image_tensor: image_np_expanded})

num_detections = int(num[0])
labels = [self._labels[label] for label in classes[0][:num_detections]]
boxes = boxes[0][:num_detections]
scores = scores[0][:num_detections]
return boxes, scores, labels

infer = self._model.signatures['serving_default']
result = infer(tf.convert_to_tensor(value=image_np_expanded))

boxes = result['boxes']
scores = result['scores']
classes = result['classes']
num_detections = result['detections']

num_detections = int(num_detections[0])
res_labels = [
self._labels[int(label)] for label in classes[0][:num_detections]
]
res_boxes = boxes[0][:num_detections]
res_scores = scores[0][:num_detections]
return res_boxes, res_scores, res_labels

def __convert_to_detected_tl(self, boxes, scores, labels, height, width):
traffic_lights = []
Expand Down
4 changes: 2 additions & 2 deletions pylot/perception/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Detection flags.
flags.DEFINE_list(
'obstacle_detection_model_paths',
'dependencies/models/obstacle_detection/faster-rcnn/frozen_inference_graph.pb', # noqa: E501
'dependencies/models/obstacle_detection/faster-rcnn/', # noqa: E501
'Comma-separated list of model paths')
flags.DEFINE_list('obstacle_detection_model_names', 'faster-rcnn',
'Comma-separated list of model names')
Expand All @@ -25,7 +25,7 @@
# Traffic light detector flags.
flags.DEFINE_string(
'traffic_light_det_model_path',
'dependencies/models/traffic_light_detection/faster-rcnn/frozen_inference_graph.pb', # noqa: E501
'dependencies/models/traffic_light_detection/faster-rcnn/', # noqa: E501
'Path to the traffic light model protobuf')
flags.DEFINE_float('traffic_light_det_min_score_threshold', 0.3,
'Min score threshold for bounding box')
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pytest
scikit-image<0.15
scipy==1.2.2
shapely==1.6.4
tensorflow-gpu==1.15.4
tensorflow-gpu==2.0.0
torch==1.4.0
torchvision==0.5.0
##### Tracking dependencies #####
Expand Down
2 changes: 1 addition & 1 deletion scripts/set_pythonpath.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ if [ -z "$CARLA_HOME" ]; then
fi

CARLA_EGG=$(ls $CARLA_HOME/PythonAPI/carla/dist/carla*py3*egg)
export PYTHONPATH=$PYTHONPATH:$PYLOT_HOME:/$PYLOT_HOME/dependencies/:$CARLA_EGG:$CARLA_HOME/PythonAPI/carla/
export PYTHONPATH=$PYTHONPATH:$PYLOT_HOME:/$PYLOT_HOME/dependencies/:$CARLA_EGG:$CARLA_HOME/PythonAPI/carla/:$PYLOT_HOME/dependencies/lanenet/
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"scikit-image<0.15",
"scipy==1.2.2",
"shapely==1.6.4",
"tensorflow-gpu==1.15.4",
"tensorflow-gpu==2.0.0",
"torch==1.4.0",
"torchvision==0.5.0",
##### Tracking dependencies #####
Expand Down
8 changes: 4 additions & 4 deletions tests/check_lanenet_lane_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, weights, config):
weights: The path of the weights to be used in the prediction.
config: The config to be used for tensorflow.
"""
self.input_tensor = tf.placeholder(dtype=tf.float32,
self.input_tensor = tf.compat.v1.placeholder(dtype=tf.float32,
shape=[1, 256, 512, 3],
name='input_tensor')
self.net = lanenet.LaneNet(phase='test', net_flag='vgg')
Expand All @@ -36,13 +36,13 @@ def __init__(self, weights, config):
ipm_remap_file_path=
'./dependencies/lanenet-lane-detection/data/tusimple_ipm_remap.yml'
)
sess_config = tf.ConfigProto()
sess_config = tf.compat.v1.ConfigProto()
sess_config.gpu_options.per_process_gpu_memory_fraction = \
config.TEST.GPU_MEMORY_FRACTION
sess_config.gpu_options.allow_growth = config.TRAIN.TF_ALLOW_GROWTH
sess_config.gpu_options.allocator_type = 'BFC'
self.sess = tf.Session(config=sess_config).__enter__()
saver = tf.train.Saver()
self.sess = tf.compat.v1.Session(config=sess_config).__enter__()
saver = tf.compat.v1.train.Saver()
saver.restore(sess=self.sess, save_path=weights)

def process_images(self, msg):
Expand Down

0 comments on commit 288bcf9

Please sign in to comment.