diff --git a/src/combinations.rs b/src/combinations.rs index d8b5351ec..9e5a661c2 100644 --- a/src/combinations.rs +++ b/src/combinations.rs @@ -14,6 +14,7 @@ pub struct Combinations { indices: Vec, pool: LazyBuffer, first: bool, + done: bool, } impl Clone for Combinations @@ -21,7 +22,7 @@ where I: Clone + Iterator, I::Item: Clone, { - clone_fields!(indices, pool, first); + clone_fields!(indices, pool, first, done); } impl fmt::Debug for Combinations @@ -29,7 +30,7 @@ where I: Iterator + fmt::Debug, I::Item: fmt::Debug, { - debug_fmt_fields!(Combinations, indices, pool, first); + debug_fmt_fields!(Combinations, indices, pool, first, done); } /// Create a new `Combinations` from a clonable iterator. @@ -41,6 +42,7 @@ where indices: (0..k).collect(), pool: LazyBuffer::new(iter), first: true, + done: false, } } @@ -70,6 +72,7 @@ impl Combinations { /// elements. pub(crate) fn reset(&mut self, k: usize) { self.first = true; + self.done = false; if k < self.indices.len() { self.indices.truncate(k); @@ -90,10 +93,56 @@ impl Combinations { indices, pool, first, + done: _, } = self; let n = pool.count(); (n, remaining_for(n, first, &indices).unwrap()) } + + /// Initialises the iterator by filling a buffer with elements from the + /// iterator. + fn init(&mut self) { + self.pool.prefill(self.k()); + if self.k() > self.n() { + self.done = true; + } else { + self.first = false; + } + } + + /// Increments indices representing the combination to advance to the next + /// (in lexicographic order by increasing sequence) combination. For example + /// if we have n=3 & k=2 then [0, 1] -> [0, 2] -> [0, 3] -> [1, 2] -> ... + fn increment_indices(&mut self) { + if self.indices.is_empty() { + self.done = true; + return; + } + + // Scan from the end, looking for an index to increment + let mut i: usize = self.indices.len() - 1; + + // Check if we need to consume more from the iterator + if self.indices[i] == self.pool.len() - 1 { + self.pool.get_next(); // may change pool size + } + + while self.indices[i] == i + self.pool.len() - self.indices.len() { + if i > 0 { + i -= 1; + } else { + // Reached the last combination + self.done = true; + return; + } + } + + // Increment index, and reset the ones to its right + self.indices[i] += 1; + for j in i + 1..self.indices.len() { + self.indices[j] = self.indices[j - 1] + 1; + } + } } impl Iterator for Combinations @@ -104,40 +153,34 @@ where type Item = Vec; fn next(&mut self) -> Option { if self.first { - self.pool.prefill(self.k()); - if self.k() > self.n() { - return None; - } - self.first = false; - } else if self.indices.is_empty() { - return None; + self.init() } else { - // Scan from the end, looking for an index to increment - let mut i: usize = self.indices.len() - 1; + self.increment_indices() + } - // Check if we need to consume more from the iterator - if self.indices[i] == self.pool.len() - 1 { - self.pool.get_next(); // may change pool size - } + if self.done { + return None; + } - while self.indices[i] == i + self.pool.len() - self.indices.len() { - if i > 0 { - i -= 1; - } else { - // Reached the last combination - return None; - } - } + Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect()) + } + + fn nth(&mut self, n: usize) -> Option { + // Delegate initialisation work to next() + let first = self.next(); - // Increment index, and reset the ones to its right - self.indices[i] += 1; - for j in i + 1..self.indices.len() { - self.indices[j] = self.indices[j - 1] + 1; + if n == 0 { + return first; + } + + for _ in 0..(n - 1) { + self.increment_indices(); + if self.done { + return None; } } - // Create result vector based on the indices - Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect()) + self.next() } fn size_hint(&self) -> (usize, Option) {