Skip to content

Commit

Permalink
formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Oct 23, 2024
1 parent 408bf3d commit 466779b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 24 deletions.
57 changes: 46 additions & 11 deletions yolo_ros/yolo_ros/debug_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,12 @@ def on_shutdown(self, state: LifecycleState) -> TransitionCallbackReturn:
self.get_logger().info(f"[{self.get_name()}] Shutted down")
return TransitionCallbackReturn.SUCCESS

def draw_box(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int]) -> np.ndarray:
def draw_box(
self,
cv_image: np.ndarray,
detection: Detection,
color: Tuple[int]
) -> np.ndarray:

# get detection info
class_name = detection.class_name
Expand Down Expand Up @@ -176,7 +181,12 @@ def draw_box(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int]

return cv_image

def draw_mask(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int]) -> np.ndarray:
def draw_mask(
self,
cv_image: np.ndarray,
detection: Detection,
color: Tuple[int]
) -> np.ndarray:

mask_msg = detection.mask
mask_array = np.array([[int(ele.x), int(ele.y)]
Expand All @@ -186,11 +196,21 @@ def draw_mask(self, cv_image: np.ndarray, detection: Detection, color: Tuple[int
layer = cv_image.copy()
layer = cv2.fillPoly(layer, pts=[mask_array], color=color)
cv2.addWeighted(cv_image, 0.4, layer, 0.6, 0, cv_image)
cv_image = cv2.polylines(cv_image, [mask_array], isClosed=True,
color=color, thickness=2, lineType=cv2.LINE_AA)
cv_image = cv2.polylines(
cv_image,
[mask_array],
isClosed=True,
color=color,
thickness=2,
lineType=cv2.LINE_AA
)
return cv_image

def draw_keypoints(self, cv_image: np.ndarray, detection: Detection) -> np.ndarray:
def draw_keypoints(
self,
cv_image: np.ndarray,
detection: Detection
) -> np.ndarray:

keypoints_msg = detection.keypoints

Expand Down Expand Up @@ -218,12 +238,22 @@ def get_pk_pose(kp_id: int) -> Tuple[int]:
kp2_pos = get_pk_pose(sk[1])

if kp1_pos is not None and kp2_pos is not None:
cv2.line(cv_image, kp1_pos, kp2_pos, [
int(x) for x in ann.limb_color[i]], thickness=2, lineType=cv2.LINE_AA)
cv2.line(
cv_image,
kp1_pos,
kp2_pos,
[int(x) for x in ann.limb_color[i]],
thickness=2,
lineType=cv2.LINE_AA
)

return cv_image

def create_bb_marker(self, detection: Detection, color: Tuple[int]) -> Marker:
def create_bb_marker(
self,
detection: Detection,
color: Tuple[int]
) -> Marker:

bbox3d = detection.bbox3d

Expand Down Expand Up @@ -288,7 +318,11 @@ def create_kp_marker(self, keypoint: KeyPoint3D) -> Marker:

return marker

def detections_cb(self, img_msg: Image, detection_msg: DetectionArray) -> None:
def detections_cb(
self,
img_msg: Image,
detection_msg: DetectionArray
) -> None:

cv_image = self.cv_bridge.imgmsg_to_cv2(img_msg)
bb_marker_array = MarkerArray()
Expand Down Expand Up @@ -327,8 +361,9 @@ def detections_cb(self, img_msg: Image, detection_msg: DetectionArray) -> None:
kp_marker_array.markers.append(marker)

# publish dbg image
self._dbg_pub.publish(self.cv_bridge.cv2_to_imgmsg(cv_image,
encoding=img_msg.encoding))
self._dbg_pub.publish(
self.cv_bridge.cv2_to_imgmsg(
cv_image, encoding=img_msg.encoding))
self._bb_markers_pub.publish(bb_marker_array)
self._kp_markers_pub.publish(kp_marker_array)

Expand Down
27 changes: 16 additions & 11 deletions yolo_ros/yolo_ros/tracking_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,11 @@ def create_tracker(self, tracker_yaml: str) -> BaseTrack:
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=1)
return tracker

def detections_cb(self, img_msg: Image, detections_msg: DetectionArray) -> None:
def detections_cb(
self,
img_msg: Image,
detections_msg: DetectionArray
) -> None:

tracked_detections_msg = DetectionArray()
tracked_detections_msg.header = img_msg.header
Expand All @@ -148,16 +152,17 @@ def detections_cb(self, img_msg: Image, detections_msg: DetectionArray) -> None:
detection: Detection
for detection in detections_msg.detections:

detection_list.append(
[
detection.bbox.center.position.x - detection.bbox.size.x / 2,
detection.bbox.center.position.y - detection.bbox.size.y / 2,
detection.bbox.center.position.x + detection.bbox.size.x / 2,
detection.bbox.center.position.y + detection.bbox.size.y / 2,
detection.score,
detection.class_id
]
)
detection_list.append([
detection.bbox.center.position.x -
detection.bbox.size.x /
2, detection.bbox.center.position.y -
detection.bbox.size.y /
2, detection.bbox.center.position.x +
detection.bbox.size.x /
2, detection.bbox.center.position.y +
detection.bbox.size.y /
2, detection.score, detection.class_id
])

# tracking
if len(detection_list) > 0:
Expand Down
9 changes: 7 additions & 2 deletions yolo_ros/yolo_ros/yolo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,11 @@ def on_shutdown(self, state: LifecycleState) -> TransitionCallbackReturn:
self.get_logger().info(f"[{self.get_name()}] Shutted down")
return TransitionCallbackReturn.SUCCESS

def enable_cb(self, request: SetBool.Request, response: SetBool.Response) -> SetBool.Response:
def enable_cb(
self,
request: SetBool.Request,
response: SetBool.Response
) -> SetBool.Response:
self.enable = request.data
response.success = True
return response
Expand Down Expand Up @@ -299,7 +303,8 @@ def parse_keypoints(self, results: Results) -> List[KeyPoint2DArray]:
if points.conf is None:
continue

for kp_id, (p, conf) in enumerate(zip(points.xy[0], points.conf[0])):
for kp_id, (p, conf) in enumerate(
zip(points.xy[0], points.conf[0])):

if conf >= self.threshold:
msg = KeyPoint2D()
Expand Down

0 comments on commit 466779b

Please sign in to comment.