-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor spanner to avoid creating large array #773
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this have tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok i think this makes sense, but can you add the following:
- tests comparing the
self.spans
object with old method vs new method - quick local perf benchmark for old vs. new spanner object creation -- want to make sure this doesn't increase init time excessively
Hi @XiaohanZhangCMU incredible job! I just tested and can confirm that this solves a problem with large number of shards. However, there may be another bottleneck with The following code takes ~5 minutes and CPU memory utilisation climbs to ~80GB before the first batch can be retrieved. As the dataloader is being used it consumes ~50GB of memory. If I end up using 8 GPUs there may be not enough CPU memory in H100s to be able to use it. dataset = StreamingDataset(
remote="/mnt/disks/raw/train",
local="./local",
batch_size=1024,
cache_limit="10GB",
)
dataloader = StreamingDataLoader(dataset, batch_size=1024)
for batch in dataloader:
print(batch)
break Any ideas what could be going wrong? |
@AugustDev yeah, I wouldn't be surprised with those profiling numbers. There are places that some large arrays of int32 are created, which may add up..... |
I see... So MosaicML currently is not suitable for datasets with large number of shards 😢. We're using litdata, but been thinking to migrate to MosaicML as has great features. Do you have any suggestions what could work for us? Currently we generate dataset using default |
Hey @AugustDev, in your case, you have a whole lot of samples and storing the sample partition array is taking up a good amount of space. However, according to your issue, you have ~27B samples, right? So the sample partition array, even in int64, will be ~216GB, which should be possible. The array is created but then saved to a shared memory file which all workers can access, so this should be a temporary memory cost. Regardless, most H100 systems have nodes with much more CPU ram than 216 GB so this should still be feasible... Reducing the number of shards can help but will probably not make a significant difference for memory savings. It's more about the sheer number of samples you have. |
@AugustDev I would still recommend give it a try using your current set up for 8 gpus (with or without this PR). And let us know if you hit any memory issue. To clarify, the number of shards is not the issue, having smaller shards is actually preferred when streaming from cloud. In your case, the partition array (a shared array of int) was causing the large mem allocation. This PR is to reduce the mem cost in initialization but the partition array alloc is not avoidable. But as @snarayan21 mentioned, the partition array is shared across processes so I think you should be good to scale up to 8 gpus. |
Thank you for the message. Will try on H100 - 8GPUs and report back the findings. |
Hi guys, for me to load the dataset very large dataset (2.4B rows) even with the spanner fix takes 2min 28sec. When it comes to loading first batch for batch in dataloader:
print(batch)
break it's been 15min (and still running). It seems I won't be able to train on this dataset using Mosaic. Any plans to support Mosaic Streaming on large datasets? I'm not sure where the problem with dataloading is, but happy to help fixing if you can guide where the performance issue might come from? Perhaps it's time to merge this? |
Hey @AugustDev, we've been able to train on datasets that have that many (or more) samples -- this is likely an issue particular to your dataset. Are you trying to retrieve a batch locally on your laptop or from the GPUs themselves? Have you tried model training while dataloading instead of dataloading alone? For performance tuning, you can use the streaming simulator to input your dataset characteristics and understand what performance you can expect. Happy to help further. |
@XiaohanZhangCMU Mind adding the tests mentioned above and we can get this one in? |
Description of changes:
Refactor spanner init to avoid creating a large array which may lead to OOM.
Issue #, if available:
OOM issue
Merge Checklist:
Put an
x
without space in the boxes that apply. If you are unsure about any checklist, please don't hesitate to ask. We are here to help! This is simply a reminder of what we are going to look for before merging your pull request.General
Tests
pre-commit
on my change. (check out thepre-commit
section of prerequisites)