Skip to content

Commit

Permalink
add faster indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
cospectrum committed Apr 27, 2024
1 parent 67c866b commit c9f1549
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions src/filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,6 @@ pub fn bilateral_filter(
(-0.5 * x.powi(2) / sigma_squared).exp()
}

/// Effectively a meshgrid command with flattened outputs.
fn window_coords(window_size: u32) -> (Vec<i32>, Vec<i32>) {
let window_start = (-(window_size as f32) / 2.0).floor() as i32;
let window_end = (window_size as f32 / 2.0).floor() as i32 + 1;
let window_range = window_start..window_end;

let cc: Vec<i32> = window_range
.clone()
.cycle()
.take(window_range.len().pow(2))
.collect();
let n = window_size as usize + 1;
let mut rr = Vec::with_capacity(n * window_range.len());
for i in window_range {
rr.extend(vec![i; n]);
}
(rr, cc)
}

/// Create look-up table of Gaussian weights for color dimension.
fn compute_color_lut(bins: u32, sigma: f32, max_value: f32) -> Vec<f32> {
let step_size = max_value / bins as f32;
Expand All @@ -90,14 +71,21 @@ pub fn bilateral_filter(

/// Create look-up table of weights corresponding to flattened 2-D Gaussian kernel.
fn compute_spatial_lut(window_size: u32, sigma: f32) -> Vec<f32> {
let (rr, cc) = window_coords(window_size);
let window_start = (-(window_size as f32) / 2.0).floor() as i32;
let window_end = (window_size as f32 / 2.0).floor() as i32 + 1;
let window_range = window_start..window_end;

let cc = window_range.clone().cycle().take(window_range.len().pow(2));
let n = window_size as usize + 1;
let rr = window_range.flat_map(|i| std::iter::repeat(i).take(n));

let sigma_squared = sigma.powi(2);
let it = rr.into_iter().zip(cc);
it.map(|(r, c)| {
let dist = ((r as f32).powi(2) + (c as f32).powi(2)).sqrt();
gaussian_weight(dist, sigma_squared)
})
.collect()
rr.zip(cc)
.map(|(r, c)| {
let dist = ((r as f32).powi(2) + (c as f32).powi(2)).sqrt();
gaussian_weight(dist, sigma_squared)
})
.collect()
}

let max_value = *image.iter().max().unwrap() as f32;
Expand All @@ -110,19 +98,27 @@ pub fn bilateral_filter(
let window_extent = (window_size - 1) / 2;

let (width, height) = image.dimensions();
ImageBuffer::from_fn(width, height, |col, row| {
Image::from_fn(width, height, |col, row| {
let mut total_val = 0f32;
let mut total_weight = 0f32;
let window_center_val = image.get_pixel(col, row)[0] as i32;
debug_assert!(col < width);
debug_assert!(row < height);
// Safety: Image::from_fn yeilds col in [0, width) and row in [0, height).
let window_center_val = unsafe { image.unsafe_get_pixel(col, row)[0] } as i32;

for window_row in -window_extent..window_extent + 1 {
let window_row_abs = (row as i32 + window_row).clamp(0, height as i32 - 1); // Wrap to edge.
let window_row_abs =
(row as i32 + window_row).clamp(0, height.saturating_sub(1) as i32) as u32;
let kr = window_row + window_extent;
for window_col in -window_extent..window_extent + 1 {
let window_col_abs = (col as i32 + window_col).clamp(0, width as i32 - 1); // Wrap to edge.
let window_col_abs =
(col as i32 + window_col).clamp(0, width.saturating_sub(1) as i32) as u32;
debug_assert!(window_col_abs < width);
debug_assert!(window_row_abs < height);
// Safety: we clamped window_row_abs and window_col_abs to be in bounds.
let val = unsafe { image.unsafe_get_pixel(window_col_abs, window_row_abs)[0] };
let kc = window_col + window_extent;
let range_bin = (kr * window_size + kc) as usize;
let val = image.get_pixel(window_col_abs as u32, window_row_abs as u32)[0];
let color_dist = (window_center_val - val as i32).abs() as f32;
let color_bin = ((color_dist * color_dist_scale) as usize).min(max_color_bin);
let weight = range_lut[range_bin] * color_lut[color_bin];
Expand Down

0 comments on commit c9f1549

Please sign in to comment.