diff --git a/TF_object_detection/Drone_Net.py b/Drone_Net.py
similarity index 100%
rename from TF_object_detection/Drone_Net.py
rename to Drone_Net.py
diff --git a/TF_object_detection/Drone_Net_Test.py b/Drone_Net_Test.py
similarity index 96%
rename from TF_object_detection/Drone_Net_Test.py
rename to Drone_Net_Test.py
index d5c4387..eae7705 100644
--- a/TF_object_detection/Drone_Net_Test.py
+++ b/Drone_Net_Test.py
@@ -130,14 +130,6 @@ def run(self):
         self.process.set_rotation(0)
         self.process.set_tilt(0)
 
-        # # Display output
-        # cv2.imshow('object detection', img)
-        #
-        # if cv2.waitKey(10) & 0xFF == ord('q'):
-        #     cv2.destroyAllWindows()
-        #     break
-
-
     def compute_boxes(self, coordinates):
         coordinates = np.array(coordinates)
         if (not coordinates.sum() == 0):
diff --git a/TF_object_detection/Manager.py b/Manager.py
similarity index 68%
rename from TF_object_detection/Manager.py
rename to Manager.py
index e0cf0b7..cc99f53 100644
--- a/TF_object_detection/Manager.py
+++ b/Manager.py
@@ -7,25 +7,42 @@
 from pyparrot_modified.pyparrot.Bebop import Bebop
 from pyparrot_modified.pyparrot.DroneVisionGUI import DroneVisionGUI
 
+while True:
+    test_mode = input("Would you like to use the test mode? (Y/N): ").capitalize()
+    if(test_mode == "Y"):
+        test_mode = True
+        break
+
+    elif(test_mode == "N"):
+        test_mode = False
+        break
+    else:
+        print("Invalid input, enter either 'Y' or 'N'")
+
+
 # this object computes the motion and then feeds it to
 # 'Move_drone' which then actually moves the drone accordingly
+
 process = Movement_processing()
 
 # creates the bebop object and connects to it
 bebop = Bebop()
-bebop.connect(5)
+success = bebop.connect(5)
 
 # creates the object that moves the drone
-move = Move_drone(bebop, process)
+move = Move_drone(bebop, process, success)
 
 # creates the GUI that will initiate the video stream
 # if testing is true, video is streamed from the webcam (this has to be used
 # in combination with Drone_Net_Test as the network)
 vision = DroneVisionGUI(bebop, move=move, process=process, is_bebop=True,
-                        user_args=(bebop,), testing=True)
+                        user_args=(bebop,), testing=test_mode)
 
 # initialises the neural net
-net = Drone_Net_Test(vision=vision, process=process)
+if(test_mode):
+    net = Drone_Net_Test(vision=vision, process=process)
+else:
+    net = Drone_Net(vision=vision, process=process)
 
 vision.feed_net(net)
 move.feed_net(net)
@@ -42,7 +59,8 @@
 print("Starting Vision")
 vision.start()
 
+move.start()
+print('move started')
+
 # starts feeding movement information to the drone once everything else
 # is up and running
-move.start()
-print('move started')
\ No newline at end of file
diff --git a/TF_object_detection/Movement_Processing.py b/Movement_Processing.py
similarity index 91%
rename from TF_object_detection/Movement_Processing.py
rename to Movement_Processing.py
index fa7d5ce..522accf 100644
--- a/TF_object_detection/Movement_Processing.py
+++ b/Movement_Processing.py
@@ -75,12 +75,13 @@ def set_min_box_size(self, min):
 
 class Move_drone(threading.Thread):
 
-    def __init__(self, bebop, process):
+    def __init__(self, bebop, process, success):
         """
         :param user_function: user code to run (presumably flies the drone)
         :param user_args: optional arguments to the user function
         """
         threading.Thread.__init__(self)
+        self.success = success
         self.net = None
         self.bebop = bebop
         self.process = process
@@ -168,15 +169,15 @@ def drone_hover(self):
     # the function that feeds the movement; runs as a separate thread
     def run(self):
         # self.bebop.safe_takeoff(5)
-        self.bebop.set_video_resolutions('rec1080_stream420')
-        self.bebop.set_video_recording('time')
-        self.bebop.set_video_stream_mode('low_latency')
-        #makes the drone patroll if no target is detected
-
+        if(self.success):
+            self.bebop.set_video_resolutions('rec1080_stream480')
+            self.bebop.set_video_recording('time')
+            self.bebop.set_video_stream_mode('low_latency')
+        # makes the drone patrol if no target is detected
 
         # moves the drone as long as it wasn't killed
-        while(self.killed == False):
-            if((not self.camera_angle ==-60 or not self.camera_angle == 60) and self.hovering == False):
+        while (self.killed == False):
+            if ((not self.camera_angle == -60 or not self.camera_angle == 60) and self.hovering == False):
                 self.yaw = self.process.get_rotation()
                 self.tilt = self.process.get_tilt()
                 self.pitch = self.process.get_pitch()
@@ -185,14 +186,14 @@ def run(self):
                 # print("Drone Discombobulated")
                 # # rotates to track object's left/right movement
                 # else:
-                if(self.rotate == True):
+                if (self.rotate == True):
                     self.bebop.fly_direct(roll=0, yaw=self.yaw, pitch=self.pitch, vertical_movement=0, duration=0.1)
                 # moves left/right itself to track object's left/right movement
-                elif(self.rotate == False):
+                elif (self.rotate == False):
                     self.bebop.fly_direct(roll=self.yaw, yaw=0, pitch=self.pitch, vertical_movement=0, duration=0.1)
                 # if 60 >= self.camera_angle >= -60:
                 self.bebop.pan_tilt_camera_velocity(tilt_velocity=self.tilt, pan_velocity=0, duration=0.1)
-                self.camera_angle += self.tilt*0.1
+                self.camera_angle += self.tilt * 0.1
 
     # all the getter functions:
     def get_pitch(self):
