Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permutation decoding #36

Merged
merged 14 commits into from
Oct 28, 2024
4 changes: 4 additions & 0 deletions jxl/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ pub enum Error {
InvalidContextMap(u32),
#[error("Invalid context map: number of histogram {0}, number of distinct histograms {1}")]
InvalidContextMapHole(u32, u32),
#[error("Invalid permutation: skipped elements {skip} and encoded elements {end} don't fit in permutation of size {size}")]
InvalidPermutationSize { size: u32, skip: u32, end: u32 },
#[error("Invalid permutation: Lehmer code {lehmer} out of bounds in permutation of size {size} at index {idx}")]
InvalidPermutationLehmerCode { size: u32, idx: u32, lehmer: u32 },
// FrameHeader format errors
#[error("Invalid extra channel upsampling: upsampling: {0} dim_shift: {1} ec_upsampling: {2}")]
InvalidEcUpsampling(u32, u32, u32),
Expand Down
1 change: 1 addition & 0 deletions jxl/src/headers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod encodings;
pub mod extra_channels;
pub mod frame_header;
pub mod image_metadata;
pub mod permutation;
pub mod size;
pub mod transform_data;

Expand Down
330 changes: 330 additions & 0 deletions jxl/src/headers/permutation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,330 @@
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use crate::bit_reader::BitReader;
use crate::entropy_coding::decode::Reader;
use crate::error::{Error, Result};
use crate::util::{tracing::instrument, value_of_lowest_1_bit, CeilLog2};

pub struct Permutation(Vec<u32>);

impl std::ops::Deref for Permutation {
type Target = [u32];

fn deref(&self) -> &[u32] {
&self.0
}
}

impl Permutation {
/// Decode a permutation from entropy-coded stream.
pub fn decode(
size: u32,
skip: u32,
br: &mut BitReader,
entropy_reader: &mut Reader,
) -> Result<Self> {
let end = entropy_reader.read(br, get_context(size))?;
Self::decode_inner(size, skip, end, |ctx| entropy_reader.read(br, ctx))
}

fn decode_inner(
size: u32,
skip: u32,
end: u32,
mut read: impl FnMut(usize) -> Result<u32>,
) -> Result<Self> {
if end > size - skip {
return Err(Error::InvalidPermutationSize { size, skip, end });
}

let mut lehmer = Vec::new();
lehmer.try_reserve(end as usize)?;

let mut prev_val = 0u32;
for idx in skip..(skip + end) {
let val = read(get_context(prev_val))?;
if val >= size - idx {
return Err(Error::InvalidPermutationLehmerCode {
size,
idx,
lehmer: val,
});
}
lehmer.push(val);
prev_val = val;
}

// Initialize the full permutation vector with skipped elements intact
let mut permutation: Vec<u32> = Vec::new();
permutation.try_reserve((size - skip) as usize)?;
permutation.extend(0..size);

// Decode the Lehmer code into the slice starting at `skip`
let permuted_slice = decode_lehmer_code(&lehmer, &permutation[skip as usize..])?;

// Replace the target slice in `permutation`
permutation[skip as usize..].copy_from_slice(&permuted_slice);

// Ensure the permutation has the correct size
assert_eq!(permutation.len(), size as usize);

Ok(Self(permutation))
}
}

// Decodes the Lehmer code in `code` and returns the permuted slice.
#[instrument(ret, err)]
fn decode_lehmer_code(code: &[u32], permutation_slice: &[u32]) -> Result<Vec<u32>> {
let n = permutation_slice.len();
if n == 0 {
return Err(Error::InvalidPermutationLehmerCode {
size: 0,
idx: 0,
lehmer: 0,
});
}

let mut permuted = vec![];
permuted.try_reserve(n)?;
permuted.extend_from_slice(permutation_slice);

let padded_n = (n as u32).next_power_of_two() as usize;

// Allocate temp array inside the function
let mut temp = vec![];
temp.try_reserve(padded_n)?;
temp.extend((0..padded_n as u32).map(|x| value_of_lowest_1_bit(x + 1)));

for (i, permuted_item) in permuted.iter_mut().enumerate() {
let code_i = *code.get(i).unwrap_or(&0);

// Adjust the maximum allowed value for code_i
if code_i as usize > n - i - 1 {
return Err(Error::InvalidPermutationLehmerCode {
size: n as u32,
idx: i as u32,
lehmer: code_i,
});
}

let mut rank = code_i + 1;

// Extract i-th unused element via implicit order-statistics tree.
let mut bit = padded_n;
let mut next = 0usize;
while bit != 0 {
let cand = next + bit;
if cand == 0 || cand > padded_n {
return Err(Error::InvalidPermutationLehmerCode {
size: n as u32,
idx: i as u32,
lehmer: code_i,
});
}
bit >>= 1;
if temp[cand - 1] < rank {
next = cand;
rank -= temp[cand - 1];
}
}

*permuted_item = permutation_slice[next];

next += 1;
while next <= padded_n {
temp[next - 1] -= 1;
next += value_of_lowest_1_bit(next as u32) as usize;
}
}

Ok(permuted)
}

// Decodes the Lehmer code in `code` and returns the permuted vector.
#[cfg(test)]
fn decode_lehmer_code_naive(code: &[u32], permutation_slice: &[u32]) -> Result<Vec<u32>> {
let n = code.len();
if n == 0 {
return Err(Error::InvalidPermutationLehmerCode {
size: 0,
idx: 0,
lehmer: 0,
});
}

// Ensure permutation_slice has sufficient length
if permutation_slice.len() < n {
return Err(Error::InvalidPermutationLehmerCode {
size: n as u32,
idx: 0,
lehmer: 0,
});
}

// Create temp array with values from permutation_slice
let mut temp = permutation_slice.to_vec();
let mut permuted = Vec::with_capacity(n);

// Iterate over the Lehmer code
for (i, &idx) in code.iter().enumerate() {
if idx as usize >= temp.len() {
return Err(Error::InvalidPermutationLehmerCode {
size: n as u32,
idx: i as u32,
lehmer: idx,
});
}

// Assign temp[idx] to permuted vector
permuted.push(temp.remove(idx as usize));
}

// Append any remaining elements from temp to permuted
permuted.extend(temp);

Ok(permuted)
}

fn get_context(x: u32) -> usize {
(x + 1).ceil_log2().min(7) as usize
}

#[cfg(test)]
mod test {
use super::*;
use arbtest::arbitrary::{self, Arbitrary, Unstructured};
use core::assert_eq;
use test_log::test;

#[test]
fn generate_permutation_arbtest() {
arbtest::arbtest(|u| {
let input = PermutationInput::arbitrary(u)?;

let permutation_slice = input.permutation.as_slice();

let perm1 = decode_lehmer_code(&input.code, permutation_slice);
let perm2 = decode_lehmer_code_naive(&input.code, permutation_slice);

assert_eq!(
perm1.map_err(|x| x.to_string()),
perm2.map_err(|x| x.to_string())
);
Ok(())
});
}

#[derive(Debug)]
struct PermutationInput {
code: Vec<u32>,
permutation: Vec<u32>,
}

impl<'a> Arbitrary<'a> for PermutationInput {
fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self, arbitrary::Error> {
// Generate a reasonable size to prevent tests from taking too long
let size_lehmer = u.int_in_range(1..=1000)?;

let mut lehmer: Vec<u32> = Vec::with_capacity(size_lehmer as usize);
for i in 0..size_lehmer {
let max_val = size_lehmer - i - 1;
let val = if max_val > 0 {
u.int_in_range(0..=max_val)?
} else {
0
};
lehmer.push(val);
}

let mut permutation = Vec::new();
let size_permutation = u.int_in_range(size_lehmer..=1000)?;
permutation.extend(0..size_permutation);

let num_of_swaps = u.int_in_range(0..=100)?;
for _ in 0..num_of_swaps {
// Randomly swap two positions
let pos1 = u.int_in_range(0..=size_permutation - 1)?;
let pos2 = u.int_in_range(0..=size_permutation - 1)?;
permutation.swap(pos1 as usize, pos2 as usize);
}

Ok(PermutationInput {
code: lehmer,
permutation,
})
}
}

#[test]
fn simple() {
// Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1]
let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1];
let skip = 4;
let size = 16;

let permutation_slice: Vec<u32> = (skip..size).collect();

let permuted = decode_lehmer_code(&code, &permutation_slice).unwrap();
let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice).unwrap();

let mut permutation = Vec::with_capacity(size as usize);
permutation.extend(0..skip); // Add skipped elements
permutation.extend(permuted.iter());
let expected_permutation = vec![0, 1, 2, 3, 5, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14];

assert_eq!(permutation, expected_permutation);
assert_eq!(permuted, permuted_naive);
}

