diff --git a/b3d/utils.py b/b3d/utils.py index 4bc2c950..9200fa6b 100644 --- a/b3d/utils.py +++ b/b3d/utils.py @@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values): enumerate_choices_get_scores, static_argnums=(2,) ) +def unproject_depth(depth, renderer): + """Unprojects a depth image into a point cloud. + + Args: + depth (jnp.ndarray): The depth image. Shape (H, W) + intrinsics (b.camera.Intrinsics): The camera intrinsics. + Returns: + jnp.ndarray: The point cloud. Shape (H, W, 3) + """ + mask = (depth < renderer.far) * (depth > renderer.near) + depth = depth * mask + renderer.far * (1.0 - mask) + y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]] + x = (x - renderer.cx) / renderer.fx + y = (y - renderer.cy) / renderer.fy + point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None] + return point_cloud_image def nn_background_segmentation(images): import torch diff --git a/test/test_render_ycb_model.py b/test/test_render_ycb_model.py index 9859b35d..4c4a50a1 100644 --- a/test/test_render_ycb_model.py +++ b/test/test_render_ycb_model.py @@ -2,7 +2,11 @@ import jax.numpy as jnp import trimesh import b3d +import rerun as rr +PORT = 8812 +rr.init("real") +rr.connect(addr=f"127.0.0.1:{PORT}") def test_renderer_full(renderer): mesh_path = os.path.join( @@ -15,7 +19,7 @@ def test_renderer_full(renderer): object_library.add_trimesh(mesh) pose = b3d.Pose.from_position_and_target( - jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0]) + jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0]) ).inv() rgb, depth = renderer.render_attribute( @@ -39,10 +43,10 @@ def test_renderer_normal_full(renderer): object_library.add_trimesh(mesh) pose = b3d.Pose.from_position_and_target( - jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0]) + jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0]) ).inv() - _, _, normal = renderer.render_attribute_normal( + rgb, depth, normal = renderer.render_attribute_normal( pose[None, ...], object_library.vertices, object_library.faces, @@ -50,6 +54,10 @@ def test_renderer_normal_full(renderer): object_library.attributes, ) - normal = jnp.abs(normal) - b3d.get_rgb_pil_image(normal).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png") - assert normal.sum() > 0 + b3d.get_rgb_pil_image((normal+1)/2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png") + + point_im = b3d.utils.unproject_depth(depth, renderer) + rr.log("pc", rr.Points3D(point_im.reshape(-1,3), colors=rgb.reshape(-1,3))) + rr.log("arrows", rr.Arrows3D(origins=point_im[::5,::5,:].reshape(-1,3), vectors=normal[::5,::5,:].reshape(-1,3)/100)) + + assert jnp.abs(normal).sum() > 0