Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ActionServerBT Overhaul and Add Static TF Object #114

Merged
merged 9 commits into from
Sep 28, 2023
34 changes: 20 additions & 14 deletions ada_feeding/ada_feeding/action_server_bt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ class ActionServerBT(ABC):
`create_action_server.py`
"""

# pylint: disable=too-many-arguments
# One over is fine.
def __init__(self, node: Node) -> None:
"""
Store the ROS2 Node that created trees is associated with. Necessary for
behaviors within the tree to connect to ROS topics/services/actions.
Subclasses should add kwargs with tree-agnostic
input parameters.
"""
self._node = node

@abstractmethod
def create_tree(
self,
name: str,
action_type: type,
tree_root_name: str,
node: Node,
self, name: str, tree_root_name: str # DEPRECATED
) -> py_trees.trees.BehaviourTree:
"""
Create the behavior tree that will be executed by this action server.
Expand All @@ -38,12 +41,9 @@ def create_tree(
Parameters
----------
name: The name of the behavior tree.
action_type: the type for the action, as a class
tree_root_name: The name of the tree. This is necessary because sometimes
trees create subtrees, but still need to track the top-level tree
name to read/write the correct blackboard variables.
node: The ROS2 node that this tree is associated with. Necessary for
behaviors within the tree connect to ROS topics/services/actions.
"""
raise NotImplementedError("create_tree not implemented")

Expand Down Expand Up @@ -91,7 +91,9 @@ def preempt_goal(self, tree: py_trees.trees.BehaviourTree) -> bool:
return False

@abstractmethod
def get_feedback(self, tree: py_trees.trees.BehaviourTree) -> object:
def get_feedback(
self, tree: py_trees.trees.BehaviourTree, action_type: type
) -> object:
"""
Creates the ROS feedback message corresponding to this action.
Expand All @@ -100,15 +102,18 @@ def get_feedback(self, tree: py_trees.trees.BehaviourTree) -> object:
Parameters
----------
tree: The behavior tree that is being executed.
action_type: the type for the action, as a class
Returns
-------
feedback: The ROS feedback message to be sent to the action client.
The ROS feedback message to be sent to the action client, type action_type.Feedback()
"""
raise NotImplementedError("get_feedback not implemented")

@abstractmethod
def get_result(self, tree: py_trees.trees.BehaviourTree) -> object:
def get_result(
self, tree: py_trees.trees.BehaviourTree, action_type: type
) -> object:
"""
Creates the ROS result message corresponding to this action.
Expand All @@ -117,9 +122,10 @@ def get_result(self, tree: py_trees.trees.BehaviourTree) -> object:
Parameters
----------
tree: The behavior tree that is being executed.
action_type: the type for the action, as a class
Returns
-------
result: The ROS result message to be sent to the action client.
The ROS result message to be sent to the action client, type action_type.Result()
"""
raise NotImplementedError("get_result not implemented")
52 changes: 25 additions & 27 deletions ada_feeding/ada_feeding/behaviors/acquisition/compute_food_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
"""
This module defines the ComputeFoodFrame behavior, which computes the
food frame from
food frame from the Mask provided from a perception algorithm.
"""
# Standard imports
from typing import Union, Optional
Expand All @@ -11,18 +11,22 @@
import cv2 as cv
from geometry_msgs.msg import PointStamped, TransformStamped, Vector3Stamped
import numpy as np
from overrides import override
import py_trees
import pyrealsense2
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import CameraInfo
import tf2_ros
from tf2_ros.static_transform_broadcaster import StaticTransformBroadcaster

# Local imports
from ada_feeding_msgs.msg import Mask
from ada_feeding_msgs.srv import AcquisitionSelect
from ada_feeding.helpers import BlackboardKey, quat_between_vectors, get_tf_object
from ada_feeding.helpers import (
BlackboardKey,
quat_between_vectors,
get_tf_object,
set_static_tf,
)
from ada_feeding.behaviors import BlackboardBehavior
from ada_feeding_perception.helpers import ros_msg_to_cv2_image

Expand All @@ -44,23 +48,21 @@ class ComputeFoodFrame(BlackboardBehavior):

def blackboard_inputs(
self,
ros2_node: Union[BlackboardKey, Node],
camera_info: Union[BlackboardKey, CameraInfo],
mask: Union[BlackboardKey, Mask],
food_frame_id: Union[BlackboardKey, str] = "food",
world_frame: Union[BlackboardKey, str] = "world",
debug_food_frame: Union[BlackboardKey, str] = "food",
) -> None:
"""
Blackboard Inputs
Parameters
----------
ros2_node (Node): ROS2 Node for reading/writing TFs
camera_info (geometry_msgs/CameraInfo): camera intrinsics matrix
mask (ada_feeding_msgs/Mask): food context, see Mask.msg
food_frame_id (string): If len>0, TF frame to publish static transform
(relative to world_frame)
world_frame (string): ID of the TF frame to represent the food frame in
debug_food_frame (string): If len>0, TF frame to publish static transform
(relative to world_frame) for debugging purposes
"""
# pylint: disable=unused-argument
# Arguments are handled generically in base class.
Expand All @@ -72,9 +74,6 @@ def blackboard_outputs(
self,
action_select_request: Optional[BlackboardKey], # AcquisitionSelect.Request
food_frame: Optional[BlackboardKey], # TransformStamped
debug_tf_publisher: Optional[
BlackboardKey
] = None, # StaticTransformBroadcaster
) -> None:
"""
Blackboard Outputs
Expand All @@ -84,17 +83,15 @@ def blackboard_outputs(
----------
action_select_request (AcquisitionSelect.Request): request to send to AcquisitionSelect
(copies mask input)
food_frame (geometry_msgs/TransformStamped): transform from world_frame to food frame
debug_tf_publisher (StaticTransformBroadcaster): If set, store
static broadcaster here to keep it alive
for debugging purposes.
food_frame (geometry_msgs/TransformStamped): transform from world_frame to food_frame
"""
# pylint: disable=unused-argument
# Arguments are handled generically in base class.
super().blackboard_outputs(
**{key: value for key, value in locals().items() if key != "self"}
)

@override
def setup(self, **kwargs):
"""
Middleware (i.e. TF) setup
Expand All @@ -104,11 +101,13 @@ def setup(self, **kwargs):
# It is okay for attributes in behaviors to be
# defined in the setup / initialise functions.

# Get Node from Kwargs
self.node = kwargs["node"]

# Get TF Listener from blackboard
self.tf_buffer, _, self.tf_lock = get_tf_object(
self.blackboard, self.blackboard_get("ros2_node")
)
self.tf_buffer, _, self.tf_lock = get_tf_object(self.blackboard, self.node)

@override
def initialise(self):
"""
Behavior initialization
Expand All @@ -131,15 +130,17 @@ def initialise(self):
self.intrinsics.fy = camera_info.k[4]
if camera_info.distortion_model == "plumb_bob":
self.intrinsics.model = pyrealsense2.distortion.brown_conrady
self.intrinsics.coeffs = list(camera_info.d)
elif camera_info.distortion_model == "equidistant":
self.intrinsics.model = pyrealsense2.distortion.kannala_brandt4
self.intrinsics.coeffs = list(camera_info.d)
else:
self.logger.warning(
f"Unsupported camera distortion model: {camera_info.distortion_model}"
)
self.intrinsics.model = pyrealsense2.distortion.none
self.intrinsics.coeffs = list(camera_info.d)

@override
def update(self) -> py_trees.common.Status:
"""
Behavior tick (DO NOT BLOCK)
Expand All @@ -153,7 +154,6 @@ def update(self) -> py_trees.common.Status:
# to ROS2 msg types, which take 3-4 statements each.

camera_frame = self.blackboard_get("camera_info").header.frame_id
node = self.blackboard_get("ros2_node")
world_frame = self.blackboard_get("world_frame")

# Lock TF Buffer
Expand All @@ -180,9 +180,9 @@ def update(self) -> py_trees.common.Status:

# Set up return objects
world_to_food_transform = TransformStamped()
world_to_food_transform.header.stamp = node.get_clock().now().to_msg()
world_to_food_transform.header.stamp = self.node.get_clock().now().to_msg()
world_to_food_transform.header.frame_id = world_frame
world_to_food_transform.child_frame_id = self.blackboard_get("debug_food_frame")
world_to_food_transform.child_frame_id = self.blackboard_get("food_frame_id")

# De-project center of ROI
mask = self.blackboard_get("mask")
Expand Down Expand Up @@ -251,10 +251,8 @@ def update(self) -> py_trees.common.Status:
)

# Write to blackboard outputs
if len(self.blackboard_get("debug_food_frame")) > 0:
stb = StaticTransformBroadcaster(self.blackboard_get("ros2_node"))
stb.sendTransform(world_to_food_transform)
self.blackboard_set("debug_tf_publisher", stb)
if len(self.blackboard_get("food_frame_id")) > 0:
set_static_tf(world_to_food_transform, self.blackboard, self.node)
self.blackboard_set("food_frame", world_to_food_transform)
request = AcquisitionSelect.Request()
request.food_context = mask
Expand Down
87 changes: 84 additions & 3 deletions ada_feeding/ada_feeding/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

# Third-party imports
import numpy as np
from geometry_msgs.msg import Vector3, Quaternion
from geometry_msgs.msg import TransformStamped, Vector3, Quaternion
import py_trees
from py_trees.common import Access
from pymoveit2 import MoveIt2
from pymoveit2.robots import kinova
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.node import Node
from tf2_ros.buffer import Buffer
from tf2_ros.static_transform_broadcaster import StaticTransformBroadcaster
from tf2_ros.transform_listener import TransformListener


Expand Down Expand Up @@ -92,6 +93,86 @@ def quat_between_vectors(vec_from: Vector3, vec_to: Vector3) -> Quaternion:
return ret


def set_static_tf(
transform_stamped: TransformStamped,
blackboard: py_trees.blackboard.Client,
node: Optional[Node] = None,
) -> bool:
"""
Adds a transform to the list sent to /tf_static.
This uses a StaticTransformBroadcaster on the global backboard.
Note this is *not* a resource-intensive operation, as both
publisher and subscribers to /tf_static use latching.
Do NOT call this function in a fast loop.
Since these transforms are assumed static until updated, they cannot be deleted.
More Info: https://answers.ros.org/question/226824/using-tf_static-for-almost-static-transforms/
Parameters
----------
transform_stamped: Transform to publish (will overwrite a transform with an identical
child_frame_id)
blackboard: Client in which to store the static transform broadcaster (STB) and mutex
node: The ROS2 node the STB is associated with. If None, this function will not create
the STB if it does exist, and will instead raise a KeyError.
Returns
---------
False if the lock is held, else True
Raises
------
KeyError: if the TF objects do not exist and node is None.
"""

static_tf_broadcaster_blackboard_key = "/tf_static/stb"
static_tf_transforms_blackboard_key = "/tf_static/transforms"
static_tf_lock_blackboard_key = "/tf_static/lock"

# First, register the TF objects and their corresponding lock for READ access
if not blackboard.is_registered(static_tf_broadcaster_blackboard_key, Access.READ):
blackboard.register_key(static_tf_broadcaster_blackboard_key, Access.READ)
if not blackboard.is_registered(static_tf_transforms_blackboard_key, Access.WRITE):
blackboard.register_key(static_tf_transforms_blackboard_key, Access.WRITE)
if not blackboard.is_registered(static_tf_lock_blackboard_key, Access.READ):
blackboard.register_key(static_tf_lock_blackboard_key, Access.READ)

# Second, check if the MoveIt2 object and its corresponding lock exist on the
# blackboard. If they do not, register the blackboard for WRITE access to those
# keys and create them.
try:
stb = blackboard.get(static_tf_broadcaster_blackboard_key)
lock = blackboard.get(static_tf_lock_blackboard_key)
except KeyError as exc:
# If no node is passed in, raise an error.
if node is None:
raise KeyError("Static TF objects do not exist on the blackboard") from exc

# If a node is passed in, create a new MoveIt2 object and lock.
node.get_logger().info(
"Static TF objects and lock do not exist on the blackboard. Creating them now."
)
blackboard.register_key(static_tf_broadcaster_blackboard_key, Access.WRITE)
blackboard.register_key(static_tf_lock_blackboard_key, Access.WRITE)
stb = StaticTransformBroadcaster(node)
transforms = {}
lock = Lock()
blackboard.set(static_tf_broadcaster_blackboard_key, stb)
blackboard.set(static_tf_transforms_blackboard_key, transforms)
blackboard.set(static_tf_lock_blackboard_key, lock)

# Check and acquire the lock
if lock.locked():
return False

with lock:
transforms = blackboard.get(static_tf_transforms_blackboard_key)
transforms[transform_stamped.child_frame_id] = transform_stamped
blackboard.set(static_tf_transforms_blackboard_key, transforms)
stb.sendTransform(list(transforms.values()))

return True


def get_tf_object(
blackboard: py_trees.blackboard.Client,
node: Optional[Node] = None,
Expand Down Expand Up @@ -135,7 +216,7 @@ def get_tf_object(
if not blackboard.is_registered(tf_lock_blackboard_key, Access.READ):
blackboard.register_key(tf_lock_blackboard_key, Access.READ)

# Second, check if the MoveIt2 object and its corresponding lock exist on the
# Second, check if the TF objects and its corresponding lock exist on the
# blackboard. If they do not, register the blackboard for WRITE access to those
# keys and create them.
try:
Expand All @@ -147,7 +228,7 @@ def get_tf_object(
if node is None:
raise KeyError("TF objects do not exist on the blackboard") from exc

# If a node is passed in, create a new MoveIt2 object and lock.
# If a node is passed in, create new TF objects and lock.
node.get_logger().info(
"TF objects and lock do not exist on the blackboard. Creating them now."
)
Expand Down
Loading