From c236a6f142b13ef1de090dc67deba0b606830bd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 5 Dec 2024 14:09:53 +0100 Subject: [PATCH 01/31] fix(`GetTransform`): add missing `args_schema` --- src/rai/rai/tools/ros/native_actions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index 1327a576..369fe1e4 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -1,5 +1,4 @@ # Copyright (C) 2024 Robotec.AI -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -187,10 +186,17 @@ def _run(self) -> str: return str(self.node.action_feedback) +class GetTransformInput(BaseModel): + target_frame: str = Field(default="map", description="Target frame") + source_frame: str = Field(default="body_link", description="Source frame") + + class GetTransformTool(Ros2BaseActionTool): name: str = "GetTransform" description: str = "Get transform between two frames" + args_schema: Type[GetTransformInput] = GetTransformInput + def _run(self, target_frame="map", source_frame="body_link") -> dict: return message_to_ordereddict( get_transform(self.node, target_frame, source_frame) From 96e0a6cc12fa51bcab318915504cbd43590cf429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 5 Dec 2024 14:23:17 +0100 Subject: [PATCH 02/31] refactor(`RaiNode`): executors and subscribers --- src/rai/rai/node.py | 293 +++++++++++++++++---------------- src/rai/rai/tools/ros/utils.py | 27 ++- src/rai/rai/utils/ros.py | 39 ++++- 3 files changed, 212 insertions(+), 147 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index d60fba93..5df004b8 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -45,12 +45,11 @@ from rclpy.topic_endpoint_info import TopicEndpointInfo from std_srvs.srv import Trigger +import rai.utils.ros from rai.agents.state_based import Report, State, create_state_based_agent from rai.messages import HumanMultimodalMessage from rai.tools.ros.native import Ros2BaseTool -from rai.tools.ros.native_actions import Ros2BaseActionTool from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str -from rai.tools.utils import wait_for_message from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks from rai.utils.ros import NodeDiscovery from rai.utils.ros_logs import create_logs_parser @@ -252,8 +251,6 @@ def __init__( ): super().__init__(*args, **kwargs) - self.robot_state: Dict[str, Any] = dict() # where Any is ROS 2 message type - self.DISCOVERY_FREQ = 2.0 self.DISCOVERY_DEPTH = 5 self.timer = self.create_timer( @@ -264,15 +261,12 @@ def __init__( self.discovery() self.qos_profile_cache: Dict[str, QoSProfile] = dict() - self.state_subscribers = dict() - - # ------- ROS2 actions handling ------- - self._async_tool_node = RaiAsyncToolsNode() + executor = rai.utils.ros.MultiThreadedExecutorFixed() + executor.add_node(self) + self.ros_executor = executor def spin(self): - executor = rclpy.executors.MultiThreadedExecutor() - executor.add_node(self) - executor.spin() + self.ros_executor.spin() rclpy.shutdown() def discovery(self): @@ -282,94 +276,6 @@ def discovery(self): get_action_names_and_types(self), ) - def adapt_requests_to_offers( - self, publisher_info: List[TopicEndpointInfo] - ) -> QoSProfile: - if not publisher_info: - return QoSProfile(depth=1) - - num_endpoints = len(publisher_info) - reliability_reliable_count = 0 - durability_transient_local_count = 0 - - for endpoint in publisher_info: - profile = endpoint.qos_profile - if profile.reliability == ReliabilityPolicy.RELIABLE: - reliability_reliable_count += 1 - if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: - durability_transient_local_count += 1 - - request_qos = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - - # Set reliability based on publisher offers - if reliability_reliable_count == num_endpoints: - request_qos.reliability = ReliabilityPolicy.RELIABLE - else: - if reliability_reliable_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering RELIABLE reliability. " - "Falling back to BEST_EFFORT as it will connect to all publishers. " - "Some messages from Reliable publishers could be dropped." - ) - request_qos.reliability = ReliabilityPolicy.BEST_EFFORT - - # Set durability based on publisher offers - if durability_transient_local_count == num_endpoints: - request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL - else: - if durability_transient_local_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " - "Falling back to VOLATILE as it will connect to all publishers. " - "Previously-published latched messages will not be retrieved." - ) - request_qos.durability = DurabilityPolicy.VOLATILE - - return request_qos - - def get_raw_message_from_topic( - self, topic: str, timeout_sec: int = 1 - ) -> Union[Any, str]: # ROS 2 topic or error string - self.get_logger().debug(f"Getting msg from topic: {topic}") - if topic in self.state_subscribers and topic in self.robot_state: - self.get_logger().debug("Returning cached message") - return self.robot_state[topic] - else: - msg_type = self.get_msg_type(topic) - if topic not in self.qos_profile_cache: - self.get_logger().debug(f"Getting qos profile for topic: {topic}") - qos_profile = self.adapt_requests_to_offers( - self.get_publishers_info_by_topic(topic) - ) - self.qos_profile_cache[topic] = qos_profile - else: - self.get_logger().debug(f"Using cached qos profile for topic: {topic}") - qos_profile = self.qos_profile_cache[topic] - - success, msg = wait_for_message( - msg_type, - self, - topic, - qos_profile=qos_profile, - time_to_wait=timeout_sec, - ) - - if success: - self.get_logger().debug( - f"Received message of type {msg_type.__class__.__name__} from topic {topic}" - ) - return msg - else: - error = ( - f"No message received in {timeout_sec} seconds from topic {topic}" - ) - self.get_logger().error(error) - return error - def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: """Sometimes node fails to do full discovery, therefore we need to retry""" for _ in range(n_tries): @@ -416,7 +322,7 @@ def __init__( self.callback_group = rclpy.callback_groups.ReentrantCallbackGroup() # ---------- Robot State ---------- - self.robot_state = dict() + self.last_subscription_msgs_buffer = dict() self.state_topics = observe_topics if observe_topics is not None else [] self.state_postprocessors = ( observe_postprocessors if observe_postprocessors is not None else dict() @@ -465,15 +371,7 @@ def _initialize_tools(self, tools: List[Type[BaseTool]]): initialized_tools: List[BaseTool] = list() for tool_cls in tools: if issubclass(tool_cls, Ros2BaseTool): - if ( - issubclass(tool_cls, Ros2BaseActionTool) - or "DetectionTool" in tool_cls.__name__ - or "GetDistance" in tool_cls.__name__ - or "GetTransformTool" in tool_cls.__name__ - ): # TODO(boczekbartek): develop a way to handle all mutially - tool = tool_cls(node=self._async_tool_node) - else: - tool = tool_cls(node=self) + tool = tool_cls(node=self) else: tool = tool_cls() @@ -489,28 +387,143 @@ def _initialize_system_prompt(self, prompt: str): def _initialize_robot_state_interfaces(self, topics: List[str]): for topic in topics: - msg_type = self.get_msg_type(topic) - topic_callback = functools.partial( - self.generic_state_subscriber_callback, topic - ) - qos_profile = self.adapt_requests_to_offers( - self.get_publishers_info_by_topic(topic) - ) - subscriber = self.create_subscription( - msg_type, - topic, - callback=topic_callback, - callback_group=self.callback_group, - qos_profile=qos_profile, + self.create_subscription_by_topic_name(topic) + + def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: + self.get_logger().debug(f"Getting msg from topic: {topic}") + + ts = time.perf_counter() + + if topic not in self.ros_discovery_info.topics_and_types: + raise KeyError( + f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" ) - self.state_subscribers[topic] = subscriber + if topic in self.last_subscription_msgs_buffer: + self.get_logger().info("Returning cached message") + return self.last_subscription_msgs_buffer[topic] + else: + self.create_subscription_by_topic_name(topic) + try: + msg = self.last_subscription_msgs_buffer.get(topic, None) + while msg is None and time.perf_counter() - ts < timeout_sec: + msg = self.last_subscription_msgs_buffer.get(topic, None) + self.get_logger().info("Waiting for message...") + time.sleep(0.1) + + success = msg is not None + + if success: + self.get_logger().debug( + f"Received message of type {type(msg)} from topic {topic}" + ) + return msg + else: + error = f"No message received in {timeout_sec} seconds from topic {topic}" + self.get_logger().error(error) + return error + finally: + self.destroy_subscription_by_topic_name(topic) def generic_state_subscriber_callback(self, topic_name: str, msg: Any): self.get_logger().debug( f"Received message of type {type(msg)} from topic {topic_name}" ) - self.robot_state[topic_name] = msg + + self.last_subscription_msgs_buffer[topic_name] = msg + + def adapt_requests_to_offers( + self, publisher_info: List[TopicEndpointInfo] + ) -> QoSProfile: + if not publisher_info: + return QoSProfile(depth=1) + + num_endpoints = len(publisher_info) + reliability_reliable_count = 0 + durability_transient_local_count = 0 + + for endpoint in publisher_info: + profile = endpoint.qos_profile + if profile.reliability == ReliabilityPolicy.RELIABLE: + reliability_reliable_count += 1 + if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: + durability_transient_local_count += 1 + + request_qos = QoSProfile( + history=HistoryPolicy.KEEP_LAST, + depth=1, + liveliness=LivelinessPolicy.AUTOMATIC, + ) + + # Set reliability based on publisher offers + if reliability_reliable_count == num_endpoints: + request_qos.reliability = ReliabilityPolicy.RELIABLE + else: + if reliability_reliable_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering RELIABLE reliability. " + "Falling back to BEST_EFFORT as it will connect to all publishers. " + "Some messages from Reliable publishers could be dropped." + ) + request_qos.reliability = ReliabilityPolicy.BEST_EFFORT + + # Set durability based on publisher offers + if durability_transient_local_count == num_endpoints: + request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL + else: + if durability_transient_local_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " + "Falling back to VOLATILE as it will connect to all publishers. " + "Previously-published latched messages will not be retrieved." + ) + request_qos.durability = DurabilityPolicy.VOLATILE + + return request_qos + + def create_subscription_by_topic_name(self, topic): + if self.has_subscription(topic): + self.get_logger().warning( + f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" + ) + return + + msg_type = self.get_msg_type(topic) + + if topic not in self.qos_profile_cache: + self.get_logger().debug(f"Getting qos profile for topic: {topic}") + qos_profile = self.adapt_requests_to_offers( + self.get_publishers_info_by_topic(topic) + ) + self.qos_profile_cache[topic] = qos_profile + else: + self.get_logger().debug(f"Using cached qos profile for topic: {topic}") + qos_profile = self.qos_profile_cache[topic] + + topic_callback = functools.partial( + self.generic_state_subscriber_callback, topic + ) + + + self.create_subscription( + msg_type, + topic, + callback=topic_callback, + callback_group=self.callback_group, + qos_profile=qos_profile, + ) + + def has_subscription(self, topic: str) -> bool: + for sub in self._subscriptions: + if sub.topic == topic: + return True + return False + + def destroy_subscription_by_topic_name(self, topic: str): + self.last_subscription_msgs_buffer.clear() + for sub in self._subscriptions: + if sub.topic == topic: + self.destroy_subscription(sub) def goal_callback(self, _) -> GoalResponse: """Accept or reject a client request to begin an action.""" @@ -619,31 +632,35 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): def state_update_callback(self): state_dict = dict() - if self.robot_state is None: - return state_dict + ts = time.perf_counter() + try: + state_dict["logs_summary"] = self.summarize_logs() + except Exception as e: + self.get_logger().error(f"Error summarizing logs: {e}") + state_dict["logs_summary"] = "" + te = time.perf_counter() - ts + self.get_logger().info(f"Logs summary retrieved in: {te:.2f}") + self.get_logger().debug(f"{state_dict=}") + + if self.last_subscription_msgs_buffer is None: + self.state_dict = state_dict + return - for t in self.state_subscribers: - if t not in self.robot_state: + for t in self.state_topics: + if t not in self.last_subscription_msgs_buffer: msg = "No message yet" state_dict[t] = msg continue + ts = time.perf_counter() - msg = self.robot_state[t] + msg = self.last_subscription_msgs_buffer[t] if t in self.state_postprocessors: msg = self.state_postprocessors[t](msg) te = time.perf_counter() - ts self.get_logger().info(f"Topic '{t}' postprocessed in: {te:.2f}") + state_dict[t] = msg - ts = time.perf_counter() - try: - state_dict["logs_summary"] = self.summarize_logs() - except Exception as e: - self.get_logger().error(f"Error summarizing logs: {e}") - state_dict["logs_summary"] = "" - te = time.perf_counter() - ts - self.get_logger().info(f"Logs summary retrieved in: {te:.2f}") - self.get_logger().debug(f"{state_dict=}") self.state_dict = state_dict def get_robot_state(self) -> Dict[str, str]: diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index 2e52ab7c..2499da2c 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -19,6 +19,7 @@ import cv2 import numpy as np import rclpy +import rclpy.executors import rclpy.node import rclpy.time import sensor_msgs.msg @@ -31,7 +32,7 @@ from rosidl_parser.definition import NamespacedType from rosidl_runtime_py.import_message import import_message_from_namespaced_type from rosidl_runtime_py.utilities import get_namespaced_type -from tf2_ros import Buffer, TransformListener, TransformStamped +from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped def import_message_from_str(msg_type: str) -> Type[object]: @@ -155,17 +156,27 @@ def wait_for_message( def get_transform( - node: rclpy.node.Node, target_frame: str, source_frame: str + node: rclpy.node.Node, + target_frame: str, + source_frame: str, + timeout_sec: float = 5.0, ) -> TransformStamped: tf_buffer = Buffer(node=node) tf_listener = TransformListener(tf_buffer, node) transform = None - while transform is None: - rclpy.spin_once(node, timeout=0.5) - if tf_buffer.can_transform(target_frame, source_frame, rclpy.time.Time()): - transform = tf_buffer.lookup_transform( - target_frame, source_frame, rclpy.time.Time() - ) + future = tf_buffer.wait_for_transform_async( + target_frame, source_frame, rclpy.time.Time() + ) + + node.ros_executor.spin_until_future_complete(future, timeout_sec=timeout_sec) + + transform = future.result() + tf_listener.unregister() + if not future.done() or transform is None: + raise LookupException( + f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds" + ) + return transform diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 89a76cc8..9c83c2b3 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -16,7 +16,6 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple - @dataclass class NodeDiscovery: topics_and_types: Dict[str, str] = field(default_factory=dict) @@ -50,3 +49,41 @@ def dict(self): "services_and_types": self.services_and_types, "actions_and_types": self.actions_and_types, } + + +class MultiThreadedExecutorFixed(MultiThreadedExecutor): + """ + Adresses a comment: + ```python + # make a copy of the list that we iterate over while modifying it + # (https://stackoverflow.com/q/1207406/3753684) + ``` + from the rclpy implementation + """ + + def _spin_once_impl( + self, + timeout_sec: Optional[Union[float, TimeoutObject]] = None, + wait_condition: Callable[[], bool] = lambda: False, + ) -> None: + try: + handler, entity, node = self.wait_for_ready_callbacks( + timeout_sec, None, wait_condition + ) + except ExternalShutdownException: + pass + except ShutdownException: + pass + except TimeoutException: + pass + except ConditionReachedException: + pass + else: + self._executor.submit(handler) + self._futures.append(handler) + futures = self._futures.copy() + for future in futures[:]: + if future.done(): + futures.remove(future) + future.result() # raise any exceptions + self._futures = futures From 6b7a2a94739f240f47949d1498c4ec52f122347d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 5 Dec 2024 14:27:33 +0100 Subject: [PATCH 03/31] improve logging of tool calls --- src/rai/rai/agents/tool_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/agents/tool_runner.py b/src/rai/rai/agents/tool_runner.py index e2014d60..12e0889d 100644 --- a/src/rai/rai/agents/tool_runner.py +++ b/src/rai/rai/agents/tool_runner.py @@ -62,7 +62,7 @@ def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any: raise ValueError("Last message is not an AIMessage") def run_one(call: ToolCall): - self.logger.info(f"Running tool: {call['name']}") + self.logger.info(f"Running tool: {call['name']}, args: {call['args']}") artifact = None try: From 6ced00cbb7258c3e7a1c16117ee71166aa74a4b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 5 Dec 2024 15:25:25 +0100 Subject: [PATCH 04/31] remove RaiAsctionToolsNode --- src/rai/rai/node.py | 49 ++++++++++++++++++++------------------------- 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 5df004b8..86a99d53 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -143,9 +143,28 @@ def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Typ return msg, msg_cls -class RaiAsyncToolsNode(Node): - def __init__(self): - super().__init__("rai_internal_action_node") +class RaiBaseNode(Node): + def __init__( + self, + allowlist: Optional[List[str]] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.DISCOVERY_FREQ = 2.0 + self.DISCOVERY_DEPTH = 5 + self.timer = self.create_timer( + self.DISCOVERY_FREQ, + self.discovery, + ) + self.ros_discovery_info = NodeDiscovery(allowlist=allowlist) + self.discovery() + self.qos_profile_cache: Dict[str, QoSProfile] = dict() + + executor = rai.utils.ros.MultiThreadedExecutorFixed() + executor.add_node(self) + self.ros_executor = executor self.goal_handle = None self.result_future = None @@ -241,30 +260,6 @@ def _cancel_task(self): rclpy.spin_until_future_complete(self, future) return True - -class RaiBaseNode(Node): - def __init__( - self, - allowlist: Optional[List[str]] = None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.DISCOVERY_FREQ = 2.0 - self.DISCOVERY_DEPTH = 5 - self.timer = self.create_timer( - self.DISCOVERY_FREQ, - self.discovery, - ) - self.ros_discovery_info = NodeDiscovery(allowlist=allowlist) - self.discovery() - self.qos_profile_cache: Dict[str, QoSProfile] = dict() - - executor = rai.utils.ros.MultiThreadedExecutorFixed() - executor.add_node(self) - self.ros_executor = executor - def spin(self): self.ros_executor.spin() rclpy.shutdown() From 7d96e5e1154ab44dbce26887f4fd562e373438cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 5 Dec 2024 19:18:10 +0100 Subject: [PATCH 05/31] refactor: executors and actions calling --- src/rai/rai/node.py | 55 +++++++++++++++++++++++++++------- src/rai/rai/tools/ros/utils.py | 2 +- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 86a99d53..7ca88109 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -186,10 +186,11 @@ def _run_action(self, action_name, action_type, action_goal_args): except Exception as e: return f"Failed to build message: {e}" - client = ActionClient(self, msg_cls, action_name) + self.client = ActionClient(self, msg_cls, action_name) + self.msg_cls = msg_cls retries = 0 - while not client.wait_for_server(timeout_sec=1.0): + while not self.client.wait_for_server(timeout_sec=1.0): retries += 1 if retries > 5: raise Exception( @@ -201,7 +202,9 @@ def _run_action(self, action_name, action_type, action_goal_args): self.get_logger().info(f"Sending action message: {goal_msg}") - send_goal_future = client.send_goal_async(goal_msg, self._feedback_callback) + send_goal_future = self.client.send_goal_async( + goal_msg, self._feedback_callback + ) self.get_logger().info("Action goal sent!") rclpy.spin_until_future_complete(self, send_goal_future) self.goal_handle = send_goal_future.result() @@ -216,18 +219,42 @@ def _run_action(self, action_name, action_type, action_goal_args): self.get_logger().info("Action sent!") return f"{action_name} started successfully with args: {action_goal_args}" - def _get_task_result(self): + def _get_task_result(self) -> str: if not self._is_task_complete(): return "Task is not complete yet" - if self.status == GoalStatus.STATUS_SUCCEEDED: - return "Succeeded" - elif self.status == GoalStatus.STATUS_ABORTED: - return "Failed" - elif self.status == GoalStatus.STATUS_CANCELED: - return "Cancelled" + def parse_status(status: int) -> str: + return { + GoalStatus.STATUS_SUCCEEDED: "succeeded", + GoalStatus.STATUS_ABORTED: "aborted", + GoalStatus.STATUS_CANCELED: "canceled", + GoalStatus.STATUS_ACCEPTED: "accepted", + GoalStatus.STATUS_CANCELING: "canceling", + GoalStatus.STATUS_EXECUTING: "executing", + GoalStatus.STATUS_UNKNOWN: "unknown", + }.get(status, "unknown") + + result = self.result_future.result() + + self.destroy_client(self.client) + + if result.status == GoalStatus.STATUS_SUCCEEDED: + msg = f"Result succeeded: {result.result}" + self.get_logger().info(msg) + return msg else: - return "Failed" + str_status = parse_status(result.status) + error_code_str = self.parse_error_code(result.result.error_code) + msg = f"Result {str_status}, because of: error_code={result.result.error_code}({error_code_str}), error_msg={result.result.error_msg}" + self.get_logger().info(msg) + return msg + + def parse_error_code(self, code: int) -> str: + name_to_code = self.msg_cls.Result.__prepare__( + name="", bases="" + ) # arguments are not used + code_to_name = {v: k for k, v in name_to_code.items()} + return code_to_name.get(code, "UNKNOWN") def _feedback_callback(self, msg): self.get_logger().info(f"Received ros2 action feedback: {msg}") @@ -258,6 +285,7 @@ def _cancel_task(self): if self.result_future and self.goal_handle: future = self.goal_handle.cancel_goal_async() rclpy.spin_until_future_complete(self, future) + self.destroy_client(self.client) return True def spin(self): @@ -340,6 +368,7 @@ def __init__( ) # Node is busy when task is executed. Only 1 task is allowed self.busy = False + self.current_task = None # ---------- LLM Agents ---------- self.tools = self._initialize_tools(tools) if tools is not None else [] @@ -403,6 +432,7 @@ def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: msg = self.last_subscription_msgs_buffer.get(topic, None) while msg is None and time.perf_counter() - ts < timeout_sec: msg = self.last_subscription_msgs_buffer.get(topic, None) + rclpy.spin_once(self, timeout_sec=0.1) self.get_logger().info("Waiting for message...") time.sleep(0.1) @@ -546,6 +576,7 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): ), HumanMessage(content=f"Task: {task}"), ] + self.current_task = task payload = State(messages=messages) @@ -623,9 +654,11 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): return result finally: self.busy = False + self.current_task = None def state_update_callback(self): state_dict = dict() + state_dict["current_task"] = self.current_task ts = time.perf_counter() try: diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index 2499da2c..d5452094 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -169,7 +169,7 @@ def get_transform( target_frame, source_frame, rclpy.time.Time() ) - node.ros_executor.spin_until_future_complete(future, timeout_sec=timeout_sec) + rclpy.spin_until_future_complete(node, future, timeout_sec=timeout_sec) transform = future.result() From 2c2fc346d9866f2cf10a9dc5eca3cbe1664bd787 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Dec 2024 03:53:31 +0100 Subject: [PATCH 06/31] further refactoring of `RaiNode` --- config.toml | 6 +- examples/rosbot-xl-demo.py | 15 +- src/rai/rai/node.py | 481 ++++++++++++++---------- src/rai/rai/tools/ros/native.py | 6 +- src/rai/rai/tools/ros/native_actions.py | 10 +- src/rai/rai/tools/ros/utils.py | 15 +- src/rai/rai/utils/ros.py | 66 +++- src/rai/rai/utils/ros_logs.py | 19 +- 8 files changed, 372 insertions(+), 246 deletions(-) diff --git a/config.toml b/config.toml index e06fc752..832a59b2 100644 --- a/config.toml +++ b/config.toml @@ -16,10 +16,10 @@ embeddings_model = "text-embedding-ada-002" base_url = "https://api.openai.com/v1/" # for openai compatible apis [ollama] -simple_model = "llama3.2" -complex_model = "llama3.1:70b" +simple_model = "qwq" +complex_model = "qwen2.5:7b" embeddings_model = "llama3.2" -base_url = "http://localhost:11434" +base_url = "http://via-ip-robo-srv-004.robotec.tm.pl:11434" [tracing] project = "rai" diff --git a/examples/rosbot-xl-demo.py b/examples/rosbot-xl-demo.py index 028f9f4b..988a8a4f 100644 --- a/examples/rosbot-xl-demo.py +++ b/examples/rosbot-xl-demo.py @@ -18,23 +18,22 @@ import rclpy import rclpy.executors import rclpy.logging -from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool from rai.node import RaiStateBasedLlmNode from rai.tools.ros.native import ( GetCameraImage, GetMsgFromTopic, + Ros2GetRobotInterfaces, Ros2PubMessageTool, Ros2ShowMsgInterfaceTool, ) from rai.tools.ros.native_actions import ( + GetTransformTool, Ros2CancelAction, Ros2GetActionResult, Ros2GetLastActionFeedback, - Ros2IsActionComplete, Ros2RunActionAsync, ) -from rai.tools.ros.tools import GetCurrentPositionTool from rai.tools.time import WaitForSecondsTool p = argparse.ArgumentParser() @@ -62,11 +61,14 @@ def main(allowlist: Optional[Path] = None): rclpy.logging.get_logger("rosbot_xl_demo").error( f"Failed to read allowlist: {e}" ) + else: + ros2_allowlist = None SYSTEM_PROMPT = """You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. Do not make assumptions about the environment you are currently in. You can use ros2 topics, services and actions to operate. + As a first step check transforms by getting 1 message from /tf topic use /cmd_vel topic very carefully. Obstacle detection works only with nav2 stack, so be careful when it is not used. > be patient with running ros2 actions. usually the take some time to run. Always check your transform before and after you perform ros2 actions, so that you can verify if it worked. @@ -88,6 +90,7 @@ def main(allowlist: Optional[Path] = None): - to find some object navigate around and check the surrounding area - when the goal is accomplished please make sure to cancel running actions - when you reach the navigation goal - double check if you reached it by checking the current position + - if you detect collision, please stop operation - you will be given your camera image description. Based on this information you can reason about positions of objects. - be careful and aboid obstacles @@ -118,19 +121,17 @@ def main(allowlist: Optional[Path] = None): allowlist=ros2_allowlist, system_prompt=SYSTEM_PROMPT, tools=[ + Ros2GetRobotInterfaces, Ros2PubMessageTool, Ros2RunActionAsync, - Ros2IsActionComplete, Ros2CancelAction, Ros2GetActionResult, Ros2GetLastActionFeedback, Ros2ShowMsgInterfaceTool, - GetCurrentPositionTool, + GetTransformTool, WaitForSecondsTool, GetMsgFromTopic, GetCameraImage, - GetDetectionTool, - GetDistanceToObjectsTool, ], ) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 7ca88109..ad04948d 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -20,6 +20,7 @@ import rclpy import rclpy.callback_groups import rclpy.executors +import rclpy.node import rclpy.qos import rclpy.subscription import rclpy.task @@ -32,7 +33,6 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langgraph.graph.graph import CompiledGraph from rclpy.action.client import ActionClient -from rclpy.action.graph import get_action_names_and_types from rclpy.action.server import ActionServer, GoalResponse, ServerGoalHandle from rclpy.node import Node from rclpy.qos import ( @@ -52,6 +52,7 @@ from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks from rai.utils.ros import NodeDiscovery +from rai.utils.ros_async import get_future_result from rai.utils.ros_logs import create_logs_parser from rai_interfaces.action import Task as TaskAction @@ -92,15 +93,15 @@ def append_whoami_info_to_prompt( node.get_logger().info("Identity service not available, waiting again...") constitution_future = constitution_service.call_async(Trigger.Request()) - rclpy.spin_until_future_complete(node, constitution_future) - constitution_response = constitution_future.result() + constitution_response: Optional[Trigger.Response] = get_future_result( + constitution_future + ) constitution_message = ( "" if constitution_response is None else constitution_response.message ) identity_future = identity_service.call_async(Trigger.Request()) - rclpy.spin_until_future_complete(node, identity_future) - identity_response = identity_future.result() + identity_response: Optional[Trigger.Response] = get_future_result(identity_future) identity_message = "" if identity_response is None else identity_response.message system_prompt = WHOAMI_SYSTEM_PROMPT_TEMPLATE.format( @@ -143,36 +144,24 @@ def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Typ return msg, msg_cls -class RaiBaseNode(Node): - def __init__( - self, - allowlist: Optional[List[str]] = None, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.DISCOVERY_FREQ = 2.0 - self.DISCOVERY_DEPTH = 5 - self.timer = self.create_timer( - self.DISCOVERY_FREQ, - self.discovery, - ) - self.ros_discovery_info = NodeDiscovery(allowlist=allowlist) - self.discovery() - self.qos_profile_cache: Dict[str, QoSProfile] = dict() - - executor = rai.utils.ros.MultiThreadedExecutorFixed() - executor.add_node(self) - self.ros_executor = executor +class AsyncRos2ActionClient: + def __init__(self, node: rclpy.node.Node): + self.node = node self.goal_handle = None self.result_future = None self.feedback = None - self.status = None + self.status: Optional[int] = None + self.client: Optional[ActionClient] = None + self.action_feedback: Optional[Any] = None + + def get_logger(self): + return self.node.get_logger() - def _run_action(self, action_name, action_type, action_goal_args): - if not self._is_task_complete(): + def run_action( + self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] + ): + if not self.is_task_complete(): raise AssertionError( "Another ros2 action is currently running and parallel actions are not supported. Please wait until the previous action is complete before starting a new one. You can also cancel the current one." ) @@ -186,7 +175,7 @@ def _run_action(self, action_name, action_type, action_goal_args): except Exception as e: return f"Failed to build message: {e}" - self.client = ActionClient(self, msg_cls, action_name) + self.client = ActionClient(self.node, msg_cls, action_name) self.msg_cls = msg_cls retries = 0 @@ -206,8 +195,8 @@ def _run_action(self, action_name, action_type, action_goal_args): goal_msg, self._feedback_callback ) self.get_logger().info("Action goal sent!") - rclpy.spin_until_future_complete(self, send_goal_future) - self.goal_handle = send_goal_future.result() + + self.goal_handle = get_future_result(send_goal_future) if not self.goal_handle: raise Exception(f"Action '{action_name}' not sent to server") @@ -219,8 +208,8 @@ def _run_action(self, action_name, action_type, action_goal_args): self.get_logger().info("Action sent!") return f"{action_name} started successfully with args: {action_goal_args}" - def _get_task_result(self) -> str: - if not self._is_task_complete(): + def get_task_result(self) -> str: + if not self.is_task_complete(): return "Task is not complete yet" def parse_status(status: int) -> str: @@ -236,8 +225,7 @@ def parse_status(status: int) -> str: result = self.result_future.result() - self.destroy_client(self.client) - + self.destroy_client() if result.status == GoalStatus.STATUS_SUCCEEDED: msg = f"Result succeeded: {result.result}" self.get_logger().info(msg) @@ -260,13 +248,14 @@ def _feedback_callback(self, msg): self.get_logger().info(f"Received ros2 action feedback: {msg}") self.action_feedback = msg - def _is_task_complete(self): + def is_task_complete(self): if not self.result_future: # task was cancelled or completed return True - rclpy.spin_until_future_complete(self, self.result_future, timeout_sec=0.10) - if self.result_future.result(): - self.status = self.result_future.result().status + + result = get_future_result(self.result_future, timeout_sec=0.10) + if result is not None: + self.status = result.status if self.status != GoalStatus.STATUS_SUCCEEDED: self.get_logger().debug( f"Task with failed with status code: {self.status}" @@ -280,25 +269,126 @@ def _is_task_complete(self): self.get_logger().info("Task succeeded!") return True - def _cancel_task(self): + def cancel_task(self) -> Union[str, bool]: self.get_logger().info("Canceling current task.") - if self.result_future and self.goal_handle: - future = self.goal_handle.cancel_goal_async() - rclpy.spin_until_future_complete(self, future) - self.destroy_client(self.client) - return True + try: + if self.result_future and self.goal_handle: + future = self.goal_handle.cancel_goal_async() + result = get_future_result(future, timeout_sec=1.0) + return "Failed to cancel result" if result is None else True + return True + finally: + self.destroy_client() - def spin(self): - self.ros_executor.spin() - rclpy.shutdown() + def destroy_client(self): + if self.client: + self.client.destroy() + + +class Ros2TopicsHandler: + def __init__( + self, + node: rclpy.node.Node, + callback_group: rclpy.callback_groups.CallbackGroup, + ros_discovery_info: NodeDiscovery, + ) -> None: + self.node = node + self.qos_profile = QoSProfile( + history=HistoryPolicy.KEEP_LAST, + depth=1, + reliability=ReliabilityPolicy.BEST_EFFORT, + durability=DurabilityPolicy.VOLATILE, + liveliness=LivelinessPolicy.AUTOMATIC, + ) + self.callback_group = callback_group + self.last_subscription_msgs_buffer = dict() + self.qos_profile_cache: Dict[str, QoSProfile] = dict() + + self.ros_discovery_info = ros_discovery_info + + def get_logger(self): + return self.node.get_logger() + + def adapt_requests_to_offers( + self, publisher_info: List[TopicEndpointInfo] + ) -> QoSProfile: + if not publisher_info: + return QoSProfile(depth=1) + + num_endpoints = len(publisher_info) + reliability_reliable_count = 0 + durability_transient_local_count = 0 - def discovery(self): - self.ros_discovery_info.set( - self.get_topic_names_and_types(), - self.get_service_names_and_types(), - get_action_names_and_types(self), + for endpoint in publisher_info: + profile = endpoint.qos_profile + if profile.reliability == ReliabilityPolicy.RELIABLE: + reliability_reliable_count += 1 + if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: + durability_transient_local_count += 1 + + request_qos = QoSProfile( + history=HistoryPolicy.KEEP_LAST, + depth=1, + liveliness=LivelinessPolicy.AUTOMATIC, ) + # Set reliability based on publisher offers + if reliability_reliable_count == num_endpoints: + request_qos.reliability = ReliabilityPolicy.RELIABLE + else: + if reliability_reliable_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering RELIABLE reliability. " + "Falling back to BEST_EFFORT as it will connect to all publishers. " + "Some messages from Reliable publishers could be dropped." + ) + request_qos.reliability = ReliabilityPolicy.BEST_EFFORT + + # Set durability based on publisher offers + if durability_transient_local_count == num_endpoints: + request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL + else: + if durability_transient_local_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " + "Falling back to VOLATILE as it will connect to all publishers. " + "Previously-published latched messages will not be retrieved." + ) + request_qos.durability = DurabilityPolicy.VOLATILE + + return request_qos + + def create_subscription_by_topic_name(self, topic): + if self.has_subscription(topic): + self.get_logger().warning( + f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" + ) + return + + msg_type = self.get_msg_type(topic) + + if topic not in self.qos_profile_cache: + self.get_logger().debug(f"Getting qos profile for topic: {topic}") + qos_profile = self.adapt_requests_to_offers( + self.node.get_publishers_info_by_topic(topic) + ) + self.qos_profile_cache[topic] = qos_profile + else: + self.get_logger().debug(f"Using cached qos profile for topic: {topic}") + qos_profile = self.qos_profile_cache[topic] + + topic_callback = functools.partial( + self.generic_state_subscriber_callback, topic + ) + + self.node.create_subscription( + msg_type, + topic, + callback=topic_callback, + callback_group=self.callback_group, + qos_profile=qos_profile, + ) + def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: """Sometimes node fails to do full discovery, therefore we need to retry""" for _ in range(n_tries): @@ -306,11 +396,123 @@ def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: msg_type = self.ros_discovery_info.topics_and_types[topic] return import_message_from_str(msg_type) else: + # Wait for next discovery cycle self.get_logger().info(f"Waiting for topic: {topic}") - self.discovery() - time.sleep(self.DISCOVERY_FREQ) + if self.ros_discovery_info: + time.sleep(self.ros_discovery_info.period_sec) + else: + time.sleep(1.0) raise KeyError(f"Topic {topic} not found") + def set_ros_discovery_info(self, new_ros_discovery_info: NodeDiscovery): + self.ros_discovery_info = new_ros_discovery_info + + def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: + self.get_logger().debug(f"Getting msg from topic: {topic}") + + ts = time.perf_counter() + + if topic not in self.ros_discovery_info.topics_and_types: + raise KeyError( + f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" + ) + + if topic in self.last_subscription_msgs_buffer: + self.get_logger().info("Returning cached message") + return self.last_subscription_msgs_buffer[topic] + else: + self.create_subscription_by_topic_name(topic) + try: + msg = self.last_subscription_msgs_buffer.get(topic, None) + while msg is None and time.perf_counter() - ts < timeout_sec: + msg = self.last_subscription_msgs_buffer.get(topic, None) + self.get_logger().info("Waiting for message...") + time.sleep(0.1) + + success = msg is not None + + if success: + self.get_logger().debug( + f"Received message of type {type(msg)} from topic {topic}" + ) + return msg + else: + error = f"No message received in {timeout_sec} seconds from topic {topic}" + self.get_logger().error(error) + return error + finally: + self.destroy_subscription_by_topic_name(topic) + + def generic_state_subscriber_callback(self, topic_name: str, msg: Any): + self.get_logger().debug( + f"Received message of type {type(msg)} from topic {topic_name}" + ) + self.last_subscription_msgs_buffer[topic_name] = msg + + def has_subscription(self, topic: str) -> bool: + for sub in self.node._subscriptions: + if sub.topic == topic: + return True + return False + + def destroy_subscription_by_topic_name(self, topic: str): + self.last_subscription_msgs_buffer.clear() + for sub in self.node._subscriptions: + if sub.topic == topic: + self.node.destroy_subscription(sub) + + +class RaiBaseNode(Node): + def __init__( + self, + allowlist: Optional[List[str]] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + # ---------- ROS configuration ---------- + self.callback_group = rclpy.callback_groups.ReentrantCallbackGroup() + + # ---------- ROS helpers ---------- + self.ros_discovery_info = NodeDiscovery(self, allowlist=allowlist) + self.async_action_client = AsyncRos2ActionClient(self) + self.topics_handler = Ros2TopicsHandler( + self, self.callback_group, self.ros_discovery_info + ) + self.ros_discovery_info.add_setter(self.topics_handler.set_ros_discovery_info) + + # ------------- ros2 topics interface ------------- + def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: + return self.topics_handler.get_raw_message_from_topic(topic, timeout_sec) + + # ------------- ros2 actions interface ------------- + def run_action( + self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] + ): + return self.async_action_client.run_action( + action_name, action_type, action_goal_args + ) + + def get_task_result(self) -> str: + return self.async_action_client.get_task_result() + + def is_task_complete(self) -> bool: + return self.async_action_client.is_task_complete() + + @property + def action_feedback(self) -> Any: + return self.async_action_client.action_feedback + + def cancel_task(self) -> Union[str, bool]: + return self.async_action_client.cancel_task() + + # ------------- other methods ------------- + def spin(self): + executor = rai.utils.ros.MultiThreadedExecutorFixed() + executor.add_node(self) + executor.spin() + rclpy.shutdown() + def parse_task_goal(ros_action_goal: TaskAction.Goal) -> Dict[str, Any]: return dict( @@ -322,6 +524,7 @@ def parse_task_goal(ros_action_goal: TaskAction.Goal) -> Dict[str, Any]: class RaiStateBasedLlmNode(RaiBaseNode): AGENT_RECURSION_LIMIT = 500 + STATE_UPDATE_PERIOD = 5.0 def __init__( self, @@ -341,18 +544,14 @@ def __init__( **kwargs, ) - # ---------- ROS configuration ---------- - self.callback_group = rclpy.callback_groups.ReentrantCallbackGroup() - # ---------- Robot State ---------- - self.last_subscription_msgs_buffer = dict() self.state_topics = observe_topics if observe_topics is not None else [] self.state_postprocessors = ( observe_postprocessors if observe_postprocessors is not None else dict() ) self._initialize_robot_state_interfaces(self.state_topics) self.state_update_timer = self.create_timer( - 7.0, + self.STATE_UPDATE_PERIOD, self.state_update_callback, callback_group=rclpy.callback_groups.MutuallyExclusiveCallbackGroup(), ) @@ -364,6 +563,7 @@ def __init__( TaskAction, "perform_task", execute_callback=self.agent_loop, + callback_group=self.callback_group, goal_callback=self.goal_callback, ) # Node is busy when task is executed. Only 1 task is allowed @@ -380,13 +580,13 @@ def __init__( logger=self.get_logger(), ) - # We have to use a separate node that we can manually spin for ros-service based - # parser and this node ros-subscriber based parser - logs_parser_node = self if logs_parser_type == "llm" else self._async_tool_node + self.simple_llm = get_llm_model(model_type="simple_model") self.logs_parser = create_logs_parser( - logs_parser_type, logs_parser_node, callback_group=self.callback_group + logs_parser_type, + self, + callback_group=self.callback_group, + llm=self.simple_llm, ) - self.simple_llm = get_llm_model(model_type="simple_model") def summarize_logs(self) -> str: return self.logs_parser.summarize() @@ -411,144 +611,7 @@ def _initialize_system_prompt(self, prompt: str): def _initialize_robot_state_interfaces(self, topics: List[str]): for topic in topics: - self.create_subscription_by_topic_name(topic) - - def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: - self.get_logger().debug(f"Getting msg from topic: {topic}") - - ts = time.perf_counter() - - if topic not in self.ros_discovery_info.topics_and_types: - raise KeyError( - f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" - ) - - if topic in self.last_subscription_msgs_buffer: - self.get_logger().info("Returning cached message") - return self.last_subscription_msgs_buffer[topic] - else: - self.create_subscription_by_topic_name(topic) - try: - msg = self.last_subscription_msgs_buffer.get(topic, None) - while msg is None and time.perf_counter() - ts < timeout_sec: - msg = self.last_subscription_msgs_buffer.get(topic, None) - rclpy.spin_once(self, timeout_sec=0.1) - self.get_logger().info("Waiting for message...") - time.sleep(0.1) - - success = msg is not None - - if success: - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic}" - ) - return msg - else: - error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.get_logger().error(error) - return error - finally: - self.destroy_subscription_by_topic_name(topic) - - def generic_state_subscriber_callback(self, topic_name: str, msg: Any): - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic_name}" - ) - - self.last_subscription_msgs_buffer[topic_name] = msg - - def adapt_requests_to_offers( - self, publisher_info: List[TopicEndpointInfo] - ) -> QoSProfile: - if not publisher_info: - return QoSProfile(depth=1) - - num_endpoints = len(publisher_info) - reliability_reliable_count = 0 - durability_transient_local_count = 0 - - for endpoint in publisher_info: - profile = endpoint.qos_profile - if profile.reliability == ReliabilityPolicy.RELIABLE: - reliability_reliable_count += 1 - if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: - durability_transient_local_count += 1 - - request_qos = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - - # Set reliability based on publisher offers - if reliability_reliable_count == num_endpoints: - request_qos.reliability = ReliabilityPolicy.RELIABLE - else: - if reliability_reliable_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering RELIABLE reliability. " - "Falling back to BEST_EFFORT as it will connect to all publishers. " - "Some messages from Reliable publishers could be dropped." - ) - request_qos.reliability = ReliabilityPolicy.BEST_EFFORT - - # Set durability based on publisher offers - if durability_transient_local_count == num_endpoints: - request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL - else: - if durability_transient_local_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " - "Falling back to VOLATILE as it will connect to all publishers. " - "Previously-published latched messages will not be retrieved." - ) - request_qos.durability = DurabilityPolicy.VOLATILE - - return request_qos - - def create_subscription_by_topic_name(self, topic): - if self.has_subscription(topic): - self.get_logger().warning( - f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" - ) - return - - msg_type = self.get_msg_type(topic) - - if topic not in self.qos_profile_cache: - self.get_logger().debug(f"Getting qos profile for topic: {topic}") - qos_profile = self.adapt_requests_to_offers( - self.get_publishers_info_by_topic(topic) - ) - self.qos_profile_cache[topic] = qos_profile - else: - self.get_logger().debug(f"Using cached qos profile for topic: {topic}") - qos_profile = self.qos_profile_cache[topic] - - topic_callback = functools.partial( - self.generic_state_subscriber_callback, topic - ) - - - self.create_subscription( - msg_type, - topic, - callback=topic_callback, - callback_group=self.callback_group, - qos_profile=qos_profile, - ) - - def has_subscription(self, topic: str) -> bool: - for sub in self._subscriptions: - if sub.topic == topic: - return True - return False - - def destroy_subscription_by_topic_name(self, topic: str): - self.last_subscription_msgs_buffer.clear() - for sub in self._subscriptions: - if sub.topic == topic: - self.destroy_subscription(sub) + self.topics_handler.create_subscription_by_topic_name(topic) def goal_callback(self, _) -> GoalResponse: """Accept or reject a client request to begin an action.""" @@ -657,6 +720,7 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): self.current_task = None def state_update_callback(self): + self.get_logger().info("Updating state.") state_dict = dict() state_dict["current_task"] = self.current_task @@ -670,18 +734,18 @@ def state_update_callback(self): self.get_logger().info(f"Logs summary retrieved in: {te:.2f}") self.get_logger().debug(f"{state_dict=}") - if self.last_subscription_msgs_buffer is None: + if self.topics_handler.last_subscription_msgs_buffer is None: self.state_dict = state_dict return for t in self.state_topics: - if t not in self.last_subscription_msgs_buffer: + if t not in self.topics_handler.last_subscription_msgs_buffer: msg = "No message yet" state_dict[t] = msg continue ts = time.perf_counter() - msg = self.last_subscription_msgs_buffer[t] + msg = self.topics_handler.last_subscription_msgs_buffer[t] if t in self.state_postprocessors: msg = self.state_postprocessors[t](msg) te = time.perf_counter() - ts @@ -690,6 +754,7 @@ def state_update_callback(self): state_dict[t] = msg self.state_dict = state_dict + self.get_logger().info("State updated.") def get_robot_state(self) -> Dict[str, str]: return self.state_dict diff --git a/src/rai/rai/tools/ros/native.py b/src/rai/rai/tools/ros/native.py index 2b09840d..3c7fb9c2 100644 --- a/src/rai/rai/tools/ros/native.py +++ b/src/rai/rai/tools/ros/native.py @@ -184,7 +184,7 @@ def _run( msg, msg_cls = self._build_msg(msg_type, msg_args) publisher = self.node.create_publisher( - msg_cls, topic_name, 10 + msg_cls, topic_name, 10, callback_group=self.node.callback_group ) # TODO(boczekbartek): infer qos profile from topic info def callback(): @@ -192,7 +192,9 @@ def callback(): self.logger.info(f"Published message '{msg}' to topic '{topic_name}'") ts = time.perf_counter() - timer = self.node.create_timer(1.0 / rate, callback) + timer = self.node.create_timer( + 1.0 / rate, callback, callback_group=self.node.callback_group + ) while time.perf_counter() - ts < timeout_seconds: time.sleep(0.1) diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index 369fe1e4..bdd6fa89 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -141,7 +141,7 @@ class Ros2RunActionAsync(Ros2BaseActionTool): def _run( self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] ): - return self.node._run_action(action_name, action_type, action_goal_args) + return self.node.run_action(action_name, action_type, action_goal_args) class Ros2IsActionComplete(Ros2BaseActionTool): @@ -151,7 +151,7 @@ class Ros2IsActionComplete(Ros2BaseActionTool): args_schema: Type[Ros2BaseInput] = Ros2BaseInput def _run(self) -> bool: - return self.node._is_task_complete() + return self.node.is_task_complete() class Ros2GetActionResult(Ros2BaseActionTool): @@ -161,7 +161,7 @@ class Ros2GetActionResult(Ros2BaseActionTool): args_schema: Type[Ros2BaseInput] = Ros2BaseInput def _run(self) -> bool: - return self.node._get_task_result() + return self.node.get_task_result() class Ros2CancelAction(Ros2BaseActionTool): @@ -171,7 +171,7 @@ class Ros2CancelAction(Ros2BaseActionTool): args_schema: Type[Ros2BaseInput] = Ros2BaseInput def _run(self) -> bool: - return self.node._cancel_task() + return self.node.cancel_task() class Ros2GetLastActionFeedback(Ros2BaseActionTool): @@ -197,7 +197,7 @@ class GetTransformTool(Ros2BaseActionTool): args_schema: Type[GetTransformInput] = GetTransformInput - def _run(self, target_frame="map", source_frame="body_link") -> dict: + def _run(self, target_frame="odom", source_frame="body_link") -> dict: return message_to_ordereddict( get_transform(self.node, target_frame, source_frame) ) diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index d5452094..b9674537 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -14,7 +14,7 @@ import base64 -from typing import Type, Union, cast +from typing import Optional, Type, Union, cast import cv2 import numpy as np @@ -34,6 +34,8 @@ from rosidl_runtime_py.utilities import get_namespaced_type from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped +from rai.utils.ros_async import get_future_result + def import_message_from_str(msg_type: str) -> Type[object]: msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type) @@ -164,17 +166,14 @@ def get_transform( tf_buffer = Buffer(node=node) tf_listener = TransformListener(tf_buffer, node) - transform = None future = tf_buffer.wait_for_transform_async( target_frame, source_frame, rclpy.time.Time() ) - - rclpy.spin_until_future_complete(node, future, timeout_sec=timeout_sec) - - transform = future.result() - + transform: Optional[TransformStamped] = get_future_result( + future, timeout_sec=timeout_sec + ) tf_listener.unregister() - if not future.done() or transform is None: + if transform is None: raise LookupException( f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds" ) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 9c83c2b3..22092284 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -12,18 +12,68 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Dict, List, Optional, Tuple, Union + +import rclpy.callback_groups +import rclpy.node +from rclpy.action.graph import get_action_names_and_types +from rclpy.executors import ( + ConditionReachedException, + ExternalShutdownException, + MultiThreadedExecutor, + ShutdownException, + TimeoutException, + TimeoutObject, +) -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple -@dataclass class NodeDiscovery: - topics_and_types: Dict[str, str] = field(default_factory=dict) - services_and_types: Dict[str, str] = field(default_factory=dict) - actions_and_types: Dict[str, str] = field(default_factory=dict) - allowlist: Optional[List[str]] = field(default_factory=list) + def __init__( + self, + node: rclpy.node.Node, + allowlist: Optional[List[str]] = None, + period_sec: float = 2.0, + setters: Optional[List[Callable]] = None, + ) -> None: + self.period_sec = period_sec + self.node = node + self.allowlist = allowlist + + self.topcies_and_types: Dict[str, str] = dict() + self.services_and_types: Dict[str, str] = dict() + self.actions_and_types: Dict[str, str] = dict() + self.allowlist: Optional[List[str]] = allowlist + + self.timer = self.node.create_timer( + self.period_sec, + self.discovery_callback, + callback_group=rclpy.callback_groups.MutuallyExclusiveCallbackGroup(), + ) + + # callables (e.g. fun(x: NodeDiscovery)) that will receive the discovery info on every timer callback + # allows to register other entities that needs up-to-date discovery info + if setters is None: + self.setters = list() + else: + self.setters = setters + + # make first callback as fast as possible + self.discovery_callback() + + def add_setter(self, setter: Callable): + self.setters.append(setter) + + def discovery_callback(self): + self.node.get_logger().info("Discovery callback") + self.__set( + self.node.get_topic_names_and_types(), + self.node.get_service_names_and_types(), + get_action_names_and_types(self.node), + ) + for callable in self.setters: + callable(self) - def set(self, topics, services, actions): + def __set(self, topics, services, actions): def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: return {k: v[0] for k, v in info} diff --git a/src/rai/rai/utils/ros_logs.py b/src/rai/rai/utils/ros_logs.py index e33d75eb..fc3b246b 100644 --- a/src/rai/rai/utils/ros_logs.py +++ b/src/rai/rai/utils/ros_logs.py @@ -18,6 +18,7 @@ import rcl_interfaces.msg import rclpy.callback_groups +import rclpy.executors import rclpy.node import rclpy.qos import rclpy.subscription @@ -25,6 +26,7 @@ from langchain_core.prompts import ChatPromptTemplate import rai_interfaces.srv +from rai.utils.ros_async import get_future_result class BaseLogsParser: @@ -40,7 +42,7 @@ def create_logs_parser( bufsize: Optional[int] = 100, ) -> BaseLogsParser: if parser_type == "rai_state_logs": - return RaiStateLogsParser(node) + return RaiStateLogsParser(node, callback_group) elif parser_type == "llm": if any([v is None for v in [llm, callback_group, bufsize]]): raise ValueError("Must provide llm, callback_group, and bufsize") @@ -54,11 +56,15 @@ class RaiStateLogsParser(BaseLogsParser): SERVICE_NAME = "/get_log_digest" - def __init__(self, node: rclpy.node.Node) -> None: + def __init__( + self, node: rclpy.node.Node, callback_group: rclpy.callback_groups.CallbackGroup + ) -> None: self.node = node self.rai_state_logs_client = node.create_client( - rai_interfaces.srv.StringList, self.SERVICE_NAME + rai_interfaces.srv.StringList, + self.SERVICE_NAME, + callback_group=callback_group, ) while not self.rai_state_logs_client.wait_for_service(timeout_sec=1.0): node.get_logger().info( @@ -68,8 +74,11 @@ def __init__(self, node: rclpy.node.Node) -> None: def summarize(self) -> str: request = rai_interfaces.srv.StringList.Request() future = self.rai_state_logs_client.call_async(request) - rclpy.spin_until_future_complete(self.node, future) - response: Optional[rai_interfaces.srv.StringList.Response] = future.result() + + response: Optional[rai_interfaces.srv.StringList.Response] = get_future_result( + future + ) + if response is None or not response.success: self.node.get_logger().error(f"'{self.SERVICE_NAME}' service call failed") return "" From b933eb9e75068caf52f11ca5b62bd359346a3b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 19 Dec 2024 14:10:25 +0100 Subject: [PATCH 07/31] revert(`config.toml`): --- config.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config.toml b/config.toml index 832a59b2..e06fc752 100644 --- a/config.toml +++ b/config.toml @@ -16,10 +16,10 @@ embeddings_model = "text-embedding-ada-002" base_url = "https://api.openai.com/v1/" # for openai compatible apis [ollama] -simple_model = "qwq" -complex_model = "qwen2.5:7b" +simple_model = "llama3.2" +complex_model = "llama3.1:70b" embeddings_model = "llama3.2" -base_url = "http://via-ip-robo-srv-004.robotec.tm.pl:11434" +base_url = "http://localhost:11434" [tracing] project = "rai" From 003a55f7114b49380b227e398310f7c190e62dd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 19 Dec 2024 14:12:40 +0100 Subject: [PATCH 08/31] revert: executor to standard `rclpy.MultiThreadedExecutor` - with current method based on callbacks the default executor is good --- src/rai/rai/node.py | 7 +++--- src/rai/rai/utils/ros.py | 53 +++------------------------------------- 2 files changed, 7 insertions(+), 53 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index ad04948d..1d41ab80 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -45,7 +45,6 @@ from rclpy.topic_endpoint_info import TopicEndpointInfo from std_srvs.srv import Trigger -import rai.utils.ros from rai.agents.state_based import Report, State, create_state_based_agent from rai.messages import HumanMultimodalMessage from rai.tools.ros.native import Ros2BaseTool @@ -144,7 +143,7 @@ def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Typ return msg, msg_cls -class AsyncRos2ActionClient: +class Ros2ActionsHelper: def __init__(self, node: rclpy.node.Node): self.node = node @@ -475,7 +474,7 @@ def __init__( # ---------- ROS helpers ---------- self.ros_discovery_info = NodeDiscovery(self, allowlist=allowlist) - self.async_action_client = AsyncRos2ActionClient(self) + self.async_action_client = Ros2ActionsHelper(self) self.topics_handler = Ros2TopicsHandler( self, self.callback_group, self.ros_discovery_info ) @@ -508,7 +507,7 @@ def cancel_task(self) -> Union[str, bool]: # ------------- other methods ------------- def spin(self): - executor = rai.utils.ros.MultiThreadedExecutorFixed() + executor = rclpy.executors.MultiThreadedExecutor() executor.add_node(self) executor.spin() rclpy.shutdown() diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 22092284..b692e31a 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -12,19 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple import rclpy.callback_groups import rclpy.node from rclpy.action.graph import get_action_names_and_types -from rclpy.executors import ( - ConditionReachedException, - ExternalShutdownException, - MultiThreadedExecutor, - ShutdownException, - TimeoutException, - TimeoutObject, -) class NodeDiscovery: @@ -50,8 +42,9 @@ def __init__( callback_group=rclpy.callback_groups.MutuallyExclusiveCallbackGroup(), ) - # callables (e.g. fun(x: NodeDiscovery)) that will receive the discovery info on every timer callback - # allows to register other entities that needs up-to-date discovery info + # callables (e.g. fun(x: NodeDiscovery)) that will receive the discovery + # info on every timer callback. This allows to register other entities that + # needs up-to-date discovery info if setters is None: self.setters = list() else: @@ -99,41 +92,3 @@ def dict(self): "services_and_types": self.services_and_types, "actions_and_types": self.actions_and_types, } - - -class MultiThreadedExecutorFixed(MultiThreadedExecutor): - """ - Adresses a comment: - ```python - # make a copy of the list that we iterate over while modifying it - # (https://stackoverflow.com/q/1207406/3753684) - ``` - from the rclpy implementation - """ - - def _spin_once_impl( - self, - timeout_sec: Optional[Union[float, TimeoutObject]] = None, - wait_condition: Callable[[], bool] = lambda: False, - ) -> None: - try: - handler, entity, node = self.wait_for_ready_callbacks( - timeout_sec, None, wait_condition - ) - except ExternalShutdownException: - pass - except ShutdownException: - pass - except TimeoutException: - pass - except ConditionReachedException: - pass - else: - self._executor.submit(handler) - self._futures.append(handler) - futures = self._futures.copy() - for future in futures[:]: - if future.done(): - futures.remove(future) - future.result() # raise any exceptions - self._futures = futures From 40ebdcf937cbdf2a2b654ead8298c99f3f358cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Mon, 23 Dec 2024 13:24:11 +0100 Subject: [PATCH 09/31] chore: remove unnecessary log --- src/rai/rai/utils/ros.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index b692e31a..2729e606 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -57,7 +57,6 @@ def add_setter(self, setter: Callable): self.setters.append(setter) def discovery_callback(self): - self.node.get_logger().info("Discovery callback") self.__set( self.node.get_topic_names_and_types(), self.node.get_service_names_and_types(), From e6e347c8ce6bea74b8f530fa701dc2a8ff06fa1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 24 Dec 2024 00:43:37 +0100 Subject: [PATCH 10/31] fix: import --- .../rai_open_set_vision/tools/gdino_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 69f229b4..4144bd59 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -26,7 +26,7 @@ ) from rclpy.task import Future -from rai.node import RaiAsyncToolsNode +from rai.node import RaiBaseNode from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray from rai.tools.utils import wait_for_message @@ -81,7 +81,7 @@ class DistanceMeasurement(NamedTuple): # --------------------- Tools --------------------- class GroundingDinoBaseTool(Ros2BaseTool): - node: RaiAsyncToolsNode = Field(..., exclude=True, required=True) + node: RaiBaseNode = Field(..., exclude=True, required=True) box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") From bb35e0d4a8ebf510a0e509a610904c990b54e59f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 24 Dec 2024 09:32:06 +0100 Subject: [PATCH 11/31] feat: add function to get rclpy.Future result using callback --- src/rai/rai/utils/ros_async.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/rai/rai/utils/ros_async.py diff --git a/src/rai/rai/utils/ros_async.py b/src/rai/rai/utils/ros_async.py new file mode 100644 index 00000000..e1cc4c50 --- /dev/null +++ b/src/rai/rai/utils/ros_async.py @@ -0,0 +1,23 @@ +import time +from typing import Any, Optional + +import rclpy.task + + +def get_future_result( + future: rclpy.task.Future, timeout_sec: float = 5.0 +) -> Optional[Any]: + """Replaces rclpy.spin_until_future_complete""" + result = None + + def callback(future: rclpy.task.Future) -> None: + nonlocal result + result = future.result() + + future.add_done_callback(callback) + + ts = time.perf_counter() + while result is None and time.perf_counter() - ts < timeout_sec: + time.sleep(0.1) + + return result From bc5ed801403d210b3c5f71cb0134311df06b54f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 24 Dec 2024 09:47:01 +0100 Subject: [PATCH 12/31] chore: add licence note --- src/rai/rai/utils/ros_async.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/rai/rai/utils/ros_async.py b/src/rai/rai/utils/ros_async.py index e1cc4c50..769f979d 100644 --- a/src/rai/rai/utils/ros_async.py +++ b/src/rai/rai/utils/ros_async.py @@ -1,3 +1,17 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from typing import Any, Optional From 02248214f7280273e4f1ee2aff0b593c9f879024 Mon Sep 17 00:00:00 2001 From: Bartek Boczek <22739059+boczekbartek@users.noreply.github.com> Date: Thu, 2 Jan 2025 10:14:26 +0100 Subject: [PATCH 13/31] Update src/rai/rai/utils/ros.py Co-authored-by: Maciej Majek <46171033+maciejmajek@users.noreply.github.com> --- src/rai/rai/utils/ros.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 2729e606..94c5ec8c 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -31,7 +31,7 @@ def __init__( self.node = node self.allowlist = allowlist - self.topcies_and_types: Dict[str, str] = dict() + self.topics_and_types: Dict[str, str] = dict() self.services_and_types: Dict[str, str] = dict() self.actions_and_types: Dict[str, str] = dict() self.allowlist: Optional[List[str]] = allowlist From fd823c3690febefbc5a6c3f0f2578a878ed7fc2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 7 Jan 2025 13:36:30 +0100 Subject: [PATCH 14/31] fix(`rosbot_xl_demo`): revert detection removal and fix it - rollback open-set detection removal - fix it by setting conversion_ratio=1.0 --- examples/rosbot-xl-demo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/rosbot-xl-demo.py b/examples/rosbot-xl-demo.py index 988a8a4f..90600b6a 100644 --- a/examples/rosbot-xl-demo.py +++ b/examples/rosbot-xl-demo.py @@ -18,10 +18,10 @@ import rclpy import rclpy.executors import rclpy.logging +from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool from rai.node import RaiStateBasedLlmNode from rai.tools.ros.native import ( - GetCameraImage, GetMsgFromTopic, Ros2GetRobotInterfaces, Ros2PubMessageTool, @@ -131,9 +131,11 @@ def main(allowlist: Optional[Path] = None): GetTransformTool, WaitForSecondsTool, GetMsgFromTopic, - GetCameraImage, + GetDetectionTool, + GetDistanceToObjectsTool, ], ) + node.declare_parameter("conversion_ratio", 1.0) node.spin() rclpy.shutdown() From 399003aec816b46cd70dbc30607debde24fac13a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 7 Jan 2025 14:24:47 +0100 Subject: [PATCH 15/31] improve ros discovery --- src/rai/rai/node.py | 11 ++++++++++- src/rai/rai/utils/ros.py | 5 ++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 1d41ab80..290cea8d 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -406,11 +406,20 @@ def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: def set_ros_discovery_info(self, new_ros_discovery_info: NodeDiscovery): self.ros_discovery_info = new_ros_discovery_info - def get_raw_message_from_topic(self, topic: str, timeout_sec: int = 5) -> Any: + def get_raw_message_from_topic( + self, topic: str, timeout_sec: int = 5, topic_wait_sec: int = 2 + ) -> Any: self.get_logger().debug(f"Getting msg from topic: {topic}") ts = time.perf_counter() + for _ in range(topic_wait_sec * 10): + if topic not in self.ros_discovery_info.topics_and_types: + time.sleep(0.1) + continue + else: + break + if topic not in self.ros_discovery_info.topics_and_types: raise KeyError( f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 94c5ec8c..d2405df2 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from typing import Callable, Dict, List, Optional, Tuple import rclpy.callback_groups @@ -24,7 +25,7 @@ def __init__( self, node: rclpy.node.Node, allowlist: Optional[List[str]] = None, - period_sec: float = 2.0, + period_sec: float = 0.5, setters: Optional[List[Callable]] = None, ) -> None: self.period_sec = period_sec @@ -51,6 +52,8 @@ def __init__( self.setters = setters # make first callback as fast as possible + # sleep before first callback due ros discovery issue: https://github.com/ros2/ros2/issues/1057 + time.sleep(0.5) self.discovery_callback() def add_setter(self, setter: Callable): From c6e721c504adf04787ed6d205195bd3653de884a Mon Sep 17 00:00:00 2001 From: Bartek Boczek <22739059+boczekbartek@users.noreply.github.com> Date: Wed, 8 Jan 2025 14:31:37 +0100 Subject: [PATCH 16/31] Apply suggestions from code review Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- src/rai/rai/utils/ros.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index d2405df2..6752f369 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -65,8 +65,8 @@ def discovery_callback(self): self.node.get_service_names_and_types(), get_action_names_and_types(self.node), ) - for callable in self.setters: - callable(self) + for setter in self.setters: + setter(self) def __set(self, topics, services, actions): def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: From baa01f88b9bc30b314054fde51711f56d5f57249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 14:31:54 +0100 Subject: [PATCH 17/31] fix(`GetTransformTool`): make default values consistent --- src/rai/rai/tools/ros/native_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index bdd6fa89..d1d741c0 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -197,7 +197,7 @@ class GetTransformTool(Ros2BaseActionTool): args_schema: Type[GetTransformInput] = GetTransformInput - def _run(self, target_frame="odom", source_frame="body_link") -> dict: + def _run(self, target_frame="map", source_frame="body_link") -> dict: return message_to_ordereddict( get_transform(self.node, target_frame, source_frame) ) From 92fca442092cbb0b7a012c7054cb57947e9fa18d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 14:39:50 +0100 Subject: [PATCH 18/31] fix(`NodeDiscovery`): repeated argument setting --- src/rai/rai/utils/ros.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 6752f369..2377e140 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -30,7 +30,6 @@ def __init__( ) -> None: self.period_sec = period_sec self.node = node - self.allowlist = allowlist self.topics_and_types: Dict[str, str] = dict() self.services_and_types: Dict[str, str] = dict() From 81a98a80389b14fa379ebd417c8b57bd8a0424b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 15:00:39 +0100 Subject: [PATCH 19/31] refactor: ros2 action error code parsing --- src/rai/rai/node.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 290cea8d..f1a15fed 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -237,10 +237,9 @@ def parse_status(status: int) -> str: return msg def parse_error_code(self, code: int) -> str: - name_to_code = self.msg_cls.Result.__prepare__( - name="", bases="" - ) # arguments are not used - code_to_name = {v: k for k, v in name_to_code.items()} + code_to_name = { + v: k for k, v in vars(self.msg_cls.Result).items() if isinstance(v, int) + } return code_to_name.get(code, "UNKNOWN") def _feedback_callback(self, msg): From c14114c01c244df7e9c1731c066f8c45639c9874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 15:29:58 +0100 Subject: [PATCH 20/31] fix deprecated way of spinning the node in open-set vision --- .../rai_open_set_vision/tools/gdino_tools.py | 60 ++++++++----------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py index 4144bd59..2fa76fba 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/gdino_tools.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, NamedTuple, Optional, Type +from typing import List, NamedTuple, Type import numpy as np import rclpy @@ -30,6 +30,7 @@ from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_ndarray from rai.tools.utils import wait_for_message +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import RAIGroundingDino @@ -86,20 +87,6 @@ class GroundingDinoBaseTool(Ros2BaseTool): box_threshold: float = Field(default=0.35, description="Box threshold for GDINO") text_threshold: float = Field(default=0.45, description="Text threshold for GDINO") - def _spin(self, future: Future) -> Optional[RAIGroundingDino.Response]: - rclpy.spin_once(self.node) - if future.done(): - try: - response = future.result() - except Exception as e: - self.node.get_logger().info("Service call failed %r" % (e,)) - raise Exception("Service call failed %r" % (e,)) - else: - assert response is not None - self.node.get_logger().info(f"{response.detections}") - return response - return None - def _call_gdino_node( self, camera_img_message: sensor_msgs.msg.Image, object_names: list[str] ) -> Future: @@ -174,12 +161,16 @@ def _run( camera_img_msg = self._get_image_message(camera_topic) future = self._call_gdino_node(camera_img_msg, object_names) - while rclpy.ok(): - resolved = self._spin(future) - if resolved is not None: - detected = self._parse_detection_array(resolved) - names = ", ".join([det.class_name for det in detected]) - return f"I have detected the following items in the picture {names or 'None'}" + resolved = get_future_result(future) + + if resolved is not None: + detected = self._parse_detection_array(resolved) + names = ", ".join([det.class_name for det in detected]) + return ( + f"I have detected the following items in the picture {names or 'None'}" + ) + else: + return "Service call failed. Can't get detections." return "Failed to get detection" @@ -256,18 +247,17 @@ def _run( "Parameter conversion_ratio not found in node, using default value: 0.001" ) conversion_ratio = 0.001 - while rclpy.ok(): - resolved = self._spin(future) - if resolved is not None: - detected = self._parse_detection_array(resolved) - measurements = self._get_distance_from_detections( - depth_img_msg, detected, threshold, conversion_ratio - ) - measurement_string = ", ".join( - [ - f"{measurement[0]}: {measurement[1]:.2f}m away" - for measurement in measurements - ] - ) - return f"I have detected the following items in the picture {measurement_string or 'no objects'}" + resolved = get_future_result(future) + if resolved is not None: + detected = self._parse_detection_array(resolved) + measurements = self._get_distance_from_detections( + depth_img_msg, detected, threshold, conversion_ratio + ) + measurement_string = ", ".join( + [ + f"{measurement[0]}: {measurement[1]:.2f}m away" + for measurement in measurements + ] + ) + return f"I have detected the following items in the picture {measurement_string or 'no objects'}" return "Failed" From ab3a7b33467b3a9331a52889c79d2022a141c16c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 15:30:22 +0100 Subject: [PATCH 21/31] get transform synchronously --- src/rai/rai/tools/ros/utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/rai/rai/tools/ros/utils.py b/src/rai/rai/tools/ros/utils.py index b9674537..7756a6c0 100644 --- a/src/rai/rai/tools/ros/utils.py +++ b/src/rai/rai/tools/ros/utils.py @@ -24,6 +24,7 @@ import rclpy.time import sensor_msgs.msg from cv_bridge import CvBridge +from rclpy.duration import Duration from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.node import Node from rclpy.qos import QoSProfile @@ -34,8 +35,6 @@ from rosidl_runtime_py.utilities import get_namespaced_type from tf2_ros import Buffer, LookupException, TransformListener, TransformStamped -from rai.utils.ros_async import get_future_result - def import_message_from_str(msg_type: str) -> Type[object]: msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type) @@ -166,13 +165,12 @@ def get_transform( tf_buffer = Buffer(node=node) tf_listener = TransformListener(tf_buffer, node) - future = tf_buffer.wait_for_transform_async( - target_frame, source_frame, rclpy.time.Time() - ) - transform: Optional[TransformStamped] = get_future_result( - future, timeout_sec=timeout_sec + transform: Optional[TransformStamped] = tf_buffer.lookup_transform( + target_frame, source_frame, rclpy.time.Time(), timeout=Duration(seconds=3) ) + tf_listener.unregister() + if transform is None: raise LookupException( f"Could not find transform from {source_frame} to {target_frame} in {timeout_sec} seconds" From c36f1acd92c9ba28b41ad248f3f951af2003c593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 16:22:29 +0100 Subject: [PATCH 22/31] rollback to fixed executor --- src/rai/rai/node.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index f1a15fed..4a02f591 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -52,6 +52,7 @@ from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks from rai.utils.ros import NodeDiscovery from rai.utils.ros_async import get_future_result +from rai.utils.ros_executors import MultiThreadedExecutorFixed from rai.utils.ros_logs import create_logs_parser from rai_interfaces.action import Task as TaskAction @@ -515,7 +516,7 @@ def cancel_task(self) -> Union[str, bool]: # ------------- other methods ------------- def spin(self): - executor = rclpy.executors.MultiThreadedExecutor() + executor = MultiThreadedExecutorFixed() executor.add_node(self) executor.spin() rclpy.shutdown() From 89663ef8737b0335a87ee9d8b833dd0781b00675 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 16:24:01 +0100 Subject: [PATCH 23/31] add missing executor file --- src/rai/rai/utils/ros_executors.py | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 src/rai/rai/utils/ros_executors.py diff --git a/src/rai/rai/utils/ros_executors.py b/src/rai/rai/utils/ros_executors.py new file mode 100644 index 00000000..d461770c --- /dev/null +++ b/src/rai/rai/utils/ros_executors.py @@ -0,0 +1,48 @@ +from typing import Callable, Optional, Union + +from rclpy.executors import ( + ConditionReachedException, + ExternalShutdownException, + MultiThreadedExecutor, + ShutdownException, + TimeoutException, + TimeoutObject, +) + + +class MultiThreadedExecutorFixed(MultiThreadedExecutor): + """ + Adresses a comment: + ```python + # make a copy of the list that we iterate over while modifying it + # (https://stackoverflow.com/q/1207406/3753684) + ``` + from the rclpy implementation + """ + + def _spin_once_impl( + self, + timeout_sec: Optional[Union[float, TimeoutObject]] = None, + wait_condition: Callable[[], bool] = lambda: False, + ) -> None: + try: + handler, entity, node = self.wait_for_ready_callbacks( + timeout_sec, None, wait_condition + ) + except ExternalShutdownException: + pass + except ShutdownException: + pass + except TimeoutException: + pass + except ConditionReachedException: + pass + else: + self._executor.submit(handler) + self._futures.append(handler) + futures = self._futures.copy() + for future in futures[:]: + if future.done(): + futures.remove(future) + future.result() # raise any exceptions + self._futures = futures From 0c06db3b7c32e563a334d5bf26338b8e5bb04f85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 16:29:00 +0100 Subject: [PATCH 24/31] fix(`text_hmi`): `TaskAction` attribute is named `report` --- src/rai_hmi/rai_hmi/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rai_hmi/rai_hmi/base.py b/src/rai_hmi/rai_hmi/base.py index 07cd1544..9502a225 100644 --- a/src/rai_hmi/rai_hmi/base.py +++ b/src/rai_hmi/rai_hmi/base.py @@ -250,12 +250,12 @@ def task_feedback_callback(self, feedback_msg, uid: UUID4): def task_result_callback(self, future, uid: UUID4): """Callback for handling the result from the action server.""" - result = future.result().result + result: TaskAction.Result = future.result().result self.task_running[uid] = False self.task_feedbacks.put(MissionDoneMessage(uid=uid, result=result)) if result.success: self.get_logger().info(f"Task completed successfully: {result.report}") self.task_results[uid] = result else: - self.get_logger().error(f"Task failed: {result.result_message}") + self.get_logger().error(f"Task failed: {result.report}") self.task_results[uid] = "ERROR" From d881e174bdb607f21f272f8a29fd821b50499783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Wed, 8 Jan 2025 16:32:39 +0100 Subject: [PATCH 25/31] chore: add missing licence header --- src/rai/rai/utils/ros_executors.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/rai/rai/utils/ros_executors.py b/src/rai/rai/utils/ros_executors.py index d461770c..03fa759d 100644 --- a/src/rai/rai/utils/ros_executors.py +++ b/src/rai/rai/utils/ros_executors.py @@ -1,3 +1,17 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Callable, Optional, Union from rclpy.executors import ( From bf40cecaeb58f94b1a2f555cb983083657e1b4f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Thu, 9 Jan 2025 15:27:00 +0100 Subject: [PATCH 26/31] avoid spinning the node in open-set vision tools --- .../tools/segmentation_tools.py | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index 17389fb4..c264bd88 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -29,6 +29,7 @@ from rai.node import RaiBaseNode from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray +from rai.utils.ros_async import get_future_result from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino # --------------------- Inputs --------------------- @@ -77,30 +78,10 @@ class GetSegmentationTool(Ros2BaseTool): def _get_gdino_response( self, future: Future ) -> Optional[RAIGroundingDino.Response]: - rclpy.spin_once(self.node) - if future.done(): - try: - response = future.result() - except Exception as e: - self.node.get_logger().info("Service call failed %r" % (e,)) - raise Exception("Service call failed %r" % (e,)) - else: - assert response is not None - return response - return None + return get_future_result(future) def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response]: - rclpy.spin_once(self.node) - if future.done(): - try: - response = future.result() - except Exception as e: - self.node.get_logger().info("Service call failed %r" % (e,)) - raise Exception("Service call failed %r" % (e,)) - else: - assert response is not None - return response - return None + return get_future_result(future) def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: msg = self.node.get_raw_message_from_topic(topic) From a877f7f42319964d4f11606c12ceb8e697dcb326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 10 Jan 2025 08:28:38 +0100 Subject: [PATCH 27/31] chore: pre-commit --- src/rai/rai/node.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 4a02f591..62386c0c 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -307,7 +307,7 @@ def __init__( def get_logger(self): return self.node.get_logger() - + def adapt_requests_to_offers( self, publisher_info: List[TopicEndpointInfo] ) -> QoSProfile: @@ -363,9 +363,9 @@ def create_subscription_by_topic_name(self, topic): f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" ) return - + msg_type = self.get_msg_type(topic) - + if topic not in self.qos_profile_cache: self.get_logger().debug(f"Getting qos profile for topic: {topic}") qos_profile = self.adapt_requests_to_offers( @@ -387,7 +387,7 @@ def create_subscription_by_topic_name(self, topic): callback_group=self.callback_group, qos_profile=qos_profile, ) - + def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: """Sometimes node fails to do full discovery, therefore we need to retry""" for _ in range(n_tries): From e65dde9229b87bf65086dfb057fd92ef650b2784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 10 Jan 2025 08:29:26 +0100 Subject: [PATCH 28/31] test(`transport`): spin RaiBaseNode in test --- tests/messages/test_transport.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/messages/test_transport.py b/tests/messages/test_transport.py index eed59da8..93a31e5e 100644 --- a/tests/messages/test_transport.py +++ b/tests/messages/test_transport.py @@ -100,6 +100,9 @@ def test_transport(qos_profile: str): rai_base_node = RaiBaseNode( node_name="test_transport_" + str(uuid.uuid4()).replace("-", "") ) + + thread2 = threading.Thread(target=rai_base_node.spin) + thread2.start() topics = ["/image", "/text"] try: for topic in topics: @@ -107,6 +110,9 @@ def test_transport(qos_profile: str): assert not isinstance(output, str), "No message received" finally: executor.shutdown() + rai_base_node.executor.shutdown() + rai_base_node.destroy_node() publisher.destroy_node() rclpy.shutdown() thread.join(timeout=1.0) + thread2.join(timeout=1.0) From 13935b9cb606e9d13f24e24da7dd1ee7918b92bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 10 Jan 2025 08:39:09 +0100 Subject: [PATCH 29/31] don't shutdown rclpy in rai node --- src/rai/rai/node.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 62386c0c..598c7aab 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -519,7 +519,6 @@ def spin(self): executor = MultiThreadedExecutorFixed() executor.add_node(self) executor.spin() - rclpy.shutdown() def parse_task_goal(ros_action_goal: TaskAction.Goal) -> Dict[str, Any]: From 91b0401e0da1e7b15218e106a2fcd5f0c79dbd81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 14 Jan 2025 12:26:04 +0100 Subject: [PATCH 30/31] rename Ros2*Helper to Ros2*Api and move to a separate file --- src/rai/rai/node.py | 365 +-------------------------------------- src/rai/rai/ros2_apis.py | 363 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 368 insertions(+), 360 deletions(-) create mode 100644 src/rai/rai/ros2_apis.py diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 598c7aab..fcc2d3ae 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -13,9 +13,8 @@ # limitations under the License. -import functools import time -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union import rclpy import rclpy.callback_groups @@ -24,31 +23,20 @@ import rclpy.qos import rclpy.subscription import rclpy.task -import rosidl_runtime_py.set_message -import rosidl_runtime_py.utilities import sensor_msgs.msg -from action_msgs.msg import GoalStatus from langchain.tools import BaseTool from langchain.tools.render import render_text_description_and_args from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langgraph.graph.graph import CompiledGraph -from rclpy.action.client import ActionClient from rclpy.action.server import ActionServer, GoalResponse, ServerGoalHandle from rclpy.node import Node -from rclpy.qos import ( - DurabilityPolicy, - HistoryPolicy, - LivelinessPolicy, - QoSProfile, - ReliabilityPolicy, -) -from rclpy.topic_endpoint_info import TopicEndpointInfo from std_srvs.srv import Trigger from rai.agents.state_based import Report, State, create_state_based_agent from rai.messages import HumanMultimodalMessage +from rai.ros2_apis import Ros2ActionsAPI, Ros2TopicsAPI from rai.tools.ros.native import Ros2BaseTool -from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str +from rai.tools.ros.utils import convert_ros_img_to_base64 from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks from rai.utils.ros import NodeDiscovery from rai.utils.ros_async import get_future_result @@ -127,349 +115,6 @@ def append_tools_text_description_to_prompt(prompt: str, tools: List[BaseTool]) """ -def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Type]: - """ - Import message and create it. Return both ready message and message class. - - msgs args can have two formats: - { "goal" : {arg 1 : xyz, ... } or {arg 1 : xyz, ... } - """ - - msg_cls: Type = rosidl_runtime_py.utilities.get_interface(msg_type) - msg = msg_cls.Goal() - - if "goal" in msg_args: - msg_args = msg_args["goal"] - rosidl_runtime_py.set_message.set_message_fields(msg, msg_args) - return msg, msg_cls - - -class Ros2ActionsHelper: - def __init__(self, node: rclpy.node.Node): - self.node = node - - self.goal_handle = None - self.result_future = None - self.feedback = None - self.status: Optional[int] = None - self.client: Optional[ActionClient] = None - self.action_feedback: Optional[Any] = None - - def get_logger(self): - return self.node.get_logger() - - def run_action( - self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] - ): - if not self.is_task_complete(): - raise AssertionError( - "Another ros2 action is currently running and parallel actions are not supported. Please wait until the previous action is complete before starting a new one. You can also cancel the current one." - ) - - if action_name[0] != "/": - action_name = "/" + action_name - self.get_logger().info(f"Action name corrected to: {action_name}") - - try: - goal_msg, msg_cls = ros2_build_msg(action_type, action_goal_args) - except Exception as e: - return f"Failed to build message: {e}" - - self.client = ActionClient(self.node, msg_cls, action_name) - self.msg_cls = msg_cls - - retries = 0 - while not self.client.wait_for_server(timeout_sec=1.0): - retries += 1 - if retries > 5: - raise Exception( - f"Action server '{action_name}' is not available. Make sure `action_name` is correct..." - ) - self.get_logger().info( - f"'{action_name}' action server not available, waiting..." - ) - - self.get_logger().info(f"Sending action message: {goal_msg}") - - send_goal_future = self.client.send_goal_async( - goal_msg, self._feedback_callback - ) - self.get_logger().info("Action goal sent!") - - self.goal_handle = get_future_result(send_goal_future) - - if not self.goal_handle: - raise Exception(f"Action '{action_name}' not sent to server") - - if not self.goal_handle.accepted: - raise Exception(f"Action '{action_name}' not accepted by server") - - self.result_future = self.goal_handle.get_result_async() - self.get_logger().info("Action sent!") - return f"{action_name} started successfully with args: {action_goal_args}" - - def get_task_result(self) -> str: - if not self.is_task_complete(): - return "Task is not complete yet" - - def parse_status(status: int) -> str: - return { - GoalStatus.STATUS_SUCCEEDED: "succeeded", - GoalStatus.STATUS_ABORTED: "aborted", - GoalStatus.STATUS_CANCELED: "canceled", - GoalStatus.STATUS_ACCEPTED: "accepted", - GoalStatus.STATUS_CANCELING: "canceling", - GoalStatus.STATUS_EXECUTING: "executing", - GoalStatus.STATUS_UNKNOWN: "unknown", - }.get(status, "unknown") - - result = self.result_future.result() - - self.destroy_client() - if result.status == GoalStatus.STATUS_SUCCEEDED: - msg = f"Result succeeded: {result.result}" - self.get_logger().info(msg) - return msg - else: - str_status = parse_status(result.status) - error_code_str = self.parse_error_code(result.result.error_code) - msg = f"Result {str_status}, because of: error_code={result.result.error_code}({error_code_str}), error_msg={result.result.error_msg}" - self.get_logger().info(msg) - return msg - - def parse_error_code(self, code: int) -> str: - code_to_name = { - v: k for k, v in vars(self.msg_cls.Result).items() if isinstance(v, int) - } - return code_to_name.get(code, "UNKNOWN") - - def _feedback_callback(self, msg): - self.get_logger().info(f"Received ros2 action feedback: {msg}") - self.action_feedback = msg - - def is_task_complete(self): - if not self.result_future: - # task was cancelled or completed - return True - - result = get_future_result(self.result_future, timeout_sec=0.10) - if result is not None: - self.status = result.status - if self.status != GoalStatus.STATUS_SUCCEEDED: - self.get_logger().debug( - f"Task with failed with status code: {self.status}" - ) - return True - else: - self.get_logger().info("There is no result") - # Timed out, still processing, not complete yet - return False - - self.get_logger().info("Task succeeded!") - return True - - def cancel_task(self) -> Union[str, bool]: - self.get_logger().info("Canceling current task.") - try: - if self.result_future and self.goal_handle: - future = self.goal_handle.cancel_goal_async() - result = get_future_result(future, timeout_sec=1.0) - return "Failed to cancel result" if result is None else True - return True - finally: - self.destroy_client() - - def destroy_client(self): - if self.client: - self.client.destroy() - - -class Ros2TopicsHandler: - def __init__( - self, - node: rclpy.node.Node, - callback_group: rclpy.callback_groups.CallbackGroup, - ros_discovery_info: NodeDiscovery, - ) -> None: - self.node = node - self.qos_profile = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - reliability=ReliabilityPolicy.BEST_EFFORT, - durability=DurabilityPolicy.VOLATILE, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - self.callback_group = callback_group - self.last_subscription_msgs_buffer = dict() - self.qos_profile_cache: Dict[str, QoSProfile] = dict() - - self.ros_discovery_info = ros_discovery_info - - def get_logger(self): - return self.node.get_logger() - - def adapt_requests_to_offers( - self, publisher_info: List[TopicEndpointInfo] - ) -> QoSProfile: - if not publisher_info: - return QoSProfile(depth=1) - - num_endpoints = len(publisher_info) - reliability_reliable_count = 0 - durability_transient_local_count = 0 - - for endpoint in publisher_info: - profile = endpoint.qos_profile - if profile.reliability == ReliabilityPolicy.RELIABLE: - reliability_reliable_count += 1 - if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: - durability_transient_local_count += 1 - - request_qos = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - - # Set reliability based on publisher offers - if reliability_reliable_count == num_endpoints: - request_qos.reliability = ReliabilityPolicy.RELIABLE - else: - if reliability_reliable_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering RELIABLE reliability. " - "Falling back to BEST_EFFORT as it will connect to all publishers. " - "Some messages from Reliable publishers could be dropped." - ) - request_qos.reliability = ReliabilityPolicy.BEST_EFFORT - - # Set durability based on publisher offers - if durability_transient_local_count == num_endpoints: - request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL - else: - if durability_transient_local_count > 0: - self.get_logger().warning( - "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " - "Falling back to VOLATILE as it will connect to all publishers. " - "Previously-published latched messages will not be retrieved." - ) - request_qos.durability = DurabilityPolicy.VOLATILE - - return request_qos - - def create_subscription_by_topic_name(self, topic): - if self.has_subscription(topic): - self.get_logger().warning( - f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" - ) - return - - msg_type = self.get_msg_type(topic) - - if topic not in self.qos_profile_cache: - self.get_logger().debug(f"Getting qos profile for topic: {topic}") - qos_profile = self.adapt_requests_to_offers( - self.node.get_publishers_info_by_topic(topic) - ) - self.qos_profile_cache[topic] = qos_profile - else: - self.get_logger().debug(f"Using cached qos profile for topic: {topic}") - qos_profile = self.qos_profile_cache[topic] - - topic_callback = functools.partial( - self.generic_state_subscriber_callback, topic - ) - - self.node.create_subscription( - msg_type, - topic, - callback=topic_callback, - callback_group=self.callback_group, - qos_profile=qos_profile, - ) - - def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: - """Sometimes node fails to do full discovery, therefore we need to retry""" - for _ in range(n_tries): - if topic in self.ros_discovery_info.topics_and_types: - msg_type = self.ros_discovery_info.topics_and_types[topic] - return import_message_from_str(msg_type) - else: - # Wait for next discovery cycle - self.get_logger().info(f"Waiting for topic: {topic}") - if self.ros_discovery_info: - time.sleep(self.ros_discovery_info.period_sec) - else: - time.sleep(1.0) - raise KeyError(f"Topic {topic} not found") - - def set_ros_discovery_info(self, new_ros_discovery_info: NodeDiscovery): - self.ros_discovery_info = new_ros_discovery_info - - def get_raw_message_from_topic( - self, topic: str, timeout_sec: int = 5, topic_wait_sec: int = 2 - ) -> Any: - self.get_logger().debug(f"Getting msg from topic: {topic}") - - ts = time.perf_counter() - - for _ in range(topic_wait_sec * 10): - if topic not in self.ros_discovery_info.topics_and_types: - time.sleep(0.1) - continue - else: - break - - if topic not in self.ros_discovery_info.topics_and_types: - raise KeyError( - f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" - ) - - if topic in self.last_subscription_msgs_buffer: - self.get_logger().info("Returning cached message") - return self.last_subscription_msgs_buffer[topic] - else: - self.create_subscription_by_topic_name(topic) - try: - msg = self.last_subscription_msgs_buffer.get(topic, None) - while msg is None and time.perf_counter() - ts < timeout_sec: - msg = self.last_subscription_msgs_buffer.get(topic, None) - self.get_logger().info("Waiting for message...") - time.sleep(0.1) - - success = msg is not None - - if success: - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic}" - ) - return msg - else: - error = f"No message received in {timeout_sec} seconds from topic {topic}" - self.get_logger().error(error) - return error - finally: - self.destroy_subscription_by_topic_name(topic) - - def generic_state_subscriber_callback(self, topic_name: str, msg: Any): - self.get_logger().debug( - f"Received message of type {type(msg)} from topic {topic_name}" - ) - self.last_subscription_msgs_buffer[topic_name] = msg - - def has_subscription(self, topic: str) -> bool: - for sub in self.node._subscriptions: - if sub.topic == topic: - return True - return False - - def destroy_subscription_by_topic_name(self, topic: str): - self.last_subscription_msgs_buffer.clear() - for sub in self.node._subscriptions: - if sub.topic == topic: - self.node.destroy_subscription(sub) - - class RaiBaseNode(Node): def __init__( self, @@ -483,8 +128,8 @@ def __init__( # ---------- ROS helpers ---------- self.ros_discovery_info = NodeDiscovery(self, allowlist=allowlist) - self.async_action_client = Ros2ActionsHelper(self) - self.topics_handler = Ros2TopicsHandler( + self.async_action_client = Ros2ActionsAPI(self) + self.topics_handler = Ros2TopicsAPI( self, self.callback_group, self.ros_discovery_info ) self.ros_discovery_info.add_setter(self.topics_handler.set_ros_discovery_info) diff --git a/src/rai/rai/ros2_apis.py b/src/rai/rai/ros2_apis.py new file mode 100644 index 00000000..150161b2 --- /dev/null +++ b/src/rai/rai/ros2_apis.py @@ -0,0 +1,363 @@ +import functools +import time +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import rclpy +import rclpy.callback_groups +import rclpy.executors +import rclpy.node +import rclpy.qos +import rclpy.subscription +import rclpy.task +import rosidl_runtime_py.set_message +import rosidl_runtime_py.utilities +from action_msgs.msg import GoalStatus +from rclpy.action.client import ActionClient +from rclpy.qos import ( + DurabilityPolicy, + HistoryPolicy, + LivelinessPolicy, + QoSProfile, + ReliabilityPolicy, +) +from rclpy.topic_endpoint_info import TopicEndpointInfo + +from rai.tools.ros.utils import import_message_from_str +from rai.utils.ros import NodeDiscovery +from rai.utils.ros_async import get_future_result + + +def ros2_build_msg(msg_type: str, msg_args: Dict[str, Any]) -> Tuple[object, Type]: + """ + Import message and create it. Return both ready message and message class. + + msgs args can have two formats: + { "goal" : {arg 1 : xyz, ... } or {arg 1 : xyz, ... } + """ + + msg_cls: Type = rosidl_runtime_py.utilities.get_interface(msg_type) + msg = msg_cls.Goal() + + if "goal" in msg_args: + msg_args = msg_args["goal"] + rosidl_runtime_py.set_message.set_message_fields(msg, msg_args) + return msg, msg_cls + + +class Ros2TopicsAPI: + def __init__( + self, + node: rclpy.node.Node, + callback_group: rclpy.callback_groups.CallbackGroup, + ros_discovery_info: NodeDiscovery, + ) -> None: + self.node = node + self.callback_group = callback_group + self.last_subscription_msgs_buffer = dict() + self.qos_profile_cache: Dict[str, QoSProfile] = dict() + + self.ros_discovery_info = ros_discovery_info + + def get_logger(self): + return self.node.get_logger() + + def adapt_requests_to_offers( + self, publisher_info: List[TopicEndpointInfo] + ) -> QoSProfile: + if not publisher_info: + return QoSProfile(depth=1) + + num_endpoints = len(publisher_info) + reliability_reliable_count = 0 + durability_transient_local_count = 0 + + for endpoint in publisher_info: + profile = endpoint.qos_profile + if profile.reliability == ReliabilityPolicy.RELIABLE: + reliability_reliable_count += 1 + if profile.durability == DurabilityPolicy.TRANSIENT_LOCAL: + durability_transient_local_count += 1 + + request_qos = QoSProfile( + history=HistoryPolicy.KEEP_LAST, + depth=1, + liveliness=LivelinessPolicy.AUTOMATIC, + ) + + # Set reliability based on publisher offers + if reliability_reliable_count == num_endpoints: + request_qos.reliability = ReliabilityPolicy.RELIABLE + else: + if reliability_reliable_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering RELIABLE reliability. " + "Falling back to BEST_EFFORT as it will connect to all publishers. " + "Some messages from Reliable publishers could be dropped." + ) + request_qos.reliability = ReliabilityPolicy.BEST_EFFORT + + # Set durability based on publisher offers + if durability_transient_local_count == num_endpoints: + request_qos.durability = DurabilityPolicy.TRANSIENT_LOCAL + else: + if durability_transient_local_count > 0: + self.get_logger().warning( + "Some, but not all, publishers are offering TRANSIENT_LOCAL durability. " + "Falling back to VOLATILE as it will connect to all publishers. " + "Previously-published latched messages will not be retrieved." + ) + request_qos.durability = DurabilityPolicy.VOLATILE + + return request_qos + + def create_subscription_by_topic_name(self, topic): + if self.has_subscription(topic): + self.get_logger().warning( + f"Subscription to {topic} already exists. To override use destroy_subscription_by_topic_name first" + ) + return + + msg_type = self.get_msg_type(topic) + + if topic not in self.qos_profile_cache: + self.get_logger().debug(f"Getting qos profile for topic: {topic}") + qos_profile = self.adapt_requests_to_offers( + self.node.get_publishers_info_by_topic(topic) + ) + self.qos_profile_cache[topic] = qos_profile + else: + self.get_logger().debug(f"Using cached qos profile for topic: {topic}") + qos_profile = self.qos_profile_cache[topic] + + topic_callback = functools.partial( + self.generic_state_subscriber_callback, topic + ) + + self.node.create_subscription( + msg_type, + topic, + callback=topic_callback, + callback_group=self.callback_group, + qos_profile=qos_profile, + ) + + def get_msg_type(self, topic: str, n_tries: int = 5) -> Any: + """Sometimes node fails to do full discovery, therefore we need to retry""" + for _ in range(n_tries): + if topic in self.ros_discovery_info.topics_and_types: + msg_type = self.ros_discovery_info.topics_and_types[topic] + return import_message_from_str(msg_type) + else: + # Wait for next discovery cycle + self.get_logger().info(f"Waiting for topic: {topic}") + if self.ros_discovery_info: + time.sleep(self.ros_discovery_info.period_sec) + else: + time.sleep(1.0) + raise KeyError(f"Topic {topic} not found") + + def set_ros_discovery_info(self, new_ros_discovery_info: NodeDiscovery): + self.ros_discovery_info = new_ros_discovery_info + + def get_raw_message_from_topic( + self, topic: str, timeout_sec: int = 5, topic_wait_sec: int = 2 + ) -> Any: + self.get_logger().debug(f"Getting msg from topic: {topic}") + + ts = time.perf_counter() + + for _ in range(topic_wait_sec * 10): + if topic not in self.ros_discovery_info.topics_and_types: + time.sleep(0.1) + continue + else: + break + + if topic not in self.ros_discovery_info.topics_and_types: + raise KeyError( + f"Topic {topic} not found. Available topics: {self.ros_discovery_info.topics_and_types.keys()}" + ) + + if topic in self.last_subscription_msgs_buffer: + self.get_logger().info("Returning cached message") + return self.last_subscription_msgs_buffer[topic] + else: + self.create_subscription_by_topic_name(topic) + try: + msg = self.last_subscription_msgs_buffer.get(topic, None) + while msg is None and time.perf_counter() - ts < timeout_sec: + msg = self.last_subscription_msgs_buffer.get(topic, None) + self.get_logger().info("Waiting for message...") + time.sleep(0.1) + + success = msg is not None + + if success: + self.get_logger().debug( + f"Received message of type {type(msg)} from topic {topic}" + ) + return msg + else: + error = f"No message received in {timeout_sec} seconds from topic {topic}" + self.get_logger().error(error) + return error + finally: + self.destroy_subscription_by_topic_name(topic) + + def generic_state_subscriber_callback(self, topic_name: str, msg: Any): + self.get_logger().debug( + f"Received message of type {type(msg)} from topic {topic_name}" + ) + self.last_subscription_msgs_buffer[topic_name] = msg + + def has_subscription(self, topic: str) -> bool: + for sub in self.node._subscriptions: + if sub.topic == topic: + return True + return False + + def destroy_subscription_by_topic_name(self, topic: str): + self.last_subscription_msgs_buffer.clear() + for sub in self.node._subscriptions: + if sub.topic == topic: + self.node.destroy_subscription(sub) + + +class Ros2ActionsAPI: + def __init__(self, node: rclpy.node.Node): + self.node = node + + self.goal_handle = None + self.result_future = None + self.feedback = None + self.status: Optional[int] = None + self.client: Optional[ActionClient] = None + self.action_feedback: Optional[Any] = None + + def get_logger(self): + return self.node.get_logger() + + def run_action( + self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] + ): + if not self.is_task_complete(): + raise AssertionError( + "Another ros2 action is currently running and parallel actions are not supported. Please wait until the previous action is complete before starting a new one. You can also cancel the current one." + ) + + if action_name[0] != "/": + action_name = "/" + action_name + self.get_logger().info(f"Action name corrected to: {action_name}") + + try: + goal_msg, msg_cls = ros2_build_msg(action_type, action_goal_args) + except Exception as e: + return f"Failed to build message: {e}" + + self.client = ActionClient(self.node, msg_cls, action_name) + self.msg_cls = msg_cls + + retries = 0 + while not self.client.wait_for_server(timeout_sec=1.0): + retries += 1 + if retries > 5: + raise Exception( + f"Action server '{action_name}' is not available. Make sure `action_name` is correct..." + ) + self.get_logger().info( + f"'{action_name}' action server not available, waiting..." + ) + + self.get_logger().info(f"Sending action message: {goal_msg}") + + send_goal_future = self.client.send_goal_async( + goal_msg, self._feedback_callback + ) + self.get_logger().info("Action goal sent!") + + self.goal_handle = get_future_result(send_goal_future) + + if not self.goal_handle: + raise Exception(f"Action '{action_name}' not sent to server") + + if not self.goal_handle.accepted: + raise Exception(f"Action '{action_name}' not accepted by server") + + self.result_future = self.goal_handle.get_result_async() + self.get_logger().info("Action sent!") + return f"{action_name} started successfully with args: {action_goal_args}" + + def get_task_result(self) -> str: + if not self.is_task_complete(): + return "Task is not complete yet" + + def parse_status(status: int) -> str: + return { + GoalStatus.STATUS_SUCCEEDED: "succeeded", + GoalStatus.STATUS_ABORTED: "aborted", + GoalStatus.STATUS_CANCELED: "canceled", + GoalStatus.STATUS_ACCEPTED: "accepted", + GoalStatus.STATUS_CANCELING: "canceling", + GoalStatus.STATUS_EXECUTING: "executing", + GoalStatus.STATUS_UNKNOWN: "unknown", + }.get(status, "unknown") + + result = self.result_future.result() + + self.destroy_client() + if result.status == GoalStatus.STATUS_SUCCEEDED: + msg = f"Result succeeded: {result.result}" + self.get_logger().info(msg) + return msg + else: + str_status = parse_status(result.status) + error_code_str = self.parse_error_code(result.result.error_code) + msg = f"Result {str_status}, because of: error_code={result.result.error_code}({error_code_str}), error_msg={result.result.error_msg}" + self.get_logger().info(msg) + return msg + + def parse_error_code(self, code: int) -> str: + code_to_name = { + v: k for k, v in vars(self.msg_cls.Result).items() if isinstance(v, int) + } + return code_to_name.get(code, "UNKNOWN") + + def _feedback_callback(self, msg): + self.get_logger().info(f"Received ros2 action feedback: {msg}") + self.action_feedback = msg + + def is_task_complete(self): + if not self.result_future: + # task was cancelled or completed + return True + + result = get_future_result(self.result_future, timeout_sec=0.10) + if result is not None: + self.status = result.status + if self.status != GoalStatus.STATUS_SUCCEEDED: + self.get_logger().debug( + f"Task with failed with status code: {self.status}" + ) + return True + else: + self.get_logger().info("There is no result") + # Timed out, still processing, not complete yet + return False + + self.get_logger().info("Task succeeded!") + return True + + def cancel_task(self) -> Union[str, bool]: + self.get_logger().info("Canceling current task.") + try: + if self.result_future and self.goal_handle: + future = self.goal_handle.cancel_goal_async() + result = get_future_result(future, timeout_sec=1.0) + return "Failed to cancel result" if result is None else True + return True + finally: + self.destroy_client() + + def destroy_client(self): + if self.client: + self.client.destroy() From 0943f2658124f4bef3b8c5c5a18a2f1f348ac490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 14 Jan 2025 14:51:09 +0100 Subject: [PATCH 31/31] add copyright notice --- src/rai/rai/ros2_apis.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/rai/rai/ros2_apis.py b/src/rai/rai/ros2_apis.py index 150161b2..359885f0 100644 --- a/src/rai/rai/ros2_apis.py +++ b/src/rai/rai/ros2_apis.py @@ -1,3 +1,17 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import functools import time from typing import Any, Dict, List, Optional, Tuple, Type, Union