Merge branch 'main' into safetensors-default-1

muellerzr authored Nov 8, 2023
2 parents 7879980 + 76de60d commit 3a88675
Showing 9 changed files with 39 additions and 10 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -269,7 +269,7 @@ If you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry:
 ```bibtex
 @Misc{accelerate,
   title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.},
-  author = {Sylvain Gugger, Lysandre Debut, Thomas Wolf, Philipp Schmid, Zachary Mueller, Sourab Mangrulkar, Marc Sun, Benjamin Bossan},
+  author = {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan},
   howpublished = {\url{https://github.com/huggingface/accelerate}},
   year = {2022}
 }
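Note: BibTeX requires multiple authors to be separated by `and` rather than commas; a comma-separated list is parsed as one long author name, so the new `author` field is the correct form.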
docs/source/package_reference/utilities.md (6 changes: 6 additions & 0 deletions)

@@ -50,6 +50,12 @@ These are basic dataclasses used throughout 🤗 Accelerate and they can be passed

 [[autodoc]] utils.ProjectConfiguration

+## Environmental Variables
+
+These are environmental variables that can be enabled for different use cases.
+
+* `ACCELERATE_DEBUG_MODE` (`str`): Whether to run accelerate in debug mode. More info is available [here](../usage_guides/debug.md).
+
 ## Plugins

 These are plugins that can be passed to the [`Accelerator`] object. While they are defined elsewhere in the documentation,
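As a usage sketch for the variable documented above (hedged: the exact point at which Accelerate reads `ACCELERATE_DEBUG_MODE` is not shown in this diff), it can be set before the `Accelerator` is created, or exported in the shell that runs `accelerate launch`:

```python
import os

# Sketch: opt in to Accelerate's debug mode via the environment variable
# documented above. The launch.py change below sets this same variable to "true".
os.environ["ACCELERATE_DEBUG_MODE"] = "true"

from accelerate import Accelerator

accelerator = Accelerator()
```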
src/accelerate/commands/config/cluster.py (2 changes: 1 addition & 1 deletion)

@@ -179,7 +179,7 @@ def get_cluster_input():

     use_mps = not use_cpu and is_mps_available()
     deepspeed_config = {}
-    if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO] and not use_mps:
+    if distributed_type in [DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.NO] and not use_mps:
         use_deepspeed = _ask_field(
             "Do you want to use DeepSpeed? [yes/NO]: ",
             _convert_yes_no_to_bool,
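With this change, the interactive `accelerate config` questionnaire offers the DeepSpeed question on Ascend NPU (`MULTI_NPU`) setups as well, not only on multi-GPU and single-device (`NO`) runs.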
src/accelerate/data_loader.py (5 changes: 4 additions & 1 deletion)

@@ -78,12 +78,15 @@ class SeedableRandomSampler(RandomSampler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.epoch = 0
+        self.seed = torch.random.initial_seed()

     def __iter__(self):
         if self.generator is None:
             self.generator = torch.Generator()
+        else:
+            self.seed = self.generator.initial_seed()
         # Allow `self.epoch` to modify the seed of the generator
-        seed = self.epoch + self.generator.initial_seed()
+        seed = self.epoch + self.seed
         self.generator.manual_seed(seed)
         yield from super().__iter__()
         self.set_epoch(self.epoch + 1)
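The cached `self.seed` matters because `torch.Generator.initial_seed()` reports the seed most recently set, so re-reading it after every `manual_seed` call would make the per-epoch seed drift instead of staying at `base_seed + epoch`. A minimal sketch of that torch behavior:

```python
import torch

g = torch.Generator()
base = g.initial_seed()  # the seed the generator currently carries

g.manual_seed(base + 1)  # reseed for "epoch 1"
# initial_seed() now reflects the reseed, so `epoch + g.initial_seed()`
# would compound across epochs; `epoch + base` stays stable.
assert g.initial_seed() == base + 1
```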
src/accelerate/state.py (6 changes: 6 additions & 0 deletions)

@@ -175,6 +175,8 @@ def __init__(self, cpu: bool = False, **kwargs):
             if is_xpu_available and is_ccl_available():
                 # Set DeepSpeed backend to ccl for xpu
                 self.backend = "ccl"
+            elif is_npu_available():
+                self.backend = "hccl"
             else:
                 self.backend = "nccl"
             dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
@@ -187,6 +189,10 @@ def __init__(self, cpu: bool = False, **kwargs):
                 self.device = torch.device("xpu", self.local_process_index)
                 if self.device is not None:
                     torch.xpu.set_device(self.device)
+            elif is_npu_available():
+                self.device = torch.device("npu", self.local_process_index)
+                if self.device is not None:
+                    torch.npu.set_device(self.device)
             else:
                 self.device = torch.device("cuda", self.local_process_index)
                 if self.device is not None:
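Taken together, the two hunks pair each accelerator with its collective-communications backend: XPU with `ccl`, Ascend NPU with `hccl`, and CUDA with `nccl`. A standalone sketch of that selection order, assuming the `torch_npu` package provides the `torch.npu` namespace (this is not Accelerate's API; the `hasattr` guards stand in for its `is_xpu_available()`/`is_npu_available()` helpers):

```python
import torch

def pick_device_and_backend(local_rank: int):
    # Selection order mirroring the diff: XPU -> NPU -> CUDA fallback.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu", local_rank), "ccl"
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.device("npu", local_rank), "hccl"
    return torch.device("cuda", local_rank), "nccl"

device, backend = pick_device_and_backend(0)
```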
src/accelerate/utils/dataclasses.py (13 changes: 11 additions & 2 deletions)

@@ -32,7 +32,7 @@

 from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_STATE_DICT_TYPE
 from .environment import str_to_bool
-from .imports import is_xpu_available
+from .imports import is_cuda_available, is_npu_available, is_xpu_available
 from .versions import compare_versions


@@ -932,7 +932,16 @@ def __post_init__(self):
         self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1

         if self.sync_module_states:
-            device = torch.cuda.current_device() if not is_xpu_available() else torch.xpu.current_device()
+            if is_npu_available():
+                device = torch.npu.current_device()
+            elif is_cuda_available():
+                device = torch.cuda.current_device()
+            elif is_xpu_available():
+                device = torch.xpu.current_device()
+            else:
+                raise RuntimeError(
+                    "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'."
+                )
             self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

     @staticmethod
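For context, the `param_init_fn` assigned at the end of the hunk is what FSDP invokes when `sync_module_states` needs parameters materialized on the right device without initializing them (rank 0's weights are broadcast over them afterwards). A minimal sketch of the `to_empty` call, using CPU so it runs anywhere (in practice `device` is whichever accelerator the chain above picked):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
# Move parameter storage to the target device without initializing values.
layer = layer.to_empty(device=torch.device("cpu"), recurse=False)
```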
src/accelerate/utils/launch.py (8 changes: 5 additions & 3 deletions)

@@ -291,10 +291,12 @@ def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict[str, str]]:
         current_env["ACCELERATE_DEBUG_MODE"] = "true"
     gpu_ids = getattr(args, "gpu_ids", "all")
     if gpu_ids != "all" and args.gpu_ids is not None:
-        if not is_xpu_available():
-            current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
-        else:
+        if is_xpu_available():
             current_env["ZE_AFFINITY_MASK"] = gpu_ids
+        elif is_npu_available():
+            current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids
+        else:
+            current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids
     try:
         mixed_precision = PrecisionType(args.mixed_precision.lower())
     except ValueError:
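The new branch order maps each vendor's device-visibility variable: `ZE_AFFINITY_MASK` for Intel XPUs (Level Zero), `ASCEND_RT_VISIBLE_DEVICES` for Ascend NPUs, and `CUDA_VISIBLE_DEVICES` for NVIDIA GPUs, which becomes the fallback rather than the default branch.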
src/accelerate/utils/operations.py (3 changes: 3 additions & 0 deletions)

@@ -231,6 +231,9 @@ def find_batch_size(data):
     Returns:
         `int`: The batch size.
     """
+    if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0):
+        raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
+
     if isinstance(data, (tuple, list)):
         return find_batch_size(data[0])
     elif isinstance(data, Mapping):
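A self-contained sketch of the recursion the new guard protects, simplified from the library code (the real `find_batch_size` handles more leaf types):

```python
from collections.abc import Mapping

import torch

def find_batch_size_sketch(data):
    # The new guard: empty containers fail loudly instead of raising an
    # opaque IndexError/StopIteration deeper in the recursion.
    if isinstance(data, (tuple, list, Mapping)) and len(data) == 0:
        raise ValueError(f"Cannot find the batch size from empty {type(data)}.")
    if isinstance(data, (tuple, list)):
        return find_batch_size_sketch(data[0])
    if isinstance(data, Mapping):
        return find_batch_size_sketch(next(iter(data.values())))
    return data.shape[0]  # assumes a tensor-like leaf

print(find_batch_size_sketch({"input_ids": torch.zeros(8, 128)}))  # 8
```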
src/accelerate/utils/other.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
 from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
-from .imports import is_deepspeed_available, is_tpu_available
+from .imports import is_deepspeed_available, is_torch_distributed_available, is_tpu_available
 from .transformer_engine import convert_model
 from .versions import is_torch_version

@@ -75,7 +75,7 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True):

     options += (DeepSpeedEngine,)

-    if is_torch_version(">=", FSDP_PYTORCH_VERSION):
+    if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

         options += (FSDP,)
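The added guard matters because PyTorch can be built without distributed support, in which case `torch.distributed.fsdp` cannot be imported even on a recent version. A minimal sketch of the underlying check (the helper presumably wraps `torch.distributed.is_available()`):

```python
import torch

# Some PyTorch builds (e.g. certain Windows/macOS wheels) ship without
# distributed support; importing FSDP there raises, so gate the import.
if torch.distributed.is_available():
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
```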
