Skip to content

Commit

Permalink
translation/rotation gridding (#192)
Browse files Browse the repository at this point in the history
Implements 
```
    grid = pose_grid(
        pose0,   # center of grid
        min_x=min_x,    # translation min/max ranges (x,y,z)
        min_y=min_y,
        min_z=min_z,
        max_x=max_x,
        max_y=max_y,
        max_z=max_z,
        nx=nx,     # translation number of grid points (x,y,z)
        ny=ny,
        nz=ny,
        min_euler_angle=min_euler,     # rotation min/max ranges (x,y,z)
        max_euler_angle=max_euler,
        n_xrot=n_xrot,    # rotations around x axis 
        n_yrot=n_yrot,
        n_zrot=n_zrot,
    )
```
Which produces a uniform translation/rotation grid from a center pose.
(Make small note that this is different from the old Bayes3D rotation
gridding scheme based on Ben Zinberg's fibonacci lattice division.)

Script also contains a test for visually sanity checking the grid poses
on rerun; this code can be moved out/deleted.

Perf benchmark with `jax.jit`: `Time taken: 0.6721019744873047
milliseconds for 15625 poses`

---------

Co-authored-by: George Matheos <[email protected]>
  • Loading branch information
karen-sy and georgematheos authored Sep 25, 2024
1 parent 3577f5c commit 4273a6b
Show file tree
Hide file tree
Showing 2 changed files with 311 additions and 0 deletions.
182 changes: 182 additions & 0 deletions src/b3d/pose/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import os

import jax
import jax.numpy as jnp

# for debug+test only
import rerun as rr
from jax.scipy.spatial.transform import Rotation

import b3d
from b3d import Mesh

from .core import Pose


def viz_rotation(pose, vertices, t=0, channel="mesh/xfm_cloud"):
"""visualize a rotation of a given mesh vertex set in rerun"""
b3d.rr_set_time(t)

# render
b3d.rr_log_cloud(
pose.apply(vertices),
channel,
)
rr.log(
"info",
rr.TextDocument(
f"""
translation: {pose.pos}
rotation (xyzw): {pose.quaternion}
""".strip(),
media_type=rr.MediaType.MARKDOWN,
),
)
b3d.rr_log_pose(pose, "axes/xfm_pose_axes")


def viz_from_grid(pose_grid, rerun_session_name="grid_test", ycb_obj_id=13):
# load vertices
ycb_dir = os.path.join(b3d.get_assets_path(), "bop/ycbv")
mesh = Mesh.from_obj_file(
os.path.join(ycb_dir, f'models/obj_{f"{ycb_obj_id + 1}".rjust(6, "0")}.ply')
).scale(0.001)
cam_pose = Pose(
position=mesh.vertices.mean(axis=0), quaternion=jnp.array([0, 0, 0, 1])
) # center the mesh for cleaner viz
mesh_vertices = cam_pose.inv().apply(mesh.vertices)

# setup viz
b3d.rr_init(rerun_session_name)

# log the default pose
b3d.rr_set_time(0)
b3d.rr_log_pose(Pose.identity(), "axes/default_pose_axes")
viz_rotation(Pose.identity(), mesh_vertices, 0, "mesh/default_pose_cloud")

# visualize
for t, pose_viz in enumerate(pose_grid):
viz_rotation(pose_viz, mesh_vertices, t + 1, "mesh/xfm_cloud")


def sorted_linspace(delta, half_num):
if half_num == 0: # only the zero-transform sample is returned
return jnp.array([0.0])

num_samples = half_num * 2 + 1
linspace = jnp.linspace(-delta, delta, num_samples)
ordered_linspace = linspace[jnp.argsort(jnp.abs(linspace))]
return ordered_linspace


def rr_log_pose_arrows_grid(pose_grid, channel="pose_grid", scale=0.02):
origins = jnp.tile(pose_grid.pos, (3, 1))
colors = jnp.tile(jnp.eye(3), (len(origins), 1))
vectors = jax.vmap(lambda pose: pose.as_matrix()[:3, :3].T)(pose_grid) * scale
vectors = vectors.reshape(-1, 3)
rr.log(channel, rr.Arrows3D(origins=origins, vectors=vectors, colors=colors))


def make_rotation_grid_enumeration(
half_d_angle,
half_n_alpha,
half_n_beta,
half_n_gamma,
) -> Pose:
"""
Enumerate rotations via euler angles uniformly gridded in the range [min_angle, max_angle]
for each axis (angle of axes: X = alpha, Y = beta, Z = gamma)
"""
alphas = sorted_linspace(half_d_angle, half_n_alpha)
betas = sorted_linspace(half_d_angle, half_n_beta)
gammas = sorted_linspace(half_d_angle, half_n_gamma)

# nest vmap over all axes
def _inner_proposal(alpha, beta, gamma):
return Rotation.from_euler("ZYX", jnp.array([gamma, beta, alpha]))

_proposal_z = lambda gamma, beta: jax.vmap( # noqa:E731
_inner_proposal, in_axes=(None, None, 0)
)(gamma, beta, alphas)
_proposal_zy = lambda gamma: jax.vmap(_proposal_z, in_axes=(None, 0))(gamma, betas) # noqa:E731
proposal_zyx = jax.vmap(_proposal_zy, in_axes=(0,))(gammas)

n_alpha, n_beta, n_gamma = (
2 * half_n_alpha + 1,
2 * half_n_beta + 1,
2 * half_n_gamma + 1,
)
return Pose(
jnp.zeros((n_alpha * n_beta * n_gamma, 3)),
proposal_zyx.as_quat().reshape(-1, 4),
)


def make_translation_grid_enumeration(
half_dx, half_dy, dz, half_num_x, half_num_y, half_num_z
) -> Pose:
"""
Generate uniformly spaced translation proposals in a 3D box
Args:
half_dx, half_dy, dz: half-dimension of each of the x, y, z directions
half_num_x, half_num_y, half_num_z: samples in each of the dimensions, EXCLUDING the zero sample.
"""
x_space = sorted_linspace(half_dx, half_num_x)
y_space = sorted_linspace(half_dy, half_num_y)
z_space = sorted_linspace(dz, half_num_z)
deltas = jnp.stack(
jnp.meshgrid(
x_space,
y_space,
z_space,
),
axis=-1,
)
deltas = deltas.reshape((-1, 3), order="F")

num_x, num_y, num_z = 2 * half_num_x + 1, 2 * half_num_y + 1, 2 * half_num_z + 1
return Pose(deltas, jnp.tile(Pose.identity_quaternion, (num_x * num_y * num_z, 1)))


def pose_grid(
pose_center: Pose,
half_dx: float,
half_dy: float,
half_dz: float,
half_nx: int,
half_ny: int,
half_nz: int,
half_dangle: float,
half_n_xrot: int,
half_n_yrot: int,
half_n_zrot: int,
) -> Pose:
"""
Returns:
A batched Pose object.
Args:
pose_center: the center pose
half_dx, dy, dz: half the step size for the x, y, z dimension
half_nx, ny, nz: half the number of samples in the x, y, z dimension
(2n + 1 samples will be used per dimension)
half_dangle: half the step size for the euler angle
half_n_xrot, half_n_yrot, half_n_zrot: half the number of samples in the x, y, z rotation
(2n + 1 samples will be used per dimension)
"""
tr_grid = make_translation_grid_enumeration(
half_dx, half_dy, half_dz, half_nx, half_ny, half_nz
)
rot_grid = make_rotation_grid_enumeration(
half_dangle, half_n_xrot, half_n_yrot, half_n_zrot
)

compose_pose = lambda tr, rot: pose_center.compose(tr).compose(rot) # noqa:E731
inner_vmap = lambda tr: jax.vmap(compose_pose, in_axes=(None, 0))( # noqa:E731
tr, rot_grid
)
_total_grid = jax.vmap(inner_vmap, in_axes=(0,))(tr_grid) # vmap over tr

total_grid = _total_grid.reshape(-1)

return total_grid
129 changes: 129 additions & 0 deletions tests/gen3d/test_posegrid_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import jax
import jax.numpy as jnp
from b3d import Pose
from b3d.pose.grid import pose_grid, viz_from_grid


def test_pose_gridding():
#################################
# Setup test
#################################
VIZ_TEST = False # toggle to visualize all grid on rerun

# center pose
xyz0 = jnp.array([0.0, 0.0, 0.0])
rot0 = jnp.array([0.0, 0.0, 0.0, 1.0])
pose0 = Pose(xyz0, rot0)

# translation grid
half_nx = half_ny = half_nz = 2
ntr = (2 * half_nx + 1) * (2 * half_ny + 1) * (2 * half_nz + 1)
print(f"Generating {ntr} translations")
half_dx, half_dy, half_dz = 1, 1, 0.001

# rotation grid
half_n_alpha = half_n_beta = half_n_gamma = 2
half_d_euler = jnp.pi / 3
nrot = (2 * half_n_alpha + 1) * (2 * half_n_beta + 1) * (2 * half_n_gamma + 1)
print(f"Generating {nrot} rotations")

##################################
# Generate pose grid
##################################

pose_grid_jit = jax.jit(
pose_grid,
static_argnames=(
"half_nx",
"half_ny",
"half_nz",
"half_n_xrot",
"half_n_yrot",
"half_n_zrot",
),
)
grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)

##################################
# Test correctness
##################################
# 1a. sanity check sizes
assert isinstance(
grid, Pose
), f"Wrong return type; expected b3d.Pose, got {type(grid)}"
assert grid.pos.shape == (
ntr * nrot,
3,
), f"Wrong shape for pos; expected {(ntr * nrot, 3)}, got {grid.pos.shape}"
assert grid.quaternion.shape == (
ntr * nrot,
4,
), f"Wrong shape for quat; expected {(ntr * nrot, 4)}, got {grid.quaternion.shape}"

# 1b. check that original pose is in grid
assert (
grid.pos.tolist().index(pose0.pos.tolist())
== grid.quat.tolist().index(pose0.quat.tolist())
!= -1
), "Center pose not in grid"
print("Size checks and center-pose checks passed")

# 2. visualize grid
if VIZ_TEST:
print(f"Visualizing {ntr*nrot} poses in rerun...")
viz_from_grid(grid, rerun_session_name=f"GRID_{ntr}_{nrot}", ycb_obj_id=13)
else:
print("Skipping visualization...")

# 3. Test that a 1-pose grid is just the starting point
pose_from_grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=0,
half_ny=0,
half_nz=0,
half_dangle=half_d_euler,
half_n_xrot=0,
half_n_yrot=0,
half_n_zrot=0,
)[0]
assert jnp.allclose(pose_from_grid.position, pose0.position) and jnp.allclose(
pose_from_grid.quaternion, pose0.quaternion
), "Single-pose grid not equal to original pose"

# ##################################
# # Test time
# ##################################
import time

print("Testing jitted runtime...")
start = time.time()
_ = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)
end = time.time()
print(f"Time taken: {(end-start)*1000} milliseconds for {ntr*nrot} poses")

0 comments on commit 4273a6b

Please sign in to comment.