From 466779b3dcf0e7c64c9bb0dd3508a9945cc1773b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20=C3=81ngel=20Gonz=C3=A1lez=20Santamarta?= Date: Wed, 23 Oct 2024 13:12:42 +0200 Subject: [PATCH] formatting fixes --- yolo_ros/yolo_ros/debug_node.py | 57 ++++++++++++++++++++++++------ yolo_ros/yolo_ros/tracking_node.py | 27 ++++++++------ yolo_ros/yolo_ros/yolo_node.py | 9 +++-- 3 files changed, 69 insertions(+), 24 deletions(-) diff --git a/yolo_ros/yolo_ros/debug_node.py b/yolo_ros/yolo_ros/debug_node.py index 8a836aa..a6deb8d 100644 --- a/yolo_ros/yolo_ros/debug_node.py +++ b/yolo_ros/yolo_ros/debug_node.py @@ -127,7 +127,12 @@ def on_shutdown(self, state: LifecycleState) -> TransitionCallbackReturn: self.get_logger().info(f"[{self.get_name()}] Shutted down") return TransitionCallbackReturn.SUCCESS - def draw_box(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int]) -> np.ndarray: + def draw_box( + self, + cv_image: np.ndarray, + detection: Detection, + color: Tuple[int] + ) -> np.ndarray: # get detection info class_name = detection.class_name @@ -176,7 +181,12 @@ def draw_box(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int] return cv_image - def draw_mask(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int]) -> np.ndarray: + def draw_mask( + self, + cv_image: np.ndarray, + detection: Detection, + color: Tuple[int] + ) -> np.ndarray: mask_msg = detection.mask mask_array = np.array([[int(ele.x), int(ele.y)] @@ -186,11 +196,21 @@ def draw_mask(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int layer = cv_image.copy() layer = cv2.fillPoly(layer, pts=[mask_array], color=color) cv2.addWeighted(cv_image, 0.4, layer, 0.6, 0, cv_image) - cv_image = cv2.polylines(cv_image, [mask_array], isClosed=True, - color=color, thickness=2, lineType=cv2.LINE_AA) + cv_image = cv2.polylines( + cv_image, + [mask_array], + isClosed=True, + color=color, + thickness=2, + lineType=cv2.LINE_AA + ) return cv_image - def draw_keypoints(self, cv_image: np.ndarray, detection: Detection) -> np.ndarray: + def draw_keypoints( + self, + cv_image: np.ndarray, + detection: Detection + ) -> np.ndarray: keypoints_msg = detection.keypoints @@ -218,12 +238,22 @@ def get_pk_pose(kp_id: int) -> Tuple[int]: kp2_pos = get_pk_pose(sk[1]) if kp1_pos is not None and kp2_pos is not None: - cv2.line(cv_image, kp1_pos, kp2_pos, [ - int(x) for x in ann.limb_color[i]], thickness=2, lineType=cv2.LINE_AA) + cv2.line( + cv_image, + kp1_pos, + kp2_pos, + [int(x) for x in ann.limb_color[i]], + thickness=2, + lineType=cv2.LINE_AA + ) return cv_image - def create_bb_marker(self, detection: Detection, color: Tuple[int]) -> Marker: + def create_bb_marker( + self, + detection: Detection, + color: Tuple[int] + ) -> Marker: bbox3d = detection.bbox3d @@ -288,7 +318,11 @@ def create_kp_marker(self, keypoint: KeyPoint3D) -> Marker: return marker - def detections_cb(self, img_msg: Image, detection_msg: DetectionArray) -> None: + def detections_cb( + self, + img_msg: Image, + detection_msg: DetectionArray + ) -> None: cv_image = self.cv_bridge.imgmsg_to_cv2(img_msg) bb_marker_array = MarkerArray() @@ -327,8 +361,9 @@ def detections_cb(self, img_msg: Image, detection_msg: DetectionArray) -> None: kp_marker_array.markers.append(marker) # publish dbg image - self._dbg_pub.publish(self.cv_bridge.cv2_to_imgmsg(cv_image, - encoding=img_msg.encoding)) + self._dbg_pub.publish( + self.cv_bridge.cv2_to_imgmsg( + cv_image, encoding=img_msg.encoding)) self._bb_markers_pub.publish(bb_marker_array) self._kp_markers_pub.publish(kp_marker_array) diff --git a/yolo_ros/yolo_ros/tracking_node.py b/yolo_ros/yolo_ros/tracking_node.py index 33d6f5e..0190a27 100644 --- a/yolo_ros/yolo_ros/tracking_node.py +++ b/yolo_ros/yolo_ros/tracking_node.py @@ -135,7 +135,11 @@ def create_tracker(self, tracker_yaml: str) -> BaseTrack: tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=1) return tracker - def detections_cb(self, img_msg: Image, detections_msg: DetectionArray) -> None: + def detections_cb( + self, + img_msg: Image, + detections_msg: DetectionArray + ) -> None: tracked_detections_msg = DetectionArray() tracked_detections_msg.header = img_msg.header @@ -148,16 +152,17 @@ def detections_cb(self, img_msg: Image, detections_msg: DetectionArray) -> None: detection: Detection for detection in detections_msg.detections: - detection_list.append( - [ - detection.bbox.center.position.x - detection.bbox.size.x / 2, - detection.bbox.center.position.y - detection.bbox.size.y / 2, - detection.bbox.center.position.x + detection.bbox.size.x / 2, - detection.bbox.center.position.y + detection.bbox.size.y / 2, - detection.score, - detection.class_id - ] - ) + detection_list.append([ + detection.bbox.center.position.x - + detection.bbox.size.x / + 2, detection.bbox.center.position.y - + detection.bbox.size.y / + 2, detection.bbox.center.position.x + + detection.bbox.size.x / + 2, detection.bbox.center.position.y + + detection.bbox.size.y / + 2, detection.score, detection.class_id + ]) # tracking if len(detection_list) > 0: diff --git a/yolo_ros/yolo_ros/yolo_node.py b/yolo_ros/yolo_ros/yolo_node.py index 79c911f..a544d52 100644 --- a/yolo_ros/yolo_ros/yolo_node.py +++ b/yolo_ros/yolo_ros/yolo_node.py @@ -196,7 +196,11 @@ def on_shutdown(self, state: LifecycleState) -> TransitionCallbackReturn: self.get_logger().info(f"[{self.get_name()}] Shutted down") return TransitionCallbackReturn.SUCCESS - def enable_cb(self, request: SetBool.Request, response: SetBool.Response) -> SetBool.Response: + def enable_cb( + self, + request: SetBool.Request, + response: SetBool.Response + ) -> SetBool.Response: self.enable = request.data response.success = True return response @@ -299,7 +303,8 @@ def parse_keypoints(self, results: Results) -> List[KeyPoint2DArray]: if points.conf is None: continue - for kp_id, (p, conf) in enumerate(zip(points.xy[0], points.conf[0])): + for kp_id, (p, conf) in enumerate( + zip(points.xy[0], points.conf[0])): if conf >= self.threshold: msg = KeyPoint2D()