From 9e86e7f74f01255379cab8248d065d5a44d14b83 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Fri, 15 Nov 2024 12:09:40 -0500
Subject: [PATCH] fix: fail when there is no tensor parallel plan with model

Signed-off-by: Mehant Kammakomati
---
 src/accelerate/accelerator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 7d9953151b0..e3cea98a7f4 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1468,6 +1468,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
+            if not model.supports_tp_plan:
+                raise NotImplementedError("Provided model does not support tensor parallelism")
             model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
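
Note (not part of the patch): a minimal sketch of the failure path this guard introduces, assuming a Transformers model that exposes the `supports_tp_plan` attribute checked above; the checkpoint name is illustrative only.

    # Illustrative only: shows when the new NotImplementedError would be raised.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical checkpoint

    # Mirrors the check added in Accelerator.prepare_model under DistributedType.TP:
    # models without a predefined tensor parallel plan report supports_tp_plan == False.
    if not model.supports_tp_plan:
        raise NotImplementedError("Provided model does not support tensor parallelism")

With the patch applied, Accelerator.prepare_model raises this error eagerly instead of letting model.tensor_parallel fail later on a model that has no tensor parallel plan.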