From 508e0a9d76d43bce549424afa1138300b39f101a Mon Sep 17 00:00:00 2001 From: MrinallU Date: Sat, 9 Sep 2023 15:06:42 -0400 Subject: [PATCH] Add python library for yolov7 node --- src/detect_ros.py | 4 ++ src/detection_server.py | 82 +++++++++++++++++++++++++++ src/utils/BoundingBox2D.py | 7 +++ src/utils/Detection2D.py | 8 +++ src/utils/Detection2DArray.py | 7 +++ src/utils/ObjectHypothesisWithPose.py | 5 ++ 6 files changed, 113 insertions(+) create mode 100644 src/detection_server.py create mode 100644 src/utils/BoundingBox2D.py create mode 100644 src/utils/Detection2D.py create mode 100644 src/utils/Detection2DArray.py create mode 100644 src/utils/ObjectHypothesisWithPose.py diff --git a/src/detect_ros.py b/src/detect_ros.py index ba9cdec..5f72975 100755 --- a/src/detect_ros.py +++ b/src/detect_ros.py @@ -17,6 +17,7 @@ from vision_msgs.msg import Detection2DArray, Detection2D, BoundingBox2D from sensor_msgs.msg import Image from cv_bridge import CvBridge +from detection_server import DetectionServer import yaml @@ -152,6 +153,9 @@ def process_img_msg(self, img_msg: Image): # publishing detection_msg = create_detection_msg(img_msg, detections) self.detection_publisher.publish(detection_msg) + # post to detection library + DetectionServer.update_detections(np_img_resized, detections) + # visualizing if required if self.visualization_publisher: diff --git a/src/detection_server.py b/src/detection_server.py new file mode 100644 index 0000000..a64c734 --- /dev/null +++ b/src/detection_server.py @@ -0,0 +1,82 @@ +import rospy +import torch +from numpy import ndarray +from utils.Detection2DArray import Detection2DArray +from utils.Detection2D import Detection2D +from utils.BoundingBox2D import BoundingBox2D +from utils.ObjectHypothesisWithPose import ObjectHypothesisWithPose + +# todo: update to classes +class DetectionServer: + def __init__(self): + self.detection_array = Detection2DArray() + self.detection_history = [] + self.number_of_detections = 0 + self.objects_found = False + +# todo: find our if I still need this + def create_time_stamp(): + h = Header() + h.stamp = rospy.Time.now() + return h + + + def update_detections(image: ndarray, detections: torch.Tensor) -> Detection2DArray: + """ + :param detections: torch tensor of shape [num_boxes, 6] where each element is + [x1, y1, x2, y2, confidence, class_id] + :returns: detections as a ros message of type Detection2DArray + """ + self.detection_array = Detection2DArray() + + # todo: add source image https://docs.ros.org/en/lunar/api/vision_msgs/html/msg/Detection2D.html + + # header + time_stamp = create_time_stamp() + self.detection_array.header = time_stamp + self.detection_array.img = image + for detection in detections: + x1, y1, x2, y2, conf, cls = detection.tolist() + single_detection_msg = Detection2D() + single_detection_msg.header = time_stamp + + + # bbox + bbox = BoundingBox2D() + w = int(round(x2 - x1)) + h = int(round(y2 - y1)) + cx = int(round(x1 + w / 2)) + cy = int(round(y1 + h / 2)) + bbox.size_x = w + bbox.size_y = h + bbox.center_x = cx + bbox.center_y = cy + + single_detection_msg.bbox = bbox + + # class id & confidence + obj_hyp = ObjectHypothesisWithPose() + obj_hyp.id = int(cls) + obj_hyp.score = conf + single_detection_msg.results = [obj_hyp] + + self.detection_array.detections.append(single_detection_msg) + self.detection_history.append(self.detection_array) + return self.detection_array + + def get_detections(): + if(len(self.detection_array)==0: + raise Exception("Detections were requested but no objects were found") + else: + return self.detection_array + + def get_number_of_objects_detected(): + return len(self.detection_array) + + def get_detection_history(): + return self.detection_history + + def erase_detection_history(): + self.detection_history = [] + + diff --git a/src/utils/BoundingBox2D.py b/src/utils/BoundingBox2D.py new file mode 100644 index 0000000..335cfe7 --- /dev/null +++ b/src/utils/BoundingBox2D.py @@ -0,0 +1,7 @@ +class BoundingBox2D: + def __init__(self, cx: int, cy: int, sx: int, sy: int): + self.center_x = cx + self.center_y = cy + self.size_x = sx + self.size_y = sy + diff --git a/src/utils/Detection2D.py b/src/utils/Detection2D.py new file mode 100644 index 0000000..19124f4 --- /dev/null +++ b/src/utils/Detection2D.py @@ -0,0 +1,8 @@ +from ObjectHypothesisWithPose import ObjectHypothesisWithPose +from BoundingBox1D import BoundingBox2D + +class Detection2D: + def __init__(self, results: ObjectHypothesisWithPose, bbox: BoundingBox2D, header): + self.results = results + self.bbox = bbox + self.header = header diff --git a/src/utils/Detection2DArray.py b/src/utils/Detection2DArray.py new file mode 100644 index 0000000..caddca8 --- /dev/null +++ b/src/utils/Detection2DArray.py @@ -0,0 +1,7 @@ +from Detection2D import Detection2D +from numpy import ndarray +class Detection2DArray: + def __init__(self, img: ndarray, detections: list[Detection2D], header): + self.img = img + self.detections = detections + self.header = header diff --git a/src/utils/ObjectHypothesisWithPose.py b/src/utils/ObjectHypothesisWithPose.py new file mode 100644 index 0000000..97192e1 --- /dev/null +++ b/src/utils/ObjectHypothesisWithPose.py @@ -0,0 +1,5 @@ +class ObjectHypothesisWithPose: + def __init__(self, id: int, score: float): + self.id = id + self.score = score +