Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
georgematheos committed Sep 10, 2024
1 parent a12e191 commit 60f838f
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/b3d/chisight/gen3d/inference_moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,11 @@ def propose_vertex_color_given_visibility(
sampled_rgb = proposed_rgbs[sampled_index]
log_K_score = log_qs.sum() + normalized_scores[sampled_index]

## "L proposal": given the sampled rgb, estimate the probability that
# it came from the one of the 3 proposals that actually was used.
## "L proposal": given the sampled rgb, the L proposal proposes
# an index for which of the 3 proposals may have produced this sample RGB,
# and also proposes the other two RGB values.
# Here, we need to compute the logpdf of this L proposal having produced
# the values we sampled out of the K proposal.
log_qs_for_this_rgb = jnp.array(
[
uniform.logpdf(sampled_rgb, min_rgbs1, max_rgbs1),
Expand All @@ -396,13 +399,16 @@ def propose_vertex_color_given_visibility(
normalized_L_logprobs = normalize_log_scores(log_qs_for_this_rgb)

# L score for proposing the index
log_L_score = normalized_L_logprobs[sampled_index]
log_L_score_for_index = normalized_L_logprobs[sampled_index]

# Also add in the L score for proposing the other two RGB values.
# The L proposal over these values will just generate them from their prior.
log_L_score += jnp.sum(log_qs) - log_qs[sampled_index]
log_L_score_for_unused_values = jnp.sum(log_qs) - log_qs[sampled_index]

## Compute the overall score.
# full L score
log_L_score = log_L_score_for_index + log_L_score_for_unused_values

## Compute the overall estimate of the marginal density of proposing `sampled_rgb`.
overall_score = log_K_score - log_L_score

## Return
Expand Down

0 comments on commit 60f838f

Please sign in to comment.