diff --git a/jxl/src/error.rs b/jxl/src/error.rs index ad1be49..5a3ea50 100644 --- a/jxl/src/error.rs +++ b/jxl/src/error.rs @@ -67,6 +67,10 @@ pub enum Error { InvalidContextMap(u32), #[error("Invalid context map: number of histogram {0}, number of distinct histograms {1}")] InvalidContextMapHole(u32, u32), + #[error("Invalid permutation: skipped elements {skip} and encoded elements {end} don't fit in permutation of size {size}")] + InvalidPermutationSize { size: u32, skip: u32, end: u32 }, + #[error("Invalid permutation: Lehmer code {lehmer} out of bounds in permutation of size {size} at index {idx}")] + InvalidPermutationLehmerCode { size: u32, idx: u32, lehmer: u32 }, // FrameHeader format errors #[error("Invalid extra channel upsampling: upsampling: {0} dim_shift: {1} ec_upsampling: {2}")] InvalidEcUpsampling(u32, u32, u32), diff --git a/jxl/src/headers/mod.rs b/jxl/src/headers/mod.rs index 2f3c2d0..aa414ce 100644 --- a/jxl/src/headers/mod.rs +++ b/jxl/src/headers/mod.rs @@ -9,6 +9,7 @@ pub mod encodings; pub mod extra_channels; pub mod frame_header; pub mod image_metadata; +pub mod permutation; pub mod size; pub mod transform_data; diff --git a/jxl/src/headers/permutation.rs b/jxl/src/headers/permutation.rs new file mode 100644 index 0000000..00c68c6 --- /dev/null +++ b/jxl/src/headers/permutation.rs @@ -0,0 +1,330 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +use crate::bit_reader::BitReader; +use crate::entropy_coding::decode::Reader; +use crate::error::{Error, Result}; +use crate::util::{tracing::instrument, value_of_lowest_1_bit, CeilLog2}; + +pub struct Permutation(Vec); + +impl std::ops::Deref for Permutation { + type Target = [u32]; + + fn deref(&self) -> &[u32] { + &self.0 + } +} + +impl Permutation { + /// Decode a permutation from entropy-coded stream. + pub fn decode( + size: u32, + skip: u32, + br: &mut BitReader, + entropy_reader: &mut Reader, + ) -> Result { + let end = entropy_reader.read(br, get_context(size))?; + Self::decode_inner(size, skip, end, |ctx| entropy_reader.read(br, ctx)) + } + + fn decode_inner( + size: u32, + skip: u32, + end: u32, + mut read: impl FnMut(usize) -> Result, + ) -> Result { + if end > size - skip { + return Err(Error::InvalidPermutationSize { size, skip, end }); + } + + let mut lehmer = Vec::new(); + lehmer.try_reserve(end as usize)?; + + let mut prev_val = 0u32; + for idx in skip..(skip + end) { + let val = read(get_context(prev_val))?; + if val >= size - idx { + return Err(Error::InvalidPermutationLehmerCode { + size, + idx, + lehmer: val, + }); + } + lehmer.push(val); + prev_val = val; + } + + // Initialize the full permutation vector with skipped elements intact + let mut permutation: Vec = Vec::new(); + permutation.try_reserve((size - skip) as usize)?; + permutation.extend(0..size); + + // Decode the Lehmer code into the slice starting at `skip` + let permuted_slice = decode_lehmer_code(&lehmer, &permutation[skip as usize..])?; + + // Replace the target slice in `permutation` + permutation[skip as usize..].copy_from_slice(&permuted_slice); + + // Ensure the permutation has the correct size + assert_eq!(permutation.len(), size as usize); + + Ok(Self(permutation)) + } +} + +// Decodes the Lehmer code in `code` and returns the permuted slice. +#[instrument(ret, err)] +fn decode_lehmer_code(code: &[u32], permutation_slice: &[u32]) -> Result> { + let n = permutation_slice.len(); + if n == 0 { + return Err(Error::InvalidPermutationLehmerCode { + size: 0, + idx: 0, + lehmer: 0, + }); + } + + let mut permuted = vec![]; + permuted.try_reserve(n)?; + permuted.extend_from_slice(permutation_slice); + + let padded_n = (n as u32).next_power_of_two() as usize; + + // Allocate temp array inside the function + let mut temp = vec![]; + temp.try_reserve(padded_n)?; + temp.extend((0..padded_n as u32).map(|x| value_of_lowest_1_bit(x + 1))); + + for (i, permuted_item) in permuted.iter_mut().enumerate() { + let code_i = *code.get(i).unwrap_or(&0); + + // Adjust the maximum allowed value for code_i + if code_i as usize > n - i - 1 { + return Err(Error::InvalidPermutationLehmerCode { + size: n as u32, + idx: i as u32, + lehmer: code_i, + }); + } + + let mut rank = code_i + 1; + + // Extract i-th unused element via implicit order-statistics tree. + let mut bit = padded_n; + let mut next = 0usize; + while bit != 0 { + let cand = next + bit; + if cand == 0 || cand > padded_n { + return Err(Error::InvalidPermutationLehmerCode { + size: n as u32, + idx: i as u32, + lehmer: code_i, + }); + } + bit >>= 1; + if temp[cand - 1] < rank { + next = cand; + rank -= temp[cand - 1]; + } + } + + *permuted_item = permutation_slice[next]; + + next += 1; + while next <= padded_n { + temp[next - 1] -= 1; + next += value_of_lowest_1_bit(next as u32) as usize; + } + } + + Ok(permuted) +} + +// Decodes the Lehmer code in `code` and returns the permuted vector. +#[cfg(test)] +fn decode_lehmer_code_naive(code: &[u32], permutation_slice: &[u32]) -> Result> { + let n = code.len(); + if n == 0 { + return Err(Error::InvalidPermutationLehmerCode { + size: 0, + idx: 0, + lehmer: 0, + }); + } + + // Ensure permutation_slice has sufficient length + if permutation_slice.len() < n { + return Err(Error::InvalidPermutationLehmerCode { + size: n as u32, + idx: 0, + lehmer: 0, + }); + } + + // Create temp array with values from permutation_slice + let mut temp = permutation_slice.to_vec(); + let mut permuted = Vec::with_capacity(n); + + // Iterate over the Lehmer code + for (i, &idx) in code.iter().enumerate() { + if idx as usize >= temp.len() { + return Err(Error::InvalidPermutationLehmerCode { + size: n as u32, + idx: i as u32, + lehmer: idx, + }); + } + + // Assign temp[idx] to permuted vector + permuted.push(temp.remove(idx as usize)); + } + + // Append any remaining elements from temp to permuted + permuted.extend(temp); + + Ok(permuted) +} + +fn get_context(x: u32) -> usize { + (x + 1).ceil_log2().min(7) as usize +} + +#[cfg(test)] +mod test { + use super::*; + use arbtest::arbitrary::{self, Arbitrary, Unstructured}; + use core::assert_eq; + use test_log::test; + + #[test] + fn generate_permutation_arbtest() { + arbtest::arbtest(|u| { + let input = PermutationInput::arbitrary(u)?; + + let permutation_slice = input.permutation.as_slice(); + + let perm1 = decode_lehmer_code(&input.code, permutation_slice); + let perm2 = decode_lehmer_code_naive(&input.code, permutation_slice); + + assert_eq!( + perm1.map_err(|x| x.to_string()), + perm2.map_err(|x| x.to_string()) + ); + Ok(()) + }); + } + + #[derive(Debug)] + struct PermutationInput { + code: Vec, + permutation: Vec, + } + + impl<'a> Arbitrary<'a> for PermutationInput { + fn arbitrary(u: &mut Unstructured<'a>) -> Result { + // Generate a reasonable size to prevent tests from taking too long + let size_lehmer = u.int_in_range(1..=1000)?; + + let mut lehmer: Vec = Vec::with_capacity(size_lehmer as usize); + for i in 0..size_lehmer { + let max_val = size_lehmer - i - 1; + let val = if max_val > 0 { + u.int_in_range(0..=max_val)? + } else { + 0 + }; + lehmer.push(val); + } + + let mut permutation = Vec::new(); + let size_permutation = u.int_in_range(size_lehmer..=1000)?; + permutation.extend(0..size_permutation); + + let num_of_swaps = u.int_in_range(0..=100)?; + for _ in 0..num_of_swaps { + // Randomly swap two positions + let pos1 = u.int_in_range(0..=size_permutation - 1)?; + let pos2 = u.int_in_range(0..=size_permutation - 1)?; + permutation.swap(pos1 as usize, pos2 as usize); + } + + Ok(PermutationInput { + code: lehmer, + permutation, + }) + } + } + + #[test] + fn simple() { + // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1] + let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1]; + let skip = 4; + let size = 16; + + let permutation_slice: Vec = (skip..size).collect(); + + let permuted = decode_lehmer_code(&code, &permutation_slice).unwrap(); + let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice).unwrap(); + + let mut permutation = Vec::with_capacity(size as usize); + permutation.extend(0..skip); // Add skipped elements + permutation.extend(permuted.iter()); + let expected_permutation = vec![0, 1, 2, 3, 5, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14]; + + assert_eq!(permutation, expected_permutation); + assert_eq!(permuted, permuted_naive); + } + + #[test] + fn decode_lehmer_compare_different_length() -> Result<(), Box> { + // Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1] + let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1]; + let skip = 4; + let size = 16; + + let permutation_slice: Vec = (skip..size).collect(); + + let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?; + let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?; + + let expected_permuted = vec![5u32, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14]; + + assert_eq!(permuted_optimized, expected_permuted); + assert_eq!(permuted_naive, expected_permuted); + assert_eq!(permuted_optimized, permuted_naive); + + Ok(()) + } + + #[test] + fn decode_lehmer_compare_same_length() -> Result<(), Box> { + // Lehmer code: [2, 3, 0, 0, 0] + let code = vec![2u32, 3, 0, 0, 0]; + let n = code.len(); + let permutation_slice: Vec = (0..n as u32).collect(); + + let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?; + let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?; + + let expected_permutation = vec![2u32, 4, 0, 1, 3]; + + assert_eq!(permuted_optimized, expected_permutation); + assert_eq!(permuted_naive, expected_permutation); + assert_eq!(permuted_optimized, permuted_naive); + + Ok(()) + } + + #[test] + fn lehmer_out_of_bounds() { + let code = vec![4]; + let permutation_slice: Vec = (4..8).collect(); + + let result = decode_lehmer_code(&code, &permutation_slice); + assert!(result.is_err()); + } +} diff --git a/jxl/src/util.rs b/jxl/src/util.rs index 6e879ff..1245160 100644 --- a/jxl/src/util.rs +++ b/jxl/src/util.rs @@ -3,12 +3,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +mod bits; #[allow(unused)] mod concat_slice; mod log2; mod shift_right_ceil; pub mod tracing; +pub use bits::*; #[allow(unused)] pub use concat_slice::*; pub use log2::*; diff --git a/jxl/src/util/bits.rs b/jxl/src/util/bits.rs new file mode 100644 index 0000000..fe578ea --- /dev/null +++ b/jxl/src/util/bits.rs @@ -0,0 +1,23 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +pub fn value_of_lowest_1_bit(t: u32) -> u32 { + t & t.wrapping_neg() +} +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_value_of_lowest_1_bit() { + assert_eq!(value_of_lowest_1_bit(0b0001), 1); + assert_eq!(value_of_lowest_1_bit(0b1111), 1); + assert_eq!(value_of_lowest_1_bit(0b0010), 2); + assert_eq!(value_of_lowest_1_bit(0b0100), 4); + assert_eq!(value_of_lowest_1_bit(0b1010), 2); + assert_eq!(value_of_lowest_1_bit(0b1000_0000), 128); + assert_eq!(value_of_lowest_1_bit(0), 0); + } +}