Skip to content

Commit

Permalink
Different method of emitting tile intersections (#63)
Browse files Browse the repository at this point in the history
* Change bench

* Remove aux tiles hit

* More robust intersections counting

* Change to tile prefix sum

* Disable rerun

* Fixes

* Rename, fixes
  • Loading branch information
ArthurBrussee authored Dec 7, 2024
1 parent 1e27349 commit 5e2c3b4
Show file tree
Hide file tree
Showing 14 changed files with 157 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
// Compute entry point. Each invocation whose global x index lies inside
// `helpers::output` adds one element of `helpers::input` to its output
// element; the input element is selected by the workgroup's x index, so every
// invocation in a workgroup adds the same input value.
//
// NOTE(review): this span is a rendered diff hunk with no +/- markers. The
// first parameter pair and the `gid` statement look like the removed side, the
// second pair and the `wid` statement like the added side (a `gid` -> `wid`
// rename) -- confirm against the commit before treating this as compilable.
@compute
@workgroup_size(helpers::THREADS_PER_GROUP, 1, 1)
fn main(
@builtin(global_invocation_id) id: vec3u,
@builtin(workgroup_id) gid: vec3u
@builtin(global_invocation_id) id: vec3u,
@builtin(workgroup_id) wid: vec3u
) {
// Bounds guard: skip invocations whose index is past the end of the output array.
if (id.x < arrayLength(&helpers::output)) {
helpers::output[id.x] += helpers::input[gid.x];
helpers::output[id.x] += helpers::input[wid.x];
}
}
6 changes: 3 additions & 3 deletions crates/brush-render/benches/render_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ const DENSE_MULT: f32 = 0.25;
const LOW_RES: glam::UVec2 = glam::uvec2(512, 512);
const HIGH_RES: glam::UVec2 = glam::uvec2(1024, 1024);

const TARGET_SAMPLE_COUNT: u32 = 5;
const INTERNAL_ITERS: u32 = 4;
const TARGET_SAMPLE_COUNT: u32 = 100;
const INTERNAL_ITERS: u32 = 10;

fn generate_bench_data() -> anyhow::Result<()> {
<DiffBack as burn::prelude::Backend>::seed(4);
Expand Down Expand Up @@ -195,7 +195,7 @@ fn bench_general(
}
}

#[divan::bench_group(max_time = 20, sample_count = TARGET_SAMPLE_COUNT, sample_size = 1)]
#[divan::bench_group(max_time = 1000, sample_count = TARGET_SAMPLE_COUNT, sample_size = 1)]
mod fwd {
use super::*;

Expand Down
1 change: 0 additions & 1 deletion crates/brush-render/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ fn main() -> miette::Result<()> {
"src/shaders/project_forward.wgsl",
"src/shaders/project_visible.wgsl",
"src/shaders/map_gaussian_to_intersects.wgsl",
"src/shaders/get_tile_bin_edges.wgsl",
"src/shaders/rasterize.wgsl",
"src/shaders/rasterize_backwards.wgsl",
"src/shaders/gather_grads.wgsl",
Expand Down
24 changes: 10 additions & 14 deletions crates/brush-render/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Backend for InnerWgpu {
state.aux.uniforms_buffer,
state.aux.compact_gid_from_isect,
state.aux.global_from_compact_gid,
state.aux.tile_bins,
state.aux.tile_offsets,
state.aux.final_index,
state.sh_degree,
)
Expand Down Expand Up @@ -178,8 +178,7 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
num_intersections: aux.num_intersections,
num_visible: aux.num_visible,
final_index: aux.final_index,
cum_tiles_hit: aux.cum_tiles_hit,
tile_bins: aux.tile_bins,
tile_offsets: aux.tile_offsets,
compact_gid_from_isect: aux.compact_gid_from_isect,
global_from_compact_gid: aux.global_from_compact_gid,
uniforms_buffer: aux.uniforms_buffer,
Expand Down Expand Up @@ -237,7 +236,7 @@ impl Backend for Fusion<InnerWgpu> {
fn execute(self: Box<Self>, h: &mut HandleContainer<JitFusionHandle<WgpuRuntime>>) {
let (
[means, log_scales, quats, sh_coeffs, raw_opacity],
[projected_splats, uniforms_buffer, num_intersections, num_visible, final_index, cum_tiles_hit, tile_bins, compact_gid_from_isect, global_from_compact_gid, radii, out_img],
[projected_splats, uniforms_buffer, num_intersections, num_visible, final_index, tile_offsets, compact_gid_from_isect, global_from_compact_gid, radii, out_img],
) = self.desc.consume();

let (img, aux) = render_forward(
Expand All @@ -258,8 +257,7 @@ impl Backend for Fusion<InnerWgpu> {
h.register_int_tensor::<InnerWgpu>(&num_intersections.id, aux.num_intersections);
h.register_int_tensor::<InnerWgpu>(&num_visible.id, aux.num_visible);
h.register_int_tensor::<InnerWgpu>(&final_index.id, aux.final_index);
h.register_int_tensor::<InnerWgpu>(&cum_tiles_hit.id, aux.cum_tiles_hit);
h.register_int_tensor::<InnerWgpu>(&tile_bins.id, aux.tile_bins);
h.register_int_tensor::<InnerWgpu>(&tile_offsets.id, aux.tile_offsets);
h.register_int_tensor::<InnerWgpu>(
&compact_gid_from_isect.id,
aux.compact_gid_from_isect,
Expand Down Expand Up @@ -298,9 +296,8 @@ impl Backend for Fusion<InnerWgpu> {
num_visible: client.tensor_uninitialized(vec![1], DType::I32),
final_index: client
.tensor_uninitialized(vec![img_size.y as usize, img_size.x as usize], DType::I32),
cum_tiles_hit: client.tensor_uninitialized(vec![num_points], DType::I32),
tile_bins: client.tensor_uninitialized(
vec![tile_bounds.y as usize, tile_bounds.x as usize, 2],
tile_offsets: client.tensor_uninitialized(
vec![(tile_bounds.y * tile_bounds.x) as usize + 1],
DType::I32,
),
compact_gid_from_isect: client
Expand All @@ -324,8 +321,7 @@ impl Backend for Fusion<InnerWgpu> {
aux.num_intersections.to_description_out(),
aux.num_visible.to_description_out(),
aux.final_index.to_description_out(),
aux.cum_tiles_hit.to_description_out(),
aux.tile_bins.to_description_out(),
aux.tile_offsets.to_description_out(),
aux.compact_gid_from_isect.to_description_out(),
aux.global_from_compact_gid.to_description_out(),
aux.radii.to_description_out(),
Expand Down Expand Up @@ -357,7 +353,7 @@ impl Backend for Fusion<InnerWgpu> {
impl Operation<FusionJitRuntime<WgpuRuntime, u32>> for CustomOp {
fn execute(self: Box<Self>, h: &mut HandleContainer<JitFusionHandle<WgpuRuntime>>) {
let (
[v_output, means, log_scales, quats, raw_opac, out_img, projected_splats, num_visible, uniforms_buffer, compact_gid_from_isect, global_from_compact_gid, tile_bins, final_index],
[v_output, means, log_scales, quats, raw_opac, out_img, projected_splats, num_visible, uniforms_buffer, compact_gid_from_isect, global_from_compact_gid, tile_offsets, final_index],
[v_means, v_quats, v_scales, v_coeffs, v_raw_opac, v_xy],
) = self.desc.consume();

Expand All @@ -373,7 +369,7 @@ impl Backend for Fusion<InnerWgpu> {
h.get_int_tensor::<InnerWgpu>(&uniforms_buffer),
h.get_int_tensor::<InnerWgpu>(&compact_gid_from_isect),
h.get_int_tensor::<InnerWgpu>(&global_from_compact_gid),
h.get_int_tensor::<InnerWgpu>(&tile_bins),
h.get_int_tensor::<InnerWgpu>(&tile_offsets),
h.get_int_tensor::<InnerWgpu>(&final_index),
self.sh_degree,
);
Expand Down Expand Up @@ -417,7 +413,7 @@ impl Backend for Fusion<InnerWgpu> {
state.aux.uniforms_buffer.into_description(),
state.aux.compact_gid_from_isect.into_description(),
state.aux.global_from_compact_gid.into_description(),
state.aux.tile_bins.into_description(),
state.aux.tile_offsets.into_description(),
state.aux.final_index.into_description(),
],
&[
Expand Down
5 changes: 2 additions & 3 deletions crates/brush-render/src/kernels.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use super::shaders::{
get_tile_bin_edges, map_gaussian_to_intersects, project_backwards, project_forward,
project_visible, rasterize, rasterize_backwards,
map_gaussian_to_intersects, project_backwards, project_forward, project_visible, rasterize,
rasterize_backwards,
};
use crate::shaders::gather_grads;
use brush_kernel::kernel_source_gen;

kernel_source_gen!(ProjectSplats {}, project_forward);
kernel_source_gen!(ProjectVisible {}, project_visible);
kernel_source_gen!(MapGaussiansToIntersect {}, map_gaussian_to_intersects);
kernel_source_gen!(GetTileBinEdges {}, get_tile_bin_edges);
kernel_source_gen!(Rasterize { raster_u32 }, rasterize);
kernel_source_gen!(RasterizeBackwards { hard_float }, rasterize_backwards);
kernel_source_gen!(GatherGrads {}, gather_grads);
Expand Down
24 changes: 16 additions & 8 deletions crates/brush-render/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#![allow(clippy::too_many_arguments)]
#![allow(clippy::single_range_in_vec_init)]
use burn::prelude::Tensor;
use burn::tensor::{ElementConversion, Int};
use burn::tensor::{ElementConversion, Int, TensorMetadata};
use burn_jit::JitBackend;
use burn_wgpu::WgpuRuntime;
use camera::Camera;
use shaders::helpers::TILE_WIDTH;

mod burn_glue;
mod dim_check;
Expand All @@ -25,8 +26,7 @@ pub struct RenderAux<B: Backend> {
pub num_intersections: B::IntTensorPrimitive,
pub num_visible: B::IntTensorPrimitive,
pub final_index: B::IntTensorPrimitive,
pub cum_tiles_hit: B::IntTensorPrimitive,
pub tile_bins: B::IntTensorPrimitive,
pub tile_offsets: B::IntTensorPrimitive,
pub compact_gid_from_isect: B::IntTensorPrimitive,
pub global_from_compact_gid: B::IntTensorPrimitive,
pub radii: B::FloatTensorPrimitive,
Expand Down Expand Up @@ -54,11 +54,19 @@ impl<B: Backend> RenderAux<B> {
}

/// Returns a 2D integer tensor holding, for each tile, the difference between
/// two consecutive entries of the tile range data.
///
/// Presumably each difference is the number of intersections that landed in
/// that tile (the commit describes `tile_offsets` as a tile prefix sum) --
/// confirm against the kernel that fills the buffer.
///
/// NOTE(review): this span is a rendered diff hunk with no +/- markers. The
/// first five body lines (reading `tile_bins`) look like the removed
/// implementation and the remainder (reading `tile_offsets`) like the added
/// one; the two are not meant to coexist -- confirm against the commit.
pub fn read_tile_depth(&self) -> Tensor<B, 2, Int> {
// `tile_bins` variant: a rank-3 tensor whose last axis is sliced at index 0
// and index 1; the result is (index 1) - (index 0) per tile.
let bins = Tensor::from_primitive(self.tile_bins.clone());
let [ty, tx, _] = bins.dims();
let max = bins.clone().slice([0..ty, 0..tx, 1..2]).squeeze(2);
let min = bins.clone().slice([0..ty, 0..tx, 0..1]).squeeze(2);
max - min
// `tile_offsets` variant: a rank-1 tensor; entry i + 1 minus entry i gives the
// value for tile i, so the result has one element fewer than the input.
let bins = Tensor::<B, 1, Int>::from_primitive(self.tile_offsets.clone());

let n_bins = bins.dims()[0];

// `max` drops the first entry and `min` drops the last, pairing each entry
// with its successor.
let max = bins.clone().slice([1..n_bins]);
let min = bins.slice([0..n_bins - 1]);

// Tile grid size is recomputed from the `final_index` dimensions, rounding up
// so a partially covered tile still counts. The reshape below assumes
// ty * tx == n_bins - 1 -- TODO confirm this matches how `tile_offsets` is
// allocated.
let [h, w] = self.final_index.shape().dims();
let [ty, tx] = [
h.div_ceil(TILE_WIDTH as usize),
w.div_ceil(TILE_WIDTH as usize),
];
(max - min).reshape([ty, tx])
}
}

Expand Down
Loading

0 comments on commit 5e2c3b4

Please sign in to comment.