Switch to parallel FFD bin packing algorithm (closes #1492) #1516
Conversation
I need to do some checking, but the estimates exist because different processes get different splits of data, so the actual count of packed samples can vary from process to process. When that happens, one process thinks it needs to run another step while another thinks it's done, and they fall out of sync. The estimate was the sanest way I could come up with for each process to arrive at a deterministic length. I'm open to other ideas for working around this.
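For context, the estimate-based approach amounts to something like this sketch (the names and efficiency factor here are illustrative, not the actual code): every rank computes its reported length from global quantities it can derive identically, so the lengths agree even when the locally packed bins don't.

```python
import math

def estimated_num_batches(total_tokens: int, batch_max_len: int, batch_size: int,
                          packing_efficiency: float = 1.0) -> int:
    """Deterministic length estimate shared by all ranks.

    Each rank computes this from the same global token count, so every rank
    reports the same number of steps even if its locally packed bins differ.
    """
    tokens_per_step = batch_max_len * batch_size
    return max(1, math.floor(total_tokens * packing_efficiency / tokens_per_step))
```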
Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.
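Roughly, the suggestion could look like this sketch (a hypothetical helper, not the PR's code): pack everything first, drop the remainder, and give each rank a contiguous, equally sized slice so all ranks agree on the step count.

```python
def split_packs_across_ranks(packs: list[list[int]], num_ranks: int) -> list[list[list[int]]]:
    """Split fully built packs evenly across ranks.

    Dropping the remainder guarantees every rank gets the same number of packs,
    so no rank tries to run an extra step that the others never see.
    """
    per_rank = len(packs) // num_ranks
    usable = packs[: per_rank * num_ranks]
    return [usable[r * per_rank:(r + 1) * per_rank] for r in range(num_ranks)]

# Usage on each rank (rank and world_size come from the distributed state):
# my_packs = split_packs_across_ranks(all_packs, world_size)[rank]
```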
Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications though
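For reference, the Accelerate option being floated here is set roughly like this (a sketch; in newer Accelerate releases the flag has moved to DataLoaderConfiguration, and whether it interacts cleanly with the packing sampler is exactly the open question):

```python
from accelerate import Accelerator

# With dispatch_batches=True, only the main process iterates the dataloader;
# batches are sliced and broadcast to the other ranks, so per-rank differences
# in pack counts can no longer push the processes out of sync.
accelerator = Accelerator(dispatch_batches=True)
# train_dataloader = accelerator.prepare(train_dataloader)
```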
Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?
Gotcha, for now I'll keep this PR simple by leaving the packing estimates in. Ready for another look.
Yeah definitely, once the code is greenlit/finalized I'll rent an instance to test it in a distributed setup.
return distributed_state.use_distributed and distributed_state.initialized
global accelerate  # pylint: disable=global-statement
if not accelerate:
Hey @dsesclei, sorry for the delay in getting back to this PR. Is there a particular reason Accelerator was added back rather than using PartialState? It's best not to explicitly load up Accelerator until the last possible moment, and I believe we can get everything about the distributed state from PartialState.
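Presumably the suggestion is something along these lines (a sketch, assuming PartialState's use_distributed and initialized properties cover what the check needs):

```python
from accelerate import PartialState

def is_distributed() -> bool:
    # PartialState is a lightweight singleton exposing process/distributed info,
    # so this check doesn't need to construct a full Accelerator.
    state = PartialState()
    return state.use_distributed and state.initialized
```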
Thanks for getting this in, Wing! No handles to give, but I appreciate it.
Thanks @dsesclei, I ended up having to revert the change because the loss was off by an order of magnitude. I need to dig into what the multipack sampler is outputting another time to see if there is something obvious it is doing differently.
Oh gotcha, I'll look into it
Description
Replace the existing sample packing algorithm with a parallel implementation of first-fit-decreasing.
Motivation and Context
I noticed recently that we could get denser sample packing with a different algorithm. Looking into it more, FFD performs just as well and is much faster than the heuristic I had 😅.
We can run FFD in parallel without losing much performance by packing samples in groups rather than all at once. On an i9-14900K, it takes 2.2s to pack 1M samples with 99.7% efficiency (the current `multipack.py` reaches 91.7% in 0.32s). I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add them back in. Two new config options are added: `sample_packing_group_size` controls the number of samples packed by each process, and `sample_packing_bin_size` sets the number of samples that can be placed in one pack (this may need to be increased for large context lengths).
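As a rough illustration of the approach (a sketch only; parameter names echo the options above, and the real `multipack.py` implementation is likely vectorized and faster), first-fit-decreasing over one group of samples looks like this:

```python
import numpy as np

def pack_group_ffd(lengths: np.ndarray, batch_max_len: int, bin_size: int) -> list[list[int]]:
    """First-fit-decreasing over a single group of samples.

    Sort samples longest-first, then place each one in the first bin that still
    has room (and fewer than bin_size items). Independent groups of
    sample_packing_group_size samples can be packed like this in parallel.
    """
    order = np.argsort(lengths)[::-1]      # longest samples first
    bins: list[list[int]] = []             # sample indices per bin
    remaining: list[int] = []              # free token capacity per bin
    for idx in order:
        length = int(lengths[idx])
        for b, free in enumerate(remaining):
            if length <= free and len(bins[b]) < bin_size:
                bins[b].append(int(idx))
                remaining[b] -= length
                break
        else:
            bins.append([int(idx)])
            remaining.append(batch_max_len - length)
    return bins
```

Sorting first is what buys the density gain over the previous heuristic, and packing each group independently is what makes the parallel speedup possible.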
How has this been tested?
Tests have been updated to verify that packing is correct. Training appears to run the same, just with fewer steps.
It seems reasonable that sorting the items in FFD would interfere with shuffling between epochs, but I haven't been able to find any evidence of that being the case. Testing against a few similarity metrics shows that even when we do the packing at once in one group, shuffling still generates a mostly new set of packs.
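One simple check along those lines (a hypothetical sketch, not the exact similarity metrics used) is to measure how many packs survive a reshuffle unchanged:

```python
def pack_overlap(packs_a: list[list[int]], packs_b: list[list[int]]) -> float:
    """Fraction of packs from one epoch's shuffle that reappear, as the same set
    of sample indices, after another shuffle; values near 0 mean the shuffle is
    producing a mostly new set of packs."""
    if not packs_a:
        return 0.0
    sets_b = {frozenset(p) for p in packs_b}
    return sum(frozenset(p) in sets_b for p in packs_a) / len(packs_a)
```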
Screenshots
Some performance checks below for 1M items.