perf: more efficient intersection #157

Merged (2 commits) on Dec 1, 2023
78 changes: 53 additions & 25 deletions src/range.rs
@@ -285,41 +285,69 @@ impl<V: Ord + Clone> Range<V> {

/// Computes the intersection of two sets of versions.
pub fn intersection(&self, other: &Self) -> Self {
let mut segments: SmallVec<Interval<V>> = SmallVec::empty();
let mut output: SmallVec<Interval<V>> = SmallVec::empty();
let mut left_iter = self.segments.iter().peekable();
let mut right_iter = other.segments.iter().peekable();

while let (Some((left_start, left_end)), Some((right_start, right_end))) =
(left_iter.peek(), right_iter.peek())
// By the definition of intersection any point that is matched by the output
// must have a segment in each of the inputs that it matches.
// Therefore, every segment in the output must be the intersection of a segment from each of the inputs.
// It would be correct to do the "O(n^2)" thing, by computing the intersection of every segment from one input
// with every segment of the other input, and sorting the result.
// We can avoid the sorting by generating our candidate segments with an increasing `end` value.
while let Some(((left_start, left_end), (right_start, right_end))) =
left_iter.peek().zip(right_iter.peek())
{
// The next smallest `end` value is going to come from one of the inputs.
let left_end_is_smaller = match (left_end, right_end) {
(Included(l), Included(r))
| (Excluded(l), Excluded(r))
| (Excluded(l), Included(r)) => l <= r,

(Included(l), Excluded(r)) => l < r,
(_, Unbounded) => true,
(Unbounded, _) => false,
};
// Now that we are processing `end` we will never have to process any segment smaller than that.
// We can ensure that the input that `end` came from is larger than `end` by advancing it one step.
// `end` is the smaller available input, so we know the other input is already larger than `end`.
// Note: We can call `other_iter.next_if( == end)`, but the ends lining up is rare enough that
// it does not end up being faster in practice.
let (other_start, end) = if left_end_is_smaller {
left_iter.next();
(right_start, left_end)
} else {
right_iter.next();
(left_start, right_end)
};
// `start` will either come from the input `end` came from or the other input, whichever one is larger.
// The intersection is invalid if `start` > `end`.
// But, we already know that the segments in our input are valid.
// So we do not need to check if the `start` from the input `end` came from is smaller than `end`.
// If the `other_start` is larger than end, then the intersection will be invalid.
if !valid_segment(other_start, end) {
// Note: We can call `this_iter.next_if(!valid_segment(other_start, this_end))` in a loop.
// But the checks make it slower for the benchmarked inputs.
continue;
}
let start = match (left_start, right_start) {
(Included(l), Included(r)) => Included(std::cmp::max(l, r)),
(Excluded(l), Excluded(r)) => Excluded(std::cmp::max(l, r)),

(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i <= e => Excluded(e),
(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e < i => Included(i),
(s, Unbounded) | (Unbounded, s) => s.as_ref(),
_ => unreachable!(),
}
.cloned();
let end = match (left_end, right_end) {
(Included(l), Included(r)) => Included(std::cmp::min(l, r)),
(Excluded(l), Excluded(r)) => Excluded(std::cmp::min(l, r)),

(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if i >= e => Excluded(e),
(Included(i), Excluded(e)) | (Excluded(e), Included(i)) if e > i => Included(i),
(Included(i), Excluded(e)) | (Excluded(e), Included(i)) => {
if i <= e {
Excluded(e)
} else {
Included(i)
}
}
(s, Unbounded) | (Unbounded, s) => s.as_ref(),
_ => unreachable!(),
}
.cloned();
left_iter.next_if(|(_, e)| e == &end);
right_iter.next_if(|(_, e)| e == &end);
if valid_segment(&start, &end) {
segments.push((start, end))
}
};
// Now we clone and push a new segment.
// By dealing with references until now we ensure that NO cloning happens when we reject the segment.
output.push((start.cloned(), end.clone()))
}

Self { segments }.check_invariants()
Self { segments: output }.check_invariants()
}
}
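
For readers who want to try the new loop outside the crate, here is a minimal standalone sketch of the same merge-style intersection. It assumes plain slices of `(Bound<u32>, Bound<u32>)` in place of `Range<V>` and `SmallVec`. The `valid_segment` helper is a hypothetical stand-in written for this sketch and may differ from the crate's own function. The free function `intersection` skips `check_invariants`.

```rust
use std::ops::Bound::{self, Excluded, Included, Unbounded};

// Assumption: a simplified interval type standing in for the crate's `Interval<V>`.
type Interval = (Bound<u32>, Bound<u32>);

/// Hypothetical stand-in for `valid_segment`: true when `start..end` can contain a point.
fn valid_segment(start: &Bound<u32>, end: &Bound<u32>) -> bool {
    match (start, end) {
        (Included(s), Included(e)) => s <= e,
        (Included(s), Excluded(e))
        | (Excluded(s), Included(e))
        | (Excluded(s), Excluded(e)) => s < e,
        (Unbounded, _) | (_, Unbounded) => true,
    }
}

/// Intersects two sorted, non-overlapping lists of intervals.
/// Candidates are produced in order of increasing `end`, so no sort is needed.
fn intersection(left: &[Interval], right: &[Interval]) -> Vec<Interval> {
    let mut output = Vec::new();
    let mut left_iter = left.iter().peekable();
    let mut right_iter = right.iter().peekable();
    while let Some(((left_start, left_end), (right_start, right_end))) =
        left_iter.peek().zip(right_iter.peek())
    {
        // Decide which input supplies the next smallest `end`.
        let left_end_is_smaller = match (left_end, right_end) {
            (Included(l), Included(r))
            | (Excluded(l), Excluded(r))
            | (Excluded(l), Included(r)) => l <= r,
            (Included(l), Excluded(r)) => l < r,
            (_, Unbounded) => true,
            (Unbounded, _) => false,
        };
        // Advance the input that `end` came from; keep the other input's `start`.
        let (other_start, end) = if left_end_is_smaller {
            left_iter.next();
            (right_start, left_end)
        } else {
            right_iter.next();
            (left_start, right_end)
        };
        // Reject before cloning anything if the candidate is empty.
        if !valid_segment(other_start, end) {
            continue;
        }
        // The larger of the two starts bounds the candidate from below.
        let start = match (left_start, right_start) {
            (Included(l), Included(r)) => Included(std::cmp::max(l, r)),
            (Excluded(l), Excluded(r)) => Excluded(std::cmp::max(l, r)),
            (Included(i), Excluded(e)) | (Excluded(e), Included(i)) => {
                if i <= e {
                    Excluded(e)
                } else {
                    Included(i)
                }
            }
            (s, Unbounded) | (Unbounded, s) => s.as_ref(),
        };
        output.push((start.cloned(), end.clone()));
    }
    output
}
```

With `left = [1, 5) ∪ [8, 12)` and `right = [3, 9] ∪ (10, ∞)`, the loop runs three times and emits `[3, 5)`, `[8, 9]`, `(10, 12)` in that order, each with a larger `end` than the last. Two disjoint inputs such as `[1, 2]` and `[5, 6]` hit the `continue` branch and produce an empty result.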
