Skip to content

Commit

Permalink
feat: add tool for open-set detection using grounding dino (#215)
Browse files Browse the repository at this point in the history
Co-authored-by: Maciej Majek <[email protected]>
  • Loading branch information
rachwalk and maciejmajek authored Sep 24, 2024
1 parent 73b1f19 commit ea213aa
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/rai/rai/tools/ros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#

from .cli import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
from .native import Ros2BaseInput, Ros2BaseTool
from .tools import (
AddDescribedWaypointToDatabaseTool,
GetCurrentPositionTool,
Expand All @@ -24,6 +25,8 @@
"Ros2TopicTool",
"Ros2InterfaceTool",
"Ros2ServiceTool",
"Ros2BaseTool",
"Ros2BaseInput",
"AddDescribedWaypointToDatabaseTool",
"GetOccupancyGridTool",
"GetCurrentPositionTool",
Expand Down
38 changes: 37 additions & 1 deletion src/rai/rai/tools/ros/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Type, Union, cast

import cv2
import numpy as np
import sensor_msgs.msg
from cv_bridge import CvBridge
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
Expand All @@ -34,6 +35,41 @@ def import_message_from_str(msg_type: str) -> Type[object]:
return import_message_from_namespaced_type(msg_namespaced_type)


def convert_ros_img_to_ndarray(
msg: sensor_msgs.msg.Image, encoding: str = ""
) -> np.ndarray:
if encoding == "":
encoding = msg.encoding.lower()

if encoding == "rgb8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width, 3))
elif encoding == "bgr8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width, 3))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif encoding == "mono8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width))
elif encoding == "16uc1":
image_data = np.frombuffer(msg.data, np.uint16)
image = image_data.reshape((msg.height, msg.width))
else:
raise ValueError(f"Unsupported encoding: {encoding}")

return image


def convert_ros_img_to_cv2mat(msg: sensor_msgs.msg.Image) -> cv2.typing.MatLike:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
else:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
return cv_image


def convert_ros_img_to_base64(msg: sensor_msgs.msg.Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
Expand All @@ -55,7 +91,7 @@ def wait_for_message(
topic: str,
*,
qos_profile: Union[QoSProfile, int] = 1,
time_to_wait=-1
time_to_wait=-1,
):
"""
Wait for the next incoming message.
Expand Down
43 changes: 43 additions & 0 deletions src/rai_extensions/rai_grounding_dino/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ ros2 launch rai_grounding_dino gdino_launch.xml [weights_path:=PATH/TO/WEIGHTS]
> By default the weights will be downloaded to `$(ros2 pkg prefix rai_grounding_dino)/share/weights/`.
> You can change this path if you downloaded the weights manually or moved them.
### RAI Tools

This package provides the following tools:

- `GetDetectionTool`
This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt.

**Example call**

```
x = GetDetectionTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture - chair, human
```

- `GetDistanceToObjectsTool`
This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt. Then it utilises messages from depth camera to create an estimation of distance to a detected object.

**Example call**

```
x = GetDistanceToObjectsTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
depth_topic="/camera/camera/depth/image_rect_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture human: 1.68 m away, chair: 2.20 m away
```

### Example

An example client is provided with the package as `rai_grounding_dino/talker.py`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .grounding_dino import GDINO_NODE_NAME, GDINO_SERVICE_NAME
from .tools import GetDetectionTool, GetDistanceToObjectsTool

__all__ = [
"GetDistanceToObjectsTool",
"GetDetectionTool",
"GDINO_NODE_NAME",
"GDINO_SERVICE_NAME",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from os import PathLike
from typing import Dict

import cv2
from cv_bridge import CvBridge
from groundingdino.util.inference import Model
from rclpy.time import Time
Expand Down Expand Up @@ -83,6 +84,7 @@ def get_boxes(
image = self.bridge.imgmsg_to_cv2(
image_msg, desired_encoding=image_msg.encoding
)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

predictions = self.model.predict_with_classes(
image=image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ class GDRequest(TypedDict):
source_img: Image


GDINO_NODE_NAME = "grounding_dino"
GDINO_SERVICE_NAME = "grounding_dino_classify"


class GDinoService(Node):
def __init__(self):
super().__init__(node_name="grounding_dino", parameter_overrides=[])
super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[])
self.srv = self.create_service(
RAIGroundingDino, "grounding_dino_classify", self.classify_callback
RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback
)
self.declare_parameter("weights_path", "")
try:
Expand Down
Loading

0 comments on commit ea213aa

Please sign in to comment.