Skip to content

Commit

Permalink
Merge pull request #878 from StanfordVL/feat/curobo
Browse files Browse the repository at this point in the history
Feat/curobo
  • Loading branch information
cremebrule authored Sep 18, 2024
2 parents 1a671f8 + ceb0547 commit 15f190b
Show file tree
Hide file tree
Showing 16 changed files with 770 additions and 11 deletions.
549 changes: 549 additions & 0 deletions omnigibson/action_primitives/curobo.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions omnigibson/object_states/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from omnigibson.object_states.heat_source_or_sink import HeatSourceOrSink
from omnigibson.object_states.heated import Heated
from omnigibson.object_states.inside import Inside
from omnigibson.object_states.joint_state import Joint
from omnigibson.object_states.max_temperature import MaxTemperature
from omnigibson.object_states.next_to import NextTo
from omnigibson.object_states.object_state_base import REGISTERED_OBJECT_STATES
Expand Down
11 changes: 11 additions & 0 deletions omnigibson/object_states/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from omnigibson.object_states import *
from omnigibson.object_states.kinematics_mixin import KinematicsMixin
from omnigibson.object_states.link_based_state_mixin import LinkBasedStateMixin

# states: list of ObjectBaseState
# requirements: list of ObjectBaseRequirement
Expand Down Expand Up @@ -163,3 +164,13 @@ def get_states_by_dependency_order(states=None):
list: all states in topological order of dependency
"""
return list(reversed(list(nx.algorithms.topological_sort(get_state_dependency_graph(states)))))


# Define all metalinks
METALINK_PREFIXES = set()
for state in get_states_by_dependency_order():
if issubclass(state, LinkBasedStateMixin):
try:
METALINK_PREFIXES.add(state.metalink_prefix)
except NotImplementedError:
pass
23 changes: 23 additions & 0 deletions omnigibson/object_states/joint_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch as th

from omnigibson.macros import create_module_macros
from omnigibson.object_states.object_state_base import AbsoluteObjectState

# Create settings for this module
m = create_module_macros(module_path=__file__)

m.POSITIONAL_VALIDATION_EPSILON = 1e-10


class Joint(AbsoluteObjectState):

def _get_value(self):
return self.obj.get_joint_positions() if self.obj.n_joints > 0 else th.tensor([])

def _has_changed(self, get_value_args, value, info):
# Only changed if the squared distance between old and current q has changed above some threshold
old_q = value
# Get current joint values
cur_q = self.get_value()
dist_squared = th.sum(th.square(cur_q - old_q))
return dist_squared > m.POSITIONAL_VALIDATION_EPSILON
15 changes: 10 additions & 5 deletions omnigibson/object_states/kinematics_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from omnigibson.object_states.aabb import AABB
from omnigibson.object_states.contact_bodies import ContactBodies
from omnigibson.object_states.joint_state import Joint
from omnigibson.object_states.object_state_base import BaseObjectState
from omnigibson.object_states.pose import Pose
from omnigibson.utils.python_utils import classproperty
Expand All @@ -14,7 +15,7 @@ class KinematicsMixin(BaseObjectState):
@classmethod
def get_dependencies(cls):
deps = super().get_dependencies()
deps.update({Pose, AABB, ContactBodies})
deps.update({Pose, Joint, AABB, ContactBodies})
return deps

def cache_info(self, get_value_args):
Expand All @@ -25,10 +26,10 @@ def cache_info(self, get_value_args):
info = super().cache_info(get_value_args=get_value_args)

# Store this object as well as any other objects from @get_value_args
info[self.obj] = self.obj.states[Pose].get_value()
info[self.obj] = {"q": self.obj.states[Joint].get_value(), "p": self.obj.states[Pose].get_value()}
for arg in get_value_args:
if isinstance(arg, StatefulObject):
info[arg] = arg.states[Pose].get_value()
info[arg] = {"q": self.obj.states[Joint].get_value(), "p": arg.states[Pose].get_value()}

return info

Expand All @@ -38,9 +39,13 @@ def _cache_is_valid(self, get_value_args):

# Cache is valid if and only if all of our cached objects have not changed
t = self._cache[get_value_args]["t"]
for obj, pose in self._cache[get_value_args]["info"].items():
for obj, info in self._cache[get_value_args]["info"].items():
if isinstance(obj, StatefulObject):
if obj.states[Pose].has_changed(get_value_args=(), value=pose, info={}, t=t):
# If pose has changed, return False
if obj.states[Pose].has_changed(get_value_args=(), value=info["p"], info={}, t=t):
return False
# If obj's joints have changed, return False
if obj.states[Joint].has_changed(get_value_args=(), value=info["q"], info={}, t=t):
return False
return True

Expand Down
2 changes: 1 addition & 1 deletion omnigibson/objects/controllable_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
# Store inputs
self._control_freq = control_freq
self._controller_config = controller_config
self._reset_joint_pos = None if reset_joint_pos is None else th.tensor(reset_joint_pos)
self._reset_joint_pos = None if reset_joint_pos is None else th.tensor(reset_joint_pos, dtype=th.float)

# Make sure action type is valid, and also save
assert_valid_key(key=action_type, valid_keys={"discrete", "continuous"}, name="action type")
Expand Down
4 changes: 0 additions & 4 deletions omnigibson/objects/object_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,6 @@ def scale(self, scale):
# Update init info for scale
self._init_info["args"]["scale"] = scale

@cached_property
def link_prim_paths(self):
return [link.prim_path for link in self._links.values()]

@property
def highlighted(self):
"""
Expand Down
4 changes: 4 additions & 0 deletions omnigibson/prims/entity_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ def links(self):
"""
return self._links

@cached_property
def link_prim_paths(self):
return [link.prim_path for link in self._links.values()]

@cached_property
def has_attachment_points(self):
"""
Expand Down
8 changes: 8 additions & 0 deletions omnigibson/robots/franka.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ def robot_arm_descriptor_yamls(self):
def urdf_path(self):
return os.path.join(gm.ASSET_PATH, f"models/franka/{self.model_name}.urdf")

