diff --git a/src/accelerate/local_sgd.py b/src/accelerate/local_sgd.py index 382c34bc91c..47b8ed7c450 100644 --- a/src/accelerate/local_sgd.py +++ b/src/accelerate/local_sgd.py @@ -69,6 +69,7 @@ def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_s DistributedType.NO, DistributedType.MULTI_CPU, DistributedType.MULTI_GPU, + DistributedType.MULTI_XPU, DistributedType.MULTI_MLU, DistributedType.MULTI_MUSA, DistributedType.MULTI_NPU,