
Commit

Separate out refining for better tracing.
ArthurBrussee committed Dec 8, 2024
1 parent 97965c8 commit 92d6ec8
Showing 5 changed files with 77 additions and 48 deletions.
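
In short, this commit pulls refinement out of SplatTrainer::step into a separate refine_if_needed call, so the gradient step and the refine pass each get their own trace span and their own reporting path. A condensed sketch of the resulting loop body, mirroring the train_loop.rs changes below (error handling and the periodic TrainStep message are elided; all names come from the diff):

// Condensed sketch of the new loop body after this commit; not a drop-in
// snippet, just the call order (TrainStep reporting elided).
let batch = dataloader.next_batch().await;
let extent = batch.scene_extent;

// Gradient step: now a plain fn, traced under its own "Train step" span.
let (new_splats, _stats) = trainer.step(batch, splats)?;

// Refinement runs as a separate async call and only on refine iterations.
let (new_splats, refine) = trainer.refine_if_needed(new_splats, extent).await;
splats = new_splats;

// Refine stats travel on their own message instead of riding on TrainStepStats.
if let Some(refine) = refine {
    emitter
        .emit(ProcessMessage::RefineStep {
            stats: Box::new(refine),
            iter: trainer.iter,
        })
        .await;
}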
38 changes: 21 additions & 17 deletions crates/brush-train/src/train.rs
@@ -125,8 +125,6 @@ pub struct TrainStepStats<B: AutodiffBackend> {
pub lr_scale: f64,
pub lr_coeffs: f64,
pub lr_opac: f64,

pub refine: Option<RefineStats>,
}

pub struct SplatTrainer<B: AutodiffBackend> {
@@ -229,11 +227,13 @@ impl<B: AutodiffBackend> SplatTrainer<B> {
);
}

pub async fn step(
pub fn step(
&mut self,
batch: SceneBatch<B>,
splats: Splats<B>,
) -> Result<(Splats<B>, TrainStepStats<B>), anyhow::Error> {
let _span = trace_span!("Train step").entered();

assert!(
batch.gt_views.len() == 1,
"Bigger batches aren't yet supported"
@@ -388,19 +388,6 @@ impl<B: AutodiffBackend> SplatTrainer<B> {
splats
});

let mut refine_stats = None;

let do_refine = self.iter < self.config.refine_stop_iter
&& self.iter >= self.config.refine_start_iter
&& self.iter % self.config.refine_every == 1;

if do_refine {
// If not refining, update splat to step with gradients applied.
let (refined_splats, refine) = self.refine_splats(splats, batch.scene_extent).await;
refine_stats = Some(refine);
splats = refined_splats;
}

self.iter += 1;

let stats = TrainStepStats {
@@ -414,12 +401,29 @@ impl<B: AutodiffBackend> SplatTrainer<B> {
lr_scale,
lr_coeffs,
lr_opac,
refine: refine_stats,
};

Ok((splats, stats))
}

pub async fn refine_if_needed(
&mut self,
splats: Splats<B>,
scene_extent: f32,
) -> (Splats<B>, Option<RefineStats>) {
let do_refine = self.iter < self.config.refine_stop_iter
&& self.iter >= self.config.refine_start_iter
&& self.iter % self.config.refine_every == 1;

if do_refine {
// Refine the splats: split, clone, and prune based on the gradient stats accumulated so far.
let (refined_splats, refine) = self.refine_splats(splats, scene_extent).await;
(refined_splats, Some(refine))
} else {
(splats, None)
}
}

async fn refine_splats(
&mut self,
splats: Splats<B>,
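For orientation: the gate in refine_if_needed above only fires on iterations of the form refine_every * k + 1 inside the configured window. A small self-contained sketch of that check, with placeholder config values rather than the crate's defaults:

// Stand-alone sketch of the gating check in refine_if_needed; the three
// config values below are placeholders for illustration only.
fn do_refine(iter: u32) -> bool {
    let (refine_start_iter, refine_stop_iter, refine_every) = (500u32, 15_000u32, 100u32);
    iter < refine_stop_iter && iter >= refine_start_iter && iter % refine_every == 1
}

fn main() {
    // With these placeholder values, refinement fires at iters 501, 601, ..., 14_901.
    assert!(do_refine(501));
    assert!(!do_refine(401)); // before refine_start_iter
    assert!(!do_refine(600)); // not of the form refine_every * k + 1
    assert!(!do_refine(15_001)); // past refine_stop_iter
}
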
51 changes: 31 additions & 20 deletions crates/brush-viewer/src/panels/rerun.rs
@@ -10,7 +10,7 @@ use brush_rerun::BurnToRerun;
use burn_wgpu::WgpuDevice;

use brush_render::{gaussian_splats::Splats, AutodiffBackend, Backend};
use brush_train::{image::tensor_into_image, scene::Scene};
use brush_train::{image::tensor_into_image, scene::Scene, train::RefineStats};
use brush_train::{ssim::Ssim, train::TrainStepStats};
use burn::tensor::{activation::sigmoid, ElementConversion};
use rerun::{Color, FillMode, RecordingStream};
@@ -310,27 +310,38 @@ impl VisualizeTools {
&rerun::Scalar::new(main_aux.read_num_visible().await as f64),
)?;

if let Some(refine) = stats.refine {
rec.log(
"refine/num_split",
&rerun::Scalar::new(refine.num_split as f64),
)?;
rec.log(
"refine/num_cloned",
&rerun::Scalar::new(refine.num_cloned as f64),
)?;
rec.log(
"refine/num_transparent_pruned",
&rerun::Scalar::new(refine.num_transparent_pruned as f64),
)?;
rec.log(
"refine/num_scale_pruned",
&rerun::Scalar::new(refine.num_scale_pruned as f64),
)?;
}
Ok(())
});
}

pub fn log_refine_stats(self: Arc<Self>, iter: u32, refine: RefineStats) {
let Some(rec) = self.rec.clone() else {
return;
};

if !rec.is_enabled() {
return;
}

rec.set_time_sequence("iterations", iter);

let _ = rec.log(
"refine/num_split",
&rerun::Scalar::new(refine.num_split as f64),
);
let _ = rec.log(
"refine/num_cloned",
&rerun::Scalar::new(refine.num_cloned as f64),
);
let _ = rec.log(
"refine/num_transparent_pruned",
&rerun::Scalar::new(refine.num_transparent_pruned as f64),
);
let _ = rec.log(
"refine/num_scale_pruned",
&rerun::Scalar::new(refine.num_scale_pruned as f64),
);
}
}

pub(crate) struct RerunPanel {
Expand Down Expand Up @@ -406,7 +417,7 @@ impl ViewerPanel for RerunPanel {
// Log out train stats.
// HACK: Always log on a refine step, as they can happen off beat.
// Not sure how to best handle this properly.
if iter % self.log_train_stats_every == 0 || stats.refine.is_some() {
if iter % self.log_train_stats_every == 0 {
visualize.log_train_stats(*iter, *stats.clone());
}
}
23 changes: 14 additions & 9 deletions crates/brush-viewer/src/train_loop.rs
@@ -15,7 +15,6 @@ use tokio::{
sync::mpsc::{error::TryRecvError, Receiver},
};
use tokio_stream::{Stream, StreamExt};
use tracing::{trace_span, Instrument};
use web_time::Instant;

use crate::viewer::ProcessMessage;
@@ -147,15 +146,12 @@ pub(crate) fn train_loop<T: AsyncRead + Unpin + 'static>(
}
// By default, continue training.
None => {
let batch = dataloader
.next_batch()
.instrument(trace_span!("Get batch"))
.await;
let batch = dataloader.next_batch().await;
let extent = batch.scene_extent;

let (new_splats, stats) = trainer.step(batch, splats)?;
let (new_splats, refine) = trainer.refine_if_needed(new_splats, extent).await;

let (new_splats, stats) = trainer
.step(batch, splats)
.instrument(trace_span!("Train step"))
.await?;
splats = new_splats;

if trainer.iter % UPDATE_EVERY == 0 {
@@ -168,6 +164,15 @@
})
.await;
}

if let Some(refine) = refine {
emitter
.emit(ProcessMessage::RefineStep {
stats: Box::new(refine),
iter: trainer.iter,
})
.await;
}
}
}

7 changes: 6 additions & 1 deletion crates/brush-viewer/src/viewer.rs
@@ -7,7 +7,7 @@ use async_fn_stream::try_fn_stream;
use brush_dataset::{self, splat_import, Dataset, LoadDatasetArgs, LoadInitArgs};
use brush_render::camera::Camera;
use brush_render::gaussian_splats::Splats;
use brush_train::train::TrainStepStats;
use brush_train::train::{RefineStats, TrainStepStats};
use brush_train::{eval::EvalStats, train::TrainConfig};
use burn::backend::Autodiff;
use burn_wgpu::{Wgpu, WgpuDevice};
@@ -94,6 +94,11 @@ pub(crate) enum ProcessMessage {
iter: u32,
timestamp: Instant,
},
/// A refine step was run with these stats.
RefineStep {
stats: Box<RefineStats>,
iter: u32,
},
/// Eval was run successfully with these results.
EvalResult {
iter: u32,
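With the new ProcessMessage::RefineStep variant added above, consumers can react to refinement independently of the per-step stats. A hypothetical handler fragment (the variant shape and the RefineStats field names come from this commit; the surrounding match and the println! sink are assumptions):

// Hypothetical consumer of the new message variant; only the variant shape
// and field names are taken from this commit.
match message {
    ProcessMessage::RefineStep { stats, iter } => {
        println!(
            "iter {iter}: split {}, cloned {}, pruned {} transparent / {} by scale",
            stats.num_split,
            stats.num_cloned,
            stats.num_transparent_pruned,
            stats.num_scale_pruned,
        );
    }
    _ => {}
}
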
6 changes: 5 additions & 1 deletion crates/train-2d/src/main.rs
@@ -65,7 +65,11 @@ fn spawn_train_loop(
};

loop {
let (new_splats, _) = trainer.step(batch.clone(), splats).await.unwrap();
let (new_splats, _) = trainer.step(batch.clone(), splats).unwrap();
let (new_splats, _) = trainer
.refine_if_needed(new_splats, batch.scene_extent)
.await;

splats = new_splats;

ctx.request_repaint();