diff --git a/TF_object_detection/README.md b/TF_object_detection/README.md
deleted file mode 100644
index 6258443..0000000
--- a/TF_object_detection/README.md
+++ /dev/null
@@ -1 +0,0 @@
-This folder contains all files related to the tensorflow object detection network
diff --git a/TF_object_detection/make_TFrecord.py b/make_TFrecord.py
similarity index 100%
rename from TF_object_detection/make_TFrecord.py
rename to make_TFrecord.py
diff --git a/object_detection/core/__pycache__/standard_fields.cpython-36.pyc b/object_detection/core/__pycache__/standard_fields.cpython-36.pyc
new file mode 100644
index 0000000..468c880
Binary files /dev/null and b/object_detection/core/__pycache__/standard_fields.cpython-36.pyc differ
diff --git a/object_detection/core/standard_fields.py b/object_detection/core/standard_fields.py
new file mode 100644
index 0000000..de11848
--- /dev/null
+++ b/object_detection/core/standard_fields.py
@@ -0,0 +1,242 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Contains classes specifying naming conventions used for object detection.
+
+
+Specifies:
+  InputDataFields: standard fields used by reader/preprocessor/batcher.
+  DetectionResultFields: standard fields returned by object detector.
+  BoxListFields: standard field used by BoxList
+  TfExampleFields: standard fields for tf-example data format (go/tf-example).
+""" + + +class InputDataFields(object): + """Names for the input tensors. + + Holds the standard data field names to use for identifying input tensors. This + should be used by the decoder to identify keys for the returned tensor_dict + containing input tensors. And it should be used by the model to identify the + tensors it needs. + + Attributes: + image: image. + image_additional_channels: additional channels. + original_image: image in the original input size. + original_image_spatial_shape: image in the original input size. + key: unique key corresponding to image. + source_id: source of the original image. + filename: original filename of the dataset (without common path). + groundtruth_image_classes: image-level class labels. + groundtruth_image_confidences: image-level class confidences. + groundtruth_boxes: coordinates of the ground truth boxes in the image. + groundtruth_classes: box-level class labels. + groundtruth_confidences: box-level class confidences. The shape should be + the same as the shape of groundtruth_classes. + groundtruth_label_types: box-level label types (e.g. explicit negative). + groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] + is the groundtruth a single object or a crowd. + groundtruth_area: area of a groundtruth segment. + groundtruth_difficult: is a `difficult` object + groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the + same class, forming a connected group, where instances are heavily + occluding each other. + proposal_boxes: coordinates of object proposal boxes. + proposal_objectness: objectness score of each proposal. + groundtruth_instance_masks: ground truth instance masks. + groundtruth_instance_boundaries: ground truth instance boundaries. + groundtruth_instance_classes: instance mask-level class labels. + groundtruth_keypoints: ground truth keypoints. + groundtruth_keypoint_visibilities: ground truth keypoint visibilities. + groundtruth_label_weights: groundtruth label weights. + groundtruth_weights: groundtruth weight factor for bounding boxes. + num_groundtruth_boxes: number of groundtruth boxes. + is_annotated: whether an image has been labeled or not. + true_image_shapes: true shapes of images in the resized images, as resized + images can be padded with zeros. + multiclass_scores: the label score per class for each box. 
+ """ + image = 'image' + image_additional_channels = 'image_additional_channels' + original_image = 'original_image' + original_image_spatial_shape = 'original_image_spatial_shape' + key = 'key' + source_id = 'source_id' + filename = 'filename' + groundtruth_image_classes = 'groundtruth_image_classes' + groundtruth_image_confidences = 'groundtruth_image_confidences' + groundtruth_boxes = 'groundtruth_boxes' + groundtruth_classes = 'groundtruth_classes' + groundtruth_confidences = 'groundtruth_confidences' + groundtruth_label_types = 'groundtruth_label_types' + groundtruth_is_crowd = 'groundtruth_is_crowd' + groundtruth_area = 'groundtruth_area' + groundtruth_difficult = 'groundtruth_difficult' + groundtruth_group_of = 'groundtruth_group_of' + proposal_boxes = 'proposal_boxes' + proposal_objectness = 'proposal_objectness' + groundtruth_instance_masks = 'groundtruth_instance_masks' + groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' + groundtruth_instance_classes = 'groundtruth_instance_classes' + groundtruth_keypoints = 'groundtruth_keypoints' + groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' + groundtruth_label_weights = 'groundtruth_label_weights' + groundtruth_weights = 'groundtruth_weights' + num_groundtruth_boxes = 'num_groundtruth_boxes' + is_annotated = 'is_annotated' + true_image_shape = 'true_image_shape' + multiclass_scores = 'multiclass_scores' + + +class DetectionResultFields(object): + """Naming conventions for storing the output of the detector. + + Attributes: + source_id: source of the original image. + key: unique key corresponding to image. + detection_boxes: coordinates of the detection boxes in the image. + detection_scores: detection scores for the detection boxes in the image. + detection_multiclass_scores: class score distribution (including background) + for detection boxes in the image including background class. + detection_classes: detection-level class labels. + detection_masks: contains a segmentation mask for each detection box. + detection_boundaries: contains an object boundary for each detection box. + detection_keypoints: contains detection keypoints for each detection box. + num_detections: number of detections in the batch. + raw_detection_boxes: contains decoded detection boxes without Non-Max + suppression. + raw_detection_scores: contains class score logits for raw detection boxes. + """ + + source_id = 'source_id' + key = 'key' + detection_boxes = 'detection_boxes' + detection_scores = 'detection_scores' + detection_multiclass_scores = 'detection_multiclass_scores' + detection_classes = 'detection_classes' + detection_masks = 'detection_masks' + detection_boundaries = 'detection_boundaries' + detection_keypoints = 'detection_keypoints' + num_detections = 'num_detections' + raw_detection_boxes = 'raw_detection_boxes' + raw_detection_scores = 'raw_detection_scores' + + +class BoxListFields(object): + """Naming conventions for BoxLists. + + Attributes: + boxes: bounding box coordinates. + classes: classes per bounding box. + scores: scores per bounding box. + weights: sample weights per bounding box. + objectness: objectness score per bounding box. + masks: masks per bounding box. + boundaries: boundaries per bounding box. + keypoints: keypoints per bounding box. + keypoint_heatmaps: keypoint heatmaps per bounding box. + is_crowd: is_crowd annotation per bounding box. 
+ """ + boxes = 'boxes' + classes = 'classes' + scores = 'scores' + weights = 'weights' + confidences = 'confidences' + objectness = 'objectness' + masks = 'masks' + boundaries = 'boundaries' + keypoints = 'keypoints' + keypoint_heatmaps = 'keypoint_heatmaps' + is_crowd = 'is_crowd' + + +class TfExampleFields(object): + """TF-example proto feature names for object detection. + + Holds the standard feature names to load from an Example proto for object + detection. + + Attributes: + image_encoded: JPEG encoded string + image_format: image format, e.g. "JPEG" + filename: filename + channels: number of channels of image + colorspace: colorspace, e.g. "RGB" + height: height of image in pixels, e.g. 462 + width: width of image in pixels, e.g. 581 + source_id: original source of the image + image_class_text: image-level label in text format + image_class_label: image-level label in numerical format + object_class_text: labels in text format, e.g. ["person", "cat"] + object_class_label: labels in numbers, e.g. [16, 8] + object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 + object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 + object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 + object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 + object_view: viewpoint of object, e.g. ["frontal", "left"] + object_truncated: is object truncated, e.g. [true, false] + object_occluded: is object occluded, e.g. [true, false] + object_difficult: is object difficult, e.g. [true, false] + object_group_of: is object a single object or a group of objects + object_depiction: is object a depiction + object_is_crowd: [DEPRECATED, use object_group_of instead] + is the object a single object or a crowd + object_segment_area: the area of the segment. + object_weight: a weight factor for the object's bounding box. + instance_masks: instance segmentation masks. + instance_boundaries: instance boundaries. + instance_classes: Classes for each instance segmentation mask. + detection_class_label: class label in numbers. + detection_bbox_ymin: ymin coordinates of a detection box. + detection_bbox_xmin: xmin coordinates of a detection box. + detection_bbox_ymax: ymax coordinates of a detection box. + detection_bbox_xmax: xmax coordinates of a detection box. + detection_score: detection score for the class label and box. 
+ """ + image_encoded = 'image/encoded' + image_format = 'image/format' # format is reserved keyword + filename = 'image/filename' + channels = 'image/channels' + colorspace = 'image/colorspace' + height = 'image/height' + width = 'image/width' + source_id = 'image/source_id' + image_class_text = 'image/class/text' + image_class_label = 'image/class/label' + object_class_text = 'image/object/class/text' + object_class_label = 'image/object/class/label' + object_bbox_ymin = 'image/object/bbox/ymin' + object_bbox_xmin = 'image/object/bbox/xmin' + object_bbox_ymax = 'image/object/bbox/ymax' + object_bbox_xmax = 'image/object/bbox/xmax' + object_view = 'image/object/view' + object_truncated = 'image/object/truncated' + object_occluded = 'image/object/occluded' + object_difficult = 'image/object/difficult' + object_group_of = 'image/object/group_of' + object_depiction = 'image/object/depiction' + object_is_crowd = 'image/object/is_crowd' + object_segment_area = 'image/object/segment/area' + object_weight = 'image/object/weight' + instance_masks = 'image/segmentation/object' + instance_boundaries = 'image/boundaries/object' + instance_classes = 'image/segmentation/object/class' + detection_class_label = 'image/detection/label' + detection_bbox_ymin = 'image/detection/bbox/ymin' + detection_bbox_xmin = 'image/detection/bbox/xmin' + detection_bbox_ymax = 'image/detection/bbox/ymax' + detection_bbox_xmax = 'image/detection/bbox/xmax' + detection_score = 'image/detection/score' diff --git a/object_detection/data/mscoco_label_map.pbtxt b/object_detection/data/mscoco_label_map.pbtxt new file mode 100644 index 0000000..1f4872b --- /dev/null +++ b/object_detection/data/mscoco_label_map.pbtxt @@ -0,0 +1,400 @@ +item { + name: "/m/01g317" + id: 1 + display_name: "person" +} +item { + name: "/m/0199g" + id: 2 + display_name: "bicycle" +} +item { + name: "/m/0k4j" + id: 3 + display_name: "car" +} +item { + name: "/m/04_sv" + id: 4 + display_name: "motorcycle" +} +item { + name: "/m/05czz6l" + id: 5 + display_name: "airplane" +} +item { + name: "/m/01bjv" + id: 6 + display_name: "bus" +} +item { + name: "/m/07jdr" + id: 7 + display_name: "train" +} +item { + name: "/m/07r04" + id: 8 + display_name: "truck" +} +item { + name: "/m/019jd" + id: 9 + display_name: "boat" +} +item { + name: "/m/015qff" + id: 10 + display_name: "traffic light" +} +item { + name: "/m/01pns0" + id: 11 + display_name: "fire hydrant" +} +item { + name: "/m/02pv19" + id: 13 + display_name: "stop sign" +} +item { + name: "/m/015qbp" + id: 14 + display_name: "parking meter" +} +item { + name: "/m/0cvnqh" + id: 15 + display_name: "bench" +} +item { + name: "/m/015p6" + id: 16 + display_name: "bird" +} +item { + name: "/m/01yrx" + id: 17 + display_name: "cat" +} +item { + name: "/m/0bt9lr" + id: 18 + display_name: "dog" +} +item { + name: "/m/03k3r" + id: 19 + display_name: "horse" +} +item { + name: "/m/07bgp" + id: 20 + display_name: "sheep" +} +item { + name: "/m/01xq0k1" + id: 21 + display_name: "cow" +} +item { + name: "/m/0bwd_0j" + id: 22 + display_name: "elephant" +} +item { + name: "/m/01dws" + id: 23 + display_name: "bear" +} +item { + name: "/m/0898b" + id: 24 + display_name: "zebra" +} +item { + name: "/m/03bk1" + id: 25 + display_name: "giraffe" +} +item { + name: "/m/01940j" + id: 27 + display_name: "backpack" +} +item { + name: "/m/0hnnb" + id: 28 + display_name: "umbrella" +} +item { + name: "/m/080hkjn" + id: 31 + display_name: "handbag" +} +item { + name: "/m/01rkbr" + id: 32 + display_name: "tie" +} +item { + name: 
"/m/01s55n" + id: 33 + display_name: "suitcase" +} +item { + name: "/m/02wmf" + id: 34 + display_name: "frisbee" +} +item { + name: "/m/071p9" + id: 35 + display_name: "skis" +} +item { + name: "/m/06__v" + id: 36 + display_name: "snowboard" +} +item { + name: "/m/018xm" + id: 37 + display_name: "sports ball" +} +item { + name: "/m/02zt3" + id: 38 + display_name: "kite" +} +item { + name: "/m/03g8mr" + id: 39 + display_name: "baseball bat" +} +item { + name: "/m/03grzl" + id: 40 + display_name: "baseball glove" +} +item { + name: "/m/06_fw" + id: 41 + display_name: "skateboard" +} +item { + name: "/m/019w40" + id: 42 + display_name: "surfboard" +} +item { + name: "/m/0dv9c" + id: 43 + display_name: "tennis racket" +} +item { + name: "/m/04dr76w" + id: 44 + display_name: "bottle" +} +item { + name: "/m/09tvcd" + id: 46 + display_name: "wine glass" +} +item { + name: "/m/08gqpm" + id: 47 + display_name: "cup" +} +item { + name: "/m/0dt3t" + id: 48 + display_name: "fork" +} +item { + name: "/m/04ctx" + id: 49 + display_name: "knife" +} +item { + name: "/m/0cmx8" + id: 50 + display_name: "spoon" +} +item { + name: "/m/04kkgm" + id: 51 + display_name: "bowl" +} +item { + name: "/m/09qck" + id: 52 + display_name: "banana" +} +item { + name: "/m/014j1m" + id: 53 + display_name: "apple" +} +item { + name: "/m/0l515" + id: 54 + display_name: "sandwich" +} +item { + name: "/m/0cyhj_" + id: 55 + display_name: "orange" +} +item { + name: "/m/0hkxq" + id: 56 + display_name: "broccoli" +} +item { + name: "/m/0fj52s" + id: 57 + display_name: "carrot" +} +item { + name: "/m/01b9xk" + id: 58 + display_name: "hot dog" +} +item { + name: "/m/0663v" + id: 59 + display_name: "pizza" +} +item { + name: "/m/0jy4k" + id: 60 + display_name: "donut" +} +item { + name: "/m/0fszt" + id: 61 + display_name: "cake" +} +item { + name: "/m/01mzpv" + id: 62 + display_name: "chair" +} +item { + name: "/m/02crq1" + id: 63 + display_name: "couch" +} +item { + name: "/m/03fp41" + id: 64 + display_name: "potted plant" +} +item { + name: "/m/03ssj5" + id: 65 + display_name: "bed" +} +item { + name: "/m/04bcr3" + id: 67 + display_name: "dining table" +} +item { + name: "/m/09g1w" + id: 70 + display_name: "toilet" +} +item { + name: "/m/07c52" + id: 72 + display_name: "tv" +} +item { + name: "/m/01c648" + id: 73 + display_name: "laptop" +} +item { + name: "/m/020lf" + id: 74 + display_name: "mouse" +} +item { + name: "/m/0qjjc" + id: 75 + display_name: "remote" +} +item { + name: "/m/01m2v" + id: 76 + display_name: "keyboard" +} +item { + name: "/m/050k8" + id: 77 + display_name: "cell phone" +} +item { + name: "/m/0fx9l" + id: 78 + display_name: "microwave" +} +item { + name: "/m/029bxz" + id: 79 + display_name: "oven" +} +item { + name: "/m/01k6s3" + id: 80 + display_name: "toaster" +} +item { + name: "/m/0130jx" + id: 81 + display_name: "sink" +} +item { + name: "/m/040b_t" + id: 82 + display_name: "refrigerator" +} +item { + name: "/m/0bt_c3" + id: 84 + display_name: "book" +} +item { + name: "/m/01x3z" + id: 85 + display_name: "clock" +} +item { + name: "/m/02s195" + id: 86 + display_name: "vase" +} +item { + name: "/m/01lsmm" + id: 87 + display_name: "scissors" +} +item { + name: "/m/0kmg4" + id: 88 + display_name: "teddy bear" +} +item { + name: "/m/03wvsk" + id: 89 + display_name: "hair drier" +} +item { + name: "/m/012xff" + id: 90 + display_name: "toothbrush" +} diff --git a/object_detection/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc 
b/object_detection/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc new file mode 100644 index 0000000..cf780e2 Binary files /dev/null and b/object_detection/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc differ diff --git a/object_detection/protos/string_int_label_map_pb2.py b/object_detection/protos/string_int_label_map_pb2.py new file mode 100644 index 0000000..3060bc8 --- /dev/null +++ b/object_detection/protos/string_int_label_map_pb2.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: object_detection/protos/string_int_label_map.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='object_detection/protos/string_int_label_map.proto', + package='object_detection.protos', + syntax='proto2', + serialized_options=None, + serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') +) + + + + +_STRINGINTLABELMAPITEM = _descriptor.Descriptor( + name='StringIntLabelMapItem', + full_name='object_detection.protos.StringIntLabelMapItem', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=79, + serialized_end=150, +) + + +_STRINGINTLABELMAP = _descriptor.Descriptor( + name='StringIntLabelMap', + full_name='object_detection.protos.StringIntLabelMap', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + 
message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=152, + serialized_end=233, +) + +_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM +DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), { + 'DESCRIPTOR' : _STRINGINTLABELMAPITEM, + '__module__' : 'object_detection.protos.string_int_label_map_pb2' + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) + }) +_sym_db.RegisterMessage(StringIntLabelMapItem) + +StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), { + 'DESCRIPTOR' : _STRINGINTLABELMAP, + '__module__' : 'object_detection.protos.string_int_label_map_pb2' + # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) + }) +_sym_db.RegisterMessage(StringIntLabelMap) + + +# @@protoc_insertion_point(module_scope) diff --git a/object_detection/utils/__pycache__/label_map_util.cpython-36.pyc b/object_detection/utils/__pycache__/label_map_util.cpython-36.pyc new file mode 100644 index 0000000..a0c684d Binary files /dev/null and b/object_detection/utils/__pycache__/label_map_util.cpython-36.pyc differ diff --git a/object_detection/utils/__pycache__/shape_utils.cpython-36.pyc b/object_detection/utils/__pycache__/shape_utils.cpython-36.pyc new file mode 100644 index 0000000..fdd2c13 Binary files /dev/null and b/object_detection/utils/__pycache__/shape_utils.cpython-36.pyc differ diff --git a/object_detection/utils/__pycache__/static_shape.cpython-36.pyc b/object_detection/utils/__pycache__/static_shape.cpython-36.pyc new file mode 100644 index 0000000..f046749 Binary files /dev/null and b/object_detection/utils/__pycache__/static_shape.cpython-36.pyc differ diff --git a/object_detection/utils/__pycache__/visualization_utils.cpython-36.pyc b/object_detection/utils/__pycache__/visualization_utils.cpython-36.pyc new file mode 100644 index 0000000..aa51ab0 Binary files /dev/null and b/object_detection/utils/__pycache__/visualization_utils.cpython-36.pyc differ diff --git a/object_detection/utils/label_map_util.py b/object_detection/utils/label_map_util.py new file mode 100644 index 0000000..f6ecbf1 --- /dev/null +++ b/object_detection/utils/label_map_util.py @@ -0,0 +1,237 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Label map utility functions.""" + +import logging + +import tensorflow as tf +from google.protobuf import text_format +from object_detection.protos import string_int_label_map_pb2 + + +def _validate_label_map(label_map): + """Checks if a label map is valid. + + Args: + label_map: StringIntLabelMap to validate. + + Raises: + ValueError: if label map is invalid. + """ + for item in label_map.item: + if item.id < 0: + raise ValueError('Label map ids should be >= 0.') + if (item.id == 0 and item.name != 'background' and + item.display_name != 'background'): + raise ValueError('Label map id 0 is reserved for the background label') + + +def create_category_index(categories): + """Creates dictionary of COCO compatible categories keyed by category id. + + Args: + categories: a list of dicts, each of which has the following keys: + 'id': (required) an integer id uniquely identifying this category. + 'name': (required) string representing category name + e.g., 'cat', 'dog', 'pizza'. + + Returns: + category_index: a dict containing the same entries as categories, but keyed + by the 'id' field of each category. + """ + category_index = {} + for cat in categories: + category_index[cat['id']] = cat + return category_index + + +def get_max_label_map_index(label_map): + """Get maximum index in label map. + + Args: + label_map: a StringIntLabelMapProto + + Returns: + an integer + """ + return max([item.id for item in label_map.item]) + + +def convert_label_map_to_categories(label_map, + max_num_classes, + use_display_name=True): + """Given label map proto returns categories list compatible with eval. + + This function converts label map proto and returns a list of dicts, each of + which has the following keys: + 'id': (required) an integer id uniquely identifying this category. + 'name': (required) string representing category name + e.g., 'cat', 'dog', 'pizza'. + We only allow class into the list if its id-label_id_offset is + between 0 (inclusive) and max_num_classes (exclusive). + If there are several items mapping to the same id in the label map, + we will only keep the first one in the categories list. + + Args: + label_map: a StringIntLabelMapProto or None. If None, a default categories + list is created with max_num_classes categories. + max_num_classes: maximum number of (consecutive) label indices to include. + use_display_name: (boolean) choose whether to load 'display_name' field as + category name. If False or if the display_name field does not exist, uses + 'name' field as category names instead. + + Returns: + categories: a list of dictionaries representing all possible categories. + """ + categories = [] + list_of_ids_already_added = [] + if not label_map: + label_id_offset = 1 + for class_id in range(max_num_classes): + categories.append({ + 'id': class_id + label_id_offset, + 'name': 'category_{}'.format(class_id + label_id_offset) + }) + return categories + for item in label_map.item: + if not 0 < item.id <= max_num_classes: + logging.info( + 'Ignore item %d since it falls outside of requested ' + 'label range.', item.id) + continue + if use_display_name and item.HasField('display_name'): + name = item.display_name + else: + name = item.name + if item.id not in list_of_ids_already_added: + list_of_ids_already_added.append(item.id) + categories.append({'id': item.id, 'name': name}) + return categories + + +def load_labelmap(path): + """Loads label map proto. 
+ + Args: + path: path to StringIntLabelMap proto text file. + Returns: + a StringIntLabelMapProto + """ + with tf.gfile.GFile(path, 'r') as fid: + label_map_string = fid.read() + label_map = string_int_label_map_pb2.StringIntLabelMap() + try: + text_format.Merge(label_map_string, label_map) + except text_format.ParseError: + label_map.ParseFromString(label_map_string) + _validate_label_map(label_map) + return label_map + + +def get_label_map_dict(label_map_path, + use_display_name=False, + fill_in_gaps_and_background=False): + """Reads a label map and returns a dictionary of label names to id. + + Args: + label_map_path: path to StringIntLabelMap proto text file. + use_display_name: whether to use the label map items' display names as keys. + fill_in_gaps_and_background: whether to fill in gaps and background with + respect to the id field in the proto. The id: 0 is reserved for the + 'background' class and will be added if it is missing. All other missing + ids in range(1, max(id)) will be added with a dummy class name + ("class_") if they are missing. + + Returns: + A dictionary mapping label names to id. + + Raises: + ValueError: if fill_in_gaps_and_background and label_map has non-integer or + negative values. + """ + label_map = load_labelmap(label_map_path) + label_map_dict = {} + for item in label_map.item: + if use_display_name: + label_map_dict[item.display_name] = item.id + else: + label_map_dict[item.name] = item.id + + if fill_in_gaps_and_background: + values = set(label_map_dict.values()) + + if 0 not in values: + label_map_dict['background'] = 0 + if not all(isinstance(value, int) for value in values): + raise ValueError('The values in label map must be integers in order to' + 'fill_in_gaps_and_background.') + if not all(value >= 0 for value in values): + raise ValueError('The values in the label map must be positive.') + + if len(values) != max(values) + 1: + # there are gaps in the labels, fill in gaps. + for value in range(1, max(values)): + if value not in values: + # TODO(rathodv): Add a prefix 'class_' here once the tool to generate + # teacher annotation adds this prefix in the data. + label_map_dict[str(value)] = value + + return label_map_dict + + +def create_categories_from_labelmap(label_map_path, use_display_name=True): + """Reads a label map and returns categories list compatible with eval. + + This function converts label map proto and returns a list of dicts, each of + which has the following keys: + 'id': an integer id uniquely identifying this category. + 'name': string representing category name e.g., 'cat', 'dog'. + + Args: + label_map_path: Path to `StringIntLabelMap` proto text file. + use_display_name: (boolean) choose whether to load 'display_name' field + as category name. If False or if the display_name field does not exist, + uses 'name' field as category names instead. + + Returns: + categories: a list of dictionaries representing all possible categories. + """ + label_map = load_labelmap(label_map_path) + max_num_classes = max(item.id for item in label_map.item) + return convert_label_map_to_categories(label_map, max_num_classes, + use_display_name) + + +def create_category_index_from_labelmap(label_map_path, use_display_name=True): + """Reads a label map and returns a category index. + + Args: + label_map_path: Path to `StringIntLabelMap` proto text file. + use_display_name: (boolean) choose whether to load 'display_name' field + as category name. If False or if the display_name field does not exist, + uses 'name' field as category names instead. 
+ + Returns: + A category index, which is a dictionary that maps integer ids to dicts + containing categories, e.g. + {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} + """ + categories = create_categories_from_labelmap(label_map_path, use_display_name) + return create_category_index(categories) + + +def create_class_agnostic_category_index(): + """Creates a category index with a single `object` class.""" + return {1: {'id': 1, 'name': 'object'}} diff --git a/object_detection/utils/shape_utils.py b/object_detection/utils/shape_utils.py new file mode 100644 index 0000000..71b3640 --- /dev/null +++ b/object_detection/utils/shape_utils.py @@ -0,0 +1,462 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utils used to manipulate tensor shapes.""" + +import tensorflow as tf + +from object_detection.utils import static_shape + + +get_dim_as_int = static_shape.get_dim_as_int + + +def _is_tensor(t): + """Returns a boolean indicating whether the input is a tensor. + + Args: + t: the input to be tested. + + Returns: + a boolean that indicates whether t is a tensor. + """ + return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable)) + + +def _set_dim_0(t, d0): + """Sets the 0-th dimension of the input tensor. + + Args: + t: the input tensor, assuming the rank is at least 1. + d0: an integer indicating the 0-th dimension of the input tensor. + + Returns: + the tensor t with the 0-th dimension set. + """ + t_shape = t.get_shape().as_list() + t_shape[0] = d0 + t.set_shape(t_shape) + return t + + +def pad_tensor(t, length): + """Pads the input tensor with 0s along the first dimension up to the length. + + Args: + t: the input tensor, assuming the rank is at least 1. + length: a tensor of shape [1] or an integer, indicating the first dimension + of the input tensor t after padding, assuming length <= t.shape[0]. + + Returns: + padded_t: the padded tensor, whose first dimension is length. If the length + is an integer, the first dimension of padded_t is set to length + statically. + """ + t_rank = tf.rank(t) + t_shape = tf.shape(t) + t_d0 = t_shape[0] + pad_d0 = tf.expand_dims(length - t_d0, 0) + pad_shape = tf.cond( + tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), + lambda: tf.expand_dims(length - t_d0, 0)) + padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) + if not _is_tensor(length): + padded_t = _set_dim_0(padded_t, length) + return padded_t + + +def clip_tensor(t, length): + """Clips the input tensor along the first dimension up to the length. + + Args: + t: the input tensor, assuming the rank is at least 1. + length: a tensor of shape [1] or an integer, indicating the first dimension + of the input tensor t after clipping, assuming length <= t.shape[0]. + + Returns: + clipped_t: the clipped tensor, whose first dimension is length. 
If the + length is an integer, the first dimension of clipped_t is set to length + statically. + """ + clipped_t = tf.gather(t, tf.range(length)) + if not _is_tensor(length): + clipped_t = _set_dim_0(clipped_t, length) + return clipped_t + + +def pad_or_clip_tensor(t, length): + """Pad or clip the input tensor along the first dimension. + + Args: + t: the input tensor, assuming the rank is at least 1. + length: a tensor of shape [1] or an integer, indicating the first dimension + of the input tensor t after processing. + + Returns: + processed_t: the processed tensor, whose first dimension is length. If the + length is an integer, the first dimension of the processed tensor is set + to length statically. + """ + return pad_or_clip_nd(t, [length] + t.shape.as_list()[1:]) + + +def pad_or_clip_nd(tensor, output_shape): + """Pad or Clip given tensor to the output shape. + + Args: + tensor: Input tensor to pad or clip. + output_shape: A list of integers / scalar tensors (or None for dynamic dim) + representing the size to pad or clip each dimension of the input tensor. + + Returns: + Input tensor padded and clipped to the output shape. + """ + tensor_shape = tf.shape(tensor) + clip_size = [ + tf.where(tensor_shape[i] - shape > 0, shape, -1) + if shape is not None else -1 for i, shape in enumerate(output_shape) + ] + clipped_tensor = tf.slice( + tensor, + begin=tf.zeros(len(clip_size), dtype=tf.int32), + size=clip_size) + + # Pad tensor if the shape of clipped tensor is smaller than the expected + # shape. + clipped_tensor_shape = tf.shape(clipped_tensor) + trailing_paddings = [ + shape - clipped_tensor_shape[i] if shape is not None else 0 + for i, shape in enumerate(output_shape) + ] + paddings = tf.stack( + [ + tf.zeros(len(trailing_paddings), dtype=tf.int32), + trailing_paddings + ], + axis=1) + padded_tensor = tf.pad(clipped_tensor, paddings=paddings) + output_static_shape = [ + dim if not isinstance(dim, tf.Tensor) else None for dim in output_shape + ] + padded_tensor.set_shape(output_static_shape) + return padded_tensor + + +def combined_static_and_dynamic_shape(tensor): + """Returns a list containing static and dynamic values for the dimensions. + + Returns a list of static and dynamic values for shape dimensions. This is + useful to preserve static shapes when available in reshape operation. + + Args: + tensor: A tensor of any type. + + Returns: + A list of size tensor.shape.ndims containing integers or a scalar tensor. + """ + static_tensor_shape = tensor.shape.as_list() + dynamic_tensor_shape = tf.shape(tensor) + combined_shape = [] + for index, dim in enumerate(static_tensor_shape): + if dim is not None: + combined_shape.append(dim) + else: + combined_shape.append(dynamic_tensor_shape[index]) + return combined_shape + + +def static_or_dynamic_map_fn(fn, elems, dtype=None, + parallel_iterations=32, back_prop=True): + """Runs map_fn as a (static) for loop when possible. + + This function rewrites the map_fn as an explicit unstack input -> for loop + over function calls -> stack result combination. This allows our graphs to + be acyclic when the batch size is static. + For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn. + + Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable + with the default tf.map_fn function as it does not accept nested inputs (only + Tensors or lists of Tensors). Likewise, the output of `fn` can only be a + Tensor or list of Tensors. + + TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn. 
+ + Args: + fn: The callable to be performed. It accepts one argument, which will have + the same structure as elems. Its output must have the + same structure as elems. + elems: A tensor or list of tensors, each of which will + be unpacked along their first dimension. The sequence of the + resulting slices will be applied to fn. + dtype: (optional) The output type(s) of fn. If fn returns a structure of + Tensors differing from the structure of elems, then dtype is not optional + and must have the same structure as the output of fn. + parallel_iterations: (optional) number of batch items to process in + parallel. This flag is only used if the native tf.map_fn is used + and defaults to 32 instead of 10 (unlike the standard tf.map_fn default). + back_prop: (optional) True enables support for back propagation. + This flag is only used if the native tf.map_fn is used. + + Returns: + A tensor or sequence of tensors. Each tensor packs the + results of applying fn to tensors unpacked from elems along the first + dimension, from first to last. + Raises: + ValueError: if `elems` a Tensor or a list of Tensors. + ValueError: if `fn` does not return a Tensor or list of Tensors + """ + if isinstance(elems, list): + for elem in elems: + if not isinstance(elem, tf.Tensor): + raise ValueError('`elems` must be a Tensor or list of Tensors.') + + elem_shapes = [elem.shape.as_list() for elem in elems] + # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail + # to all be the same size along the batch dimension. + for elem_shape in elem_shapes: + if (not elem_shape or not elem_shape[0] + or elem_shape[0] != elem_shapes[0][0]): + return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop) + arg_tuples = zip(*[tf.unstack(elem) for elem in elems]) + outputs = [fn(arg_tuple) for arg_tuple in arg_tuples] + else: + if not isinstance(elems, tf.Tensor): + raise ValueError('`elems` must be a Tensor or list of Tensors.') + elems_shape = elems.shape.as_list() + if not elems_shape or not elems_shape[0]: + return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop) + outputs = [fn(arg) for arg in tf.unstack(elems)] + # Stack `outputs`, which is a list of Tensors or list of lists of Tensors + if all([isinstance(output, tf.Tensor) for output in outputs]): + return tf.stack(outputs) + else: + if all([isinstance(output, list) for output in outputs]): + if all([all( + [isinstance(entry, tf.Tensor) for entry in output_list]) + for output_list in outputs]): + return [tf.stack(output_tuple) for output_tuple in zip(*outputs)] + raise ValueError('`fn` should return a Tensor or a list of Tensors.') + + +def check_min_image_dim(min_dim, image_tensor): + """Checks that the image width/height are greater than some number. + + This function is used to check that the width and height of an image are above + a certain value. If the image shape is static, this function will perform the + check at graph construction time. Otherwise, if the image shape varies, an + Assertion control dependency will be added to the graph. + + Args: + min_dim: The minimum number of pixels along the width and height of the + image. + image_tensor: The image tensor to check size for. + + Returns: + If `image_tensor` has dynamic size, return `image_tensor` with a Assert + control dependency. Otherwise returns image_tensor. + + Raises: + ValueError: if `image_tensor`'s' width or height is smaller than `min_dim`. 
+ """ + image_shape = image_tensor.get_shape() + image_height = static_shape.get_height(image_shape) + image_width = static_shape.get_width(image_shape) + if image_height is None or image_width is None: + shape_assert = tf.Assert( + tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim), + tf.greater_equal(tf.shape(image_tensor)[2], min_dim)), + ['image size must be >= {} in both height and width.'.format(min_dim)]) + with tf.control_dependencies([shape_assert]): + return tf.identity(image_tensor) + + if image_height < min_dim or image_width < min_dim: + raise ValueError( + 'image size must be >= %d in both height and width; image dim = %d,%d' % + (min_dim, image_height, image_width)) + + return image_tensor + + +def assert_shape_equal(shape_a, shape_b): + """Asserts that shape_a and shape_b are equal. + + If the shapes are static, raises a ValueError when the shapes + mismatch. + + If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes + mismatch. + + Args: + shape_a: a list containing shape of the first tensor. + shape_b: a list containing shape of the second tensor. + + Returns: + Either a tf.no_op() when shapes are all static and a tf.assert_equal() op + when the shapes are dynamic. + + Raises: + ValueError: When shapes are both static and unequal. + """ + if (all(isinstance(dim, int) for dim in shape_a) and + all(isinstance(dim, int) for dim in shape_b)): + if shape_a != shape_b: + raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) + else: return tf.no_op() + else: + return tf.assert_equal(shape_a, shape_b) + + +def assert_shape_equal_along_first_dimension(shape_a, shape_b): + """Asserts that shape_a and shape_b are the same along the 0th-dimension. + + If the shapes are static, raises a ValueError when the shapes + mismatch. + + If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes + mismatch. + + Args: + shape_a: a list containing shape of the first tensor. + shape_b: a list containing shape of the second tensor. + + Returns: + Either a tf.no_op() when shapes are all static and a tf.assert_equal() op + when the shapes are dynamic. + + Raises: + ValueError: When shapes are both static and unequal. + """ + if isinstance(shape_a[0], int) and isinstance(shape_b[0], int): + if shape_a[0] != shape_b[0]: + raise ValueError('Unequal first dimension {}, {}'.format( + shape_a[0], shape_b[0])) + else: return tf.no_op() + else: + return tf.assert_equal(shape_a[0], shape_b[0]) + + +def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1): + """Asserts the input box tensor is normalized. + + Args: + boxes: a tensor of shape [N, 4] where N is the number of boxes. + maximum_normalized_coordinate: Maximum coordinate value to be considered + as normalized, default to 1.1. + + Returns: + a tf.Assert op which fails when the input box tensor is not normalized. + + Raises: + ValueError: When the input box tensor is not normalized. + """ + box_minimum = tf.reduce_min(boxes) + box_maximum = tf.reduce_max(boxes) + return tf.Assert( + tf.logical_and( + tf.less_equal(box_maximum, maximum_normalized_coordinate), + tf.greater_equal(box_minimum, 0)), + [boxes]) + + +def flatten_dimensions(inputs, first, last): + """Flattens `K-d` tensor along [first, last) dimensions. + + Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape + [D0, D1, ..., D(first) * D(first+1) * ... * D(last-1), D(last), ..., D(K-1)]. + + Example: + `inputs` is a tensor with initial shape [10, 5, 20, 20, 3]. 
+ new_tensor = flatten_dimensions(inputs, last=4, first=2) + new_tensor.shape -> [10, 100, 20, 3]. + + Args: + inputs: a tensor with shape [D0, D1, ..., D(K-1)]. + first: first value for the range of dimensions to flatten. + last: last value for the range of dimensions to flatten. Note that the last + dimension itself is excluded. + + Returns: + a tensor with shape + [D0, D1, ..., D(first) * D(first + 1) * ... * D(last - 1), D(last), ..., + D(K-1)]. + + Raises: + ValueError: if first and last arguments are incorrect. + """ + if first >= inputs.shape.ndims or last > inputs.shape.ndims: + raise ValueError('`first` and `last` must be less than inputs.shape.ndims. ' + 'found {} and {} respectively while ndims is {}'.format( + first, last, inputs.shape.ndims)) + shape = combined_static_and_dynamic_shape(inputs) + flattened_dim_prod = tf.reduce_prod(shape[first:last], + keepdims=True) + new_shape = tf.concat([shape[:first], flattened_dim_prod, + shape[last:]], axis=0) + return tf.reshape(inputs, new_shape) + + +def flatten_first_n_dimensions(inputs, n): + """Flattens `K-d` tensor along first n dimension to be a `(K-n+1)-d` tensor. + + Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape + [D0 * D1 * ... * D(n-1), D(n), ... D(K-1)]. + + Example: + `inputs` is a tensor with initial shape [10, 5, 20, 20, 3]. + new_tensor = flatten_first_n_dimensions(inputs, 2) + new_tensor.shape -> [50, 20, 20, 3]. + + Args: + inputs: a tensor with shape [D0, D1, ..., D(K-1)]. + n: The number of dimensions to flatten. + + Returns: + a tensor with shape [D0 * D1 * ... * D(n-1), D(n), ... D(K-1)]. + """ + return flatten_dimensions(inputs, first=0, last=n) + + +def expand_first_dimension(inputs, dims): + """Expands `K-d` tensor along first dimension to be a `(K+n-1)-d` tensor. + + Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape + [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)]. + + Example: + `inputs` is a tensor with shape [50, 20, 20, 3]. + new_tensor = expand_first_dimension(inputs, [10, 5]). + new_tensor.shape -> [10, 5, 20, 20, 3]. + + Args: + inputs: a tensor with shape [D0, D1, ..., D(K-1)]. + dims: List with new dimensions to expand first axis into. The length of + `dims` is typically 2 or larger. + + Returns: + a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)]. + """ + inputs_shape = combined_static_and_dynamic_shape(inputs) + expanded_shape = tf.stack(dims + inputs_shape[1:]) + + # Verify that it is possible to expand the first axis of inputs. + assert_op = tf.assert_equal( + inputs_shape[0], tf.reduce_prod(tf.stack(dims)), + message=('First dimension of `inputs` cannot be expanded into provided ' + '`dims`')) + + with tf.control_dependencies([assert_op]): + inputs_reshaped = tf.reshape(inputs, expanded_shape) + + return inputs_reshaped diff --git a/object_detection/utils/static_shape.py b/object_detection/utils/static_shape.py new file mode 100644 index 0000000..307c4d3 --- /dev/null +++ b/object_detection/utils/static_shape.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Helper functions to access TensorShape values. + +The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. +""" + + +def get_dim_as_int(dim): + """Utility to get v1 or v2 TensorShape dim as an int. + + Args: + dim: The TensorShape dimension to get as an int + + Returns: + None or an int. + """ + try: + return dim.value + except AttributeError: + return dim + + +def get_batch_size(tensor_shape): + """Returns batch size from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the batch size of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return get_dim_as_int(tensor_shape[0]) + + +def get_height(tensor_shape): + """Returns height from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the height of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return get_dim_as_int(tensor_shape[1]) + + +def get_width(tensor_shape): + """Returns width from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the width of the tensor. + """ + tensor_shape.assert_has_rank(rank=4) + return get_dim_as_int(tensor_shape[2]) + + +def get_depth(tensor_shape): + """Returns depth from the tensor shape. + + Args: + tensor_shape: A rank 4 TensorShape. + + Returns: + An integer representing the depth of the tensor. 
+ """ + tensor_shape.assert_has_rank(rank=4) + return get_dim_as_int(tensor_shape[3]) diff --git a/TF_object_detection/visualization_utils.py b/object_detection/utils/visualization_utils.py similarity index 100% rename from TF_object_detection/visualization_utils.py rename to object_detection/utils/visualization_utils.py diff --git a/pyparrot_modified/pyparrot/DroneVisionGUI.py b/pyparrot_modified/pyparrot/DroneVisionGUI.py index a118f7c..bbbf8d4 100644 --- a/pyparrot_modified/pyparrot/DroneVisionGUI.py +++ b/pyparrot_modified/pyparrot/DroneVisionGUI.py @@ -427,6 +427,7 @@ def _buffer_vision(self): if (self.vision_running): if(self.testing == True): self.img = self.net.get_image() + img_backup = self.img width = 800 height = 600 @@ -440,10 +441,11 @@ def _buffer_vision(self): #self.vlc_gui.set_values(50 + self.move.get_yaw(), 50 + self.move.get_pitch()) # read the picture into opencv self.img = cv2.imread(self.file) - self.img = cv2.resize(self.img, (856, 480)) - #height, width, channels = self.img.shape - width = 860 - height = 480 + img_backup = self.img + self.img = cv2.resize(self.img, (426, 240)) + height, width, channels = self.img.shape + # width = 640 + # height = 480 boxes = self.net.get_boxes() @@ -486,7 +488,7 @@ def _buffer_vision(self): self.buffer_index += 1 self.buffer_index %= self.buffer_size # print video_frame - self.buffer[self.buffer_index] = self.img + self.buffer[self.buffer_index] = img_backup self.new_frame = True def get_latest_valid_picture(self): diff --git a/pyparrot_modified/pyparrot/__pycache__/DroneVisionGUI.cpython-36.pyc b/pyparrot_modified/pyparrot/__pycache__/DroneVisionGUI.cpython-36.pyc index 52a2faf..786af61 100644 Binary files a/pyparrot_modified/pyparrot/__pycache__/DroneVisionGUI.cpython-36.pyc and b/pyparrot_modified/pyparrot/__pycache__/DroneVisionGUI.cpython-36.pyc differ diff --git a/pyparrot_modified/pyparrot/images/visionStream.jpg b/pyparrot_modified/pyparrot/images/visionStream.jpg index da3996b..49af1d6 100644 Binary files a/pyparrot_modified/pyparrot/images/visionStream.jpg and b/pyparrot_modified/pyparrot/images/visionStream.jpg differ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..16eeb27 --- /dev/null +++ b/setup.py @@ -0,0 +1,16 @@ +"""Setup script for object_detection.""" + +from setuptools import find_packages +from setuptools import setup + + +REQUIRED_PACKAGES = ['Pillow>=1.0', 'Matplotlib>=2.1', 'Cython>=0.28.1'] + +setup( + name='object_detection', + version='0.1', + install_requires=REQUIRED_PACKAGES, + include_package_data=True, + packages=[p for p in find_packages() if p.startswith('object_detection')], + description='Tensorflow Object Detection Library', +)