Skip to content

Commit

Permalink
docstring; move grid test to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
georgematheos committed Sep 19, 2024
1 parent 75a260f commit e76938c
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 135 deletions.
158 changes: 23 additions & 135 deletions src/b3d/pose/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,30 @@ def make_translation_grid_enumeration(

def pose_grid(
pose_center: Pose,
half_dx=float,
half_dy=float,
half_dz=float,
half_nx=int,
half_ny=int,
half_nz=int,
half_dangle=float,
half_n_xrot=int,
half_n_yrot=int,
half_n_zrot=int,
half_dx: float,
half_dy: float,
half_dz: float,
half_nx: int,
half_ny: int,
half_nz: int,
half_dangle: float,
half_n_xrot: int,
half_n_yrot: int,
half_n_zrot: int,
) -> Pose:
"""
Returns:
A batched Pose object.
Args:
pose_center: the center pose
half_dx, dy, dz: half the step size for the x, y, z dimension
half_nx, ny, nz: half the number of samples in the x, y, z dimension
(2n + 1 samples will be used per dimension)
half_dangle: half the step size for the euler angle
half_n_xrot, half_n_yrot, half_n_zrot: half the number of samples in the x, y, z rotation
(2n + 1 samples will be used per dimension)
"""
tr_grid = make_translation_grid_enumeration(
half_dx, half_dy, half_dz, half_nx, half_ny, half_nz
)
Expand All @@ -167,128 +180,3 @@ def pose_grid(
total_grid = _total_grid.reshape(-1)

return total_grid


if __name__ == "__main__":
#################################
# Setup test
#################################
VIZ_TEST = False # toggle to visualize all grid on rerun

# center pose
xyz0 = jnp.array([0.0, 0.0, 0.0])
rot0 = jnp.array([0.0, 0.0, 0.0, 1.0])
pose0 = Pose(xyz0, rot0)

# translation grid
half_nx = half_ny = half_nz = 2
ntr = (2 * half_nx + 1) * (2 * half_ny + 1) * (2 * half_nz + 1)
print(f"Generating {ntr} translations")
half_dx, half_dy, half_dz = 1, 1, 0.001

# rotation grid
half_n_alpha = half_n_beta = half_n_gamma = 2
half_d_euler = jnp.pi / 3
nrot = (2 * half_n_alpha + 1) * (2 * half_n_beta + 1) * (2 * half_n_gamma + 1)
print(f"Generating {nrot} rotations")

##################################
# Generate pose grid
##################################

pose_grid_jit = jax.jit(
pose_grid,
static_argnames=(
"half_nx",
"half_ny",
"half_nz",
"half_n_xrot",
"half_n_yrot",
"half_n_zrot",
),
)
grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)

##################################
# Test correctness
##################################
# 1a. sanity check sizes
assert isinstance(
grid, Pose
), f"Wrong return type; expected b3d.Pose, got {type(grid)}"
assert grid.pos.shape == (
ntr * nrot,
3,
), f"Wrong shape for pos; expected {(ntr * nrot, 3)}, got {grid.pos.shape}"
assert grid.quaternion.shape == (
ntr * nrot,
4,
), f"Wrong shape for quat; expected {(ntr * nrot, 4)}, got {grid.quaternion.shape}"

# 1b. check that original pose is in grid
assert (
grid.pos.tolist().index(pose0.pos.tolist())
== grid.quat.tolist().index(pose0.quat.tolist())
!= -1
), "Center pose not in grid"
print("Size checks and center-pose checks passed")

# 2. visualize grid
if VIZ_TEST:
print(f"Visualizing {ntr*nrot} poses in rerun...")
viz_from_grid(grid, rerun_session_name=f"GRID_{ntr}_{nrot}", ycb_obj_id=13)
else:
print("Skipping visualization...")

# ##################################
# # Test time
# ##################################
import time

print("Testing jitted runtime...")
start = time.time()
timed_grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)
end = time.time()
print(f"Time taken: {(end-start)*1000} milliseconds for {ntr*nrot} poses")

