From f1ed0822867a702105cdd2bc91ebed569f008e75 Mon Sep 17 00:00:00 2001 From: georgematheos Date: Tue, 10 Sep 2024 17:39:59 -0400 Subject: [PATCH] Gm/gen3d/inference3 (#156) --- notebooks/bayes3d_paper/tester.ipynb | 1442 ++++++++++++++++- src/b3d/chisight/gen3d/inference.py | 18 +- src/b3d/chisight/gen3d/inference_moves.py | 222 ++- src/b3d/chisight/gen3d/projection.py | 1 + .../test_depth_nonreturn_prob_inference.py | 32 + 5 files changed, 1619 insertions(+), 96 deletions(-) create mode 100644 tests/gen3d/inference/test_depth_nonreturn_prob_inference.py diff --git a/notebooks/bayes3d_paper/tester.ipynb b/notebooks/bayes3d_paper/tester.ipynb index 6c582ebe..0c9560c0 100644 --- a/notebooks/bayes3d_paper/tester.ipynb +++ b/notebooks/bayes3d_paper/tester.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 30, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -16,7 +16,17 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import genjax\n", + "genjax.pretty()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -25,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -46,7 +56,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 49/49 [00:03<00:00, 14.41it/s]\n" + "100%|██████████| 49/49 [00:03<00:00, 12.99it/s]\n", + "/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n" ] }, { @@ -57,7 +70,7 @@ "" ] }, - "execution_count": 32, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -92,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -120,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -151,7 +164,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -162,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -196,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -210,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -229,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -250,16 +263,43 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 36, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
" + ], "text/plain": [ - "Array(82541.78, dtype=float32)" + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" ] }, - "execution_count": 41, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -272,16 +312,43 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], "text/plain": [ - "Array(82541.78, dtype=float32)" + "" ] }, - "execution_count": 42, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -292,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -304,31 +371,59 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 31, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], "text/plain": [ - "Array(35023.812, dtype=float32)" + "" ] }, - "execution_count": 87, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "inference_hyperparams = InferenceHyperparams(\n", - " n_poses=10000,\n", + "inference_hyperparams = i.InferenceHyperparams(\n", + " n_poses=6000,\n", " pose_proposal_std=0.04,\n", " pose_proposal_conc=1000.,\n", + " color_proposal_params=None\n", ")\n", "\n", "stepped_trace, step_weight, metadata = i.inference_step(\n", " jax.random.PRNGKey(21),\n", " trace,\n", - " all_data[1][\"rgbd\"],\n", + " all_data[0][\"rgbd\"],\n", " inference_hyperparams\n", ")\n", "step_weight" @@ -336,30 +431,1315 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ - "T = 1\n", + "T = 0\n", "b3d.chisight.gen3d.model.viz_trace(stepped_trace, T, ground_truth_vertices=meshes[OBJECT_INDEX].vertices, ground_truth_pose=all_data[T][\"camera_pose\"].inv() @ all_data[T][\"object_poses\"][OBJECT_INDEX])" ] }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "[2024-09-10T18:07:03Z WARN re_log_types::path::parse_path] When parsing the entity path \"proposed positions\": Unescaped whitespace. The path will be interpreted as /proposed\\ positions\n" + "100%|██████████| 49/49 [00:59<00:00, 1.21s/it]\n" ] } ], "source": [ "import rerun as rr\n", - "rr.log(\"proposed positions\", rr.Points3D(metadata[\"proposed_poses\"].position))" + "\n", + "### Run inference ###\n", + "for T in tqdm(range(len(all_data))):\n", + " key = b3d.split_key(key)\n", + " trace, wt, _ = i.inference_step(\n", + " key,\n", + " trace,\n", + " all_data[T][\"rgbd\"],\n", + " inference_hyperparams\n", + " )\n", + " b3d.chisight.gen3d.model.viz_trace(trace, T, ground_truth_vertices=meshes[OBJECT_INDEX].vertices, ground_truth_pose=all_data[T][\"camera_pose\"].inv() @ all_data[T][\"object_poses\"][OBJECT_INDEX])\n", + " rr.log(\"importance_weight\", rr.Scalar(wt))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "dict_keys(['chosen_pose_index', 'log_q_nonpose_latents', 'log_q_poses', 'other_latents_metadata', 'p_scores', 'proposed_poses'])" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.nn.softmax(jnp.array([-jnp.inf, -jnp.inf]))" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata[\"other_latents_metadata\"][\"depth_nonreturn_proposal\"][\"log_normalized_scores\"]\n", + "metadata[\"other_latents_metadata\"][\"depth_nonreturn_proposal\"][\"likelihood_score\"][4]" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "i = jnp.argmax(metadata[\"p_scores\"])\n", + "metadata[\"p_scores\"][i-3:i+3]" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# get the 10 largest values in p_scores\n", + "jnp.sort(metadata[\"p_scores\"])[-10:]" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.max(metadata[\"log_q_poses\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.min(jnp.nan_to_num(metadata[\"log_q_nonpose_latents\"], -jnp.inf))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "dict_keys(['chosen_pose_index', 'other_latents_metadata', 'proposed_poses'])" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "{'depth_nonreturn_proposal': {'index': ,\n", + " 'latent_depth': ,\n", + " 'likelihood_score': \n", + " >,\n", + " 'log_normalized_scores': \n", + " >,\n", + " 'observed_depth': ,\n", + " 'prev_dnrp': ,\n", + " 'support': \n", + " >,\n", + " 'transition_score': \n", + " >},\n", + " 'dnrps': }" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "md = jax.tree.map(lambda x: x[metadata['chosen_pose_index']], metadata['other_latents_metadata'])\n", + "jax.tree.map(lambda x: x[0], md)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def sample(key):\n", + " return jax.random.categorical(key, jnp.array([-1.01509 , -0.44999695]))\n", + "\n", + "key = jax.random.PRNGKey(0)\n", + "jax.vmap(sample)(jax.random.split(key, 100))" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.exp(jnp.array([-1.01509 , -0.44999695]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "b3d.reload(b3d.chisight.gen3d.projection)\n", + "from b3d.chisight.gen3d.projection import PixelsPointsAssociation\n", + "import b3d.chisight.gen3d.model as m\n", + "\n", + "obs_point_depths = PixelsPointsAssociation.from_hyperparams_and_pose(\n", + " m.get_hypers(trace), m.get_new_state(trace)[\"pose\"]\n", + ").get_point_depths(m.get_observed_rgbd(trace))\n", + "\n", + "true_point_depths = template_pose.apply(hyperparams[\"vertices\"])[:, 2]\n", + "\n", + "jnp.all(jnp.abs(obs_point_depths - true_point_depths) < 1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.any(jnp.abs(obs_point_depths - true_point_depths[0]) < 1e-6)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_point_colors = m.get_prev_state(trace)[\"colors\"]\n", + "obs_point_colors = PixelsPointsAssociation.from_hyperparams_and_pose(\n", + " m.get_hypers(trace), m.get_new_state(trace)[\"pose\"]\n", + ").get_point_rgbds(m.get_observed_rgbd(trace))[..., :3]\n", + "\n", + "true_point_colors - obs_point_colors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "dict_keys(['chosen_pose_index', 'other_latents_metadata', 'proposed_poses'])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jnp.all(stepped_trace.get_retval()[\"new_state\"][\"depth_nonreturn_prob\"] == metadata[\"other_latents_metadata\"][\"dnrps\"][metadata[\"chosen_pose_index\"]])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "# jax.Array float32(100,) ≈0.68 ±0.46 [≥0.01, ≤0.99] nonzero:100\n", + " Array([0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.01, 0.01, 0.99, 0.99, 0.01,\n", + " 0.99, 0.99, 0.99, 0.99, 0.01, 0.99, 0.99, 0.99, 0.01, 0.99, 0.99,\n", + " 0.01, 0.01, 0.99, 0.01, 0.99, 0.99, 0.01, 0.01, 0.99, 0.99, 0.01,\n", + " 0.99, 0.99, 0.99, 0.99, 0.01, 0.99, 0.01, 0.01, 0.01, 0.99, 0.99,\n", + " 0.99, 0.99, 0.99, 0.01, 0.01, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99,\n", + " 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.01, 0.99, 0.01, 0.99, 0.01,\n", + " 0.99, 0.99, 0.01, 0.01, 0.99, 0.99, 0.01, 0.99, 0.99, 0.01, 0.99,\n", + " 0.99, 0.01, 0.99, 0.01, 0.99, 0.99, 0.99, 0.99, 0.99, 0.01, 0.99,\n", + " 0.99, 0.99, 0.99, 0.01, 0.99, 0.01, 0.01, 0.99, 0.01, 0.99, 0.99,\n", + " 0.01], dtype=float32)\n" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stepped_trace.get_retval()[\"new_state\"][\"depth_nonreturn_prob\"][:100]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "T = 0\n", + "b3d.chisight.gen3d.model.viz_trace(stepped_trace, T, ground_truth_vertices=meshes[OBJECT_INDEX].vertices, ground_truth_pose=all_data[T][\"camera_pose\"].inv() @ all_data[T][\"object_poses\"][OBJECT_INDEX])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "{'depth_nonreturn_proposal': {'index': \n", + " >,\n", + " 'log_normalized_scores': \n", + " >,\n", + " 'support': \n", + " >}}" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.tree.map(\n", + " lambda x: x[closest_pose_idx], metadata[\"other_latents_metadata\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [], + "source": [ + "gt_pose = all_data[T][\"camera_pose\"].inv() @ all_data[T][\"object_poses\"][OBJECT_INDEX]" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata[\"proposed_poses\"].position" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "closest_pose_idx" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "closest_pose_idx = jnp.argmin(\n", + " jnp.linalg.norm(\n", + " metadata[\"proposed_poses\"].position - gt_pose.position, axis=-1\n", + " )\n", + ")\n", + "metadata[\"proposed_poses\"].quaternion[closest_pose_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "\n", + ">" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gt_pose.quaternion" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "T = 1\n", + "b3d.chisight.gen3d.model.viz_trace(stepped_trace, T, ground_truth_vertices=meshes[OBJECT_INDEX].vertices, ground_truth_pose=all_data[T][\"camera_pose\"].inv() @ all_data[T][\"object_poses\"][OBJECT_INDEX])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[2024-09-10T18:07:03Z WARN re_log_types::path::parse_path] When parsing the entity path \"proposed positions\": Unescaped whitespace. The path will be interpreted as /proposed\\ positions\n" + ] + } + ], + "source": [ + "import rerun as rr\n", + "rr.log(\"proposed positions\", rr.Points3D(metadata[\"proposed_poses\"].position))" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.scipy.special.logsumexp(jnp.array([-jnp.inf, -.2, -1.]))" ] }, { diff --git a/src/b3d/chisight/gen3d/inference.py b/src/b3d/chisight/gen3d/inference.py index 7db9ddbe..611911fe 100644 --- a/src/b3d/chisight/gen3d/inference.py +++ b/src/b3d/chisight/gen3d/inference.py @@ -20,7 +20,8 @@ # Use namedtuple rather than dict so we can hash this, and use it as a static arg to a jitted function. InferenceHyperparams = namedtuple( - "InferenceHyperparams", ["n_poses", "pose_proposal_std", "pose_proposal_conc"] + "InferenceHyperparams", + ["n_poses", "pose_proposal_std", "pose_proposal_conc", "color_proposal_params"], ) @@ -66,7 +67,7 @@ def inference_step(key, old_trace, observed_rgbd, inference_hyperparams): ) param_generation_keys = split(k3, inference_hyperparams.n_poses) - proposed_traces, log_q_nonpose_latents = jax.vmap( + proposed_traces, log_q_nonpose_latents, other_latents_metadata = jax.vmap( propose_other_latents_given_pose, in_axes=(0, None, 0, None) )(param_generation_keys, trace, proposed_poses, inference_hyperparams) p_scores = jax.vmap(lambda tr: tr.get_score())(proposed_traces) @@ -75,7 +76,18 @@ def inference_step(key, old_trace, observed_rgbd, inference_hyperparams): chosen_index = jax.random.categorical(k4, scores) new_trace = jax.tree.map(lambda x: x[chosen_index], proposed_traces) - return new_trace, logmeanexp(scores), {"proposed_poses": proposed_poses} + return ( + new_trace, + logmeanexp(scores), + { + "proposed_poses": proposed_poses, + "chosen_pose_index": chosen_index, + "p_scores": p_scores, + "log_q_poses": log_q_poses, + "log_q_nonpose_latents": log_q_nonpose_latents, + "other_latents_metadata": other_latents_metadata, + }, + ) def inference_step_noweight(*args): diff --git a/src/b3d/chisight/gen3d/inference_moves.py b/src/b3d/chisight/gen3d/inference_moves.py index 539f3dd4..f8842ed3 100644 --- a/src/b3d/chisight/gen3d/inference_moves.py +++ b/src/b3d/chisight/gen3d/inference_moves.py @@ -19,6 +19,18 @@ from .projection import PixelsPointsAssociation +def normalize_log_scores(scores): + """ + Util for constructing log resampling distributions, avoiding NaN issues. + + (Conversely, since there will be no NaNs, this could make it harder to debug.) + """ + val = scores - jax.scipy.special.logsumexp(scores) + return jnp.where( + jnp.any(jnp.isnan(val)), -jnp.log(len(val)) * jnp.ones_like(val), val + ) + + def propose_pose(key, advanced_trace, inference_hyperparams): """ Propose a random pose near the previous timestep's pose. @@ -43,35 +55,41 @@ def propose_other_latents_given_pose(key, advanced_trace, pose, inference_hyperp proposed latents (and the same pose and observed rgbd as in the given trace). `log_q` is (a fair estimate of) the log proposal density. """ - k1, k2, k3, k4, k5, k6 = split(key, 6) + k1, k2, k3, k4, k5 = split(key, 5) - trace_with_pose = update_field(k1, advanced_trace, "pose", pose) + trace = update_field(k1, advanced_trace, "pose", pose) - depth_nonreturn_probs, log_q_dnrps = propose_depth_nonreturn_probs( - k2, trace_with_pose + k2a, k2b = split(k2) + depth_nonreturn_probs, log_q_dnrps, dnrp_metadata = propose_depth_nonreturn_probs( + k2a, trace + ) + trace = update_vmapped_field( + k2b, trace, "depth_nonreturn_prob", depth_nonreturn_probs ) + + k3a, k3b = split(k3) colors, visibility_probs, log_q_cvp = propose_colors_and_visibility_probs( - k3, trace_with_pose + k3a, trace, inference_hyperparams + ) + trace = update_vmapped_fields( + k3b, trace, ["colors", "visibility_prob"], [colors, visibility_probs] ) log_q_cvp = 0.0 - depth_scale, log_q_ds = propose_depth_scale(k4, trace_with_pose) - color_scale, log_q_cs = propose_color_scale(k5, trace_with_pose) - proposed_trace = update_fields( - k6, - trace_with_pose, - [ - "depth_nonreturn_prob", - "colors", - "visibility_prob", - "depth_scale", - "color_scale", - ], - [depth_nonreturn_probs, colors, visibility_probs, depth_scale, color_scale], - ) - log_q = log_q_dnrps + log_q_cvp + log_q_ds + log_q_cs + k4a, k4b = split(k4) + depth_scale, log_q_ds = propose_depth_scale(k4a, trace) + trace = update_field(k4b, trace, "depth_scale", depth_scale) + + k5a, k5b = split(k5) + color_scale, log_q_cs = propose_color_scale(k5a, trace) + trace = update_field(k5b, trace, "color_scale", color_scale) - return proposed_trace, log_q + log_q = log_q_dnrps + log_q_cvp + log_q_ds + log_q_cs + return ( + trace, + log_q, + {"depth_nonreturn_proposal": dnrp_metadata, "dnrps": depth_nonreturn_probs}, + ) def propose_depth_nonreturn_probs(key, trace): @@ -86,7 +104,7 @@ def propose_depth_nonreturn_probs(key, trace): get_hypers(trace), get_new_state(trace)["pose"] ).get_point_depths(get_observed_rgbd(trace)) - depth_nonreturn_probs, per_vertex_log_qs = jax.vmap( + depth_nonreturn_probs, per_vertex_log_qs, metadata = jax.vmap( propose_vertex_depth_nonreturn_prob, in_axes=(0, 0, 0, None, None, None) )( split(key, get_n_vertices(trace)), @@ -97,10 +115,10 @@ def propose_depth_nonreturn_probs(key, trace): get_hypers(trace), ) - return depth_nonreturn_probs, per_vertex_log_qs.sum() + return depth_nonreturn_probs, per_vertex_log_qs.sum(), metadata -def propose_colors_and_visibility_probs(key, trace): +def propose_colors_and_visibility_probs(key, trace, inference_hyperparams): """ Propose a new color and visibility probability for every vertex, conditioned upon the other values in `trace`. @@ -114,7 +132,8 @@ def propose_colors_and_visibility_probs(key, trace): ).get_point_rgbds(get_observed_rgbd(trace)) colors, visibility_probs, per_vertex_log_qs = jax.vmap( - propose_vertex_color_and_visibility_prob, in_axes=(0, 0, 0, None, None, None) + propose_vertex_color_and_visibility_prob, + in_axes=(0, 0, 0, None, None, None, None), )( split(key, get_n_vertices(trace)), jnp.arange(get_n_vertices(trace)), @@ -122,6 +141,7 @@ def propose_colors_and_visibility_probs(key, trace): get_prev_state(trace), get_new_state(trace), get_hypers(trace), + inference_hyperparams, ) return colors, visibility_probs, per_vertex_log_qs.sum() @@ -136,35 +156,78 @@ def propose_vertex_depth_nonreturn_prob( Returns (depth_nonreturn_prob, log_q) where `depth_nonreturn_prob` is the proposed value and `log_q` is (a fair estimate of) the log proposal density. """ - - # TODO: could factor into a sub-function that just receives the values - # we pull out of the previous and new state here, if that facilitates - # unit testing. - previous_dnrp = previous_state["depth_nonreturn_prob"][vertex_index] visibility_prob = new_state["visibility_prob"][vertex_index] latent_depth = new_state["pose"].apply(hyperparams["vertices"][vertex_index])[2] - depth_scale = new_state["depth_scale"] - obs_depth_kernel = hyperparams["image_kernel"].get_depth_vertex_kernel() + return _propose_vertex_depth_nonreturn_prob( + key, + observed_depth, + latent_depth, + visibility_prob, + new_state["depth_scale"], + previous_dnrp, + hyperparams["depth_nonreturn_prob_kernel"], + hyperparams["image_kernel"].get_depth_vertex_kernel(), + ) + +def _propose_vertex_depth_nonreturn_prob( + key, + observed_depth, + latent_depth, + visibility_prob, + depth_scale, + previous_dnrp, + dnrp_transition_kernel, + obs_depth_kernel, + return_metadata=True, +): def score_dnrp_value(dnrp_value): - transition_score = hyperparams["depth_nonreturn_prob_kernel"].logpdf( - dnrp_value, previous_dnrp - ) + transition_score = dnrp_transition_kernel.logpdf(dnrp_value, previous_dnrp) likelihood_score = obs_depth_kernel.logpdf( - observed_depth, latent_depth, visibility_prob, dnrp_value, depth_scale + observed_depth=observed_depth, + latent_depth=latent_depth, + depth_scale=depth_scale, + visibility_prob=visibility_prob, + depth_nonreturn_prob=dnrp_value, ) return transition_score + likelihood_score - support = hyperparams["depth_nonreturn_prob_kernel"].support + support = dnrp_transition_kernel.support log_pscores = jax.vmap(score_dnrp_value)(support) - log_normalized_scores = log_pscores - jax.scipy.special.logsumexp(log_pscores) + log_normalized_scores = normalize_log_scores(log_pscores) index = jax.random.categorical(key, log_normalized_scores) # ^ since we are enumerating over every value in the domain, it is unnecessary # to add a 1/q score when resampling. (Equivalently, we could include # q = 1/len(support), which does not change the resampling distribuiton at all.) - return support[index], log_normalized_scores[index] + if return_metadata: + metadata = { + "support": support, + "log_normalized_scores": log_normalized_scores, + "index": index, + "observed_depth": observed_depth, + "latent_depth": latent_depth, + "prev_dnrp": previous_dnrp, + "transition_score": jax.vmap( + lambda dnrp_value: dnrp_transition_kernel.logpdf( + dnrp_value, previous_dnrp + ) + )(support), + "likelihood_score": jax.vmap( + lambda dnrp_value: obs_depth_kernel.logpdf( + observed_depth, + latent_depth, + visibility_prob, + dnrp_value, + depth_scale, + ) + )(support), + } + else: + metadata = {} + + return support[index], log_normalized_scores[index], metadata def propose_vertex_color_and_visibility_prob( @@ -174,6 +237,7 @@ def propose_vertex_color_and_visibility_prob( previous_state, new_state, hyperparams, + inference_hyperparams, ): """ Propose a new color and visibility probability for the single vertex @@ -199,12 +263,12 @@ def score_visprob_rgb(visprob, rgb): hyperparams["image_kernel"] .get_rgbd_vertex_kernel() .logpdf( - observed_rgbd_for_this_vertex, - jnp.append(rgb, latent_depth), - new_state["color_scale"], - new_state["depth_scale"], - visprob, - new_state["depth_nonreturn_prob"][vertex_index], + observed_rgbd=observed_rgbd_for_this_vertex, + latent_rgbd=jnp.append(rgb, latent_depth), + rgb_scale=new_state["color_scale"], + depth_scale=new_state["depth_scale"], + visibility_prob=visprob, + depth_nonreturn_prob=new_state["depth_nonreturn_prob"][vertex_index], ) ) return rgb_transition_score + visprob_transition_score + likelihood_score @@ -231,7 +295,7 @@ def score_visprob_rgb(visprob, rgb): # we are enumerating over every value in the domain. (Equivalently, # we could subtract a log q score of log(1/len(support)) for each value.) log_weights = log_pscores - log_qs_rgb - log_normalized_scores = log_weights - jax.scipy.special.logsumexp(log_weights) + log_normalized_scores = normalize_log_scores(log_weights) index = jax.random.categorical(k2, log_normalized_scores) rgb = rgbs[index] @@ -288,21 +352,21 @@ def propose_vertex_color_given_visibility( (k1, k2, k3) = split(key, 3) ## Proposal 1: near the previous value. - min_rgbs1 = previous_rgb - diffs / 10 - 2 * d - max_rgbs1 = previous_rgb + diffs / 10 + 2 * d + min_rgbs1 = jnp.maximum(0.0, previous_rgb - diffs / 10 - 2 * d) + max_rgbs1 = jnp.minimum(1.0, previous_rgb + diffs / 10 + 2 * d) proposed_rgb_1 = uniform.sample(k1, min_rgbs1, max_rgbs1) log_q_rgb_1 = uniform.logpdf(proposed_rgb_1, min_rgbs1, max_rgbs1) ## Proposal 2: near the observed value. - min_rgbs2 = observed_rgb - diffs / 10 - 2 * d - max_rgbs2 = observed_rgb + diffs / 10 + 2 * d + min_rgbs2 = jnp.maximum(0.0, observed_rgb - diffs / 10 - 2 * d) + max_rgbs2 = jnp.minimum(1.0, observed_rgb + diffs / 10 + 2 * d) proposed_rgb_2 = uniform.sample(k2, min_rgbs2, max_rgbs2) log_q_rgb_2 = uniform.logpdf(proposed_rgb_2, min_rgbs2, max_rgbs2) ## Proposal 3: somewhere in the middle mean_rgb = (previous_rgb + observed_rgb) / 2 - min_rgbs3 = mean_rgb - 8 / 10 * diffs - 2 * d - max_rgbs3 = mean_rgb + 8 / 10 * diffs + 2 * d + min_rgbs3 = jnp.maximum(0.0, mean_rgb - 8 / 10 * diffs - 2 * d) + max_rgbs3 = jnp.minimum(1.0, mean_rgb + 8 / 10 * diffs + 2 * d) proposed_rgb_3 = uniform.sample(k3, min_rgbs3, max_rgbs3) log_q_rgb_3 = uniform.logpdf(proposed_rgb_3, min_rgbs3, max_rgbs3) @@ -315,13 +379,16 @@ def propose_vertex_color_given_visibility( jax.vmap(lambda rgb: score_visprob_and_rgb(visprob, rgb))(proposed_rgbs) - log_qs ) - normalized_scores = scores - jax.scipy.special.logsumexp(scores) + normalized_scores = normalize_log_scores(scores) sampled_index = jax.random.categorical(key, normalized_scores) sampled_rgb = proposed_rgbs[sampled_index] log_K_score = log_qs.sum() + normalized_scores[sampled_index] - ## "L proposal": given the sampled rgb, estimate the probability that - # it came from the one of the 3 proposals that actually was used. + ## "L proposal": given the sampled rgb, the L proposal proposes + # an index for which of the 3 proposals may have produced this sample RGB, + # and also proposes the other two RGB values. + # Here, we need to compute the logpdf of this L proposal having produced + # the values we sampled out of the K proposal. log_qs_for_this_rgb = jnp.array( [ uniform.logpdf(sampled_rgb, min_rgbs1, max_rgbs1), @@ -329,18 +396,19 @@ def propose_vertex_color_given_visibility( uniform.logpdf(sampled_rgb, min_rgbs3, max_rgbs3), ] ) - normalized_L_logprobs = log_qs_for_this_rgb - jax.scipy.special.logsumexp( - log_qs_for_this_rgb - ) + normalized_L_logprobs = normalize_log_scores(log_qs_for_this_rgb) # L score for proposing the index - log_L_score = normalized_L_logprobs[sampled_index] + log_L_score_for_index = normalized_L_logprobs[sampled_index] # Also add in the L score for proposing the other two RGB values. # The L proposal over these values will just generate them from their prior. - log_L_score += jnp.sum(log_qs) - log_qs[sampled_index] + log_L_score_for_unused_values = jnp.sum(log_qs) - log_qs[sampled_index] + + # full L score + log_L_score = log_L_score_for_index + log_L_score_for_unused_values - ## Compute the overall score. + ## Compute the overall estimate of the marginal density of proposing `sampled_rgb`. overall_score = log_K_score - log_L_score ## Return @@ -385,8 +453,38 @@ def update_fields(key, trace, fieldnames, values): trace, _, _, _ = trace.update( key, U.g( - (Diff.no_change(hyperparams), Diff.unknown_change(previous_state)), + (Diff.no_change(hyperparams), Diff.no_change(previous_state)), C.kw(**dict(zip(fieldnames, values))), ), ) return trace + + +def update_vmapped_fields(key, trace, fieldnames, values): + """ + For each `fieldname` in fieldnames, and each array `arr` in the + corresponding slot in `values`, updates `trace` at addresses + (0, fieldname) through (len(arr) - 1, fieldname) to the corresponding + values in `arr`. + (That is, this assumes for each fieldname, there is a vmap combinator + sampled at that address in the trace.) + """ + c = C.n() + for addr, val in zip(fieldnames, values): + c = c ^ jax.vmap(lambda idx: C[addr, idx].set(val[idx]))( + jnp.arange(val.shape[0]) + ) + + hyperparams, previous_state = trace.get_args() + trace, _, _, _ = trace.update( + key, + U.g((Diff.no_change(hyperparams), Diff.no_change(previous_state)), c), + ) + return trace + + +def update_vmapped_field(key, trace, fieldname, value): + """ + For information, see `update_vmapped_fields`. + """ + return update_vmapped_fields(key, trace, [fieldname], [value]) diff --git a/src/b3d/chisight/gen3d/projection.py b/src/b3d/chisight/gen3d/projection.py index 09e00947..e2b011c9 100644 --- a/src/b3d/chisight/gen3d/projection.py +++ b/src/b3d/chisight/gen3d/projection.py @@ -56,6 +56,7 @@ def from_points_and_intrinsics( intrinsics["cx"], intrinsics["cy"], ) + - 0.5 ) # handle NaN before converting to int (otherwise NaN will be converted # to 0) diff --git a/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py b/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py new file mode 100644 index 00000000..4af3e1db --- /dev/null +++ b/tests/gen3d/inference/test_depth_nonreturn_prob_inference.py @@ -0,0 +1,32 @@ +import b3d.chisight.gen3d.inference_moves as im +import b3d.chisight.gen3d.transition_kernels as transition_kernels +import jax +import jax.numpy as jnp +import jax.random as r +from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( + FullPixelDepthDistribution, +) + +near, far = 0.001, 1.0 + +dnrp_transition_kernel = transition_kernels.DiscreteFlipKernel( + resample_probability=0.05, support=jnp.array([0.01, 0.99]) +) + + +def propose_val(k): + return im._propose_vertex_depth_nonreturn_prob( + k, + observed_depth=0.8, + latent_depth=1.0, + visibility_prob=1.0, + depth_scale=0.00001, + previous_dnrp=0.01, + dnrp_transition_kernel=dnrp_transition_kernel, + obs_depth_kernel=FullPixelDepthDistribution(near, far), + ) + + +values, log_qs, _ = jax.vmap(propose_val)(r.split(r.PRNGKey(0), 1000)) +n_01 = jnp.sum((values == 0.01).astype(jnp.int32)) +assert n_01 >= 950