From 250a4cdcfd9f98d030f3b7d1009c20fdd5735ecf Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Tue, 26 Nov 2024 22:55:03 +0000 Subject: [PATCH] Fix a potential memory corruption for tiny sized splats Due to floating point accuracy, some splats would be small enough that project_visible decided it had N tiles but map gaussians to intersects did not agree. This then leads to subsequent terrible behaviour as not all intersections have data written. --- crates/brush-dataset/src/splat_import.rs | 10 +++-- crates/brush-render/src/burn_glue.rs | 38 ++++++++--------- crates/brush-render/src/render.rs | 41 +++++++++++++------ .../shaders/map_gaussian_to_intersects.wgsl | 12 +++--- .../src/shaders/project_forward.wgsl | 2 +- .../src/shaders/project_visible.wgsl | 4 +- 6 files changed, 59 insertions(+), 48 deletions(-) diff --git a/crates/brush-dataset/src/splat_import.rs b/crates/brush-dataset/src/splat_import.rs index 9f50f85c..478669c0 100644 --- a/crates/brush-dataset/src/splat_import.rs +++ b/crates/brush-dataset/src/splat_import.rs @@ -41,7 +41,7 @@ impl PropertyAccess for GaussianData { fn set_property(&mut self, key: &str, property: Property) { let ascii = key.as_bytes(); - let value = if let Property::Float(value) = property { + let mut value = if let Property::Float(value) = property { value } else if let Property::UChar(value) = property { (value as f32) / (u8::MAX as f32) @@ -51,6 +51,11 @@ impl PropertyAccess for GaussianData { return; }; + if value.is_nan() || value.is_infinite() || value.is_subnormal() { + log::warn!("Invalid numbers in your friggin splat!!"); + value = 0.0; + } + match ascii { b"x" => self.means[0] = value, b"y" => self.means[1] = value, @@ -322,9 +327,6 @@ pub fn load_splat_from_ply( .await; } else if element.name.starts_with("meta_delta_min_") { let splat = decode_splat(&mut reader, &gaussian_parser, &header, element).await?; - - log::info!("Splat means:::: {:?}", splat.means); - meta_min.mean = splat.means; meta_min.rotation = 
splat.rotation.into(); meta_min.scale = splat.log_scale; diff --git a/crates/brush-render/src/burn_glue.rs b/crates/brush-render/src/burn_glue.rs index 42a208b1..ee53cde8 100644 --- a/crates/brush-render/src/burn_glue.rs +++ b/crates/brush-render/src/burn_glue.rs @@ -15,11 +15,13 @@ use burn::{ use burn_fusion::{client::FusionClient, stream::Operation, Fusion}; use burn_jit::fusion::{FusionJitRuntime, JitFusionHandle}; use burn_wgpu::WgpuRuntime; -use glam::uvec2; use crate::{ camera::Camera, - render::{render_backward, render_forward, sh_coeffs_for_degree, sh_degree_from_coeffs}, + render::{ + calc_tile_bounds, max_intersections, render_backward, render_forward, sh_coeffs_for_degree, + sh_degree_from_coeffs, + }, shaders, AutodiffBackend, Backend, GaussianBackwardState, InnerWgpu, RenderAux, SplatGrads, }; @@ -280,16 +282,8 @@ impl Backend for Fusion { let proj_size = size_of::() / 4; let uniforms_size = size_of::() / 4; - - // Divide screen into tiles. - let tile_bounds = uvec2( - img_size.x.div_ceil(shaders::helpers::TILE_WIDTH), - img_size.y.div_ceil(shaders::helpers::TILE_WIDTH), - ); - - let max_intersects = num_points - .saturating_mul(tile_bounds[0] as usize * tile_bounds[1] as usize) - .min(128 * 65535); + let tile_bounds = calc_tile_bounds(img_size); + let max_intersects = max_intersections(img_size, num_points as u32); // If render_u32_buffer is true, we render a packed buffer of u32 values, otherwise // render RGBA f32 values. 
@@ -302,19 +296,19 @@ impl Backend for Fusion { let aux = RenderAux:: { projected_splats: client.tensor_uninitialized(vec![num_points, proj_size], DType::F32), - uniforms_buffer: client - .tensor_uninitialized(vec![num_points, uniforms_size], DType::I32), - num_intersections: client.tensor_uninitialized(vec![1], DType::I32), - num_visible: client.tensor_uninitialized(vec![1], DType::I32), + uniforms_buffer: client.tensor_uninitialized(vec![uniforms_size], DType::U32), + num_intersections: client.tensor_uninitialized(vec![1], DType::U32), + num_visible: client.tensor_uninitialized(vec![1], DType::U32), final_index: client - .tensor_uninitialized(vec![img_size.y as usize, img_size.x as usize], DType::I32), - cum_tiles_hit: client.tensor_uninitialized(vec![num_points], DType::I32), + .tensor_uninitialized(vec![img_size.y as usize, img_size.x as usize], DType::U32), + cum_tiles_hit: client.tensor_uninitialized(vec![num_points], DType::U32), tile_bins: client.tensor_uninitialized( - vec![img_size.y as usize, img_size.x as usize, 2], - DType::I32, + vec![tile_bounds.y as usize, tile_bounds.x as usize, 2], + DType::U32, ), - compact_gid_from_isect: client.tensor_uninitialized(vec![max_intersects], DType::I32), - global_from_compact_gid: client.tensor_uninitialized(vec![num_points], DType::I32), + compact_gid_from_isect: client + .tensor_uninitialized(vec![max_intersects as usize], DType::U32), + global_from_compact_gid: client.tensor_uninitialized(vec![num_points], DType::U32), }; let desc = CustomOpDescription::new( diff --git a/crates/brush-render/src/render.rs b/crates/brush-render/src/render.rs index 8a3d0eb9..0f3a2d84 100644 --- a/crates/brush-render/src/render.rs +++ b/crates/brush-render/src/render.rs @@ -50,6 +50,30 @@ pub fn rgb_to_sh(rgb: f32) -> f32 { (rgb - 0.5) / shaders::gather_grads::SH_C0 } +pub(crate) fn calc_tile_bounds(img_size: glam::UVec2) -> glam::UVec2 { + uvec2( + img_size.x.div_ceil(shaders::helpers::TILE_WIDTH), + 
img_size.y.div_ceil(shaders::helpers::TILE_WIDTH), + ) +} + +pub(crate) fn max_intersections(img_size: glam::UVec2, num_splats: u32) -> u32 { + // Divide screen into tiles. + let tile_bounds = calc_tile_bounds(img_size); + let num_tiles = tile_bounds[0] as u32 * tile_bounds[1] as u32; + + // On wasm, we cannot do a sync readback at all. + // Instead, can just estimate a max number of intersects. All the kernels only handle the actual + // count of intersects, and spin up empty threads for the rest atm. In the future, could use indirect + // dispatch to avoid this. + // Estimating the max number of intersects can be a bad hack though... The worst case scenario is so massive + // that it's easy to run out of memory... How do we actually properly deal with this :/ + let max = num_splats.saturating_mul(num_tiles); + + // clamp to max nr. of dispatches. + max.min(256 * 65535) +} + pub(crate) fn render_forward( camera: &Camera, img_size: glam::UVec2, @@ -198,23 +222,14 @@ pub(crate) fn render_forward( InnerWgpu::int_slice(cum_tiles_hit.clone(), &[num_points - 1..num_points]); let num_tiles = tile_bounds[0] * tile_bounds[1]; + let max_intersects = max_intersections(img_size, num_points as u32); // Each intersection maps to a gaussian. let (tile_bins, compact_gid_from_isect) = { - // On wasm, we cannot do a sync readback at all. - // Instead, can just estimate a max number of intersects. All the kernels only handle the actual - // cound of intersects, and spin up empty threads for the rest atm. In the future, could use indirect - // dispatch to avoid this. - // Estimating the max number of intersects can be a bad hack though... The worst case sceneario is so massive - // that it's easy to run out of memory... 
How do we actually properly deal with this :/ - let max_intersects = num_points - .saturating_mul(num_tiles as usize) - .min(256 * 65535); - let tile_id_from_isect = - create_tensor::<1, _>([max_intersects], device, client, DType::U32); + create_tensor::<1, _>([max_intersects as usize], device, client, DType::U32); let compact_gid_from_isect = - create_tensor::<1, _>([max_intersects], device, client, DType::U32); + create_tensor::<1, _>([max_intersects as usize], device, client, DType::U32); tracing::trace_span!("MapGaussiansToIntersect", sync_burn = true).in_scope(|| unsafe { client.execute_unchecked( @@ -306,7 +321,7 @@ pub(crate) fn render_forward( [img_size.y as usize, img_size.x as usize], device, client, - DType::U32, + DType::I32, ); if !raster_u32 { diff --git a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl index 39f42be5..868eddd5 100644 --- a/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl +++ b/crates/brush-render/src/shaders/map_gaussian_to_intersects.wgsl @@ -18,17 +18,16 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) { let projected = projected_splats[compact_gid]; // get the tile bbox for gaussian - let xy = vec2f(projected.xy_x, projected.xy_y); + let mean2d = vec2f(projected.xy_x, projected.xy_y); + let conic = vec3f(projected.conic_x, projected.conic_y, projected.conic_z); - let invDet = 1.0 / (conic.x * conic.z - conic.y * conic.y); - let cov2d = vec3f(conic.z * invDet, -conic.y * invDet, conic.x * invDet); + let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0); let opac = projected.color_a; - let radius = helpers::radius_from_cov(cov2d, opac); let tile_bounds = uniforms.tile_bounds; - let tile_minmax = helpers::get_tile_bbox(xy, radius, tile_bounds); + let tile_minmax = helpers::get_tile_bbox(mean2d, radius, tile_bounds); let tile_min = tile_minmax.xy; let tile_max = tile_minmax.zw; @@ -40,8 +39,9 @@ fn 
main(@builtin(global_invocation_id) global_id: vec3u) { for (var ty = tile_min.y; ty < tile_max.y; ty++) { for (var tx = tile_min.x; tx < tile_max.x; tx++) { - if helpers::can_be_visible(vec2u(tx, ty), xy, conic, opac) && isect_id < arrayLength(&tile_id_from_isect) { + if helpers::can_be_visible(vec2u(tx, ty), mean2d, conic, opac) { let tile_id = tx + ty * tile_bounds.x; // tile within image + tile_id_from_isect[isect_id] = tile_id; compact_gid_from_isect[isect_id] = compact_gid; isect_id++; // handles gaussians that hit more than one tile diff --git a/crates/brush-render/src/shaders/project_forward.wgsl b/crates/brush-render/src/shaders/project_forward.wgsl index 3d49029f..18e07048 100644 --- a/crates/brush-render/src/shaders/project_forward.wgsl +++ b/crates/brush-render/src/shaders/project_forward.wgsl @@ -52,7 +52,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) { let mean2d = uniforms.focal * mean_c.xy * (1.0 / mean_c.z) + uniforms.pixel_center; // TODO: Include opacity here or is this ok? 
- let radius = helpers::radius_from_cov(cov2d, 1.0); + let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0); if (radius <= 0) { return; diff --git a/crates/brush-render/src/shaders/project_visible.wgsl b/crates/brush-render/src/shaders/project_visible.wgsl index 83681dee..2dcbdfa8 100644 --- a/crates/brush-render/src/shaders/project_visible.wgsl +++ b/crates/brush-render/src/shaders/project_visible.wgsl @@ -237,7 +237,7 @@ fn main(@builtin(global_invocation_id) gid: vec3u) { var color = sh_coeffs_to_color(sh_degree, viewdir, sh) + vec3f(0.5); // color = max(color, vec3f(0.0)); - let radius = helpers::radius_from_cov(cov2d, opac); + let radius = helpers::radius_from_cov(helpers::inverse_symmetric(conic), 1.0); let tile_minmax = helpers::get_tile_bbox(mean2d, radius, uniforms.tile_bounds); let tile_min = tile_minmax.xy; @@ -257,5 +257,5 @@ fn main(@builtin(global_invocation_id) gid: vec3u) { conic, vec4f(color, opac) ); - num_tiles_hit[compact_gid] = u32(tile_area); + num_tiles_hit[compact_gid] = tile_area; }