perf: Reduce TASO hashtable size #133
After some more testing (and one very crucial improvement, see the latest commit), the improvements are very clear.
(The benchmark command and the results for main and for this PR were not preserved.)
Nice!
The changes to the priority channel will also be useful as a base for the sharding idea.
src/optimiser/taso.rs
Outdated
if (pq.len() > PRIORITY_QUEUE_CAPACITY / 2
    && new_circ_cost > *pq.max_cost().unwrap())
Check this before computing the hash, so we may skip that computation on some cases.
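A minimal sketch of that reordering, under stated assumptions: `Circ`, `cost`, `circuit_hash` and `hash_if_promising` are hypothetical stand-ins, not the project's types or functions, and the cost is assumed to be a gate count. The point is only that the cheap cost test runs first, so rejected candidates are never hashed.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hypothetical stand-in for a circuit: just a list of gate names.
type Circ = Vec<&'static str>;

/// Cheap cost function (assumed here to be the gate count).
fn cost(circ: &Circ) -> usize {
    circ.len()
}

/// Comparatively expensive hash; stands in for `circuit_hash()`.
fn circuit_hash(circ: &Circ) -> u64 {
    let mut hasher = DefaultHasher::new();
    circ.hash(&mut hasher);
    hasher.finish()
}

/// Returns the hash only for candidates that pass the cost filter.
/// The filter mirrors the condition in the diff: the queue is more than
/// half full and the candidate costs more than the queue's current maximum.
fn hash_if_promising(
    circ: &Circ,
    queue_len: usize,
    queue_capacity: usize,
    queue_max_cost: usize,
) -> Option<u64> {
    if queue_len > queue_capacity / 2 && cost(circ) > queue_max_cost {
        return None; // rejected before any hashing happens
    }
    Some(circuit_hash(circ))
}
```

With a queue of capacity 10 holding 8 entries and a maximum cost of 3, a five-gate candidate returns `None` without being hashed, while a two-gate candidate is hashed as usual.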
src/optimiser/taso.rs
Outdated
@@ -130,18 +131,23 @@ where
let rewrites = self.rewriter.get_rewrites(&circ);
for new_circ in self.strategy.apply_rewrites(rewrites, &circ) {
    let new_circ_hash = new_circ.circuit_hash();
    let new_circ_cost = (self.cost)(&new_circ);
    circ_cnt += 1;
Looking at this; do we want to count repeated hashes as seen multiple times? Otherwise this should go after the branch.
src/optimiser/taso/hugr_pchannel.rs
Outdated
self.log
    .send(PriorityChannelLog::CircuitCount(
        self.circ_cnt,
        self.seen_hashes.len(),
    ))
    .unwrap();
Move this out of the loop, so other `break`s also trigger a last log.
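As a rough illustration of the suggestion, the sketch below places the final log after the loop so that every exit path emits it. The `Log` enum, the `run` function and the stop conditions are hypothetical; only the `CircuitCount(count, seen)` shape is borrowed from the diff.

```rust
use std::collections::HashSet;

/// Hypothetical log message, modelled on the `CircuitCount` variant in the diff.
#[derive(Debug, PartialEq)]
enum Log {
    CircuitCount(usize, usize),
}

/// Processes hashes until one of two exit conditions fires, logging every
/// `log_every` items and once more after the loop, whichever `break` was taken.
fn run(items: &[u64], max_items: usize, log_every: usize) -> Vec<Log> {
    let mut logs = Vec::new();
    let mut seen = HashSet::new();
    let mut circ_cnt = 0;
    for &hash in items {
        if circ_cnt >= max_items {
            break; // first exit path: item budget exhausted
        }
        if hash == 0 {
            break; // second exit path: a hypothetical stop marker
        }
        seen.insert(hash);
        circ_cnt += 1;
        if circ_cnt % log_every == 0 {
            logs.push(Log::CircuitCount(circ_cnt, seen.len()));
        }
    }
    // Placed after the loop, so both `break`s and normal exhaustion log a final count.
    logs.push(Log::CircuitCount(circ_cnt, seen.len()));
    logs
}
```

One side effect of this layout is that a final count can repeat the last periodic one when the loop ends exactly on a logging boundary; whether that matters depends on the consumer.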
self.circ_cnt += 1;
if self.circ_cnt % 1000 == 0 {
    // TODO: Add a minimum time between logs
    self.log
We could log directly from this thread, but currently `TasoLogger` is non-copyable so we cannot share it.
LGTM for now, but we'll probably want to simplify it later.
Yes I agree.
    self.buckets.push_front([hash].into_iter().collect());
    return true;
};
while cost < *min_cost {
Suggested change:
- while cost < *min_cost {
+ self.buckets.reserve(min_cost.saturating_sub(cost));
+ while cost < *min_cost {
src/optimiser/taso/hugr_hash_set.rs
Outdated
    *min_cost -= 1;
}
let bucket_index = cost - *min_cost;
while bucket_index >= self.buckets.len() {
Suggested change:
- while bucket_index >= self.buckets.len() {
+ let missing_back = (bucket_index+1).saturating_sub(self.buckets.len());
+ self.buckets.reserve(missing_back);
+ while bucket_index >= self.buckets.len() {
Or alternatively
let missing_back = (bucket_index+1).saturating_sub(self.buckets.len());
self.buckets.extend(iter::repeat_with(|| FxHashSet::default()).take(missing_back));
or even
if bucket_index >= self.buckets.len() {
self.buckets.resize_with(bucket_index + 1, FxHashSet::default);
}
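The three alternatives proposed in this thread should leave the deque in the same state. The sketch below puts them side by side for comparison; it assumes a `VecDeque` of sets and substitutes the standard library `HashSet` for `FxHashSet` so that it has no external dependencies. The function names are hypothetical.

```rust
use std::collections::{HashSet, VecDeque};
use std::iter;

/// Assumed bucket layout: one set of hashes per cost value.
type Buckets = VecDeque<HashSet<u64>>;

/// Variant 1: reserve the missing capacity once, then push in a loop.
fn grow_with_reserve(buckets: &mut Buckets, bucket_index: usize) {
    let missing_back = (bucket_index + 1).saturating_sub(buckets.len());
    buckets.reserve(missing_back);
    while bucket_index >= buckets.len() {
        buckets.push_back(HashSet::new());
    }
}

/// Variant 2: extend with exactly the number of missing empty sets.
fn grow_with_extend(buckets: &mut Buckets, bucket_index: usize) {
    let missing_back = (bucket_index + 1).saturating_sub(buckets.len());
    buckets.extend(iter::repeat_with(HashSet::new).take(missing_back));
}

/// Variant 3: `resize_with`, guarded so that it never shrinks the deque.
fn grow_with_resize(buckets: &mut Buckets, bucket_index: usize) {
    if bucket_index >= buckets.len() {
        buckets.resize_with(bucket_index + 1, HashSet::new);
    }
}
```

The guard in the third variant matters: `resize_with` truncates when the new length is smaller than the current one, which would silently drop buckets for any `bucket_index` already in range.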
src/optimiser/taso/hugr_hash_set.rs
Outdated
    self.buckets[bucket_index].insert(hash)
}

// /// Returns whether the given hash is present in the set.
Suggested change:
- // /// Returns whether the given hash is present in the set.
+ /// Returns whether the given hash is present in the set.
Co-authored-by: Agustín Borgna <[email protected]>
The idea of this PR is to store hashes of seen circuits in buckets given by the gate count of the circuit. That way hashes above a certain gate count can be cleared.
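The bucketing idea can be sketched roughly as follows. This is not the code from the PR: `BucketedHashSet`, its fields and its methods are hypothetical names, the standard `HashSet` stands in for `FxHashSet`, and the cost is assumed to be an integer gate count. Each hash is assumed to always be inserted with the same cost.

```rust
use std::collections::{HashSet, VecDeque};

/// Seen-circuit hashes, bucketed by cost (assumed to be the gate count).
/// `buckets[i]` holds the hashes whose cost is `min_cost + i`.
#[derive(Default)]
struct BucketedHashSet {
    min_cost: Option<usize>,
    buckets: VecDeque<HashSet<u64>>,
}

impl BucketedHashSet {
    /// Inserts a hash at the given cost; returns `true` if it was not present.
    fn insert(&mut self, hash: u64, cost: usize) -> bool {
        let min_cost = *self.min_cost.get_or_insert(cost);
        if cost < min_cost {
            // Grow at the front so that index 0 corresponds to the new minimum.
            for _ in cost..min_cost {
                self.buckets.push_front(HashSet::new());
            }
            self.min_cost = Some(cost);
        }
        let index = cost - min_cost.min(cost);
        if index >= self.buckets.len() {
            self.buckets.resize_with(index + 1, HashSet::new);
        }
        self.buckets[index].insert(hash)
    }

    /// Returns whether the hash is present in the bucket for `cost`.
    fn contains(&self, hash: u64, cost: usize) -> bool {
        match self.min_cost {
            Some(min_cost) if cost >= min_cost => self
                .buckets
                .get(cost - min_cost)
                .map_or(false, |bucket| bucket.contains(&hash)),
            _ => false,
        }
    }

    /// Drops every hash whose cost is strictly greater than `max_cost`.
    fn clear_over(&mut self, max_cost: usize) {
        if let Some(min_cost) = self.min_cost {
            let keep = max_cost.saturating_add(1).saturating_sub(min_cost);
            self.buckets.truncate(keep);
        }
    }

    /// Total number of stored hashes across all buckets.
    fn len(&self) -> usize {
        self.buckets.iter().map(HashSet::len).sum()
    }
}
```

Because each bucket covers exactly one cost value, discarding all hashes above a threshold is a single truncation of the deque rather than a scan over every stored hash.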
@aborgna-q I had to remove the call to `tracing::trace_span`. Where should I re-introduce it in the new code?
EDIT: I ran tests to compare memory usage, but the variation between runs was too large for the results to mean anything. I am not sure how to track memory usage well enough.