# ### Visualize ###
# viz_grid = pose_grid_jit(
# pose0,
# half_dx=half_dx,
# half_dy=half_dy,
# half_dz=half_dz,
# half_nx=half_nx,
# half_ny=half_ny,
# half_nz=half_nz,
# half_dangle=half_d_euler,
# half_n_xrot=half_n_alpha,
# half_n_yrot=half_n_beta,
# half_n_zrot=half_n_gamma,
# )
# b3d.rr_init("pose_grid_test")
# rr_log_pose_arrows_grid(viz_grid)
# b3d.rr_log_pose(pose0, channel="original_pose")
147 changes: 147 additions & 0 deletions tests/gen3d/test_posegrid_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import jax
import jax.numpy as jnp
from b3d import Pose
from b3d.pose.grid import pose_grid, viz_from_grid


def test_pose_gridding():
#################################
# Setup test
#################################
VIZ_TEST = False # toggle to visualize all grid on rerun

# center pose
xyz0 = jnp.array([0.0, 0.0, 0.0])
rot0 = jnp.array([0.0, 0.0, 0.0, 1.0])
pose0 = Pose(xyz0, rot0)

# translation grid
half_nx = half_ny = half_nz = 2
ntr = (2 * half_nx + 1) * (2 * half_ny + 1) * (2 * half_nz + 1)
print(f"Generating {ntr} translations")
half_dx, half_dy, half_dz = 1, 1, 0.001

# rotation grid
half_n_alpha = half_n_beta = half_n_gamma = 2
half_d_euler = jnp.pi / 3
nrot = (2 * half_n_alpha + 1) * (2 * half_n_beta + 1) * (2 * half_n_gamma + 1)
print(f"Generating {nrot} rotations")

##################################
# Generate pose grid
##################################

pose_grid_jit = jax.jit(
pose_grid,
static_argnames=(
"half_nx",
"half_ny",
"half_nz",
"half_n_xrot",
"half_n_yrot",
"half_n_zrot",
),
)
grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)

##################################
# Test correctness
##################################
# 1a. sanity check sizes
assert isinstance(
grid, Pose
), f"Wrong return type; expected b3d.Pose, got {type(grid)}"
assert grid.pos.shape == (
ntr * nrot,
3,
), f"Wrong shape for pos; expected {(ntr * nrot, 3)}, got {grid.pos.shape}"
assert grid.quaternion.shape == (
ntr * nrot,
4,
), f"Wrong shape for quat; expected {(ntr * nrot, 4)}, got {grid.quaternion.shape}"

# 1b. check that original pose is in grid
assert (
grid.pos.tolist().index(pose0.pos.tolist())
== grid.quat.tolist().index(pose0.quat.tolist())
!= -1
), "Center pose not in grid"
print("Size checks and center-pose checks passed")

# 2. visualize grid
if VIZ_TEST:
print(f"Visualizing {ntr*nrot} poses in rerun...")
viz_from_grid(grid, rerun_session_name=f"GRID_{ntr}_{nrot}", ycb_obj_id=13)
else:
print("Skipping visualization...")

# 3. Test that a 1-pose grid is just the starting point
pose_from_grid = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=0,
half_ny=0,
half_nz=0,
half_dangle=half_d_euler,
half_n_xrot=0,
half_n_yrot=0,
half_n_zrot=0,
)[0]
assert jnp.allclose(pose_from_grid.position, pose0.position) and jnp.allclose(
pose_from_grid.quaternion, pose0.quaternion
), "Single-pose grid not equal to original pose"

# ##################################
# # Test time
# ##################################
import time

print("Testing jitted runtime...")
start = time.time()
_ = pose_grid_jit(
pose0,
half_dx=half_dx,
half_dy=half_dy,
half_dz=half_dz,
half_nx=half_nx,
half_ny=half_ny,
half_nz=half_nz,
half_dangle=half_d_euler,
half_n_xrot=half_n_alpha,
half_n_yrot=half_n_beta,
half_n_zrot=half_n_gamma,
)
end = time.time()
print(f"Time taken: {(end-start)*1000} milliseconds for {ntr*nrot} poses")

# ### Visualize ###
# viz_grid = pose_grid_jit(
# pose0,
# half_dx=half_dx,
# half_dy=half_dy,
# half_dz=half_dz,
# half_nx=half_nx,
# half_ny=half_ny,
# half_nz=half_nz,
# half_dangle=half_d_euler,
# half_n_xrot=half_n_alpha,
# half_n_yrot=half_n_beta,
# half_n_zrot=half_n_gamma,
# )
# b3d.rr_init("pose_grid_test")
# rr_log_pose_arrows_grid(viz_grid)
# b3d.rr_log_pose(pose0, channel="original_pose")

0 comments on commit e76938c

Please sign in to comment.