Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee committed Dec 11, 2024
1 parent 78d2a8a commit 675729d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion crates/brush-render/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ naga.workspace = true
log.workspace = true
serde.workspace = true
kiddo.workspace = true
tokio = { workspace = true, features = ["macros", "rt"] }
tokio = { workspace = true, features = ["macros", "rt", "sync"] }
rand.workspace = true

[build-dependencies]
Expand Down
4 changes: 3 additions & 1 deletion crates/brush-render/src/burn_glue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ impl Backend for Fusion<BBase> {
let client = v_output.client.clone();

let num_points = state.means.shape[0];
let num_visible = state.rx.borrow().data().num_visible;

let coeffs = sh_coeffs_for_degree(state.sh_degree) as usize;

let grads = SplatGrads::<Self> {
Expand All @@ -420,7 +422,7 @@ impl Backend for Fusion<BBase> {
v_scales: client.tensor_uninitialized(vec![num_points, 3], DType::F32),
v_coeffs: client.tensor_uninitialized(vec![num_points, coeffs, 3], DType::F32),
v_raw_opac: client.tensor_uninitialized(vec![num_points], DType::F32),
v_xy: client.tensor_uninitialized(vec![num_points, 2], DType::F32),
v_xy: client.tensor_uninitialized(vec![num_visible as usize, 2], DType::F32),
};

let desc = CustomOpDescription::new(
Expand Down
24 changes: 11 additions & 13 deletions crates/brush-render/src/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,10 @@ mod tests {
"img/dif",
&(img_ref.clone() - out.clone()).into_rerun().await,
)?;
// rec.log(
// "images/tile depth",
// &aux.read_tile_depth().into_rerun().await,
// )?;
rec.log(
"images/tile depth",
&aux.read_tile_depth().into_rerun().await,
)?;
}

// Check if images match.
Expand All @@ -679,37 +679,35 @@ mod tests {
Tensor::from_primitive(TensorPrimitive::Float(aux.projected_splats.clone()));

let gs_ids =
Tensor::<DiffBack, 1, Int>::from_primitive(aux.global_from_compact_gid.clone());
Tensor::<DiffBack, 1, Int>::from_primitive(aux.global_from_compact_gid.clone())
.slice([0..num_visible]);

let xys: Tensor<DiffBack, 2, Float> =
projected_splats.clone().slice([0..num_visible, 0..2]);
let xys_ref = safetensor_to_burn::<DiffBack, 2>(tensors.tensor("xys")?, &device);
let xys_ref = xys_ref.select(0, gs_ids.clone()).slice([0..num_visible]);
let xys_ref = xys_ref.select(0, gs_ids.clone());

assert!(xys.all_close(xys_ref, Some(1e-1), Some(1e-6)));

let conics: Tensor<DiffBack, 2, Float> =
projected_splats.clone().slice([0..num_visible, 2..5]);
let conics_ref = safetensor_to_burn::<DiffBack, 2>(tensors.tensor("conics")?, &device);
let conics_ref = conics_ref.select(0, gs_ids.clone()).slice([0..num_visible]);
let conics_ref = conics_ref.select(0, gs_ids.clone());

assert!(conics.all_close(conics_ref, Some(1e-3), Some(1e-6)));

aux.resolve_bwd_data().await;

let grads = (out.clone() - crab_tens.clone())
.powi_scalar(2.0)
.mean()
.backward();

// XY gradients are also in compact format.
let v_xys = splats.xys_dummy.grad(&grads).context("no xys grad")?;
let v_xys = v_xys.slice([0..num_visible]);

let v_xys_ref =
safetensor_to_burn::<DiffBack, 2>(tensors.tensor("v_xy")?, &device).inner();
let v_xys_ref = v_xys_ref
.select(0, gs_ids.inner().clone())
.slice([0..num_visible]);

let v_xys_ref = v_xys_ref.select(0, gs_ids.inner().clone());
assert!(v_xys.all_close(v_xys_ref, Some(1e-5), Some(1e-9)));

let v_opacities_ref =
Expand Down

0 comments on commit 675729d

Please sign in to comment.