-
Notifications
You must be signed in to change notification settings - Fork 989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support tensor parallel & Data loader #3173
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! This looks great to me. We do still need to update this to work with accelerate config
however, whcih happens in commands/config
and commands/launch
. Would you like to do so?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
@kmehant if you rebase from |
@muellerzr Appreciate your response. I would like to bring to your notice the below two points.
For point (1) I can keep this PR simple and allow only for the paradigm 1 and address the paradigm 2 in another PR. WDYT? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR, this looks nice. I have a few smaller comments, please take a look.
Also, please ensure that make quality
passes.
src/accelerate/accelerator.py
Outdated
@@ -1457,6 +1463,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e | |||
) | |||
if self.ddp_handler is not None: | |||
self.ddp_handler.register_comm_hook(model) | |||
elif self.distributed_type == DistributedType.TP: | |||
model.apply_tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apply_tensor_parallel
will be implemented in huggingface/transformers#34194 but only for select model architectures, right? Should we check this and if not present, raise an appropriate error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tensor_parallel()
interface will be implemented here - https://github.com/huggingface/transformers/pull/34184/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaR5017
I have raised a comment on providing a way to know if tensor_parallel succeeded or not. Once that PR is ready, we can handle it here. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let's see what the final result will be. But we could also check hasattr(model, "apply_tensor_parallel")
or would that not work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@BenjaminBossan
The function tensor_parallel
is being added to the parent class PretrainedModel
so all the model classes would have this function irrespective of it being available or not for a model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see, in that case it is crucial to add a method or attribute to check the support for TP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
has_tp_plan
property is added, so updated the code here to fail when the model has no support thank you.
) | ||
|
||
def __post_init__(self): | ||
from torch.distributed.device_mesh import init_device_mesh |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we perform a check on the minimum PyTorch and transformers versions? Not sure if here is the best place or somewhere else, Zach?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not 100% sure there, because ideally we'd have this API work with custom models and transformer ones. If we decide just transformers, yes we should guard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, good point. Still, torch could be checked, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added torch version check thanks
da67cba
to
c096d40
Compare
@muellerzr can I work on this #3173 (review) in a separate PR? I have fetched and rebased my PR and addressed all the review comments thank you. |
This feature is really useful, thank you @kmehant. I wonder if it is possible to combine tensor parallel with data parallel after this PR, say, TP for same-node parallelism and DP for multi-node parallelism. |
Hi @HoangCongDuc, support for that is in my TODOs but not covered in this PR, should be coming soon after discussing with HF. Thank you. |
src/accelerate/accelerator.py
Outdated
@@ -1461,6 +1467,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e | |||
) | |||
if self.ddp_handler is not None: | |||
self.ddp_handler.register_comm_hook(model) | |||
elif self.distributed_type == DistributedType.TP: | |||
if not model.has_tp_plan: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It appears that the attribute was renamed to supports_tp_plan
? Maybe let's wait until that other PR is merged so that this one does not need to be adapted constantly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@BenjaminBossan
Yes, it got modified. I have updated this PR again and also that PR to transformers is now merged :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Overall the code looks sound, what I'd appreciate however is if we could bring this the last 10% of the way through:
- Actually implementing this in the CLI and setting the env variable up properly
- Writing some tests (
src/accelerate/test_utils/scripts/test_tensor_parallel.py
IMO)
cf95d34
to
c80c030
Compare
Signed-off-by: Mehant Kammakomati <[email protected]>
Signed-off-by: Mehant Kammakomati <[email protected]>
e08c364
to
0fe41dd
Compare
8f1071b
to
d7ed517
Compare
Signed-off-by: Mehant Kammakomati <[email protected]>
3d8eeca
to
d5cc290
Compare
Signed-off-by: Mehant Kammakomati <[email protected]>
Let me know if I have missed out something. Thank you. |
What does this PR do?
TorchTensorParallelPlugin
to support TP with Pytorch 2.0. This work should be seen along with the PR feat: add support for tensor parallel using Pytorch transformers#34194.Please review in conjunction with huggingface/transformers#34194
Results
See significant improvement in both memory and throughput compared against single gpu training, and FSDP across different settings (checkpointing on/off) and context lengths.
Done on two models
Tables below show the max cuda memory and throughput for various configurations showing the potential of TP contributed in this PR. There is gains in both memory and throughput.
Note: Please be aware that the effective TPS for FSDP would be multiplicative of the parallel factor (number of GPUs/devices engaged in distributed training) whereas that is not the case with TP. Therefore, when effective throughput is considered we can find FSDP is better than TP in terms of throughput. However, that may be compensated by increasing the batch size utilizing the memory gains etc.
Fixes # (issue)
huggingface/transformers#32470
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
I have cycles to bring in more improvements over this PR to bring in Pytorch TP support to HF. Looking forward. Thank you