From c3f1405bbf83543ccb61fa42403ac0d376e6e53c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 17 Oct 2024 18:58:27 +0200 Subject: [PATCH 1/8] feat: integrate manipulation vision tools with RaiNode --- examples/rosbot-xl-generic-node-demo.py | 24 +++++-- src/rai/rai/node.py | 16 +++-- src/rai/rai/tools/ros/manipulation.py | 12 ++-- src/rai/rai/tools/ros/native_actions.py | 4 +- src/rai/rai/tools/ros/utils.py | 14 +++- .../rai_open_set_vision/tools/gdino_tools.py | 6 +- .../tools/segmentation_tools.py | 71 ++++++++++++++----- 7 files changed, 106 insertions(+), 41 deletions(-) diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index 025e3083..e6f6147f 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -16,25 +16,24 @@ import rclpy import rclpy.executors -from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool from rai.node import RaiStateBasedLlmNode +from rai.tools.ros.manipulation import GetObjectPositionsTool from rai.tools.ros.native import ( GetCameraImage, - GetMsgFromTopic, Ros2PubMessageTool, Ros2ShowMsgInterfaceTool, ) # from rai.tools.ros.native_actions import Ros2RunActionSync from rai.tools.ros.native_actions import ( + GetTransformTool, Ros2CancelAction, Ros2GetActionResult, Ros2GetLastActionFeedback, Ros2IsActionComplete, Ros2RunActionAsync, ) -from rai.tools.ros.tools import GetCurrentPositionTool from rai.tools.time import WaitForSecondsTool @@ -50,7 +49,9 @@ def main(): topics_whitelist = [ "/rosout", "/camera/camera/color/image_raw", + "/camera/camera/color/camera_info", "/camera/camera/depth/image_rect_raw", + "/camera/camera/depth/camera_info", "/map", "/scan", "/diagnostics", @@ -136,12 +137,21 @@ def main(): Ros2GetActionResult, Ros2GetLastActionFeedback, Ros2ShowMsgInterfaceTool, - GetCurrentPositionTool, WaitForSecondsTool, - GetMsgFromTopic, + # GetMsgFromTopic, GetCameraImage, - GetDetectionTool, - GetDistanceToObjectsTool, + # GetDetectionTool, + GetTransformTool, + ( + GetObjectPositionsTool, + dict( + target_frame="odom", + source_frame="sensor_frame", + camera_topic="/camera/camera/color/image_raw", + depth_topic="/camera/camera/depth/image_rect_raw", + camera_info_topic="/camera/camera/color/camera_info", + ), + ), ], ) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 7d0208a0..f8651656 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -403,14 +403,21 @@ def __init__( def _initialize_tools(self, tools: List[Type[BaseTool]]): initialized_tools = list() for tool_cls in tools: + args = None + if type(tool_cls) is tuple: + tool_cls, args = tool_cls if issubclass(tool_cls, Ros2BaseTool): if ( issubclass(tool_cls, Ros2BaseActionTool) or "DetectionTool" in tool_cls.__name__ or "GetDistance" in tool_cls.__name__ or "GetTransformTool" in tool_cls.__name__ + or "GetObjectPositionsTool" in tool_cls.__name__ ): # TODO(boczekbartek): develop a way to handle all mutially - tool = tool_cls(node=self._async_tool_node) + if args: + tool = tool_cls(node=self._async_tool_node, **args) + else: + tool = tool_cls(node=self._async_tool_node) else: tool = tool_cls(node=self) else: @@ -496,11 +503,10 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): else: last_msg = msg.content - if len(str(last_msg)) > 0 and "Retrieved state: {}" not in last_msg: - feedback_msg = TaskAction.Feedback() - feedback_msg.current_status = f"{graph_node_name}: {last_msg}" + feedback_msg = TaskAction.Feedback() + feedback_msg.current_status = f"{graph_node_name}: {last_msg}" - goal_handle.publish_feedback(feedback_msg) + goal_handle.publish_feedback(feedback_msg) # ---- Share Action Result ---- if state is None: diff --git a/src/rai/rai/tools/ros/manipulation.py b/src/rai/rai/tools/ros/manipulation.py index a5aaec5e..22cebe70 100644 --- a/src/rai/rai/tools/ros/manipulation.py +++ b/src/rai/rai/tools/ros/manipulation.py @@ -29,7 +29,8 @@ from rclpy.node import Node from tf2_geometry_msgs import do_transform_pose -from rai.tools.utils import TF2TransformFetcher +from rai.tools.ros.native_actions import Ros2BaseActionTool +from rai.tools.ros.utils import get_transform from rai_interfaces.srv import ManipulatorMoveTo @@ -141,7 +142,7 @@ class GetObjectPositionsToolInput(BaseModel): ) -class GetObjectPositionsTool(BaseTool): +class GetObjectPositionsTool(Ros2BaseActionTool): name: str = "get_object_positions" description: str = ( "Retrieve the positions of all objects of a specified type within the manipulator's frame of reference. " @@ -154,7 +155,6 @@ class GetObjectPositionsTool(BaseTool): camera_topic: str # rgb camera topic depth_topic: str camera_info_topic: str # rgb camera info topic - node: Node get_grabbing_point_tool: GetGrabbingPointTool def __init__(self, node: Node, **kwargs): @@ -169,9 +169,8 @@ def format_pose(pose: Pose): return f"Centroid(x={pose.position.x:.2f}, y={pose.position.y:2f}, z={pose.position.z:2f})" def _run(self, object_name: str): - transform = TF2TransformFetcher( - target_frame=self.target_frame, source_frame=self.source_frame - ).get_data() + transform = get_transform(self.node, self.source_frame, self.target_frame) + self.logger.info("Got transform: {transform}") results = self.get_grabbing_point_tool._run( camera_topic=self.camera_topic, @@ -179,6 +178,7 @@ def _run(self, object_name: str): camera_info_topic=self.camera_info_topic, object_name=object_name, ) + self.logger.info("Got result: {results}") poses = [] for result in results: diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index 94947b45..8baea5e2 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -184,14 +184,14 @@ class Ros2GetLastActionFeedback(Ros2BaseActionTool): args_schema: Type[Ros2BaseInput] = Ros2BaseInput def _run(self) -> str: - return str(self.node.action_feedback) + return str(self.node.feedback) class GetTransformTool(Ros2BaseActionTool): name: str = "GetTransform" description: str = "Get transform between two frames" - def _run(self, target_frame="map", source_frame="body_link") -> dict: + def _run(self, target_frame="odom", source_frame="body_link") -> dict: return message_to_ordereddict( get_transform(self.node, target_frame, source_frame) ) diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index 72107533..09bd73fa 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -14,6 +14,7 @@ # import base64 +import time from typing import Type, Union, cast import cv2 @@ -155,17 +156,24 @@ def wait_for_message( def get_transform( - node: rclpy.node.Node, target_frame: str, source_frame: str + node: rclpy.node.Node, target_frame: str, source_frame: str, timeout=30 ) -> TransformStamped: + node.get_logger().info( + "Waiting for transform from {} to {}".format(source_frame, target_frame) + ) tf_buffer = Buffer(node=node) tf_listener = TransformListener(tf_buffer, node) transform = None - while transform is None: - rclpy.spin_once(node) + for _ in range(timeout * 10): + rclpy.spin_once(node, timeout_sec=0.1) if tf_buffer.can_transform(target_frame, source_frame, rclpy.time.Time()): transform = tf_buffer.lookup_transform( target_frame, source_frame, rclpy.time.Time() ) + break + else: + time.sleep(0.1) + tf_listener.unregister() return transform diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 0415cfb6..22af5be6 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -25,11 +25,13 @@ ParameterUninitializedException, ) from rclpy.task import Future +from rclpy.wait_for_message import wait_for_message from rai.node import RaiAsyncToolsNode from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray -from rai.tools.utils import wait_for_message + +# from rai.tools.utils import wait_for_message from rai_interfaces.srv import RAIGroundingDino @@ -115,7 +117,7 @@ def _call_gdino_node( future = cli.call_async(req) return future - def get_img_from_topic(self, topic: str, timeout_sec: int = 2): + def get_img_from_topic(self, topic: str, timeout_sec: int = 4): success, msg = wait_for_message( sensor_msgs.msg.Image, self.node, diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 5a4dbf16..60ffe779 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -17,6 +17,7 @@ import cv2 import numpy as np import rclpy +import rclpy.qos import sensor_msgs.msg from pydantic import Field from rai_open_set_vision import GDINO_SERVICE_NAME @@ -26,8 +27,11 @@ ParameterUninitializedException, ) -from rai.node import RaiBaseNode -from rai.tools.ros import Ros2BaseInput, Ros2BaseTool +# from rai.tools.utils import wait_for_message +from rclpy.wait_for_message import wait_for_message + +from rai.tools.ros import Ros2BaseInput +from rai.tools.ros.native_actions import Ros2BaseActionTool from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino @@ -63,9 +67,7 @@ class GetGrabbingPointInput(Ros2BaseInput): # --------------------- Tools --------------------- -class GetSegmentationTool(Ros2BaseTool): - node: RaiBaseNode = Field(..., exclude=True) - +class GetSegmentationTool(Ros2BaseActionTool): name: str = "" description: str = "" @@ -102,12 +104,29 @@ def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response return response return None + def get_img_from_topic(self, topic: str, timeout_sec: int = 10): + success, msg = wait_for_message( + sensor_msgs.msg.Image, + self.node, + topic, + qos_profile=rclpy.qos.qos_profile_sensor_data, + time_to_wait=timeout_sec, + ) + + if success: + self.node.get_logger().info(f"Received message of type from topic {topic}") + return msg + else: + error = f"No message received in {timeout_sec} seconds from topic {topic}" + self.node.get_logger().error(error) + return error + def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: - msg = self.node.get_raw_message_from_topic(topic) + msg = self.get_img_from_topic(topic) if type(msg) is sensor_msgs.msg.Image: return msg else: - raise Exception("Received wrong message") + raise Exception(f"Received wrong message: {type(msg)}") def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_name: str @@ -211,13 +230,25 @@ class GetGrabbingPointTool(GetSegmentationTool): args_schema: Type[GetGrabbingPointInput] = GetGrabbingPointInput def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: - for _ in range(3): - msg = self.node.get_raw_message_from_topic(topic, timeout_sec=3.0) - if isinstance(msg, sensor_msgs.msg.CameraInfo): - return msg - self.node.get_logger().warn("Received wrong message type. Retrying...") + self.node.get_logger().info(f"Waiting for CameraInfo from topic {topic}") + success, msg = wait_for_message( + sensor_msgs.msg.CameraInfo, + self.node, + topic, + qos_profile=rclpy.qos.qos_profile_sensor_data, + time_to_wait=3, + ) + print(msg) - raise Exception("Failed to receive correct CameraInfo message after 3 attempts") + if success: + self.node.get_logger().info(f"Received message of type from topic {topic}") + return msg + else: + error = f"No message received in 3 seconds from topic {topic}" + self.node.get_logger().error(error) + raise Exception( + "Failed to receive correct CameraInfo message after 3 attempts" + ) def _get_intrinsic_from_camera_info(self, camera_info: sensor_msgs.msg.CameraInfo): """Extract camera intrinsic parameters from the CameraInfo message.""" @@ -276,11 +307,14 @@ def _run( camera_info_topic: str, object_name: str, ): - camera_img_msg = self._get_image_message(camera_topic) - depth_msg = self._get_image_message(depth_topic) camera_info = self._get_camera_info_message(camera_info_topic) - + self.logger.info("Received camera info") + camera_img_msg = self.get_img_from_topic(camera_topic) + self.logger.info("Received camera image") + depth_msg = self.get_img_from_topic(depth_topic) + self.logger.info("Received depth image") intrinsic = self._get_intrinsic_from_camera_info(camera_info) + self.logger.info("Received camera intrinsic") future = self._call_gdino_node(camera_img_msg, object_name) logger = self.node.get_logger() @@ -297,13 +331,16 @@ def _run( ) conversion_ratio = 0.001 resolved = None + self.logger.info("Waiting gdino response") while rclpy.ok(): resolved = self._get_gdino_response(future) if resolved is not None: break assert resolved is not None + self.logger.info("Got gdino response") future = self._call_gsam_node(camera_img_msg, resolved) + self.logger.info("Waiting gsam response") ret = [] while rclpy.ok(): @@ -313,6 +350,8 @@ def _run( ret.append(convert_ros_img_to_base64(img_msg)) break assert resolved is not None + + self.logger.info("Got gsam response") rets = [] for mask_msg in resolved.masks: rets.append(self._process_mask(mask_msg, depth_msg, intrinsic)) From 3d8ba70c61489092e10515f530a84c5c4cbc718c Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Thu, 17 Oct 2024 20:03:01 +0200 Subject: [PATCH 2/8] chore: reduce verbosity feat: enhance typing --- src/rai/rai/node.py | 49 ++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 7d0208a0..029e586a 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -30,7 +30,7 @@ from action_msgs.msg import GoalStatus from langchain.tools import BaseTool from langchain.tools.render import render_text_description_and_args -from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langgraph.graph.graph import CompiledGraph from rclpy.action.client import ActionClient from rclpy.action.graph import get_action_names_and_types @@ -226,7 +226,7 @@ def _is_task_complete(self): ) return True else: - self.get_logger().info("There is not result") + self.get_logger().info("There is no result") # Timed out, still processing, not complete yet return False @@ -250,7 +250,7 @@ def __init__( ): super().__init__(*args, **kwargs) - self.robot_state = dict() + self.robot_state: Dict[str, Any] = dict() # where Any is ROS 2 message type self.DISCOVERY_FREQ = 2.0 self.DISCOVERY_DEPTH = 5 @@ -288,9 +288,9 @@ def discovery(self): ) def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 1) -> Any: - self.get_logger().info(f"Getting msg from topic: {topic}") + self.get_logger().debug(f"Getting msg from topic: {topic}") if topic in self.state_subscribers and topic in self.robot_state: - self.get_logger().info("Returning cached message") + self.get_logger().debug("Returning cached message") return self.robot_state[topic] else: msg_type = self.get_msg_type(topic) @@ -303,7 +303,7 @@ def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 1) -> Any: ) if success: - self.get_logger().info( + self.get_logger().debug( f"Received message of type {msg_type.__class__.__name__} from topic {topic}" ) return msg @@ -342,7 +342,7 @@ def __init__( self, system_prompt: str, observe_topics: Optional[List[str]] = None, - observe_postprocessors: Optional[Dict[str, Callable]] = None, + observe_postprocessors: Optional[Dict[str, Callable[[Any], Any]]] = None, whitelist: Optional[List[str]] = None, tools: Optional[List[Type[BaseTool]]] = None, *args, @@ -399,9 +399,10 @@ def __init__( state_retriever=self.get_robot_state, logger=self.get_logger(), ) + self.simple_llm = get_llm_model(model_type="simple_model") def _initialize_tools(self, tools: List[Type[BaseTool]]): - initialized_tools = list() + initialized_tools: List[BaseTool] = list() for tool_cls in tools: if issubclass(tool_cls, Ros2BaseTool): if ( @@ -426,7 +427,7 @@ def _initialize_system_prompt(self, prompt: str): ) return system_prompt - def _initialize_robot_state_interfaces(self, topics): + def _initialize_robot_state_interfaces(self, topics: List[str]): self.rosout_buffer = RosoutBuffer(get_llm_model(model_type="simple_model")) for topic in topics: @@ -467,7 +468,7 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): self.get_logger().info(f"Received task: {task}") # ---- LLM Task Handling ---- - self.get_logger().info(f'This is system prompt: "{self.system_prompt}"') + self.get_logger().debug(f'This is system prompt: "{self.system_prompt}"') messages = [ SystemMessage(content=self.system_prompt), @@ -484,19 +485,36 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} ): - print(state.keys()) graph_node_name = list(state.keys())[0] if graph_node_name == "reporter": continue msg = state[graph_node_name]["messages"][-1] - if isinstance(msg, HumanMultimodalMessage): last_msg = msg.text + elif isinstance(msg, BaseMessage): + if isinstance(msg.content, list): + assert len(msg.content) == 1 + last_msg = msg.content[0].get("text", "") + else: + last_msg = msg.content else: - last_msg = msg.content - - if len(str(last_msg)) > 0 and "Retrieved state: {}" not in last_msg: + raise ValueError(f"Unexpected type of message: {type(msg)}") + + last_msg = self.simple_llm.invoke( + [ + SystemMessage( + content=( + "You are an experienced reporter deployed on a autonomous robot. " # type: ignore + "Your task is to summarize the message in a way that is easy for other agents to understand. " + "Do not use markdown formatting. Keep it short and concise. If the message is empty, please return empty string." + ) + ), + HumanMessage(content=last_msg), + ] + ).content + + if len(str(last_msg)) > 0 and graph_node_name != "state_retriever": feedback_msg = TaskAction.Feedback() feedback_msg.current_status = f"{graph_node_name}: {last_msg}" @@ -505,7 +523,6 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): # ---- Share Action Result ---- if state is None: raise ValueError("No output from LLM") - print(state) graph_node_name = list(state.keys())[0] if graph_node_name != "reporter": From 743e408c7ff80c38664a694db794105315755ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 09:27:12 +0200 Subject: [PATCH 3/8] [system_prompt] Add transform checking before and after nav2 actions to --- examples/rosbot-xl-generic-node-demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index e6f6147f..ac73e350 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -82,6 +82,7 @@ def main(): use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. > be patient with running ros2 actions. usually the take some time to run. + Always check your transform before and after you perform ros2 actions, so that you can verify if it worked. Navigation tips: - it's good to start finding objects by rotating, then navigating to some diverse location with occasional rotations. Remember to frequency detect objects. From c505cbe72d625a2189062eb5b39240a3ce88a8f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 09:27:40 +0200 Subject: [PATCH 4/8] fix(`text_hmi`): remove repeated tool --- src/rai_hmi/rai_hmi/base.py | 3 +-- src/rai_hmi/rai_hmi/tools.py | 20 -------------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index 6bcffb42..9e0fb81d 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -35,7 +35,7 @@ MissionFeedbackMessage, ) from rai_hmi.task import Task -from rai_hmi.tools import QueryDatabaseTool, QueueTaskTool +from rai_hmi.tools import QueryDatabaseTool from rai_interfaces.action import Task as TaskAction @@ -138,7 +138,6 @@ def _initialize_available_tools(self): tools.append( QueryDatabaseTool(get_response=self.query_faiss_index_with_scores) ) - tools.append(QueueTaskTool(add_task=self.add_task_to_queue)) return tools def status_callback(self): diff --git a/src/rai_hmi/rai_hmi/tools.py b/src/rai_hmi/rai_hmi/tools.py index 55ff7d5c..cdd4fd79 100644 --- a/src/rai_hmi/rai_hmi/tools.py +++ b/src/rai_hmi/rai_hmi/tools.py @@ -18,8 +18,6 @@ from langchain_core.tools import BaseTool from pydantic import BaseModel, Field -from .task import Task - class QueryDatabaseInput(BaseModel): query: str = Field( @@ -42,21 +40,3 @@ class QueryDatabaseTool(BaseTool): def _run(self, query: str): retrieval_response = self.get_response(query) return str(retrieval_response) - - -class QueueTaskInput(BaseModel): - task: Task = Field(..., description="The task to queue") - - -class QueueTaskTool(BaseTool): - name: str = "queue_task" - description: str = "Queue a task for the platform" - input_type: Type[QueueTaskInput] = QueueTaskInput - - args_schema: Type[QueueTaskInput] = QueueTaskInput - - add_task: Any - - def _run(self, task: Task): - self.add_task(task) - return f"Task {task} has been queued for the LLM" From 2e4b1d889bb7e766fb5c24c8329c8bad0539d4d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 09:52:23 +0200 Subject: [PATCH 5/8] rosbot-example: add detection tool and img describer again --- examples/rosbot-xl-generic-node-demo.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index ac73e350..0330f9e2 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -16,16 +16,15 @@ import rclpy import rclpy.executors +from rai_open_set_vision import GetDetectionTool -from rai.node import RaiStateBasedLlmNode +from rai.node import RaiStateBasedLlmNode, describe_ros_image from rai.tools.ros.manipulation import GetObjectPositionsTool from rai.tools.ros.native import ( GetCameraImage, Ros2PubMessageTool, Ros2ShowMsgInterfaceTool, ) - -# from rai.tools.ros.native_actions import Ros2RunActionSync from rai.tools.ros.native_actions import ( GetTransformTool, Ros2CancelAction, @@ -40,11 +39,11 @@ def main(): rclpy.init() - # observe_topics = [ - # "/camera/camera/color/image_raw", - # ] - # - # observe_postprocessors = {"/camera/camera/color/image_raw": describe_ros_image} + observe_topics = [ + "/camera/camera/color/image_raw", + ] + + observe_postprocessors = {"/camera/camera/color/image_raw": describe_ros_image} topics_whitelist = [ "/rosout", @@ -126,8 +125,8 @@ def main(): """ node = RaiStateBasedLlmNode( - observe_topics=None, - observe_postprocessors=None, + observe_topics=observe_topics, + observe_postprocessors=observe_postprocessors, whitelist=topics_whitelist + actions_whitelist, system_prompt=SYSTEM_PROMPT, tools=[ @@ -139,9 +138,8 @@ def main(): Ros2GetLastActionFeedback, Ros2ShowMsgInterfaceTool, WaitForSecondsTool, - # GetMsgFromTopic, GetCameraImage, - # GetDetectionTool, + GetDetectionTool, GetTransformTool, ( GetObjectPositionsTool, From 28a53251dcfffc467544366c07e5785475063dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 09:56:14 +0200 Subject: [PATCH 6/8] change "task_queue" to "submit_mission" --- src/rai_hmi/rai_hmi/agent.py | 10 +++++----- src/rai_hmi/rai_hmi/base.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rai_hmi/rai_hmi/agent.py b/src/rai_hmi/rai_hmi/agent.py index 08634508..a74e907b 100644 --- a/src/rai_hmi/rai_hmi/agent.py +++ b/src/rai_hmi/rai_hmi/agent.py @@ -33,15 +33,15 @@ def initialize_agent(hmi_node: BaseHMINode, rai_node: RaiBaseNode, memory: Memor @tool def get_mission_memory(uid: str) -> List[MissionMessage]: - """List mission memory. Mission uid is required.""" + """List mission memory. It contains the information about running tasks. Mission uid is required.""" return memory.get_mission_memory(uid) @tool - def add_task_to_queue(task: TaskInput): - """Use this tool to add a task to the queue. The task will be handled by the executor part of your system.""" + def submit_mission_to_the_robot(task: TaskInput): + """Use this tool submit the task to the robot. The task will be handled by the executor part of your system.""" uid = uuid.uuid4() - hmi_node.add_task_to_queue( + hmi_node.execute_mission( Task( name=task.name, description=task.description, @@ -55,7 +55,7 @@ def add_task_to_queue(task: TaskInput): Ros2GetRobotInterfaces(node=rai_node), GetCameraImage(node=rai_node), ] - task_tools = [add_task_to_queue, get_mission_memory] + task_tools = [submit_mission_to_the_robot, get_mission_memory] tools = hmi_node.tools + node_tools + task_tools agent = create_conversational_agent( diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index 9e0fb81d..d64c9067 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -72,7 +72,7 @@ class BaseHMINode(Node): If you have multiple questions, please ask them one by one allowing user to respond before moving forward to the next question. Keep the conversation short and to the point. If you are requested tasks that you are capable of perfoming as a robot, not as a - conversational agent, please use tools to submit them to the task queue - only 1 + conversational agent, please use tools to submit them to the robot - only 1 task in parallel is supported. For more complicated tasks, don't split them, just add as 1 task. They will be done by another agent resposible for communication with the robotic's @@ -182,7 +182,7 @@ def initialize_task_action_client_and_server(self): # self, TaskFeedback, "provide_task_feedback", self.handle_task_feedback # ) - def add_task_to_queue(self, task: Task): + def execute_mission(self, task: Task): """Sends a task to the action server to be handled by the rai node.""" if not self.task_action_client.wait_for_server(timeout_sec=10.0): From ea1164e2744c24fea4e1999590590c346e777edd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 10:07:56 +0200 Subject: [PATCH 7/8] remove image describer and cmd_vel --- examples/rosbot-xl-generic-node-demo.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index 0330f9e2..7f3e36b0 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -18,7 +18,7 @@ import rclpy.executors from rai_open_set_vision import GetDetectionTool -from rai.node import RaiStateBasedLlmNode, describe_ros_image +from rai.node import RaiStateBasedLlmNode from rai.tools.ros.manipulation import GetObjectPositionsTool from rai.tools.ros.native import ( GetCameraImage, @@ -39,11 +39,11 @@ def main(): rclpy.init() - observe_topics = [ - "/camera/camera/color/image_raw", - ] - - observe_postprocessors = {"/camera/camera/color/image_raw": describe_ros_image} + # observe_topics = [ + # "/camera/camera/color/image_raw", + # ] + # + # observe_postprocessors = {"/camera/camera/color/image_raw": describe_ros_image} topics_whitelist = [ "/rosout", @@ -54,7 +54,7 @@ def main(): "/map", "/scan", "/diagnostics", - "/cmd_vel", + # "/cmd_vel", "/led_strip", ] @@ -125,8 +125,8 @@ def main(): """ node = RaiStateBasedLlmNode( - observe_topics=observe_topics, - observe_postprocessors=observe_postprocessors, + observe_topics=None, + observe_postprocessors=None, whitelist=topics_whitelist + actions_whitelist, system_prompt=SYSTEM_PROMPT, tools=[ From f37a3f2e1d306a9f9451c4c8b37519adff7da951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 18 Oct 2024 11:57:19 +0200 Subject: [PATCH 8/8] wip --- src/rai/rai/node.py | 1 + src/rai/rai/tools/utils.py | 1 - src/rai/rai/utils/ros.py | 4 ++++ .../rai_open_set_vision/tools/gdino_tools.py | 4 +--- .../rai_open_set_vision/tools/segmentation_tools.py | 7 ++++--- src/rai_hmi/rai_hmi/base.py | 1 + 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 129c6a05..9552ce83 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -416,6 +416,7 @@ def _initialize_tools(self, tools: List[Type[BaseTool]]): or "GetDistance" in tool_cls.__name__ or "GetTransformTool" in tool_cls.__name__ or "GetObjectPositionsTool" in tool_cls.__name__ + or "GetDetectionTool" in tool_cls.__name__ ): # TODO(boczekbartek): develop a way to handle all mutially if args: tool = tool_cls(node=self._async_tool_node, **args) diff --git a/src/rai/rai/tools/utils.py b/src/rai/rai/tools/utils.py index a9aca8a5..fb17774e 100644 --- a/src/rai/rai/tools/utils.py +++ b/src/rai/rai/tools/utils.py @@ -43,7 +43,6 @@ from rai.messages import ToolMultimodalMessage -# Copied from https://github.com/ros2/rclpy/blob/jazzy/rclpy/rclpy/wait_for_message.py, to support humble def wait_for_message( msg_type, node: "Node", diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index c281de74..18d6c0f9 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -36,11 +36,15 @@ def __init__(self, llm: BaseChatModel, bufsize: int = 100) -> None: ) llm = llm self.llm = self.template | llm + self.filter_out = ["rviz", "rai"] def clear(self): self._buffer.clear() def append(self, line: str): + for w in self.filter_out: + if w in line: + return self._buffer.append(line) if len(self._buffer) > self.bufsize: self._buffer.popleft() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 22af5be6..591af011 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -25,13 +25,11 @@ ParameterUninitializedException, ) from rclpy.task import Future -from rclpy.wait_for_message import wait_for_message from rai.node import RaiAsyncToolsNode from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray - -# from rai.tools.utils import wait_for_message +from rai.tools.utils import wait_for_message from rai_interfaces.srv import RAIGroundingDino diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 60ffe779..eb5abd6c 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -27,14 +27,15 @@ ParameterUninitializedException, ) -# from rai.tools.utils import wait_for_message -from rclpy.wait_for_message import wait_for_message - from rai.tools.ros import Ros2BaseInput from rai.tools.ros.native_actions import Ros2BaseActionTool from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray +from rai.tools.utils import wait_for_message from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino +# from rclpy.wait_for_message import wait_for_message + + # --------------------- Inputs --------------------- diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index d64c9067..2aed9f9a 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -77,6 +77,7 @@ class BaseHMINode(Node): add as 1 task. They will be done by another agent resposible for communication with the robotic's stack. + If you are asked about logs, or what was write or wrong about the mission, use get_mission_memory tool to get such information. """ def __init__(