Skip to content

Commit

Permalink
Training improvements (#56)
Browse files Browse the repository at this point in the history
* Implement adam surgery

* More tweaks to match gsplat

* Fix names of settings

* cargo update

* Typo

* More tweaks to resemble gsplat

* Add radii screen splitting, fix quaternion * vec multiply

* Version bump
  • Loading branch information
ArthurBrussee authored Dec 5, 2024
1 parent e2147e7 commit 1b978b9
Show file tree
Hide file tree
Showing 18 changed files with 851 additions and 1,135 deletions.
1,245 changes: 456 additions & 789 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ serde_json = { version = "1.0.128", default-features = false }

rand = "0.8.5"
anyhow = "1.0.81"
tracing = "0.1.40"
tracing = "0.1.30"
tracing-tracy = "0.11.0"
tracing-subscriber = "0.3.18"

Expand Down Expand Up @@ -78,6 +78,7 @@ burn-wgpu = { git = "https://github.com/tracel-ai/burn", features = [
burn-fusion = { git = "https://github.com/tracel-ai/burn" }

egui = { git = "https://github.com/emilk/egui/", rev = "5bfff316c9818b3c140d02bb6cdc488556d46ab7" }

eframe = { git = "https://github.com/emilk/egui/", rev = "5bfff316c9818b3c140d02bb6cdc488556d46ab7", default-features = false, features = [
"wgpu",
"android-game-activity",
Expand All @@ -89,7 +90,7 @@ egui_extras = { git = "https://github.com/emilk/egui/", rev = "5bfff316c9818b3c1
"all_loaders",
] }

rerun = { version = "0.19.1", default-features = false, features = [
rerun = { version = "0.20.0", default-features = false, features = [
'sdk',
'glam',
'image',
Expand All @@ -107,6 +108,7 @@ web-sys = { version = "0.3.72", features = [
wasm-logger = "0.2.0"
zip = { version = "2.1.3", default-features = false, features = ["deflate"] }
urlencoding = "2.1"
hashbrown = "0.15"

[patch."https://github.com/tracel-ai/burn"]
# Uncomment this to use local burn.
Expand All @@ -124,6 +126,9 @@ urlencoding = "2.1"
wgpu = { git = "https://github.com/ArthurBrussee/wgpu", branch = "flt-atom" }
naga = { git = "https://github.com/ArthurBrussee/wgpu", branch = "flt-atom" }

emath = { git = "https://github.com/emilk/egui/", rev = "5bfff316c9818b3c140d02bb6cdc488556d46ab7" }
ecolor = { git = "https://github.com/emilk/egui/", rev = "5bfff316c9818b3c140d02bb6cdc488556d46ab7" }

[profile.dev]
opt-level = 0
debug = true
Expand Down
4 changes: 2 additions & 2 deletions crates/brush-dataset/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct LoadInitArgs {

impl Default for LoadInitArgs {
fn default() -> Self {
Self { sh_degree: 2 }
Self { sh_degree: 3 }
}
}

Expand Down Expand Up @@ -90,7 +90,7 @@ pub(crate) fn stream_fut_parallel<T: Send + 'static>(
.get()
};

log::info!("Loading steam with {parallel} threads");
log::info!("Loading stream with {parallel} threads");

let mut futures = futures;
fn_stream(|emitter| async move {
Expand Down
11 changes: 10 additions & 1 deletion crates/brush-dataset/src/scene_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@ impl<B: Backend> SceneLoader<B> {
// The bounded size == number of batches to prefetch.
let (tx, rx) = mpsc::channel(5);
let device = device.clone();
let scene_extent = scene.bounds(0.0, 0.0).extent.max_element() as f64;

let center = scene.bounds().center;
let dists = scene
.views
.iter()
.map(|v| (v.camera.position - center).length())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less))
.unwrap();

let scene_extent = dists * 1.1; // Idk why exactly, but gsplat multiplies this by 1.1

let mut rng = rand::rngs::StdRng::seed_from_u64(seed);

Expand Down
21 changes: 10 additions & 11 deletions crates/brush-render/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use burn::{
Autodiff,
},
tensor::{
backend::AutodiffBackend,
repr::{CustomOpDescription, HandleContainer, OperationDescription},
BasicAutodiffOps, DType, Float, Tensor, TensorPrimitive,
DType, Tensor, TensorPrimitive,
},
};
use burn_fusion::{client::FusionClient, stream::Operation, Fusion};
Expand All @@ -22,7 +23,7 @@ use crate::{
calc_tile_bounds, max_intersections, render_backward, render_forward, sh_coeffs_for_degree,
sh_degree_from_coeffs,
},
shaders, AutodiffBackend, Backend, GaussianBackwardState, InnerWgpu, RenderAux, SplatGrads,
shaders, Backend, GaussianBackwardState, InnerWgpu, RenderAux, SplatGrads,
};

// Implement forward functions for the inner wgpu backend.
Expand Down Expand Up @@ -170,15 +171,10 @@ impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
render_u32_buffer,
);

// Not sure why going into the autodiff float tensor type is so verbose.
let diff_proj = <Float as BasicAutodiffOps<Self>>::from_inner(TensorPrimitive::Float(
aux.projected_splats.clone(),
))
.tensor();

let auxc = aux.clone();
let wrapped_aux = RenderAux::<Self> {
projected_splats: diff_proj,
projected_splats: <Self as AutodiffBackend>::from_inner(aux.projected_splats),
radii: <Self as AutodiffBackend>::from_inner(aux.radii),
num_intersections: aux.num_intersections,
num_visible: aux.num_visible,
final_index: aux.final_index,
Expand Down Expand Up @@ -241,7 +237,7 @@ impl Backend for Fusion<InnerWgpu> {
fn execute(self: Box<Self>, h: &mut HandleContainer<JitFusionHandle<WgpuRuntime>>) {
let (
[means, log_scales, quats, sh_coeffs, raw_opacity],
[projected_splats, uniforms_buffer, num_intersections, num_visible, final_index, cum_tiles_hit, tile_bins, compact_gid_from_isect, global_from_compact_gid, out_img],
[projected_splats, uniforms_buffer, num_intersections, num_visible, final_index, cum_tiles_hit, tile_bins, compact_gid_from_isect, global_from_compact_gid, radii, out_img],
) = self.desc.consume();

let (img, aux) = render_forward(
Expand Down Expand Up @@ -272,6 +268,7 @@ impl Backend for Fusion<InnerWgpu> {
&global_from_compact_gid.id,
aux.global_from_compact_gid,
);
h.register_float_tensor::<InnerWgpu>(&radii.id, aux.radii);
}
}

Expand Down Expand Up @@ -309,6 +306,7 @@ impl Backend for Fusion<InnerWgpu> {
compact_gid_from_isect: client
.tensor_uninitialized(vec![max_intersects as usize], DType::I32),
global_from_compact_gid: client.tensor_uninitialized(vec![num_points], DType::I32),
radii: client.tensor_uninitialized(vec![num_points], DType::F32),
};

let desc = CustomOpDescription::new(
Expand All @@ -330,6 +328,7 @@ impl Backend for Fusion<InnerWgpu> {
aux.tile_bins.to_description_out(),
aux.compact_gid_from_isect.to_description_out(),
aux.global_from_compact_gid.to_description_out(),
aux.radii.to_description_out(),
out_img.to_description_out(),
],
);
Expand Down Expand Up @@ -441,4 +440,4 @@ impl Backend for Fusion<InnerWgpu> {
}
}

impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {}
impl<B: Backend, C: CheckpointStrategy> crate::AutodiffBackend for Autodiff<B, C> {}
11 changes: 6 additions & 5 deletions crates/brush-render/src/gaussian_splats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,21 @@ impl<B: Backend> Splats<B> {
.iter()
.map(|p| {
// Get average of 3 nearest squared distances.
tree.nearest_n::<SquaredEuclidean>(p, 3)
(tree
.nearest_n::<SquaredEuclidean>(p, 4)
.iter()
.map(|x| x.distance)
.sum::<f32>()
/ 4.0)
.sqrt()
/ 3.0
.max(1e-12)
.ln()
})
.collect();

Tensor::<B, 1>::from_floats(extents.as_slice(), device)
.reshape([n_splats, 1])
.repeat_dim(1, 3)
.clamp_min(0.00001)
.log()
};

let sh_coeffs = if let Some(sh_coeffs) = sh_coeffs {
Expand Down Expand Up @@ -150,7 +151,7 @@ impl<B: Backend> Splats<B> {
)
}

pub fn with_min_sh_degree(mut self, sh_degree: u32) -> Self {
pub fn with_sh_degree(mut self, sh_degree: u32) -> Self {
let n_coeffs = sh_coeffs_for_degree(sh_degree) as usize;

let [n, c, _] = self.sh_coeffs.dims();
Expand Down
6 changes: 5 additions & 1 deletion crates/brush-render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub struct RenderAux<B: Backend> {
pub tile_bins: B::IntTensorPrimitive,
pub compact_gid_from_isect: B::IntTensorPrimitive,
pub global_from_compact_gid: B::IntTensorPrimitive,
pub radii: B::FloatTensorPrimitive,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -115,5 +116,8 @@ pub trait Backend: burn::tensor::backend::Backend {
}
}

pub trait AutodiffBackend: Backend + burn::tensor::backend::AutodiffBackend {}
pub trait AutodiffBackend:
Backend + burn::tensor::backend::AutodiffBackend<InnerBackend: Backend>
{
}
type InnerWgpu = JitBackend<WgpuRuntime, f32, i32, u32>;
26 changes: 17 additions & 9 deletions crates/brush-render/src/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ pub(crate) fn render_forward(
create_tensor::<1, WgpuRuntime>([num_points], device, client, DType::U32);
let num_tiles = InnerWgpu::int_zeros([num_points].into(), device);

let radii = InnerWgpu::float_zeros([num_points].into(), device);

let (global_from_compact_gid, num_visible) = {
let global_from_presort_gid = InnerWgpu::int_zeros([num_points].into(), device);
let depths = create_tensor([num_points], device, client, DType::F32);
Expand All @@ -172,7 +174,8 @@ pub(crate) fn render_forward(
raw_opacities.clone().handle.binding(),
global_from_presort_gid.clone().handle.binding(),
depths.clone().handle.binding(),
num_tiles_scatter.clone().handle.binding()
num_tiles_scatter.clone().handle.binding(),
radii.clone().handle.binding(),
],
);
});
Expand Down Expand Up @@ -354,6 +357,7 @@ pub(crate) fn render_forward(
final_index,
compact_gid_from_isect,
global_from_compact_gid,
radii,
},
)
}
Expand Down Expand Up @@ -384,7 +388,7 @@ pub(crate) fn render_backward(

let client = &means.client;

let (v_xys_local, v_xys_global, v_conics, v_coeffs, v_raw_opac) = {
let (v_xys_local, v_conics, v_coeffs, v_raw_opac) = {
let tile_bounds = uvec2(
img_size.x.div_ceil(shaders::helpers::TILE_WIDTH),
img_size.y.div_ceil(shaders::helpers::TILE_WIDTH),
Expand Down Expand Up @@ -430,7 +434,6 @@ pub(crate) fn render_backward(

let num_vis_wg = create_dispatch_buffer(num_visible.clone(), GatherGrads::WORKGROUP_SIZE);

let v_xys_global = InnerWgpu::float_zeros([num_points, 2].into(), device);
unsafe {
client.execute_unchecked(
GatherGrads::task(),
Expand All @@ -441,15 +444,13 @@ pub(crate) fn render_backward(
raw_opac.clone().handle.binding(),
means.clone().handle.binding(),
v_colors.clone().handle.binding(),
v_xys_local.clone().handle.binding(),
v_coeffs.handle.clone().binding(),
v_opacities.handle.clone().binding(),
v_xys_global.handle.clone().binding(),
],
);
}

(v_xys_local, v_xys_global, v_conics, v_coeffs, v_opacities)
(v_xys_local, v_conics, v_coeffs, v_opacities)
};

// Create tensors to hold gradients.
Expand Down Expand Up @@ -485,7 +486,7 @@ pub(crate) fn render_backward(
v_scales,
v_coeffs,
v_raw_opac,
v_xy: v_xys_global,
v_xy: v_xys_local,
}
}

Expand Down Expand Up @@ -592,7 +593,7 @@ mod tests {

let rec = if USE_RERUN {
rerun::RecordingStreamBuilder::new("render test")
.connect()
.connect_tcp()
.ok()
} else {
None
Expand Down Expand Up @@ -688,9 +689,16 @@ mod tests {
.mean()
.backward();

// XY gradients are also in compact format.
let v_xys = splats.xys_dummy.grad(&grads).context("no xys grad")?;
let v_xys = v_xys.slice([0..num_visible]);

let v_xys_ref =
safetensor_to_burn::<DiffBack, 2>(tensors.tensor("v_xy")?, &device).inner();
let v_xys = splats.xys_dummy.grad(&grads).context("no xys grad")?;
let v_xys_ref = v_xys_ref
.select(0, gs_ids.inner().clone())
.slice([0..num_visible]);

assert!(v_xys.all_close(v_xys_ref, Some(1e-5), Some(1e-9)));

let v_opacities_ref =
Expand Down
10 changes: 2 additions & 8 deletions crates/brush-render/src/shaders/gather_grads.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
@group(0) @binding(2) var<storage, read> raw_opacities: array<f32>;
@group(0) @binding(3) var<storage, read> means: array<helpers::PackedVec3>;
@group(0) @binding(4) var<storage, read> v_colors: array<vec4f>;
@group(0) @binding(5) var<storage, read> v_xy_local: array<vec2f>;

@group(0) @binding(6) var<storage, read_write> v_coeffs: array<f32>;
@group(0) @binding(7) var<storage, read_write> v_opacs: array<f32>;
@group(0) @binding(8) var<storage, read_write> v_xy_global: array<vec2f>;
@group(0) @binding(5) var<storage, read_write> v_coeffs: array<f32>;
@group(0) @binding(6) var<storage, read_write> v_opacs: array<f32>;

const SH_C0: f32 = 0.2820947917738781f;

Expand Down Expand Up @@ -223,8 +221,4 @@ fn main(@builtin(global_invocation_id) gid: vec3u) {
let raw_opac = raw_opacities[global_gid];
let v_opac = v_color.w * v_sigmoid(raw_opac);
v_opacs[global_gid] = v_opac;

// Scatter the xy gradients, as later operations need them to be global.
let v_xy_local = v_xy_local[compact_gid];
v_xy_global[global_gid] = v_xy_local;
}
8 changes: 5 additions & 3 deletions crates/brush-render/src/shaders/project_forward.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
@group(0) @binding(6) var<storage, read_write> depths: array<f32>;
@group(0) @binding(7) var<storage, read_write> num_tiles: array<u32>;

@group(0) @binding(8) var<storage, read_write> radii: array<f32>;

@compute
@workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3u) {
Expand All @@ -31,7 +33,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {
let R = mat3x3f(viewmat[0].xyz, viewmat[1].xyz, viewmat[2].xyz);
let mean_c = R * mean + viewmat[3].xyz;

if mean_c.z < 0.01 || mean_c.z > 1e12 {
if mean_c.z < 0.01 || mean_c.z > 1e10 {
return;
}

Expand All @@ -52,8 +54,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {

// compute the projected mean
let mean2d = uniforms.focal * mean_c.xy * (1.0 / mean_c.z) + uniforms.pixel_center;


let opac = helpers::sigmoid(raw_opacities[global_gid]);

// NB: It might seem silly to use the inverse of the conic here (as that's the same as cov2d)
Expand Down Expand Up @@ -91,4 +91,6 @@ fn main(@builtin(global_invocation_id) global_id: vec3u) {
depths[write_id] = mean_c.z;
// Write metadata to global array.
num_tiles[global_gid] = tile_area;

radii[global_gid] = radius;
}
1 change: 1 addition & 0 deletions crates/brush-train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ glam.workspace = true
rand.workspace = true
tracing.workspace = true
log.workspace = true
hashbrown.workspace = true

burn.workspace = true
8 changes: 7 additions & 1 deletion crates/brush-train/src/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ impl Scene {
}

// Returns the extent of the cameras in the scene.
pub fn bounds(&self, cam_near: f32, cam_far: f32) -> BoundingBox {
pub fn bounds(&self) -> BoundingBox {
self.adjusted_bounds(0.0, 0.0)
}

// Returns the extent of the cameras in the scene, taking into account
// the near and far plane of the cameras.
pub fn adjusted_bounds(&self, cam_near: f32, cam_far: f32) -> BoundingBox {
let (min, max) = self.views.iter().fold(
(Vec3::splat(f32::INFINITY), Vec3::splat(f32::NEG_INFINITY)),
|(min, max), view| {
Expand Down
Loading

0 comments on commit 1b978b9

Please sign in to comment.