@property
def curobo_path(self):
# Only supported for normal franka now
assert (
self._model_name == "franka_panda"
), f"Only franka_panda is currently supported for curobo. Got: {self._model_name}"
return os.path.join(gm.ASSET_PATH, f"models/franka/{self.model_name}_description_curobo.yaml")

@property
def eef_usd_path(self):
return {self.default_arm: os.path.join(gm.ASSET_PATH, f"models/franka/{self.model_name}_eef.usd")}
Expand Down
4 changes: 4 additions & 0 deletions omnigibson/robots/franka_mounted.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def robot_arm_descriptor_yamls(self):
def urdf_path(self):
return os.path.join(gm.ASSET_PATH, "models/franka/franka_mounted.urdf")

@property
def curobo_path(self):
return os.path.join(gm.ASSET_PATH, "models/franka/franka_mounted_description_curobo.yaml")

@property
def eef_usd_path(self):
# TODO: Update!
Expand Down
8 changes: 8 additions & 0 deletions omnigibson/robots/robot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,14 @@ def urdf_path(self):
"""
raise NotImplementedError

@property
def curobo_path(self):
"""
Returns:
str: file path to the robot curobo configuration yaml file.
"""
raise NotImplementedError

@classproperty
def _do_not_register_classes(cls):
# Don't register this class since it's an abstract template
Expand Down
4 changes: 4 additions & 0 deletions omnigibson/robots/vx300s.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def robot_arm_descriptor_yamls(self):
def urdf_path(self):
return os.path.join(gm.ASSET_PATH, "models/vx300s/vx300s.urdf")

@property
def curobo_path(self):
return os.path.join(gm.ASSET_PATH, "models/vx300s/vx300s_description_curobo.yaml")

@property
def eef_usd_path(self):
# return {self.default_arm: os.path.join(gm.ASSET_PATH, "models/vx300s/vx300s_eef.usd")}
Expand Down
2 changes: 2 additions & 0 deletions omnigibson/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class ParticleModifyCondition(str, Enum):
# Structure categories that need to always be loaded for stability purposes
STRUCTURE_CATEGORIES = frozenset({"floors", "walls", "ceilings", "lawn", "driveway", "fence", "roof", "background"})

# Ground categories / prim names used for filtering collisions, e.g.: during motion planning
GROUND_CATEGORIES = frozenset({"floors", "lawn", "driveway", "carpet"})

# Joint friction magic values to assign to objects based on their category
DEFAULT_JOINT_FRICTION = 10.0
Expand Down
11 changes: 10 additions & 1 deletion omnigibson/utils/usd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from collections.abc import Iterable

import numpy as np
import torch as th
import trimesh

Expand Down Expand Up @@ -473,7 +474,7 @@ def in_contact(self, prim_paths_a, prim_paths_b):
if key not in self._CONTACT_CACHE:
# In contact if any of the matrix values representing the interaction between the two groups is non-zero
self._CONTACT_CACHE[key] = th.any(self.get_impulses(prim_paths_a=prim_paths_a, prim_paths_b=prim_paths_b))
return self._CONTACT_CACHE[key]
return self._CONTACT_CACHE[key].item()

def clear(self):
"""
Expand Down Expand Up @@ -1605,6 +1606,14 @@ def create_primitive_mesh(prim_path, primitive_type, extents=1.0, u_patches=None
)
)

