Skip to content

Commit

Permalink
Implement Robosuite Benchmarking Standardization (#2)
Browse files Browse the repository at this point in the history
* Cleaned up some of the code

* Added benchmarking code to be more clear

* Added printing of parameters for a user

* Updated with inpaint code

* Added support for files that contain the configurations

* Updated benchmarking code to work

* Removed critical hardcoded paths
  • Loading branch information
KDharmarajanDev authored Feb 27, 2024
1 parent 69109ab commit 79265fc
Show file tree
Hide file tree
Showing 20 changed files with 603 additions and 278 deletions.
4 changes: 3 additions & 1 deletion xembody/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
description='Repo for Mirage',
author='Lawrence Chen, Karthik Dharmarajan, Kush Hari',
package_dir = {'': '.'},
packages=find_packages(include=['xembody', 'xembody.*']),
packages=find_packages(include=['xembody', 'xembody.*', 'xembody_robosuite', 'xembody_robosuite.*']),
install_requires=[
"numpy",
"pyyaml",
],
extras_require={
"docs": [
Expand Down
Empty file.
11 changes: 11 additions & 0 deletions xembody/xembody/src/benchmark/experiment_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod

class ExperimentConfig(ABC):

@abstractmethod
def validate_config(self):
"""
Validate the configuration to see if the values are feasible.
:throws ValueError: If the configuration is not valid.
"""
pass
17 changes: 11 additions & 6 deletions xembody/xembody/src/general/ros_inpaint_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sensor_msgs.msg import Image, PointCloud2
from cv_bridge import CvBridge
from xembody.src.general.xembody_publisher import XEmbodyPublisher
from sensor_msgs.msg import PointCloud2, PointField
from sensor_msgs.msg import PointCloud2, PointField, Image
from input_filenames_msg.msg import MultipleInpaintImages
import cv2
import numpy as np
Expand All @@ -15,7 +15,7 @@ class ROSInpaintPublisher(XEmbodyPublisher):
Publishing data is left to sim or real subclasses.
"""

def __init__(self, use_diffusion: bool = False):
def __init__(self, use_diffusion: bool = False, uses_single_img: bool = False):
"""
Initializes the ROS2 node.
"""
Expand All @@ -39,9 +39,10 @@ def __init__(self, use_diffusion: bool = False):
# 0.1
# )
# self._time_sync.registerCallback(self._inpaint_image_callback)

self._uses_single_img = uses_single_img
self._inpaint_sub_type = MultipleInpaintImages if not uses_single_img else Image
self._inpaint_sub = self.node.create_subscription(
MultipleInpaintImages, 'inpainted_image', self._inpaint_single, 1)
self._inpaint_sub_type, 'inpainted_image', self._inpaint_single, 1)

self._cv_bridge = CvBridge()
self._cv_images = None
Expand Down Expand Up @@ -111,8 +112,12 @@ def _inpaint_single(self, inpaint_msg):
print("Received inpainted images")
with self._internal_lock:
images = []
for image in inpaint_msg.images:
images.append(self._cv_bridge.imgmsg_to_cv2(image))
if not self._uses_single_img:
for image in inpaint_msg.images:
images.append(self._cv_bridge.imgmsg_to_cv2(image))
else:
images = self._cv_bridge.imgmsg_to_cv2(inpaint_msg)

self._cv_images = images

if self._use_diffusion:
Expand Down
18 changes: 12 additions & 6 deletions xembody/xembody/src/general/ros_inpaint_publisher_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,22 @@ class ROSInpaintPublisherSim(ROSInpaintPublisher):
to a node that performs inpainting on a target robot.
"""

def __init__(self):
def __init__(self, source_robot_info: str, target_robot_info: str):
"""
Initializes the ROS2 node.
:param source_robot_info: the information about the source robot to determine which interpolation scheme to use
:param target_robot_info: the information about the target robot to determine which interpolation scheme to use
"""
super().__init__()
super().__init__(uses_single_img=True)

self._publisher = self.node.create_publisher(
InputFilesSimData, 'input_files_data_sim', 1)
InputFilesSimData, '/input_files_data_sim', 1)

# TODO: generalize this
self.gripper_interpolator = GripperInterpolator('panda', 'ur5', '/home/kdharmarajan/x-embody/xembody/xembody_robosuite/paired_trajectories_collection/gripper_interpolation_results_no_task_diff.pkl')
self.gripper_interpolator = GripperInterpolator(source_robot_info, target_robot_info, ['/home/mirage/x-embody/xembody/xembody_robosuite/paired_trajectories_collection/gripper_interpolation_results_no_task_diff.pkl',
'/home/mirage/x-embody/xembody/xembody_robosuite/paired_trajectories_collection/gripper_interpolation_results_20_rollouts.pkl'])
# self.gripper_interpolator = GripperInterpolator('panda', 'panda', ['/home/mirage/x-embody/xembody/xembody_robosuite/paired_trajectories_collection/gripper_interpolation_results_no_task_diff.pkl'])
# self.gripper_interpolator = GripperInterpolator('panda', 'ur5', '/home/mirage/x-embody/xembody/xembody_robosuite/paired_trajectories_collection/gripper_interpolation_results_no_task_diff.pkl')

def publish_to_ros_node(self, data: ROSInpaintSimData):
"""
Expand All @@ -56,5 +61,6 @@ def publish_to_ros_node(self, data: ROSInpaintSimData):
msg.segmentation = segmentation_mask.flatten().tolist()
msg.ee_pose = data.ee_pose.flatten().tolist()
msg.interpolated_gripper = self.gripper_interpolator.interpolate_gripper(data.gripper_angles).flatten().tolist()
msg.camera_name = data.camera_name
self._publisher.publish(msg)
# msg.camera_name = data.camera_name
self._publisher.publish(msg)
print("Published message")
12 changes: 12 additions & 0 deletions xembody/xembody/src/ros_ws/src/gazebo_env/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ ament_target_dependencies(panda_control_plugin
)
target_link_libraries(panda_control_plugin ${GAZEBO_LIBRARIES} ${OpenCV_LIBS})

add_library(panda_no_gripper_control_plugin SHARED panda_no_gripper_control_plugin/panda_no_gripper_control_plugin.cc)
ament_target_dependencies(panda_no_gripper_control_plugin
"gazebo_dev"
"gazebo_ros"
"rclcpp"
"std_msgs"
"sensor_msgs"
"message_filters"
"cv_bridge"
)
target_link_libraries(panda_no_gripper_control_plugin ${GAZEBO_LIBRARIES} ${OpenCV_LIBS})

add_library(ur5_and_panda_no_gripper_control_plugin SHARED ur5_and_panda_no_gripper_control_plugin/ur5_and_panda_no_gripper_control_plugin.cc)
ament_target_dependencies(ur5_and_panda_no_gripper_control_plugin
"gazebo_dev"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,27 @@
)
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare

import os

def generate_launch_description():

# Set the path to the Gazebo ROS package
pkg_gazebo_ros = FindPackageShare(package='gazebo_ros').find('gazebo_ros')

# Set the path to this package.
pkg_share = FindPackageShare(package='gazebo_env').find('gazebo_env')

# Set the path to the world file
world_file_name = 'no_shadow_sim.world'
world_path = os.path.join(pkg_share, 'worlds', world_file_name)

world = LaunchConfiguration('world')

declare_world_cmd = DeclareLaunchArgument(
name='world',
default_value=world_path,
description='Full path to the world model file to load')

# Declare arguments
declared_arguments = []
declared_arguments.append(
Expand All @@ -37,18 +55,14 @@ def generate_launch_description():
)
)

declared_arguments.append(declare_world_cmd)

# Initialize Arguments
gui = LaunchConfiguration("gui")

gazebo_server = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
PathJoinSubstitution(
[FindPackageShare("gazebo_ros"), "launch", "gzserver.launch.py"]
)
]
)
)
PythonLaunchDescriptionSource(os.path.join(pkg_gazebo_ros, 'launch', 'gzserver.launch.py')),
launch_arguments={'world': world}.items())
gazebo_client = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,26 @@
)
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare

import os

def generate_launch_description():
# Set the path to the Gazebo ROS package
pkg_gazebo_ros = FindPackageShare(package='gazebo_ros').find('gazebo_ros')

# Set the path to this package.
pkg_share = FindPackageShare(package='gazebo_env').find('gazebo_env')

# Set the path to the world file
world_file_name = 'no_shadow_sim.world'
world_path = os.path.join(pkg_share, 'worlds', world_file_name)

world = LaunchConfiguration('world')

declare_world_cmd = DeclareLaunchArgument(
name='world',
default_value=world_path,
description='Full path to the world model file to load')

# Declare arguments
declared_arguments = []
declared_arguments.append(
Expand All @@ -37,18 +54,14 @@ def generate_launch_description():
)
)

declared_arguments.append(declare_world_cmd)

# Initialize Arguments
gui = LaunchConfiguration("gui")

gazebo_server = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
PathJoinSubstitution(
[FindPackageShare("gazebo_ros"), "launch", "gzserver.launch.py"]
)
]
)
)
PythonLaunchDescriptionSource(os.path.join(pkg_gazebo_ros, 'launch', 'gzserver.launch.py')),
launch_arguments={'world': world}.items())
gazebo_client = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
Expand Down Expand Up @@ -128,7 +141,7 @@ def generate_launch_description():

nodes = [
gazebo_server,
gazebo_client,
#gazebo_client,
node_robot_state_publisher,
spawn_entity,
joint_state_publisher_node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,26 @@
)
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare

import os

def generate_launch_description():
# Set the path to the Gazebo ROS package
pkg_gazebo_ros = FindPackageShare(package='gazebo_ros').find('gazebo_ros')

# Set the path to this package.
pkg_share = FindPackageShare(package='gazebo_env').find('gazebo_env')

# Set the path to the world file
world_file_name = 'no_shadow_sim.world'
world_path = os.path.join(pkg_share, 'worlds', world_file_name)

world = LaunchConfiguration('world')

declare_world_cmd = DeclareLaunchArgument(
name='world',
default_value=world_path,
description='Full path to the world model file to load')

# Declare arguments
declared_arguments = []
declared_arguments.append(
Expand All @@ -37,18 +54,14 @@ def generate_launch_description():
)
)

declared_arguments.append(declare_world_cmd)

# Initialize Arguments
gui = LaunchConfiguration("gui")

gazebo_server = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
PathJoinSubstitution(
[FindPackageShare("gazebo_ros"), "launch", "gzserver.launch.py"]
)
]
)
)
PythonLaunchDescriptionSource(os.path.join(pkg_gazebo_ros, 'launch', 'gzserver.launch.py')),
launch_arguments={'world': world}.items())
gazebo_client = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@
<horizontal_fov>0.7853981634</horizontal_fov>
<image>
<format>R8G8B8</format>
<width>256</width>
<height>256</height>
<width>512</width>
<height>512</height>
<!-- <width>256</width>
<height>256</height> -->
</image>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<!-- for some reason we have to modify the xyz because it is a little shifted-->

<!-- Lift task-->
<origin xyz="0.516 0.018 11.749" rpy="0 0.7840421 3.1415927" />
<origin xyz="1.016 0.018 11.749" rpy="0 0.7840421 3.1415927" />

<!-- Can task-->
<!-- <origin xyz="1.016 0.018 11.749" rpy="0 0.7840421 3.1415927" /> -->
Expand Down
Loading

0 comments on commit 79265fc

Please sign in to comment.