Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix torch compiled transform utils #1033

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion omnigibson/configs/tiago_primitives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ robots:
name: JointController
arm_left:
name: JointController
subsume_controllers: [trunk]
motor_type: position
command_input_limits: null
command_output_limits: null
use_delta_commands: false
arm_right:
name: JointController
subsume_controllers: [trunk]
motor_type: position
command_input_limits: null
command_output_limits: null
Expand Down
11 changes: 7 additions & 4 deletions omnigibson/objects/controllable_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def _load_controllers(self):
# Generate the controller config
self._controller_config = self._generate_controller_config(custom_config=self._controller_config)

# We copy the controller config here because we add/remove some keys in-place that shouldn't persist
_controller_config = deepcopy(self._controller_config)

# Store dof idx mapping to dof name
self.dof_names_ordered = list(self._joints.keys())

Expand All @@ -237,8 +240,8 @@ def _load_controllers(self):
subsume_names = set()
for name in self._raw_controller_order:
# Make sure we have the valid controller name specified
assert_valid_key(key=name, valid_keys=self._controller_config, name="controller name")
cfg = self._controller_config[name]
assert_valid_key(key=name, valid_keys=_controller_config, name="controller name")
cfg = _controller_config[name]
subsume_controllers = cfg.pop("subsume_controllers", [])
# If this controller subsumes other controllers, it cannot be subsumed by another controller
# (i.e.: we don't allow nested / cyclical subsuming)
Expand All @@ -262,11 +265,11 @@ def _load_controllers(self):
# If this controller is subsumed by another controller, simply skip it
if name in subsume_names:
continue
cfg = self._controller_config[name]
cfg = _controller_config[name]
# If we subsume other controllers, prepend the subsumed' dof idxs to this controller's idxs
if name in controller_subsumes:
for subsumed_name in controller_subsumes[name]:
subsumed_cfg = self._controller_config[subsumed_name]
subsumed_cfg = _controller_config[subsumed_name]
cfg["dof_idx"] = th.concatenate([subsumed_cfg["dof_idx"], cfg["dof_idx"]])

# If we're using normalized action space, override the inputs for all controllers
Expand Down
4 changes: 2 additions & 2 deletions omnigibson/utils/python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ def create_object_from_init_info(init_info):

def safe_equal(a, b):
if isinstance(a, th.Tensor) and isinstance(b, th.Tensor):
return (a == b).all().item()
return a.shape == b.shape and (a == b).all().item()
elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
return all(safe_equal(a_item, b_item) for a_item, b_item in zip(a, b))
return len(a) == len(b) and all(safe_equal(a_item, b_item) for a_item, b_item in zip(a, b))
else:
return a == b

Expand Down
16 changes: 8 additions & 8 deletions omnigibson/utils/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15):
# type: (Tensor, Tensor, Tensor, bool, float) -> Tensor
# reshape quaternion
quat_shape = quat0.shape
quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1)
quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1)
quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1, out=None)
quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1, out=None)

# Check for endpoint cases
where_start = frac <= 0.0
Expand Down Expand Up @@ -481,8 +481,8 @@ def vec2quat(vec: torch.Tensor, up: torch.Tensor = torch.tensor([0.0, 0.0, 1.0])
if up.dim() == 1:
up = up.unsqueeze(0)

vec_n = torch.nn.functional.normalize(vec, dim=-1)
up_n = torch.nn.functional.normalize(up, dim=-1)
vec_n = normalize(vec, dim=-1, eps=1e-10)
up_n = normalize(up, dim=-1, eps=1e-10)

s_n = torch.cross(up_n, vec_n, dim=-1)
u_n = torch.cross(vec_n, s_n, dim=-1)
Expand Down Expand Up @@ -1141,8 +1141,8 @@ def vecs2axisangle(vec0, vec1):
vec1 (torch.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized
"""
# Normalize vectors
vec0 = normalize(vec0, dim=-1)
vec1 = normalize(vec1, dim=-1)
vec0 = normalize(vec0, dim=-1, eps=1e-10)
vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Get cross product for direction of angle, and multiply by arcos of the dot product which is the angle
return torch.linalg.cross(vec0, vec1) * torch.arccos((vec0 * vec1).sum(-1, keepdim=True))
Expand All @@ -1162,8 +1162,8 @@ def vecs2quat(vec0: torch.Tensor, vec1: torch.Tensor, normalized: bool = False)
"""
# Normalize vectors if requested
if not normalized:
vec0 = normalize(vec0, dim=-1)
vec1 = normalize(vec1, dim=-1)
vec0 = normalize(vec0, dim=-1, eps=1e-10)
vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Half-way Quaternion Solution -- see https://stackoverflow.com/a/11741520
cos_theta = torch.sum(vec0 * vec1, dim=-1, keepdim=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_robot_states_flatcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def camera_pose_test(flatcache):
relative_pose_transform(sensor_world_pos, sensor_world_ori, robot_world_pos, robot_world_ori)
)

sensor_world_pos_gt = th.tensor([150.1620, 149.9999, 101.2193])
sensor_world_pos_gt = th.tensor([150.1628, 149.9993, 101.3773])
sensor_world_ori_gt = th.tensor([-0.2952, 0.2959, 0.6427, -0.6421])

assert th.allclose(sensor_world_pos, sensor_world_pos_gt, atol=1e-3)
Expand Down
Loading