From 1a6af0bd6dc125db287bbe7cf8577a45ebe252ec Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 20 Aug 2024 10:40:21 -0400
Subject: [PATCH] Improve config handling and add a zoo (#3029)

* Improve config handling and add a zoo

* Docs

* rm comment

* Tweak doc
---
 README.md                                     |  4 +-
 docs/source/quicktour.md                      |  2 +
 examples/README.md                            | 32 ++++++++-------
 examples/config_yaml_templates/README.md      | 10 +++++
 examples/config_yaml_templates/deepspeed.yaml | 15 +++++++
 examples/config_yaml_templates/fp8.yaml       | 18 +++++++++
 examples/config_yaml_templates/fsdp.yaml      | 18 +++++++++
 examples/config_yaml_templates/multi_gpu.yaml |  6 +++
 .../config_yaml_templates/multi_node.yaml     | 16 ++++++++
 examples/config_yaml_templates/run_me.py      | 26 ++++++++++++
 .../config_yaml_templates/single_gpu.yaml     |  4 ++
 src/accelerate/commands/config/config_args.py | 40 +++++++++----------
 src/accelerate/utils/dataclasses.py           |  5 ++-
 13 files changed, 157 insertions(+), 39 deletions(-)
 create mode 100644 examples/config_yaml_templates/README.md
 create mode 100644 examples/config_yaml_templates/deepspeed.yaml
 create mode 100644 examples/config_yaml_templates/fp8.yaml
 create mode 100644 examples/config_yaml_templates/fsdp.yaml
 create mode 100644 examples/config_yaml_templates/multi_gpu.yaml
 create mode 100644 examples/config_yaml_templates/multi_node.yaml
 create mode 100644 examples/config_yaml_templates/run_me.py
 create mode 100644 examples/config_yaml_templates/single_gpu.yaml

diff --git a/README.md b/README.md
index 3f443983383..fc64f02e4fe 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,8 @@ accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py
 
 To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli).
 
+Or view the configuration zoo [here](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates/).
+
 ## Launching multi-CPU run using MPI
 
 🤗 Here is another way to launch multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well.
@@ -256,7 +258,7 @@ pip install accelerate
 - multi-GPU on several nodes (machines)
 - TPU
 - FP16/BFloat16 mixed precision
-- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine)
+- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) or [MS-AMP](https://github.com/Azure/MS-AMP/)
 - DeepSpeed support (Experimental)
 - PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
 - Megatron-LM support (Experimental)
diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md
index 38929ec8031..fd571c861b3 100644
--- a/docs/source/quicktour.md
+++ b/docs/source/quicktour.md
@@ -53,6 +53,8 @@ accelerate launch path_to_script.py --args_for_the_script
 
 To learn more, check out the [Launch distributed code](basic_tutorials/launch) tutorial for more information about launching your scripts.
 
+We also have a [configuration zoo](https://github.com/huggingface/accelerate/blob/main/examples/config_yaml_templates) which showcases a number of premade **minimal** example configurations for a variety of setups you can run.
+
 ## Adapt training code
 
 The next main feature of Accelerate is the [`Accelerator`] class which adapts your PyTorch code to run on different distributed setups.
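As a quick illustration of the `Accelerator` adaptation the quicktour describes above, here is a minimal sketch of an adapted training loop; the toy model, optimizer, and data are placeholders for illustration and are not part of this patch:

```python
# Minimal sketch of adapting a plain PyTorch loop with the Accelerator
# (toy model/optimizer/data are placeholders, not part of this patch).
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()  # picks up whatever environment `accelerate launch` configured

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,)))
dataloader = DataLoader(dataset, batch_size=4)

# Let Accelerate move everything to the right device(s)/processes
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()
```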
diff --git a/examples/README.md b/examples/README.md
index 5060c9ad733..27938cd7f97 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -208,23 +208,13 @@ To run it in each of these various modes, use the following commands:
 
 - [huggan project](https://github.com/huggingface/community-events/tree/main/huggan)
 
+
 ### Using AWS SageMaker integration
 - [Examples showcasing AWS SageMaker integration of 🤗 Accelerate.](https://github.com/pacman100/accelerate-aws-sagemaker)
 
-
-## Simple Multi-GPU Hardware Launcher
-
-[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
-on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
-easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
-run the script to automatically launch multi GPU training on remote hardware.
-
-This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
-cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
-with `pip install runhouse`, and you can refer to
-[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
-for hardware setup instructions, or this
-[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
+## Configuration zoo
+In [/config_yaml_templates](./config_yaml_templates/) we have a variety of *minimal* `config.yaml` templates and examples to help you learn
+how to create your own configuration files depending on the scenario.
 
 ## SLURM Scripts
 In [/slurm/submit_multigpu.sh](./slurm/submit_multigpu.sh) and [/slurm/submit_multinode.sh](./slurm/submit_multinode.sh) we present two scripts for running the examples on a machine with [SLURM](https://slurm.schedmd.com/documentation.html) workload manager.
@@ -251,6 +241,20 @@ export PYTHONPATH=/home/nct01/nct01328/transformers-in-supercomputers:$PYTHONPAT
 export GPUS_PER_NODE=4
 ```
 
+## Simple Multi-GPU Hardware Launcher (using an external platform)
+
+[multigpu_remote_launcher.py](./multigpu_remote_launcher.py) is a minimal script that demonstrates launching accelerate
+on multiple remote GPUs, and with automatic hardware environment and dependency setup for reproducibility. You can
+easily customize the training function used, training arguments, hyperparameters, and type of compute hardware, and then
+run the script to automatically launch multi GPU training on remote hardware.
+
+This script uses [Runhouse](https://github.com/run-house/runhouse) to launch on self-hosted hardware (e.g. in your own
+cloud account or on-premise cluster) but there are other options for running remotely as well. Runhouse can be installed
+with `pip install runhouse`, and you can refer to
+[hardware setup](https://runhouse-docs.readthedocs-hosted.com/en/latest/api/python/cluster.html#hardware-setup)
+for hardware setup instructions, or this
+[Colab tutorial](https://colab.research.google.com/drive/1qVwYyLTCPYPSdz9ZX7BZl9Qm0A3j7RJe) for a more in-depth walkthrough.
+
 ## Finer Examples
 
 While the first two scripts are extremely barebones when it comes to what you can do with accelerate, more advanced features are documented in two other locations.
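As an aside to the configuration zoo section added above, one quick way to see what a minimal template contains is to load it with PyYAML, as sketched below; the path is assumed relative to the repository root and the snippet is illustration only, not part of this patch:

```python
# Sketch: inspect one of the minimal config templates with PyYAML
# (path assumed relative to the accelerate repo root; illustration only).
import yaml

with open("examples/config_yaml_templates/multi_gpu.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

for key, value in config.items():
    print(f"{key}: {value}")
# For the DDP template this prints distributed_type, mixed_precision, and num_processes.
```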
diff --git a/examples/config_yaml_templates/README.md b/examples/config_yaml_templates/README.md
new file mode 100644
index 00000000000..4cfd56f2ac5
--- /dev/null
+++ b/examples/config_yaml_templates/README.md
@@ -0,0 +1,10 @@
+# Config Zoo
+
+This folder contains a variety of minimal configurations for `Accelerate`, each achieving a particular goal. You can use these
+config YAMLs directly, or build off of them to create your own.
+
+These are highly annotated versions, aiming to teach you what each section does.
+
+Each config can be run via `accelerate launch --config_file {file} run_me.py`.
+
+`run_me.py` will then print out how the current environment is set up (the contents of the `AcceleratorState`).
\ No newline at end of file
diff --git a/examples/config_yaml_templates/deepspeed.yaml b/examples/config_yaml_templates/deepspeed.yaml
new file mode 100644
index 00000000000..5efddd05245
--- /dev/null
+++ b/examples/config_yaml_templates/deepspeed.yaml
@@ -0,0 +1,15 @@
+# Similar to FSDP, we set the distributed type as DEEPSPEED
+distributed_type: DEEPSPEED
+# With DeepSpeed, we utilize a deepspeed config file for the entire configuration
+deepspeed_config:
+  # Can also be any of the config JSONs in accelerate/examples/deepspeed_config_templates
+  deepspeed_config_file: ../deepspeed_config_templates/zero_stage1_config.json
+  # If using ZeRO-3 and you want to load in big models, this should be set to `true` so
+  # `transformers` uses the right `init` function
+  zero3_init_flag: false # true
+
+# Finally we need to specify the number of GPUs to use
+num_processes: 2
+# Optionally we can set the mixed precision now instead of in the deepspeed config file,
+# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file
+# mixed_precision: "bf16"
diff --git a/examples/config_yaml_templates/fp8.yaml b/examples/config_yaml_templates/fp8.yaml
new file mode 100644
index 00000000000..4e81ac8e9fb
--- /dev/null
+++ b/examples/config_yaml_templates/fp8.yaml
@@ -0,0 +1,18 @@
+# This config template simply sets up the TransformerEngine config (and a config for a single GPU);
+# it can interop with the other configs in this folder
+distributed_type: "NO"
+mixed_precision: "fp8"
+# Then we specify the fp8 configuration:
+fp8_config:
+  backend: TE # Can be TE | MS-AMP
+  # The following are TE specific arguments.
+  # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
+  amax_history_length: 1024
+  fp8_format: E4M3
+  interval: 1
+  margin: 0
+  override_linear_precision: false
+  # Generally this should always be set to `false` to have the most realistic fp8 eval performance
+  use_autocast_during_eval: false
+  # If using MS-AMP, we ignore all of the prior settings and set an opt_level
+  #opt_level: O1
\ No newline at end of file
diff --git a/examples/config_yaml_templates/fsdp.yaml b/examples/config_yaml_templates/fsdp.yaml
new file mode 100644
index 00000000000..07c3e1b83e0
--- /dev/null
+++ b/examples/config_yaml_templates/fsdp.yaml
@@ -0,0 +1,18 @@
+# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
+distributed_type: FSDP
+# Can be one of "no", "fp16", or "bf16" (see `fp8.yaml` for `fp8`, but it works for FSDP as well)
+mixed_precision: 'bf16'
+# Specify the number of GPUs to use
+num_processes: 2
+# Then we can specify the FSDP config
+fsdp_config:
+  fsdp_activation_checkpointing: false
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: true
diff --git a/examples/config_yaml_templates/multi_gpu.yaml b/examples/config_yaml_templates/multi_gpu.yaml
new file mode 100644
index 00000000000..9d6c12bf667
--- /dev/null
+++ b/examples/config_yaml_templates/multi_gpu.yaml
@@ -0,0 +1,6 @@
+# Specify distributed_type as `MULTI_GPU` for DDP
+distributed_type: "MULTI_GPU"
+# Can be one of "no", "fp16", or "bf16" (see `fp8.yaml` for `fp8`)
+mixed_precision: "bf16"
+# Specify the number of GPUs to use
+num_processes: 2
\ No newline at end of file
diff --git a/examples/config_yaml_templates/multi_node.yaml b/examples/config_yaml_templates/multi_node.yaml
new file mode 100644
index 00000000000..b76699f849c
--- /dev/null
+++ b/examples/config_yaml_templates/multi_node.yaml
@@ -0,0 +1,16 @@
+# This config template is for a multi-node setup. It assumes DDP, but can be combined with the other configs in this folder
+# Generally it's recommended to look at the SLURM config template for a more robust multi-node setup
+distributed_type: MULTI_GPU
+# We need to specify the current machine's rank
+machine_rank: 0
+# We then need to specify the IP address and port of the main process
+main_process_ip: '1234'
+main_process_port: 9999
+# We need to specify the number of machines
+num_machines: 2
+# We need to specify the *total* number of processes
+num_processes: 8
+# And then we need to specify how rendezvous (`rdzv`) communications will be handled
+rdzv_backend: static # or c10d
+# Whether the compute nodes are on the same network (for cloud setups this will more than likely be false)
+same_network: false
diff --git a/examples/config_yaml_templates/run_me.py b/examples/config_yaml_templates/run_me.py
new file mode 100644
index 00000000000..70bed48cc0f
--- /dev/null
+++ b/examples/config_yaml_templates/run_me.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+A base script which outputs the accelerate config for the given environment
+"""
+from accelerate import Accelerator
+
+
+accelerator = Accelerator()
+
+accelerator.print(f"Accelerator state from the current environment:\n{accelerator.state}")
+if accelerator.fp8_recipe_handler is not None:
+    accelerator.print(f"FP8 config:\n{accelerator.fp8_recipe_handler}")
+accelerator.end_training()
diff --git a/examples/config_yaml_templates/single_gpu.yaml b/examples/config_yaml_templates/single_gpu.yaml
new file mode 100644
index 00000000000..3d1a81cedaf
--- /dev/null
+++ b/examples/config_yaml_templates/single_gpu.yaml
@@ -0,0 +1,4 @@
+# Since this is a single-GPU setup, we don't need distributed training
+distributed_type: "NO"
+# Can be one of "no", "fp16", or "bf16" (see `fp8.yaml` for `fp8`)
+mixed_precision: "bf16"
\ No newline at end of file
diff --git a/src/accelerate/commands/config/config_args.py b/src/accelerate/commands/config/config_args.py
index f6ee4deb534..3039e45054b 100644
--- a/src/accelerate/commands/config/config_args.py
+++ b/src/accelerate/commands/config/config_args.py
@@ -99,13 +99,17 @@ def _convert_enums(value):
         result = {k: v for k, v in result.items() if v is not None}
         return result
 
-    @classmethod
-    def from_json_file(cls, json_file=None):
-        json_file = default_json_config_file if json_file is None else json_file
-        with open(json_file, encoding="utf-8") as f:
-            config_dict = json.load(f)
+    @staticmethod
+    def process_config(config_dict):
+        """
+        Processes `config_dict` and sets default values for any missing keys
+        """
         if "compute_environment" not in config_dict:
             config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
+        if "distributed_type" not in config_dict:
+            raise ValueError("A `distributed_type` must be specified in the config file.")
+        if "num_processes" not in config_dict and config_dict["distributed_type"] == DistributedType.NO:
+            config_dict["num_processes"] = 1
         if "mixed_precision" not in config_dict:
             config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
         if "fp16" in config_dict:  # Convert the config to the new format.
@@ -119,6 +123,14 @@ def from_json_file(cls, json_file=None):
             config_dict["debug"] = False
         if "enable_cpu_affinity" not in config_dict:
             config_dict["enable_cpu_affinity"] = False
+        return config_dict
+
+    @classmethod
+    def from_json_file(cls, json_file=None):
+        json_file = default_json_config_file if json_file is None else json_file
+        with open(json_file, encoding="utf-8") as f:
+            config_dict = json.load(f)
+        config_dict = cls.process_config(config_dict)
         extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
         if len(extra_keys) > 0:
             raise ValueError(
@@ -138,23 +150,7 @@ def from_yaml_file(cls, yaml_file=None):
         yaml_file = default_yaml_config_file if yaml_file is None else yaml_file
         with open(yaml_file, encoding="utf-8") as f:
             config_dict = yaml.safe_load(f)
-        if "compute_environment" not in config_dict:
-            config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE
-        if "mixed_precision" not in config_dict:
-            config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None
-        if isinstance(config_dict["mixed_precision"], bool) and not config_dict["mixed_precision"]:
-            config_dict["mixed_precision"] = "no"
-        if "fp16" in config_dict:  # Convert the config to the new format.
-            del config_dict["fp16"]
-        if "dynamo_backend" in config_dict:  # Convert the config to the new format.
-            dynamo_backend = config_dict.pop("dynamo_backend")
-            config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend}
-        if "use_cpu" not in config_dict:
-            config_dict["use_cpu"] = False
-        if "debug" not in config_dict:
-            config_dict["debug"] = False
-        if "enable_cpu_affinity" not in config_dict:
-            config_dict["enable_cpu_affinity"] = False
+        config_dict = cls.process_config(config_dict)
         extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys()))
         if len(extra_keys) > 0:
             raise ValueError(
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 1151cd73fda..0f35a294736 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -1338,10 +1338,11 @@ class FullyShardedDataParallelPlugin:
         },
     )
     sync_module_states: bool = field(
-        default=False,
+        default=None,
         metadata={
             "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
-            "to ensure they are the same across all ranks after initialization. Defaults to `True`"
+            "to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
+            "`cpu_ram_efficient_loading` is `True`, in which case it will be forcibly enabled."
         },
     )
     forward_prefetch: bool = field(
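For reference, the sketch below re-implements, in stand-alone form, the default-filling behavior that the `process_config` method in `config_args.py` above now centralizes for both the JSON and YAML loaders. It is a simplified illustration rather than the actual method: it omits the legacy `fp16`/`dynamo_backend` conversions and uses plain strings instead of the `ComputeEnvironment`/`DistributedType` enums:

```python
# Stand-alone sketch of the default filling that `process_config` centralizes
# (simplified illustration; the real method also converts legacy keys and uses enums).
def fill_config_defaults(config_dict: dict) -> dict:
    if "compute_environment" not in config_dict:
        config_dict["compute_environment"] = "LOCAL_MACHINE"
    if "distributed_type" not in config_dict:
        raise ValueError("A `distributed_type` must be specified in the config file.")
    if "num_processes" not in config_dict and config_dict["distributed_type"] == "NO":
        config_dict["num_processes"] = 1
    if "mixed_precision" not in config_dict:
        config_dict["mixed_precision"] = "fp16" if config_dict.get("fp16") else None
    for key, default in (("use_cpu", False), ("debug", False), ("enable_cpu_affinity", False)):
        config_dict.setdefault(key, default)
    return config_dict


print(fill_config_defaults({"distributed_type": "NO"}))
# -> compute_environment, num_processes=1, mixed_precision=None, and the remaining flags filled in
```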