Skip to content

Commit

Permalink
Fix for input images with un-multiplied alpha.
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee committed Jan 9, 2025
1 parent 5e4d35f commit 1c9f93f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
4 changes: 2 additions & 2 deletions crates/brush-dataset/src/scene_loader.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use brush_render::Backend;
use brush_train::image::image_to_tensor;
use brush_train::image::image_to_sample;
use brush_train::scene::Scene;
use brush_train::train::SceneBatch;
use burn::tensor::Tensor;
Expand Down Expand Up @@ -45,7 +45,7 @@ impl<B: Backend> SceneLoader<B> {
.expect("Need at least one view in dataset")
});
let view = scene.views[index].clone();
(image_to_tensor(&view.image, &device), view)
(image_to_sample(&view.image, &device), view)
})
.unzip();

Expand Down
4 changes: 2 additions & 2 deletions crates/brush-train/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use burn::tensor::{ElementConversion, Tensor};
use image::DynamicImage;
use rand::seq::IteratorRandom;

use crate::image::image_to_tensor;
use crate::image::image_to_sample;
use crate::scene::{Scene, SceneView};
use crate::ssim::Ssim;

Expand Down Expand Up @@ -59,7 +59,7 @@ pub async fn eval_stats<B: Backend>(
let ground_truth: DynamicImage = view.image.clone().to_rgb8().into();
let res = glam::uvec2(ground_truth.width(), ground_truth.height());

let gt_tensor = image_to_tensor::<B>(&ground_truth, device);
let gt_tensor = image_to_sample::<B>(&ground_truth, device);
let (rendered, aux) = splats.render(&view.camera, res, false);

let render_rgb = rendered.slice([0..res.y as usize, 0..res.x as usize, 0..3]);
Expand Down
16 changes: 13 additions & 3 deletions crates/brush-train/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@ use burn::{
};
use image::{DynamicImage, Rgb32FImage, Rgba32FImage};

// Converts an image to a tensor. The tensor will be a floating point image with a [0, 1] image.
pub fn image_to_tensor<B: Backend>(image: &DynamicImage, device: &B::Device) -> Tensor<B, 3> {
// Converts an image to a training sample. The result is a floating point tensor with values in [0, 1].
//
// This assumes the input image has un-premultiplied alpha, whereas the output has pre-multiplied alpha.
pub fn image_to_sample<B: Backend>(image: &DynamicImage, device: &B::Device) -> Tensor<B, 3> {
let (w, h) = (image.width(), image.height());

let tensor_data = if image.color().has_alpha() {
TensorData::new(image.to_rgba32f().into_vec(), [h as usize, w as usize, 4])
// Assume the image has un-multiplied alpha and convert it to pre-multiplied.
let mut rgba = image.to_rgba32f();
for pixel in rgba.pixels_mut() {
let a = pixel[3];
pixel[0] *= a;
pixel[1] *= a;
pixel[2] *= a;
}
TensorData::new(rgba.into_vec(), [h as usize, w as usize, 4])
} else {
TensorData::new(image.to_rgb32f().into_vec(), [h as usize, w as usize, 3])
};
Expand Down
4 changes: 2 additions & 2 deletions crates/train-2d/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use brush_render::{
gaussian_splats::{RandomSplatsConfig, Splats},
};
use brush_train::{
image::image_to_tensor,
image::image_to_sample,
scene::SceneView,
train::{SceneBatch, SplatTrainer, TrainConfig},
};
Expand Down Expand Up @@ -56,7 +56,7 @@ fn spawn_train_loop(

// One batch of training data; it's the same every step, so we can just construct it once.
let batch = SceneBatch {
gt_images: image_to_tensor(&view.image, &device).unsqueeze(),
gt_images: image_to_sample(&view.image, &device).unsqueeze(),
gt_views: vec![view],
scene_extent: 1.0,
};
Expand Down

0 comments on commit 1c9f93f

Please sign in to comment.