diff --git a/src/permutations.rs b/src/permutations.rs index b6ff5c4fa..5dd3dbf9a 100644 --- a/src/permutations.rs +++ b/src/permutations.rs @@ -1,6 +1,7 @@ use alloc::vec::Vec; use std::fmt; use std::iter::once; +use std::iter::FusedIterator; use super::lazy_buffer::LazyBuffer; use crate::size_hint::{self, SizeHint}; @@ -26,22 +27,17 @@ where #[derive(Clone, Debug)] enum PermutationState { - StartUnknownLen { k: usize }, - OngoingUnknownLen { k: usize, min_n: usize }, - Complete(CompleteState), - Empty, -} - -#[derive(Clone, Debug)] -enum CompleteState { - Start { - n: usize, - k: usize, - }, - Ongoing { + /// No permutation generated yet. + Start { k: usize }, + /// Values from the iterator are not fully loaded yet so `n` is still unknown. + Buffered { k: usize, min_n: usize }, + /// All values from the iterator are known so `n` is known. + Loaded { indices: Vec, cycles: Vec, }, + /// No permutation left to generate. + End, } impl fmt::Debug for Permutations @@ -55,20 +51,13 @@ where pub fn permutations(iter: I, k: usize) -> Permutations { let mut vals = LazyBuffer::new(iter); - if k == 0 { - // Special case, yields single empty vec; `n` is irrelevant - let state = PermutationState::Complete(CompleteState::Start { n: 0, k: 0 }); - - return Permutations { vals, state }; - } - vals.prefill(k); let enough_vals = vals.len() == k; let state = if enough_vals { - PermutationState::StartUnknownLen { k } + PermutationState::Start { k } } else { - PermutationState::Empty + PermutationState::End }; Permutations { vals, state } @@ -82,169 +71,117 @@ where type Item = Vec; fn next(&mut self) -> Option { - { - let &mut Permutations { - ref mut vals, - ref mut state, - } = self; - match *state { - PermutationState::StartUnknownLen { k } => { - *state = PermutationState::OngoingUnknownLen { k, min_n: k }; - } - PermutationState::OngoingUnknownLen { k, min_n } => { - if vals.get_next() { - *state = PermutationState::OngoingUnknownLen { - k, - min_n: min_n + 1, - }; - } else { - let n = min_n; - let prev_iteration_count = n - k + 1; - let mut complete_state = CompleteState::Start { n, k }; - - // Advance the complete-state iterator to the correct point - for _ in 0..(prev_iteration_count + 1) { - complete_state.advance(); + let Self { vals, state } = self; + match state { + PermutationState::Start { k: 0 } => { + *state = PermutationState::End; + Some(Vec::new()) + } + &mut PermutationState::Start { k } => { + *state = PermutationState::Buffered { k, min_n: k }; + Some(vals[0..k].to_vec()) + } + PermutationState::Buffered { ref k, min_n } => { + if vals.get_next() { + let item = (0..*k - 1) + .chain(once(*min_n)) + .map(|i| vals[i].clone()) + .collect(); + *min_n += 1; + Some(item) + } else { + let n = *min_n; + let prev_iteration_count = n - *k + 1; + let mut indices: Vec<_> = (0..n).collect(); + let mut cycles: Vec<_> = (n - k..n).rev().collect(); + // Advance the state to the correct point. + for _ in 0..prev_iteration_count { + if advance(&mut indices, &mut cycles) { + *state = PermutationState::End; + return None; } - - *state = PermutationState::Complete(complete_state); } + let item = indices[0..*k].iter().map(|&i| vals[i].clone()).collect(); + *state = PermutationState::Loaded { indices, cycles }; + Some(item) } - PermutationState::Complete(ref mut state) => { - state.advance(); - } - PermutationState::Empty => {} - }; - } - let &mut Permutations { - ref vals, - ref state, - } = self; - match *state { - PermutationState::StartUnknownLen { .. } => panic!("unexpected iterator state"), - PermutationState::OngoingUnknownLen { k, min_n } => { - let latest_idx = min_n - 1; - let indices = (0..(k - 1)).chain(once(latest_idx)); - - Some(indices.map(|i| vals[i].clone()).collect()) } - PermutationState::Complete(CompleteState::Ongoing { - ref indices, - ref cycles, - }) => { + PermutationState::Loaded { indices, cycles } => { + if advance(indices, cycles) { + *state = PermutationState::End; + return None; + } let k = cycles.len(); Some(indices[0..k].iter().map(|&i| vals[i].clone()).collect()) } - PermutationState::Complete(CompleteState::Start { .. }) | PermutationState::Empty => { - None - } + PermutationState::End => None, } } fn count(self) -> usize { - fn from_complete(complete_state: CompleteState) -> usize { - complete_state - .remaining() - .expect("Iterator count greater than usize::MAX") - } - - let Permutations { vals, state } = self; - match state { - PermutationState::StartUnknownLen { k } => { - let n = vals.count(); - let complete_state = CompleteState::Start { n, k }; - - from_complete(complete_state) - } - PermutationState::OngoingUnknownLen { k, min_n } => { - let prev_iteration_count = min_n - k + 1; - let n = vals.count(); - let complete_state = CompleteState::Start { n, k }; - - from_complete(complete_state) - prev_iteration_count - } - PermutationState::Complete(state) => from_complete(state), - PermutationState::Empty => 0, - } + let Self { vals, state } = self; + let n = vals.count(); + state.size_hint_for(n).1.unwrap() } fn size_hint(&self) -> SizeHint { - let at_start = |k| { - // At the beginning, there are `n!/(n-k)!` items to come (see `remaining`) but `n` might be unknown. - let (mut low, mut upp) = self.vals.size_hint(); - low = CompleteState::Start { n: low, k } - .remaining() - .unwrap_or(usize::MAX); - upp = upp.and_then(|n| CompleteState::Start { n, k }.remaining()); - (low, upp) - }; - match self.state { - PermutationState::StartUnknownLen { k } => at_start(k), - PermutationState::OngoingUnknownLen { k, min_n } => { - // Same as `StartUnknownLen` minus the previously generated items. - size_hint::sub_scalar(at_start(k), min_n - k + 1) - } - PermutationState::Complete(ref state) => match state.remaining() { - Some(count) => (count, Some(count)), - None => (::std::usize::MAX, None), - }, - PermutationState::Empty => (0, Some(0)), - } + let (mut low, mut upp) = self.vals.size_hint(); + low = self.state.size_hint_for(low).0; + upp = upp.and_then(|n| self.state.size_hint_for(n).1); + (low, upp) } } -impl CompleteState { - fn advance(&mut self) { - *self = match *self { - CompleteState::Start { n, k } => { - let indices = (0..n).collect(); - let cycles = ((n - k)..n).rev().collect(); - - CompleteState::Ongoing { cycles, indices } - } - CompleteState::Ongoing { - ref mut indices, - ref mut cycles, - } => { - let n = indices.len(); - let k = cycles.len(); - - for i in (0..k).rev() { - if cycles[i] == 0 { - cycles[i] = n - i - 1; - - let to_push = indices.remove(i); - indices.push(to_push); - } else { - let swap_index = n - cycles[i]; - indices.swap(i, swap_index); - - cycles[i] -= 1; - return; - } - } +impl FusedIterator for Permutations +where + I: Iterator, + I::Item: Clone, +{ +} - CompleteState::Start { n, k } - } +fn advance(indices: &mut [usize], cycles: &mut [usize]) -> bool { + let n = indices.len(); + let k = cycles.len(); + // NOTE: if `cycles` are only zeros, then we reached the last permutation. + for i in (0..k).rev() { + if cycles[i] == 0 { + cycles[i] = n - i - 1; + indices[i..].rotate_left(1); + } else { + let swap_index = n - cycles[i]; + indices.swap(i, swap_index); + cycles[i] -= 1; + return false; } } + true +} - /// Returns the count of remaining permutations, or None if it would overflow. - fn remaining(&self) -> Option { +impl PermutationState { + fn size_hint_for(&self, n: usize) -> SizeHint { + // At the beginning, there are `n!/(n-k)!` items to come. + let at_start = |n, k| { + debug_assert!(n >= k); + let total = (n - k + 1..=n).try_fold(1usize, |acc, i| acc.checked_mul(i)); + (total.unwrap_or(usize::MAX), total) + }; match *self { - CompleteState::Start { n, k } => { - if n < k { - return Some(0); - } - (n - k + 1..=n).try_fold(1usize, |acc, i| acc.checked_mul(i)) + Self::Start { k } => at_start(n, k), + Self::Buffered { k, min_n } => { + // Same as `Start` minus the previously generated items. + size_hint::sub_scalar(at_start(n, k), min_n - k + 1) } - CompleteState::Ongoing { + Self::Loaded { ref indices, ref cycles, - } => cycles.iter().enumerate().try_fold(0usize, |acc, (i, &c)| { - acc.checked_mul(indices.len() - i) - .and_then(|count| count.checked_add(c)) - }), + } => { + let count = cycles.iter().enumerate().try_fold(0usize, |acc, (i, &c)| { + acc.checked_mul(indices.len() - i) + .and_then(|count| count.checked_add(c)) + }); + (count.unwrap_or(usize::MAX), count) + } + Self::End => (0, Some(0)), } } }