From 8b4587608aff9e1bf4f2a16da7613319b2f08372 Mon Sep 17 00:00:00 2001 From: Stephen Crane Date: Thu, 27 Jun 2024 14:26:24 -0700 Subject: [PATCH] Do not zero `txa` before initializing it --- src/ctx.rs | 41 +++++++++++++++++++++++++++ src/lf_mask.rs | 75 ++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 96 insertions(+), 20 deletions(-) diff --git a/src/ctx.rs b/src/ctx.rs index be22f7e27..49808f38b 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -40,6 +40,7 @@ use crate::src::disjoint_mut::AsMutPtr; use crate::src::disjoint_mut::DisjointMut; use std::iter::zip; +use std::ptr; /// Perform a `memset` optimized for lengths that are small powers of 2. /// @@ -75,6 +76,37 @@ pub fn small_memset( + buf: *mut [T], + val: T, + offset: usize, + len: usize, +) { + assert!(buf.len() >= offset && buf.len() - offset >= len); + // SAFETY: `buf` is correctly aligned for type T and offset is within bounds. + let buf = unsafe { (buf as *mut T).add(offset) }; + match len { + 01 if UP_TO >= 01 => unsafe { ptr::write(buf as *mut [T; 01], [val; 01]) }, + 02 if UP_TO >= 02 => unsafe { ptr::write(buf as *mut [T; 02], [val; 02]) }, + 04 if UP_TO >= 04 => unsafe { ptr::write(buf as *mut [T; 04], [val; 04]) }, + 08 if UP_TO >= 08 => unsafe { ptr::write(buf as *mut [T; 08], [val; 08]) }, + 16 if UP_TO >= 16 => unsafe { ptr::write(buf as *mut [T; 16], [val; 16]) }, + 32 if UP_TO >= 32 => unsafe { ptr::write(buf as *mut [T; 32], [val; 32]) }, + 64 if UP_TO >= 64 => unsafe { ptr::write(buf as *mut [T; 64], [val; 64]) }, + _ => { + if WITH_DEFAULT { + for i in 0..len { + unsafe { buf.add(i).write(val) }; + } + } + } + } +} + pub struct CaseSetter { offset: usize, len: usize, @@ -99,6 +131,15 @@ impl CaseSetter(&mut *buf, val); } + + /// # Safety + /// + /// `buf` must be correctly aligned and dereferenceable (but need not be + /// initialized). + #[inline] + pub unsafe fn set_raw(&self, buf: *mut [T], val: T) { + small_memset_raw::(buf, val, self.offset, self.len); + } } /// The entrypoint to the [`CaseSet`] API. diff --git a/src/lf_mask.rs b/src/lf_mask.rs index 7d8622943..486006bb1 100644 --- a/src/lf_mask.rs +++ b/src/lf_mask.rs @@ -18,6 +18,7 @@ use libc::ptrdiff_t; use parking_lot::RwLock; use std::cmp; use std::ffi::c_int; +use std::mem::MaybeUninit; #[repr(C)] pub struct Av1FilterLUT { @@ -92,8 +93,14 @@ pub struct Av1Restoration { /// but in Rust, dereferencing such a pointer would be an out-of-bounds access, and thus UB. /// Instead of offsetting `txa`, the offsets are calculated from /// the existing `y_off` and `x_off` args and applied at each use site of `txa. +/// +/// Initializes: +/// * `txa[0][0][y][x]` for all `y` and `x` in the range of the current block +/// * `txa[1][0][y][x]` for all `y` and `x` in the range of the current block +/// * `txa[0][1][y][x_off * t_dim.w]` for all `y` in the range of the current block +/// * `txa[1][1][y_off * t_dim.h][x]` for all `x` in the range of the current block fn decomp_tx( - txa: &mut [[[[u8; 32]; 32]; 2]; 2], + txa: &mut [[[[MaybeUninit; 32]; 32]; 2]; 2], from: TxfmSize, depth: usize, y_off: u8, @@ -128,15 +135,16 @@ fn decomp_tx( let lw = cmp::min(2, t_dim.lw); let lh = cmp::min(2, t_dim.lh); + debug_assert!(t_dim.w == 1 << t_dim.lw && t_dim.w <= 16); CaseSet::<16, false>::one((), t_dim.w as usize, x0, |case, ()| { for y in 0..t_dim.h as usize { - case.set(&mut txa[0][0][y0 + y], lw); - case.set(&mut txa[1][0][y0 + y], lh); - txa[0][1][y0 + y][x0] = t_dim.w; + case.set(&mut txa[0][0][y0 + y], MaybeUninit::new(lw)); + case.set(&mut txa[1][0][y0 + y], MaybeUninit::new(lh)); + txa[0][1][y0 + y][x0].write(t_dim.w); } }); CaseSet::<16, false>::one((), t_dim.w as usize, x0, |case, ()| { - case.set(&mut txa[1][1][y0], t_dim.h); + case.set(&mut txa[1][1][y0], MaybeUninit::new(t_dim.h)); }); }; } @@ -157,7 +165,8 @@ fn mask_edges_inter( let t_dim = &dav1d_txfm_dimensions[max_tx as usize]; // See [`decomp_tx`]'s docs for the `txa` arg. - let mut txa = Align16([[[[0; 32]; 32]; 2]; 2]); + + let mut txa = Align16([[[[MaybeUninit::uninit(); 32]; 32]; 2]; 2]); for (y_off, _) in (0..h4).step_by(t_dim.h as usize).enumerate() { for (x_off, _) in (0..w4).step_by(t_dim.w as usize).enumerate() { @@ -165,13 +174,20 @@ fn mask_edges_inter( } } + // After these calls to `decomp_tx`, the following elements of `txa` are initialized: + // * `txa[0][0][0..h4][0..w4]` + // * `txa[1][0][0..h4][0..w4]` + // * `txa[0][1][0..h4][x]` where `x` is the start of a block edge + // * `txa[1][1][y][0..w4]` where `y` is the start of a block edge + // left block edge for y in 0..h4 { let mask = 1u32 << (by4 + y); let sidx = (mask >= 0x10000) as usize; let smask = mask >> (sidx << 4); - masks[0][bx4][cmp::min(txa[0][0][y][0], l[y]) as usize][sidx] - .update(|it| it | smask as u16); + // SAFETY: y < h4 so txa[0][0][y][0] is initialized. + let txa_y = unsafe { txa[0][0][y][0].assume_init() }; + masks[0][bx4][cmp::min(txa_y, l[y]) as usize][sidx].update(|it| it | smask as u16); } // top block edge @@ -179,8 +195,9 @@ fn mask_edges_inter( let mask = 1u32 << (bx4 + x); let sidx = (mask >= 0x10000) as usize; let smask = mask >> (sidx << 4); - masks[1][by4][cmp::min(txa[1][0][0][x], a[x]) as usize][sidx] - .update(|it| it | smask as u16); + // SAFETY: x < h4 so txa[1][0][0][x] is initialized. + let txa_x = unsafe { txa[1][0][0][x].assume_init() }; + masks[1][by4][cmp::min(txa_x, a[x]) as usize][sidx].update(|it| it | smask as u16); } if !skip { // inner (tx) left|right edges @@ -188,14 +205,20 @@ fn mask_edges_inter( let mask = 1u32 << (by4 + y); let sidx = (mask >= 0x10000) as usize; let smask = mask >> (sidx << 4); - let mut ltx = txa[0][0][y][0]; - let step = txa[0][1][y][0] as usize; + // SAFETY: y < h4 so txa[0][0][y][0] is initialized. + let mut ltx = unsafe { txa[0][0][y][0].assume_init() }; + // SAFETY: y < h4 and x == 0 so txa[0][1][y][0] is initialized. + let step = unsafe { txa[0][1][y][0].assume_init() } as usize; let mut x = step; while x < w4 { - let rtx = txa[0][0][y][x]; + // SAFETY: x < w4 and y < h4 so txa[0][0][y][x] is initialized. + let rtx = unsafe { txa[0][0][y][x].assume_init() }; masks[0][bx4 + x][cmp::min(rtx, ltx) as usize][sidx].update(|it| it | smask as u16); ltx = rtx; - let step = txa[0][1][y][x] as usize; + // SAFETY: x is incremented by tdim.w from previously + // initialized element, so we know that this element is a block + // edge and also initialized. + let step = unsafe { txa[0][1][y][x].assume_init() } as usize; x += step; } } @@ -207,23 +230,35 @@ fn mask_edges_inter( let mask = 1u32 << (bx4 + x); let sidx = (mask >= 0x10000) as usize; let smask = mask >> (sidx << 4); - let mut ttx = txa[1][0][0][x]; - let step = txa[1][1][0][x] as usize; + // SAFETY: x < w4 so txa[1][0][0][x] is initialized. + let mut ttx = unsafe { txa[1][0][0][x].assume_init() }; + // SAFETY: x < h4 and y == 0 so txa[1][1][0][x] is initialized. + let step = unsafe { txa[1][1][0][x].assume_init() } as usize; let mut y = step; while y < h4 { - let btx = txa[1][0][y][x]; + // SAFETY: x < w4 and y < h4 so txa[1][0][y][x] is initialized. + let btx = unsafe { txa[1][0][y][x].assume_init() }; masks[1][by4 + y][cmp::min(ttx, btx) as usize][sidx].update(|it| it | smask as u16); ttx = btx; - let step = txa[1][1][y][x] as usize; + // SAFETY: y is incremented by tdim.h from previously + // initialized element, so we know that this element is a block + // edge and also initialized. + let step = unsafe { txa[1][1][y][x].assume_init() } as usize; y += step; } } } for (l, txa) in l[..h4].iter_mut().zip(&txa[0][0][..h4]) { - *l = txa[w4 - 1]; + // SAFETY: y < h4 and x < w4 so txa[0][0][y][x] is initialized. + *l = unsafe { txa[w4 - 1].assume_init() }; } - a[..w4].copy_from_slice(&txa[1][0][h4 - 1][..w4]); + // SAFETY: y < h4 and x < w4 so txa[1][0][y][x] is initialized. Note that + // this can be replaced by `MaybeUninit::slice_assume_init_ref` if it is + // stabilized. + let txa_slice = + unsafe { &*(&txa[1][0][h4 - 1][..w4] as *const [MaybeUninit] as *const [u8]) }; + a[..w4].copy_from_slice(txa_slice); } #[inline]