Switch to parallel FFD bin packing algorithm (closes #1492) #1516

Closed
wants to merge 3 commits into from

Contributor

@dsesclei dsesclei commented Apr 11, 2024

Description

Replace the existing sample packing algorithm with a parallel implementation of first-fit-decreasing.

Motivation and Context

I noticed recently that we could get denser sample packing with a different algorithm. Looking into it more, FFD performs just as well and is much faster than the heuristic I had 😅.

We can run FFD in parallel without losing much performance by packing samples in groups rather than all at once. On an i9-14900K, packing 1M samples takes 2.2s at 99.7% efficiency (the current multipack.py reaches 91.7% in 0.32s).

I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add that back in. Two new config options are added: sample_packing_group_size controls the number of samples packed by each process, and sample_packing_bin_size sets the number of samples that can be placed in one pack (it may need to be increased for large context lengths).
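For readers unfamiliar with the approach, here is a minimal sketch of first-fit-decreasing applied group by group. It is not the code from this PR: the function names, the `capacity` parameter (token budget per pack), and the reading of `bin_size` as a cap on samples per pack are assumptions made for illustration. The groups are independent of each other, so each `ffd_pack` call could be handed to a separate worker process; the sketch runs them sequentially to stay short.

```python
from typing import List, Sequence


def ffd_pack(lengths: Sequence[int], capacity: int, bin_size: int) -> List[List[int]]:
    """First-fit-decreasing over one group; returns packs of indices into `lengths`."""
    # Visit samples from longest to shortest (the "decreasing" part).
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    packs: List[List[int]] = []
    free: List[int] = []  # remaining token capacity of each pack
    for i in order:
        n = lengths[i]
        if n > capacity:
            raise ValueError(f"sample {i} has length {n} > capacity {capacity}")
        # Place the sample in the first pack that has room (the "first-fit" part).
        for b in range(len(packs)):
            if free[b] >= n and len(packs[b]) < bin_size:
                packs[b].append(i)
                free[b] -= n
                break
        else:
            packs.append([i])
            free.append(capacity - n)
    return packs


def pack_in_groups(
    lengths: Sequence[int], capacity: int, group_size: int, bin_size: int
) -> List[List[int]]:
    """Pack each consecutive group of `group_size` samples independently."""
    packs: List[List[int]] = []
    for start in range(0, len(lengths), group_size):
        group = lengths[start:start + group_size]
        # Map group-local indices back to positions in the full dataset.
        packs.extend([start + i for i in p] for p in ffd_pack(group, capacity, bin_size))
    return packs


def efficiency(lengths: Sequence[int], packs: List[List[int]], capacity: int) -> float:
    """Fraction of the total pack capacity that holds real tokens."""
    used = sum(lengths[i] for p in packs for i in p)
    return used / (len(packs) * capacity)


if __name__ == "__main__":
    sample_lengths = [5, 4, 3, 2, 2, 4]
    result = pack_in_groups(sample_lengths, capacity=8, group_size=6, bin_size=4)
    print(result, efficiency(sample_lengths, result, capacity=8))
```

Smaller groups parallelize better but give FFD fewer candidates per pack, which is the trade-off the group-size measurements in this PR are about.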

How has this been tested?

Tests have been updated to verify that packing is correct. Training appears to run the same, just with fewer steps.

It seems reasonable that sorting the items in FFD would interfere with shuffling between epochs, but I haven't been able to find any evidence of that being the case. Testing against a few similarity metrics shows that even when we do the packing at once in one group, shuffling still generates a mostly new set of packs.
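The similarity metrics used for this check are not listed in the PR. As one illustrative possibility (an assumption, not the author's method), the packs from two epochs can be compared with a Jaccard index over the set of packs, treating each pack as an unordered set of sample ids. The function name `pack_overlap` is hypothetical.

```python
from typing import Iterable, List


def pack_overlap(packs_a: Iterable[List[int]], packs_b: Iterable[List[int]]) -> float:
    """Jaccard index between two packings: shared packs / distinct packs overall."""
    # A pack is identified by its members, regardless of their order inside it.
    a = {frozenset(p) for p in packs_a}
    b = {frozenset(p) for p in packs_b}
    if not a and not b:
        return 1.0  # two empty packings are trivially identical
    return len(a & b) / len(a | b)


if __name__ == "__main__":
    epoch_1 = [[0, 1], [2, 3]]
    epoch_2 = [[1, 0], [2, 4]]
    print(pack_overlap(epoch_1, epoch_2))
```

A value near 0 means shuffling produced an almost entirely new set of packs; a value near 1 means the sort inside FFD is undoing the shuffle.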

Screenshots

Some performance checks below for 1M items.

[Figure: group size vs. excess]
[Figure: bin size vs. excess]

@dsesclei force-pushed the ds-packing branch 2 times, most recently from e77a87a to 739dd5f (April 11, 2024 04:00)
@winglian
Collaborator

> I removed the length estimates around packing in favor of just counting the batches, but let me know if I should add that back in.

I need to do some checking, but the estimates exist because different processes get different splits of the data, so the actual count of packed samples can vary from process to process. When this happens, one process thinks it needs to run another step while another process thinks it's done, and they get out of sync. The estimate was the most sane way I could come up with to have each process arrive at a deterministic length. I'm open to other ideas for working around this.

@dsesclei force-pushed the ds-packing branch 2 times, most recently from 67f1504 to 8c233a0 (April 11, 2024 18:38)
@dsesclei
Contributor Author

Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.
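A rough sketch of that idea, under two assumptions that are not confirmed by the PR: every rank computes the identical, deterministically ordered list of packs, and any leftover packs that do not divide evenly are dropped. The function name `split_packs_evenly` is hypothetical and is not taken from multipack.py.

```python
from typing import List, Sequence


def split_packs_evenly(packs: Sequence[List[int]], world_size: int) -> List[List[List[int]]]:
    """Give every rank the same number of packs, dropping any remainder."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    per_rank = len(packs) // world_size
    # Contiguous, equally sized shards: rank r takes packs [r*per_rank, (r+1)*per_rank).
    return [list(packs[r * per_rank:(r + 1) * per_rank]) for r in range(world_size)]


if __name__ == "__main__":
    all_packs = [[i] for i in range(7)]
    for rank, shard in enumerate(split_packs_evenly(all_packs, world_size=3)):
        print(rank, shard)
```

Because each shard has exactly `len(packs) // world_size` entries, no rank would expect an extra step, but only if the pack list really is identical on every rank.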

@winglian
Collaborator

> Could we generate all the packs, and then evenly split those up (like in the updated multipack.py)? I think each rank should then get an exact number of batches and stay in sync.

Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications though

@NanoCode012
Collaborator

Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?

@dsesclei
Contributor Author

> Perhaps we could do something like dispatch_batches=True to only run the packing on rank 0. I'm not 100% certain of the implications though

Gotcha, for now I'll keep this PR simple by leaving the packing estimates in. Ready for another look.

> Hey, this is very interesting. Should there be some full run comparisons to make sure that there is no loss in performance?

Yeah definitely, once the code is greenlit/finalized I'll rent an instance to test it in a distributed setup.


return distributed_state.use_distributed and distributed_state.initialized
global accelerate # pylint: disable=global-statement
if not accelerate:
Collaborator

Hey @dsesclei, sorry for the delay in getting back to this PR. Is there a particular reason Accelerator was added back rather than using PartialState? It's best to not explicitly load up Accelerator until the last possible moment, and I believe we can get everything about the distributed state from PartialState.

@winglian
Collaborator

Hey @dsesclei, we cherry-picked and merged your fixes in #1619. Thanks! Would love to give you a shoutout if you're on Twitter or Discord and could share your handle.

@winglian winglian closed this May 23, 2024
@dsesclei
Contributor Author

Thanks for getting this in Wing! No handles to give, but I appreciate it

@dsesclei dsesclei deleted the ds-packing branch May 29, 2024 22:08
@winglian
Collaborator

Thanks @dsesclei, I ended up having to revert the change b/c the loss was off by an order of magnitude. I need to dig into what the multipack sampler is outputting another time to see if there is something obvious that it is doing differently

@dsesclei
Contributor Author

Oh gotcha, I'll look into it

3 participants