#[test]
fn decode_lehmer_compare_different_length() -> Result<(), Box<dyn std::error::Error>> {
// Lehmer code: [1, 1, 2, 3, 3, 6, 0, 1]
let code = vec![1u32, 1, 2, 3, 3, 6, 0, 1];
let skip = 4;
let size = 16;

let permutation_slice: Vec<u32> = (skip..size).collect();

let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?;
let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?;

let expected_permuted = vec![5u32, 6, 8, 10, 11, 15, 4, 9, 7, 12, 13, 14];

assert_eq!(permuted_optimized, expected_permuted);
assert_eq!(permuted_naive, expected_permuted);
assert_eq!(permuted_optimized, permuted_naive);

Ok(())
}

#[test]
fn decode_lehmer_compare_same_length() -> Result<(), Box<dyn std::error::Error>> {
// Lehmer code: [2, 3, 0, 0, 0]
let code = vec![2u32, 3, 0, 0, 0];
let n = code.len();
let permutation_slice: Vec<u32> = (0..n as u32).collect();

let permuted_optimized = decode_lehmer_code(&code, &permutation_slice)?;
let permuted_naive = decode_lehmer_code_naive(&code, &permutation_slice)?;

let expected_permutation = vec![2u32, 4, 0, 1, 3];

assert_eq!(permuted_optimized, expected_permutation);
assert_eq!(permuted_naive, expected_permutation);
assert_eq!(permuted_optimized, permuted_naive);

Ok(())
}

