-
Notifications
You must be signed in to change notification settings - Fork 336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compute density compensation for screen space blurring of tiny gaussians #117
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ def project_gaussians( | |
img_width: int, | ||
tile_bounds: Tuple[int, int, int], | ||
clip_thresh: float = 0.01, | ||
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]: | ||
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: | ||
"""This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting. | ||
|
||
Note: | ||
|
@@ -47,12 +47,13 @@ def project_gaussians( | |
clip_thresh (float): minimum z depth threshold. | ||
|
||
Returns: | ||
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}: | ||
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}: | ||
|
||
- **xys** (Tensor): x,y locations of 2D gaussian projections. | ||
- **depths** (Tensor): z depth of gaussians. | ||
- **radii** (Tensor): radii of 2D gaussian projections. | ||
- **conics** (Tensor): conic parameters for 2D gaussian. | ||
- **compensation** (Tensor): the density compensation for blurring 2D kernel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This extra return would break backward compatibility. Personally I'm fine with it as we are in active-developing version There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @vye16 , could you help take a look at this PR and see if you have any other comments other than @liruilong940607 |
||
- **num_tiles_hit** (Tensor): number of tiles hit per gaussian. | ||
- **cov3d** (Tensor): 3D covariances. | ||
""" | ||
|
@@ -105,6 +106,7 @@ def forward( | |
depths, | ||
radii, | ||
conics, | ||
compensation, | ||
num_tiles_hit, | ||
) = _C.project_gaussians_forward( | ||
num_points, | ||
|
@@ -146,10 +148,19 @@ def forward( | |
conics, | ||
) | ||
|
||
return (xys, depths, radii, conics, num_tiles_hit, cov3d) | ||
return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d) | ||
|
||
@staticmethod | ||
def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d): | ||
def backward( | ||
ctx, | ||
v_xys, | ||
v_depths, | ||
v_radii, | ||
v_conics, | ||
v_compensation, | ||
v_num_tiles_hit, | ||
v_cov3d, | ||
): | ||
( | ||
means3d, | ||
scales, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Read/Write the global memory is usually the most time consuming part in a kernel (computation is usually not the burden). I tested this a bit and it slows down the
project_gaussians
from 3000 it/s to 2800 it/s which is not that much so I think is fine. Especially thatproject_gaussians
is much cheaper comparing to therasterization
stage. I'm fine with this tiny little extra burden but just want to point it out for future reference.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code I used to test this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the evaluation effort.