Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tool for open-set detection using grounding dino #215

Merged
merged 13 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/rai/rai/tools/ros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#

from .cli import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
from .native import Ros2BaseInput, Ros2BaseTool
from .tools import (
AddDescribedWaypointToDatabaseTool,
GetCurrentPositionTool,
Expand All @@ -24,6 +25,8 @@
"Ros2TopicTool",
"Ros2InterfaceTool",
"Ros2ServiceTool",
"Ros2BaseTool",
"Ros2BaseInput",
"AddDescribedWaypointToDatabaseTool",
"GetOccupancyGridTool",
"GetCurrentPositionTool",
Expand Down
38 changes: 37 additions & 1 deletion src/rai/rai/tools/ros/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Type, Union, cast

import cv2
import numpy as np
import sensor_msgs.msg
from cv_bridge import CvBridge
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
Expand All @@ -34,6 +35,41 @@ def import_message_from_str(msg_type: str) -> Type[object]:
return import_message_from_namespaced_type(msg_namespaced_type)


def convert_ros_img_to_ndarray(
msg: sensor_msgs.msg.Image, encoding: str = ""
) -> np.ndarray:
if encoding == "":
encoding = msg.encoding.lower()

if encoding == "rgb8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width, 3))
elif encoding == "bgr8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width, 3))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
elif encoding == "mono8":
image_data = np.frombuffer(msg.data, np.uint8)
image = image_data.reshape((msg.height, msg.width))
elif encoding == "16uc1":
image_data = np.frombuffer(msg.data, np.uint16)
image = image_data.reshape((msg.height, msg.width))
else:
raise ValueError(f"Unsupported encoding: {encoding}")

return image


def convert_ros_img_to_cv2mat(msg: sensor_msgs.msg.Image) -> cv2.typing.MatLike:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
if cv_image.shape[-1] == 4:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
else:
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
return cv_image


def convert_ros_img_to_base64(msg: sensor_msgs.msg.Image) -> str:
bridge = CvBridge()
cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")) # type: ignore
Expand All @@ -55,7 +91,7 @@ def wait_for_message(
topic: str,
*,
qos_profile: Union[QoSProfile, int] = 1,
time_to_wait=-1
time_to_wait=-1,
):
"""
Wait for the next incoming message.
Expand Down
43 changes: 43 additions & 0 deletions src/rai_extensions/rai_grounding_dino/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,49 @@ ros2 launch rai_grounding_dino gdino_launch.xml [weights_path:=PATH/TO/WEIGHTS]
> By default the weights will be downloaded to `$(ros2 pkg prefix rai_grounding_dino)/share/weights/`.
> You can change this path if you downloaded the weights manually or moved them.
### RAI Tools

This package provides the following tools:

- `GetDetectionTool`
This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt.

**Example call**

```
x = GetDetectionTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture - chair, human
```

- `GetDistanceToObjectsTool`
This tool calls the grounding dino service to use the model to see if the message from the provided camera topic contains objects from a comma separated prompt. Then it utilises messages from depth camera to create an estimation of distance to a detected object.

**Example call**

```
x = GetDistanceToObjectsTool(node=RaiBaseNode(node_name="test_node"))._run(
camera_topic="/camera/camera/color/image_raw",
depth_topic="/camera/camera/depth/image_rect_raw",
object_names=["chair", "human", "plushie", "box", "ball"],
)
```

**Example output**

```
I have detected the following items in the picture human: 1.68 m away, chair: 2.20 m away
```

### Example

An example client is provided with the package as `rai_grounding_dino/talker.py`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .grounding_dino import GDINO_NODE_NAME, GDINO_SERVICE_NAME
from .tools import GetDetectionTool, GetDistanceToObjectsTool

__all__ = [
"GetDistanceToObjectsTool",
"GetDetectionTool",
"GDINO_NODE_NAME",
"GDINO_SERVICE_NAME",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from os import PathLike
from typing import Dict

import cv2
from cv_bridge import CvBridge
from groundingdino.util.inference import Model
from rclpy.time import Time
Expand Down Expand Up @@ -83,6 +84,7 @@ def get_boxes(
image = self.bridge.imgmsg_to_cv2(
image_msg, desired_encoding=image_msg.encoding
)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

predictions = self.model.predict_with_classes(
image=image,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ class GDRequest(TypedDict):
source_img: Image


GDINO_NODE_NAME = "grounding_dino"
GDINO_SERVICE_NAME = "grounding_dino_classify"


class GDinoService(Node):
def __init__(self):
super().__init__(node_name="grounding_dino", parameter_overrides=[])
super().__init__(node_name=GDINO_NODE_NAME, parameter_overrides=[])
self.srv = self.create_service(
RAIGroundingDino, "grounding_dino_classify", self.classify_callback
RAIGroundingDino, GDINO_SERVICE_NAME, self.classify_callback
)
self.declare_parameter("weights_path", "")
try:
Expand Down
Loading