#[test]
fn lehmer_out_of_bounds() {
let code = vec![4];
let permutation_slice: Vec<u32> = (4..8).collect();

let result = decode_lehmer_code(&code, &permutation_slice);
assert!(result.is_err());
}
}
2 changes: 2 additions & 0 deletions jxl/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

mod bits;
#[allow(unused)]
mod concat_slice;
mod log2;
mod shift_right_ceil;
pub mod tracing;

pub use bits::*;
#[allow(unused)]
pub use concat_slice::*;
pub use log2::*;
Expand Down
23 changes: 23 additions & 0 deletions jxl/src/util/bits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

pub fn value_of_lowest_1_bit(t: u32) -> u32 {
t & t.wrapping_neg()
}
#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_value_of_lowest_1_bit() {
assert_eq!(value_of_lowest_1_bit(0b0001), 1);
assert_eq!(value_of_lowest_1_bit(0b1111), 1);
assert_eq!(value_of_lowest_1_bit(0b0010), 2);
assert_eq!(value_of_lowest_1_bit(0b0100), 4);
assert_eq!(value_of_lowest_1_bit(0b1010), 2);
assert_eq!(value_of_lowest_1_bit(0b1000_0000), 128);
assert_eq!(value_of_lowest_1_bit(0), 0);
}
}