Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix torch compiled transform utils #1033

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions omnigibson/utils/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15):
# type: (Tensor, Tensor, Tensor, bool, float) -> Tensor
# reshape quaternion
quat_shape = quat0.shape
quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1)
quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1)
quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1, out=None)
quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1, out=None)

# Check for endpoint cases
where_start = frac <= 0.0
Expand Down Expand Up @@ -481,8 +481,8 @@ def vec2quat(vec: torch.Tensor, up: torch.Tensor = torch.tensor([0.0, 0.0, 1.0])
if up.dim() == 1:
up = up.unsqueeze(0)

vec_n = torch.nn.functional.normalize(vec, dim=-1)
up_n = torch.nn.functional.normalize(up, dim=-1)
vec_n = torch.nn.functional.normalize(vec, dim=-1, eps=1e-10)
ChengshuLi marked this conversation as resolved.
Show resolved Hide resolved
up_n = torch.nn.functional.normalize(up, dim=-1, eps=1e-10)

s_n = torch.cross(up_n, vec_n, dim=-1)
u_n = torch.cross(vec_n, s_n, dim=-1)
Expand Down Expand Up @@ -1141,8 +1141,8 @@ def vecs2axisangle(vec0, vec1):
vec1 (torch.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized
"""
# Normalize vectors
vec0 = normalize(vec0, dim=-1)
vec1 = normalize(vec1, dim=-1)
vec0 = normalize(vec0, dim=-1, eps=1e-10)
vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Get cross product for direction of angle, and multiply by arcos of the dot product which is the angle
return torch.linalg.cross(vec0, vec1) * torch.arccos((vec0 * vec1).sum(-1, keepdim=True))
Expand All @@ -1162,8 +1162,8 @@ def vecs2quat(vec0: torch.Tensor, vec1: torch.Tensor, normalized: bool = False)
"""
# Normalize vectors if requested
if not normalized:
vec0 = normalize(vec0, dim=-1)
vec1 = normalize(vec1, dim=-1)
vec0 = normalize(vec0, dim=-1, eps=1e-10)
vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Half-way Quaternion Solution -- see https://stackoverflow.com/a/11741520
cos_theta = torch.sum(vec0 * vec1, dim=-1, keepdim=True)
Expand Down
Loading