Skip to content

Commit

Permalink
add normal rerun viz
Browse files Browse the repository at this point in the history
  • Loading branch information
esli999 committed May 31, 2024
1 parent 09ad80a commit 1dad458
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
16 changes: 16 additions & 0 deletions b3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
enumerate_choices_get_scores, static_argnums=(2,)
)

def unproject_depth(depth, renderer):
"""Unprojects a depth image into a point cloud.
Args:
depth (jnp.ndarray): The depth image. Shape (H, W)
intrinsics (b.camera.Intrinsics): The camera intrinsics.
Returns:
jnp.ndarray: The point cloud. Shape (H, W, 3)
"""
mask = (depth < renderer.far) * (depth > renderer.near)
depth = depth * mask + renderer.far * (1.0 - mask)
y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
x = (x - renderer.cx) / renderer.fx
y = (y - renderer.cy) / renderer.fy
point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None]
return point_cloud_image

def nn_background_segmentation(images):
import torch
Expand Down
20 changes: 14 additions & 6 deletions test/test_render_ycb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import jax.numpy as jnp
import trimesh
import b3d
import rerun as rr

PORT = 8812
rr.init("real")
rr.connect(addr=f"127.0.0.1:{PORT}")

def test_renderer_full(renderer):
mesh_path = os.path.join(
Expand All @@ -15,7 +19,7 @@ def test_renderer_full(renderer):
object_library.add_trimesh(mesh)

pose = b3d.Pose.from_position_and_target(
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
).inv()

rgb, depth = renderer.render_attribute(
Expand All @@ -39,17 +43,21 @@ def test_renderer_normal_full(renderer):
object_library.add_trimesh(mesh)

pose = b3d.Pose.from_position_and_target(
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
).inv()

_, _, normal = renderer.render_attribute_normal(
rgb, depth, normal = renderer.render_attribute_normal(
pose[None, ...],
object_library.vertices,
object_library.faces,
jnp.array([[0, len(object_library.faces)]]),
object_library.attributes,
)

normal = jnp.abs(normal)
b3d.get_rgb_pil_image(normal).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")
assert normal.sum() > 0
b3d.get_rgb_pil_image((normal+1)/2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")

point_im = b3d.utils.unproject_depth(depth, renderer)
rr.log("pc", rr.Points3D(point_im.reshape(-1,3), colors=rgb.reshape(-1,3)))
rr.log("arrows", rr.Arrows3D(origins=point_im[::5,::5,:].reshape(-1,3), vectors=normal[::5,::5,:].reshape(-1,3)/100))

assert jnp.abs(normal).sum() > 0

0 comments on commit 1dad458

Please sign in to comment.