Skip to content

Commit

Permalink
Fix a potential memory corruption for tiny sized splats
Browse files Browse the repository at this point in the history
Due to floating point accuracy, some splats would be small enough that project_visible decided it had N tiles but map gaussians to intersects did not agree. This then leads to sebsequent terrible behaviour as not all intersections have data written.
  • Loading branch information
ArthurBrussee committed Nov 26, 2024
1 parent be21a9e commit 250a4cd
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 48 deletions.
10 changes: 6 additions & 4 deletions crates/brush-dataset/src/splat_import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl PropertyAccess for GaussianData {
fn set_property(&mut self, key: &str, property: Property) {
let ascii = key.as_bytes();

let value = if let Property::Float(value) = property {
let mut value = if let Property::Float(value) = property {
value
} else if let Property::UChar(value) = property {
(value as f32) / (u8::MAX as f32)
Expand All @@ -51,6 +51,11 @@ impl PropertyAccess for GaussianData {
return;
};

if value.is_nan() || value.is_infinite() || value.is_subnormal() {
log::warn!("Invalid numbers in your friggin splat!!");
value = 0.0;
}

match ascii {
b"x" => self.means[0] = value,
b"y" => self.means[1] = value,
Expand Down Expand Up @@ -322,9 +327,6 @@ pub fn load_splat_from_ply<T: AsyncRead + Unpin + 'static, B: Backend>(
.await;
} else if element.name.starts_with("meta_delta_min_") {
let splat = decode_splat(&mut reader, &gaussian_parser, &header, element).await?;

log::info!("Splat means:::: {:?}", splat.means);

meta_min.mean = splat.means;
meta_min.rotation = splat.rotation.into();
meta_min.scale = splat.log_scale;
Expand Down
38 changes: 16 additions & 22 deletions crates/brush-render/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ use burn::{
use burn_fusion::{client::FusionClient, stream::Operation, Fusion};
use burn_jit::fusion::{FusionJitRuntime, JitFusionHandle};
use burn_wgpu::WgpuRuntime;
use glam::uvec2;

use crate::{
camera::Camera,
render::{render_backward, render_forward, sh_coeffs_for_degree, sh_degree_from_coeffs},
render::{
calc_tile_bounds, max_intersections, render_backward, render_forward, sh_coeffs_for_degree,
sh_degree_from_coeffs,
},
shaders, AutodiffBackend, Backend, GaussianBackwardState, InnerWgpu, RenderAux, SplatGrads,
};

Expand Down Expand Up @@ -280,16 +282,8 @@ impl Backend for Fusion<InnerWgpu> {

let proj_size = size_of::<shaders::helpers::ProjectedSplat>() / 4;
let uniforms_size = size_of::<shaders::helpers::RenderUniforms>() / 4;

// Divide screen into tiles.
let tile_bounds = uvec2(
img_size.x.div_ceil(shaders::helpers::TILE_WIDTH),
img_size.y.div_ceil(shaders::helpers::TILE_WIDTH),
);

let max_intersects = num_points
.saturating_mul(tile_bounds[0] as usize * tile_bounds[1] as usize)
.min(128 * 65535);
let tile_bounds = calc_tile_bounds(img_size);
let max_intersects = max_intersections(img_size, num_points as u32);

// If render_u32_buffer is true, we render a packed buffer of u32 values, otherwise
// render RGBA f32 values.
Expand All @@ -302,19 +296,19 @@ impl Backend for Fusion<InnerWgpu> {

let aux = RenderAux::<Self> {
projected_splats: client.tensor_uninitialized(vec![num_points, proj_size], DType::F32),
uniforms_buffer: client
.tensor_uninitialized(vec![num_points, uniforms_size], DType::I32),
num_intersections: client.tensor_uninitialized(vec![1], DType::I32),
num_visible: client.tensor_uninitialized(vec![1], DType::I32),
uniforms_buffer: client.tensor_uninitialized(vec![uniforms_size], DType::U32),
num_intersections: client.tensor_uninitialized(vec![1], DType::U32),
num_visible: client.tensor_uninitialized(vec![1], DType::U32),
final_index: client
.tensor_uninitialized(vec![img_size.y as usize, img_size.x as usize], DType::I32),
cum_tiles_hit: client.tensor_uninitialized(vec![num_points], DType::I32),
.tensor_uninitialized(vec![img_size.y as usize, img_size.x as usize], DType::U32),
cum_tiles_hit: client.tensor_uninitialized(vec![num_points], DType::U32),
tile_bins: client.tensor_uninitialized(
vec![img_size.y as usize, img_size.x as usize, 2],
DType::I32,
vec![tile_bounds.y as usize, tile_bounds.x as usize, 2],
DType::U32,
),
compact_gid_from_isect: client.tensor_uninitialized(vec![max_intersects], DType::I32),
global_from_compact_gid: client.tensor_uninitialized(vec![num_points], DType::I32),
compact_gid_from_isect: client
.tensor_uninitialized(vec![max_intersects as usize], DType::U32),
global_from_compact_gid: client.tensor_uninitialized(vec![num_points], DType::U32),
};

let desc = CustomOpDescription::new(
Expand Down
41 changes: 28 additions & 13 deletions crates/brush-render/src/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ pub fn rgb_to_sh(rgb: f32) -> f32 {
(rgb - 0.5) / shaders::gather_grads::SH_C0
}

pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 {
uvec2(
img_size.x.div_ceil(shaders::helpers::TILE_WIDTH),
img_size.y.div_ceil(shaders::helpers::TILE_WIDTH),
)
}

pub(crate) fn max_intersections(img_size: glam::UVec2, num_splats: u32) -> u32 {
// Divide screen into tiles.
let tile_bounds = calc_tile_bounds(img_size);
let num_tiles = tile_bounds[0] as u32 * tile_bounds[1] as u32;

// On wasm, we cannot do a sync readback at all.
// Instead, can just estimate a max number of intersects. All the kernels only handle the actual
// cound of intersects, and spin up empty threads for the rest atm. In the future, could use indirect
// dispatch to avoid this.
// Estimating the max number of intersects can be a bad hack though... The worst case sceneario is so massive
// that it's easy to run out of memory... How do we actually properly deal with this :/
let max = num_splats.saturating_mul(num_tiles);

// clamp to max nr. of dispatches.
max.min(256 * 65535)
}

pub(crate) fn render_forward(
camera: &Camera,
img_size: glam::UVec2,
Expand Down Expand Up @@ -198,23 +222,14 @@ pub(crate) fn render_forward(
InnerWgpu::int_slice(cum_tiles_hit.clone(), &[num_points - 1..num_points]);

let num_tiles = tile_bounds[0] * tile_bounds[1];
let max_intersects = max_intersections(img_size, num_points as u32);

// Each intersection maps to a gaussian.
let (tile_bins, compact_gid_from_isect) = {
// On wasm, we cannot do a sync readback at all.
// Instead, can just estimate a max number of intersects. All the kernels only handle the actual
// cound of intersects, and spin up empty threads for the rest atm. In the future, could use indirect
// dispatch to avoid this.
// Estimating the max number of intersects can be a bad hack though... The worst case sceneario is so massive
// that it's easy to run out of memory... How do we actually properly deal with this :/
let max_intersects = num_points
.saturating_mul(num_tiles as usize)
.min(256 * 65535);

let tile_id_from_isect =
create_tensor::<1, _>([max_intersects], device, client, DType::U32);
create_tensor::<1, _>([max_intersects as usize], device, client, DType::U32);
let compact_gid_from_isect =
create_tensor::<1, _>([max_intersects], device, client, DType::U32);
create_tensor::<1, _>([max_intersects as usize], device, client, DType::U32);

tracing::trace_span!("MapGaussiansToIntersect", sync_burn = true).in_scope(|| unsafe {
client.execute_unchecked(
Expand Down Expand Up @@ -306,7 +321,7 @@ pub(crate) fn render_forward(
[img_size.y as usize, img_size.x as usize],
device,
client,
DType::U32,
DType::I32,
);

if !raster_u32 {
Expand Down
12 changes: 6 additions & 6 deletions crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {

let projected = projected_splats[compact_gid];
// get the tile bbox for gaussian
let xy = vec2f(projected.xy_x, projected.xy_y);
let mean2d = vec2f(projected.xy_x, projected.xy_y);

let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z);
let invDet = 1.0 / (conic.x * conic.z - conic.y * conic.y);
let cov2d = vec3f(conic.z * invDet, -conic.y * invDet, conic.x * invDet);
let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0);

let opac = projected.color_a;

let radius = helpers::radius_from_cov(cov2d, opac);
let tile_bounds = uniforms.tile_bounds;

let tile_minmax = helpers::get_tile_bbox(xy, radius, tile_bounds);
let tile_minmax = helpers::get_tile_bbox(mean2d, radius, tile_bounds);
let tile_min = tile_minmax.xy;
let tile_max = tile_minmax.zw;

Expand All @@ -40,8 +39,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {

for (var ty = tile_min.y; ty < tile_max.y; ty++) {
for (var tx = tile_min.x; tx < tile_max.x; tx++) {
if helpers::can_be_visible(vec2u(tx, ty), xy, conic, opac) && isect_id < arrayLength(&tile_id_from_isect) {
if helpers::can_be_visible(vec2u(tx, ty), mean2d, conic, opac) {
let tile_id = tx + ty * tile_bounds.x; // tile within image

tile_id_from_isect[isect_id] = tile_id;
compact_gid_from_isect[isect_id] = compact_gid;
isect_id++; // handles gaussians that hit more than one tile
Expand Down
2 changes: 1 addition & 1 deletion crates/brush-render/src/shaders/project_forward.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {
let mean2d = uniforms.focal * mean_c.xy * (1.0 / mean_c.z) + uniforms.pixel_center;

// TODO: Include opacity here or is this ok?
let radius = helpers::radius_from_cov(cov2d, 1.0);
let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0);

if (radius <= 0) {
return;
Expand Down
4 changes: 2 additions & 2 deletions crates/brush-render/src/shaders/project_visible.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
var color = sh_coeffs_to_color(sh_degree, viewdir, sh) + vec3f(0.5);
// color = max(color, vec3f(0.0));

let radius = helpers::radius_from_cov(cov2d, opac);
let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0);

let tile_minmax = helpers::get_tile_bbox(mean2d, radius, uniforms.tile_bounds);
let tile_min = tile_minmax.xy;
Expand All @@ -257,5 +257,5 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
conic,
vec4f(color, opac)
);
num_tiles_hit[compact_gid] = u32(tile_area);
num_tiles_hit[compact_gid] = tile_area;
}

0 comments on commit 250a4cd

Please sign in to comment.