Commit

fix: fail when the model has no tensor parallel plan
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Nov 15, 2024
1 parent 913f8d1 commit 189e202
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/accelerate/accelerator.py
@@ -1468,6 +1468,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             if self.ddp_handler is not None:
                 self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
+            if not model.has_tp_plan:
+                raise NotImplementedError("Provided model does not support tensor parallelism")
             model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
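For reference, the two added lines amount to a guard in front of the tensor-parallel call. The sketch below restates that guard as a standalone helper, assuming a transformers-style model object that exposes `has_tp_plan` and `tensor_parallel(device_mesh)` as in the diff above; `apply_tensor_parallel` and `tp_mesh` are hypothetical names used only for illustration, not part of the accelerate API.

```python
# Minimal sketch of the guard added in this commit, not the accelerate implementation.
# Assumes a transformers-style model exposing `has_tp_plan` and `tensor_parallel(device_mesh)`
# as shown in the diff; `apply_tensor_parallel` and `tp_mesh` are illustrative names.
def apply_tensor_parallel(model, tp_mesh):
    # Mirror the two added lines: refuse models that do not ship a tensor parallel plan
    # instead of letting `tensor_parallel()` proceed without one.
    if not model.has_tp_plan:
        raise NotImplementedError("Provided model does not support tensor parallelism")
    # With a plan present, shard the model over the "tp" dimension of the device mesh,
    # as accelerate does with self.state.torch_tp_plugin.torch_device_mesh["tp"].
    model.tensor_parallel(tp_mesh)
```

With this check in place, `prepare_model` surfaces the missing plan immediately rather than attempting tensor parallelism on a model that does not support it.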