# Modify values so that all faces are triangular
tm = mesh_prim_to_trimesh_mesh(mesh.GetPrim())
face_vertex_counts = np.array([len(face) for face in tm.faces], dtype=int)
mesh.GetFaceVertexCountsAttr().Set(face_vertex_counts)
mesh.GetFaceVertexIndicesAttr().Set(tm.faces.flatten())
mesh.GetNormalsAttr().Set(lazy.pxr.Vt.Vec3fArray.FromNumpy(tm.vertex_normals[tm.faces.flatten()]))
mesh.GetPrim().GetAttribute("primvars:st").Set(lazy.pxr.Vt.Vec2fArray.FromNumpy(tm.visual.uv[tm.faces.flatten()]))

return mesh


Expand Down
118 changes: 118 additions & 0 deletions tests/test_curobo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest
import torch as th

import omnigibson as og
from omnigibson.action_primitives.curobo import CuRoboMotionGenerator
from omnigibson.macros import gm
from omnigibson.object_states import Touching


def test_curobo():
# Make sure object states are enabled
assert gm.ENABLE_OBJECT_STATES

# Create env
cfg = {
"scene": {
"type": "Scene",
},
"objects": [
{
"type": "PrimitiveObject",
"name": "obj0",
"primitive_type": "Cube",
"scale": [0.4, 0.4, 0.4],
"fixed_base": True,
"position": [0.5, -0.1, 0.2],
"orientation": [0, 0, 0, 1],
},
],
"robots": [
{
"type": "FrankaPanda",
"obs_modalities": "rgb",
"position": [0.7, -0.55, 0.0],
"orientation": [0, 0, 0.707, 0.707],
"self_collisions": True,
},
],
}
env = og.Environment(configs=cfg)
robot = env.robots[0]
obj = env.scene.object_registry("name", "obj0")

robot.reset()
robot.keep_still()

for _ in range(5):
og.sim.step()

# Create CuRobo instance
batch_size = 25
n_samples = 55
cmg = CuRoboMotionGenerator(
robot=robot,
batch_size=batch_size,
)

# Sample values for robot
th.manual_seed(1)
lo, hi = robot.joint_lower_limits.view(1, -1), robot.joint_upper_limits.view(1, -1)
random_qs = lo + th.rand((n_samples, robot.n_dof)) * (hi - lo)

# Test collision
collision_results = cmg.check_collisions(q=random_qs, activation_distance=0.0)
eef_positions, eef_quats = [], []

# View results
n_mismatch = 0
for i, (q, result) in enumerate(zip(random_qs, collision_results)):
# Set robot to desired qpos
robot.set_joint_positions(q)
robot.keep_still()
og.sim.step()

# Validate that expected collision result is correct
true_result = robot.states[Touching].get_value(obj)

if result.item() != true_result:
n_mismatch += 1

# If we're collision-free, record this pose so that we can test trajectory planning afterwards
if not result and len(robot.contact_list()) == 0:
eef_pos, eef_quat = robot.get_relative_eef_pose()
eef_positions.append(eef_pos)
eef_quats.append(eef_quat)

# Make sure mismatched results are small
# Slight mismatch may occur because sphere approximation is not quite equal to the collision sim representation
assert n_mismatch / n_samples < 0.1, f"Proportion mismatched results: {n_mismatch / n_samples}"

# Test trajectories
robot.reset()
robot.keep_still()
og.sim.step()

successes, traj_paths = cmg.compute_trajectories(
target_pos=th.stack(eef_positions, dim=0),
target_quat=th.stack(eef_quats, dim=0),
is_local=True,
max_attempts=1,
enable_finetune_trajopt=True,
return_full_result=False,
success_ratio=1.0,
attached_obj=None,
)

# Execute the trajectory and make sure there's rarely any collisions
assert th.sum(successes) > 0.95, f"Failed to find > 95% collision-free trajectories: {successes}"
print(f"Total successes: {th.sum(successes)} / {len(successes)}")
for success, traj_path in zip(successes, traj_paths):
if not success:
continue
q_traj = cmg.path_to_joint_trajectory(traj_path)
for q in q_traj:
robot.set_joint_positions(q)
robot.keep_still()
og.sim.step()
assert len(robot.contact_list()) == 0
17 changes: 17 additions & 0 deletions tests/test_object_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ def test_pose(env):
breakfast_table.states[Pose].set_value(None)


@og_test
def test_joint(env):
breakfast_table = env.scene.object_registry("name", "breakfast_table")
bottom_cabinet = env.scene.object_registry("name", "bottom_cabinet")

lo = bottom_cabinet.joint_lower_limits
hi = bottom_cabinet.joint_upper_limits
q_rand = lo + (hi - lo) * th.rand(bottom_cabinet.n_joints)
bottom_cabinet.set_joint_positions(q_rand)

assert th.allclose(bottom_cabinet.states[Joint].get_value(), q_rand)
assert len(breakfast_table.states[Joint].get_value()) == 0

with pytest.raises(NotImplementedError):
bottom_cabinet.states[Joint].set_value(None)


@og_test
def test_aabb(env):
breakfast_table = env.scene.object_registry("name", "breakfast_table")
Expand Down

0 comments on commit 15f190b

Please sign in to comment.