From 7739c0aca47fa8a20682ddb1208c137bffbba2cc Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 17 Jan 2024 00:32:35 +0800 Subject: [PATCH 1/8] [docs] Add new autotp supported model in tutorial (#4960) This PR refresh the list of models supported by AutoTP. Newly added models are: - baichuan - codellama - falcon - llama2 - mistral - qwen - starcode --- docs/_tutorials/automatic-tensor-parallelism.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/_tutorials/automatic-tensor-parallelism.md b/docs/_tutorials/automatic-tensor-parallelism.md index 5d182b2a4532..aea98ad9b22a 100644 --- a/docs/_tutorials/automatic-tensor-parallelism.md +++ b/docs/_tutorials/automatic-tensor-parallelism.md @@ -121,15 +121,18 @@ The following results were collected using V100 SXM2 32GB GPUs. The following model families have been successfully tested with automatic tensor parallelism. Other models may work but have not been tested yet. - albert +- baichuan - bert - bigbird_pegasus - bloom - camembert - codegen +- codellama - deberta_v2 - electra - ernie - esm +- falcon - glm - gpt-j - gpt-neo @@ -137,8 +140,10 @@ The following model families have been successfully tested with automatic tensor - longt5 - luke - llama +- llama2 - m2m_100 - marian +- mistral - mpt - mvp - nezha @@ -147,10 +152,12 @@ The following model families have been successfully tested with automatic tensor - pegasus - perceiver - plbart +- qwen - reformer - roberta - roformer - splinter +- starcode - t5 - xglm - xlm_roberta From e278076495078cb0450f911dd86d32fbc4f726a2 Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Wed, 17 Jan 2024 20:51:12 +0200 Subject: [PATCH 2/8] Add missing op_builder.hpu component for HPU accelerator (#4963) --- op_builder/hpu/__init__.py | 10 ++++++++++ op_builder/hpu/builder.py | 35 +++++++++++++++++++++++++++++++++++ op_builder/hpu/cpu_adam.py | 33 +++++++++++++++++++++++++++++++++ op_builder/hpu/fused_adam.py | 29 +++++++++++++++++++++++++++++ op_builder/hpu/no_impl.py | 24 ++++++++++++++++++++++++ 5 files changed, 131 insertions(+) create mode 100644 op_builder/hpu/__init__.py create mode 100644 op_builder/hpu/builder.py create mode 100644 op_builder/hpu/cpu_adam.py create mode 100644 op_builder/hpu/fused_adam.py create mode 100644 op_builder/hpu/no_impl.py diff --git a/op_builder/hpu/__init__.py b/op_builder/hpu/__init__.py new file mode 100644 index 000000000000..6527ace087b5 --- /dev/null +++ b/op_builder/hpu/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +'''Copyright The Microsoft DeepSpeed Team''' + +from .cpu_adam import CPUAdamBuilder +from .fused_adam import FusedAdamBuilder +from .no_impl import NotImplementedBuilder diff --git a/op_builder/hpu/builder.py b/op_builder/hpu/builder.py new file mode 100644 index 000000000000..5a538c84040c --- /dev/null +++ b/op_builder/hpu/builder.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +try: + # is op_builder from deepspeed or a 3p version? 
this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class CPUOpBuilder(OpBuilder): + + def builder(self): + from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + + compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} + + cpp_ext = ExtensionBuilder(name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + libraries=self.strip_empty_entries(self.libraries_args()), + extra_compile_args=compile_args) + + return cpp_ext + + def cxx_args(self): + return ['-O3', '-g', '-Wno-reorder'] + + def libraries_args(self): + return [] diff --git a/op_builder/hpu/cpu_adam.py b/op_builder/hpu/cpu_adam.py new file mode 100644 index 000000000000..2f3b7aefe705 --- /dev/null +++ b/op_builder/hpu/cpu_adam.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class CPUAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/adam/cpu_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def cxx_args(self): + args = super().cxx_args() + args += ['-DENABLE_BFLOAT16'] + return args + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py new file mode 100644 index 000000000000..d77228317ddb --- /dev/null +++ b/op_builder/hpu/fused_adam.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class FusedAdamBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] + + def cxx_args(self): + args = super().cxx_args() + args += ['-DENABLE_BFLOAT16'] + return args + + def include_paths(self): + return ['csrc/includes'] diff --git a/op_builder/hpu/no_impl.py b/op_builder/hpu/no_impl.py new file mode 100644 index 000000000000..140d65b48def --- /dev/null +++ b/op_builder/hpu/no_impl.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import CPUOpBuilder + + +class NotImplementedBuilder(CPUOpBuilder): + BUILD_VAR = "DS_BUILD_NOT_IMPLEMENTED" + NAME = "deepspeed_not_implemented" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.comm.{self.NAME}_op' + + def load(self, verbose=True): + raise ValueError("This op had not been implemented on HPU backend.") + + def sources(self): + return [] From 69a459887f461eddb0e2e62e92edb87836e7cf1e Mon Sep 17 00:00:00 2001 From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com> Date: Thu, 18 Jan 2024 00:28:35 +0200 Subject: [PATCH 3/8] Stage_1_and_2.py: fix assert for reduce_scatter configurations combinations (#4964) today it does not allow to set reduce_scatter with: - invalid comm data type - gradient_predivide_factor != 1.0 - postscale_gradients is False but reduce scatter is relevant only for Zero2, also reduce_scatter is set to True by default. So setting prescale_gradients=True in json will end up with assert on Zero1. Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 703e62d13bef..3aa7e6f5e0a2 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -232,7 +232,7 @@ def __init__(self, f"Currently only supported using ZeRO-Offload with DeepSpeedCPUAdam. But current setting is ZeRO-Offload:{self.cpu_offload} and optimizer type {type(self.optimizer)}." \ f"Either disable fp16_master_weights_and_gradients or enable {self.zero_stage_string} Offload with DeepSpeedCPUAdam." - if self.reduce_scatter: + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" From 752a50acddaf122f86e9b3f6f60d7d233a76402e Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Thu, 18 Jan 2024 06:29:45 +0800 Subject: [PATCH 4/8] [MiCS]Add the path to support sequence_data_parallel on MiCS (#4926) This pr is to match the latest update on sequence_data_parallel for MiCS, sequence_data_parallel_group is added in the latest Megatron-DeepSpeed init method https://github.com/microsoft/Megatron-DeepSpeed/blob/main/pretrain_gpt.py#L39, if we want to enable zero3+MiCS on GPT training it will be unsupported, add the path to support the sequence_data_parallel can fix this issue. 
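For illustration, a minimal usage sketch of the new argument is shown below. The group accessor, config values, and model constructor are placeholders for whatever the training script provides; any process group created by the Megatron-DeepSpeed parallel-state setup can be passed, and the remaining keyword arguments follow deepspeed.zero.Init (which MiCS_Init subclasses).

```python
# Sketch only: get_sequence_data_parallel_group() and build_model() are
# hypothetical stand-ins for the caller's own helpers.
from deepspeed.runtime.zero.mics import MiCS_Init

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 3, "mics_shard_size": 4},  # illustrative values
}

sdp_group = get_sequence_data_parallel_group()  # group from the caller's mpu setup

with MiCS_Init(sequence_data_parallel_group=sdp_group, config_dict_or_path=ds_config):
    model = build_model()
```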
Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/mics.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index 6c7003085685..40d0ea977e43 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -64,6 +64,7 @@ class MiCS_Init(Init): def __init__(self, module=None, data_parallel_group=None, + sequence_data_parallel_group=None, mem_efficient_linear=True, remote_device=None, pin_memory=False, @@ -145,9 +146,21 @@ def __init__(self, if not dist.is_initialized(): dist.init_distributed() assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + + if data_parallel_group is None and sequence_data_parallel_group is None: + ds_process_group = dist.get_world_group() + elif sequence_data_parallel_group is not None: + ds_process_group = sequence_data_parallel_group + elif data_parallel_group is not None: + ds_process_group = data_parallel_group + else: # both given + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + self.mics_comm_groups = create_mics_comm_groups( _ds_config.mics_shard_size, - data_parallel_group, + ds_process_group, hierarchical_allgather=_ds_config.mics_hierarchial_params_gather, mpu=mpu) From 1b34a4d3053db1d3e3fac0481aed796fa04b9c76 Mon Sep 17 00:00:00 2001 From: Arash Bakhtiari Date: Wed, 17 Jan 2024 15:20:26 -0800 Subject: [PATCH 5/8] Update the DeepSpeed Phi-2 impl. to work with the HF latest changes (#4950) The latest changes in Huggingface Phi-2 implementation (https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869#d2h-025836) have broken the DeepSpeed implementation. This PR address the related issues. 
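For context, the relevant structural change is that the HF Phi-2 attention block now exposes separate q_proj/k_proj/v_proj linears instead of a single fused Wqkv, which is why the containers below switch from FusedQKVParameter to UnfusedQKVParameter. A rough shape-bookkeeping sketch (purely illustrative, not DeepSpeed's internal parameter-mapping code):

```python
import torch

hidden = 2560  # Phi-2 hidden size, per the model printout in containers.py

# New HF layout: three separate projection weights ...
q_w = torch.empty(hidden, hidden)
k_w = torch.empty(hidden, hidden)
v_w = torch.empty(hidden, hidden)

# ... which can be stitched back into the single [3 * hidden, hidden] matrix
# that the old fused Wqkv layout provided.
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
assert qkv_w.shape == (3 * hidden, hidden)
```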
--------- Co-authored-by: Michael Wyatt --- deepspeed/inference/v2/engine_factory.py | 2 +- .../model_implementations/phi/containers.py | 75 +++++++++---------- .../v2/model_implementations/phi/model.py | 14 ++-- .../v2/model_implementations/phi/policy.py | 2 +- 4 files changed, 43 insertions(+), 50 deletions(-) diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index 41cb47729237..9281640f844a 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -116,7 +116,7 @@ def build_hf_engine(path: str, policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "falcon": policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) - elif model_config.model_type == "phi-msft": + elif model_config.model_type == "phi": policy = PhiPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "qwen": policy = QwenPolicy(model_config, checkpoint_engine=checkpoint_engine) diff --git a/deepspeed/inference/v2/model_implementations/phi/containers.py b/deepspeed/inference/v2/model_implementations/phi/containers.py index ab6d0181611c..21f07eb8c99a 100644 --- a/deepspeed/inference/v2/model_implementations/phi/containers.py +++ b/deepspeed/inference/v2/model_implementations/phi/containers.py @@ -11,41 +11,30 @@ # HF Phi-2 model looks like this: PhiForCausalLM( - (transformer): PhiModel( - (embd): Embedding( - (wte): Embedding(51200, 2560) - (drop): Dropout(p=0.0, inplace=False) - ) - (h): ModuleList( - (0-31): 32 x ParallelBlock( - (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) - (resid_dropout): Dropout(p=0.1, inplace=False) - (mixer): MHA( - (rotary_emb): RotaryEmbedding() - (Wqkv): Linear(in_features=2560, out_features=7680, bias=True) - (out_proj): Linear(in_features=2560, out_features=2560, bias=True) - (inner_attn): SelfAttention( - (drop): Dropout(p=0.0, inplace=False) - ) - (inner_cross_attn): CrossAttention( - (drop): Dropout(p=0.0, inplace=False) - ) + (model): PhiModel( + (embed_tokens): Embedding(51200, 2560) + (embed_dropout): Dropout(p=0.0, inplace=False) + (layers): ModuleList( + (0-31): 32 x PhiDecoderLayer( + (self_attn): PhiAttention( + (q_proj): Linear(in_features=2560, out_features=2560, bias=True) + (k_proj): Linear(in_features=2560, out_features=2560, bias=True) + (v_proj): Linear(in_features=2560, out_features=2560, bias=True) + (dense): Linear(in_features=2560, out_features=2560, bias=True) + (rotary_emb): PhiRotaryEmbedding() ) - (mlp): MLP( + (mlp): PhiMLP( + (activation_fn): NewGELUActivation() (fc1): Linear(in_features=2560, out_features=10240, bias=True) (fc2): Linear(in_features=10240, out_features=2560, bias=True) - (act): NewGELUActivation() ) + (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) + (resid_dropout): Dropout(p=0.1, inplace=False) ) ) + (final_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) ) - (lm_head): CausalLMHead( - (ln): LayerNorm((2560,), eps=1e-05, elementwise_affine=True) - (linear): Linear(in_features=2560, out_features=51200, bias=True) - ) - (loss): CausalLMLoss( - (loss_fct): CrossEntropyLoss() - ) + (lm_head): Linear(in_features=2560, out_features=51200, bias=True) ) ''' @@ -54,8 +43,8 @@ class PhiTransformerContainer(LayerContainer): """ Transformer layer container for the Phi model. 
""" - qkv_w: FusedQKVParameter - qkv_b: FusedQKVParameter + qkv_w: UnfusedQKVParameter + qkv_b: UnfusedQKVParameter attn_out_w: AttentionOutputParameter attn_out_b: AttentionOutputParameter mlp_1_w: MLP1Parameter @@ -66,16 +55,20 @@ class PhiTransformerContainer(LayerContainer): ln_beta: NormParameter PARAM_MAPPING = { - "mixer.Wqkv.weight": "qkv_w.params", - "mixer.Wqkv.bias": "qkv_b.params", - "mixer.out_proj.weight": "attn_out_w.params", - "mixer.out_proj.bias": "attn_out_b.params", + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.q_proj.bias": "qkv_b.q_params", + "self_attn.k_proj.bias": "qkv_b.k_params", + "self_attn.v_proj.bias": "qkv_b.v_params", + "self_attn.dense.weight": "attn_out_w.params", + "self_attn.dense.bias": "attn_out_b.params", "mlp.fc1.weight": "mlp_1_w.params", "mlp.fc1.bias": "mlp_1_b.params", "mlp.fc2.weight": "mlp_2_w.params", "mlp.fc2.bias": "mlp_2_b.params", - "ln.weight": "ln_gamma.params", - "ln.bias": "ln_beta.params", + "input_layernorm.weight": "ln_gamma.params", + "input_layernorm.bias": "ln_beta.params", } @@ -90,9 +83,9 @@ class PhiNonTransformerContainer(LayerContainer): final_norm_beta: NormParameter PARAM_MAPPING = { - "transformer.embd.wte.weight": "word_emb.params", - "lm_head.ln.weight": "final_norm_gamma.params", - "lm_head.ln.bias": "final_norm_beta.params", - "lm_head.linear.weight": "word_unembed_w.params", - "lm_head.linear.bias": "word_unembed_b.params", + "model.embed_tokens.weight": "word_emb.params", + "model.final_layernorm.weight": "final_norm_gamma.params", + "model.final_layernorm.bias": "final_norm_beta.params", + "lm_head.weight": "word_unembed_w.params", + "lm_head.bias": "word_unembed_b.params", } diff --git a/deepspeed/inference/v2/model_implementations/phi/model.py b/deepspeed/inference/v2/model_implementations/phi/model.py index 0127c87c7bff..2d5826810cb5 100644 --- a/deepspeed/inference/v2/model_implementations/phi/model.py +++ b/deepspeed/inference/v2/model_implementations/phi/model.py @@ -47,11 +47,11 @@ def max_sequence_length(self) -> int: @property def num_layers(self) -> int: - return self._config.n_layer + return self._config.num_hidden_layers @property def model_dim(self) -> int: - return self._config.n_embd + return self._config.hidden_size @property def vocab_size(self) -> int: @@ -63,16 +63,15 @@ def head_size(self) -> int: @property def n_heads(self) -> int: - return self._config.n_head + return self._config.num_attention_heads @property def intermediate_dim(self) -> int: - n_inner = getattr(self._config, "n_inner", None) - return n_inner if n_inner is not None else 4 * self.model_dim + return self._config.intermediate_size @property def n_heads_kv(self) -> int: - return getattr(self._config, "n_head_kv", None) or self.n_heads + return self._config.num_key_value_heads @property def activation_dtype(self) -> DtypeEnum: @@ -97,7 +96,8 @@ def positional_embedding_type(self) -> PositionalEmbeddingType: @property def positional_embedding_config(self) -> Optional[RotateHalfConfig]: - return RotateHalfConfig(rotate_dim=self._config.rotary_dim) + rotary_dim = int(self._config.partial_rotary_factor * self.head_size) + return RotateHalfConfig(rotate_dim=rotary_dim, theta_base=self._config.rope_theta) """ Forward implementations diff --git a/deepspeed/inference/v2/model_implementations/phi/policy.py b/deepspeed/inference/v2/model_implementations/phi/policy.py index 1e9db24022f5..4b081a8e61bd 100644 --- 
a/deepspeed/inference/v2/model_implementations/phi/policy.py +++ b/deepspeed/inference/v2/model_implementations/phi/policy.py @@ -22,7 +22,7 @@ def build_container_map(self) -> ContainerMap: trans_container_cls = PhiTransformerContainer transformer_containers = [trans_container_cls(self.model) for _ in range(self.model.num_layers)] - map.set_transformer_params(['transformer.h'], transformer_containers) + map.set_transformer_params(['model.layers'], transformer_containers) map.set_non_transformer_params(PhiNonTransformerContainer(self.model)) From 740080c050e700f970deb73acc70615aaadfdecd Mon Sep 17 00:00:00 2001 From: Shukant Pal Date: Wed, 17 Jan 2024 15:20:35 -0800 Subject: [PATCH 6/8] Prevent infinite recursion when DS_ACCELERATOR is set to cuda (#4962) When DS_ACCELERATOR is overriden to CUDA, `get_accelerator` attempts to check if `is_current_accelerator_supported`. But since that calls `get_accelerator` again and `ds_accelerator` has not been initialized, DeepSpeed runs into infinite recursion. ``` elif is_current_accelerator_supported(): File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 48, in is_current_accelerator_supported return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 101, in get_accelerator elif is_current_accelerator_supported(): File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 48, in is_current_accelerator_supported return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 101, in get_accelerator elif is_current_accelerator_supported(): File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 48, in is_current_accelerator_supported return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 101, in get_accelerator elif is_current_accelerator_supported(): File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 48, in is_current_accelerator_supported return get_accelerator().device_name() in SUPPORTED_ACCELERATOR_LIST File "/usr/local/lib/python3.8/dist-packages/deepspeed/accelerator/real_accelerator.py", line 59, in get_accelerator if "DS_ACCELERATOR" in os.environ.keys(): File "/usr/lib/python3.8/_collections_abc.py", line 717, in __contains__ return key in self._mapping File "/usr/lib/python3.8/_collections_abc.py", line 666, in __contains__ self[key] File "/usr/lib/python3.8/os.py", line 672, in __getitem__ value = self._data[self.encodekey(key)] RecursionError: maximum recursion depth exceeded ``` This change fixes that by comparing the accelerator directly with the supported list of accelerators. 
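A stripped-down sketch of the failure mode (names simplified; this is not the actual real_accelerator.py code):

```python
SUPPORTED_ACCELERATOR_LIST = ["cuda", "cpu", "hpu"]  # illustrative subset
ds_accelerator = None

def is_current_accelerator_supported():
    # Calls back into get_accelerator(); before ds_accelerator is assigned this
    # re-enters the override branch below, giving unbounded mutual recursion.
    return get_accelerator() in SUPPORTED_ACCELERATOR_LIST

def get_accelerator(accelerator_name="cuda"):
    global ds_accelerator
    if ds_accelerator is None:
        if not is_current_accelerator_supported():
            raise ValueError(f'"{accelerator_name}" is not supported')
        ds_accelerator = accelerator_name
    return ds_accelerator

# The fix checks membership directly instead, without re-entering the getter:
#     if accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
#         raise ValueError(...)
```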
--- accelerator/real_accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index ee421ed4cca7..1090a61681d9 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -98,7 +98,7 @@ def get_accelerator(): except ImportError as e: raise ValueError( f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") - elif is_current_accelerator_supported(): + elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST: raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. ' f'Value "{accelerator_name}" is not supported') ds_set_method = "override" From 870ae041d42190be8139afc12bef51d6ed7719f3 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Wed, 17 Jan 2024 19:43:59 -0500 Subject: [PATCH 7/8] Fixes for training models with bf16 + freshly initialized optimizer via `load_module_only` (#4141) This PR makes some fixes to the case where we want to resume training from a DeepSpeed ZeRO checkpoint and initialize a new optimizer, while not using the old optimizer in the checkpoint or relying on its existence at all. in this situation, despite passing `load_module_only=True` and `load_optimizer_states=False` to `load_checkpoint()`, the previous behavior was that: - `self._load_zero_checkpoint` would still be called, which attempts to load from the (in this case, nonexistent) checkpoint files. This PR stops this function from being called if using `load_module_only=True` and `load_optimizer_states=False`. Alternatively, calling this function may be alright if `"load_from_fp32_weights": true` is set in the DeepSpeed ZeRO config (reference: https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733) but this parameter does not seem to be documented in the docs for ZeRO config dicts. - in `_load_checkpoint`, the following codeblock: ``` if self.optimizer is not None and self.fp16_enabled(): self.optimizer.refresh_fp32_params() ``` results in `self.optimizer.refresh_fp32_params()` being called only if using FP16. As a result, the FP32 optimizer state is never initialized from the 16-bit model weights. This PR removes the fp16-specific condition. Previously reported in: https://github.com/EleutherAI/gpt-neox/issues/947 https://github.com/EleutherAI/gpt-neox/issues/843 Should also close: https://github.com/microsoft/DeepSpeed/issues/4017 Fixes: #4944 and #4017 This caused problems for a freshly-converted LLama checkpoint, which did not contain optimizer states, when trying to train with this model as initialization. I have confirmed the following fixes prevent this behavior. 
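For reference, a minimal sketch of the call pattern this fixes (model, ds_config, and checkpoint_dir are placeholders; the engine is assumed to be configured with ZeRO and/or bf16):

```python
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# Resume module weights only; the checkpoint carries no usable optimizer state.
engine.load_checkpoint(checkpoint_dir,
                       load_module_only=True,
                       load_optimizer_states=False,
                       load_lr_scheduler_states=False)
# With this change, _load_zero_checkpoint is no longer attempted for the missing
# optimizer shards, and the fresh optimizer's fp32 master weights are refreshed
# from the 16-bit module weights for bf16 as well as fp16.
```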
cc @Quentin-Anthony @zhangir-azerbayev --------- Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 00346317ebf1..9c0da4c2d406 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2745,10 +2745,12 @@ def load_checkpoint(self, load_module_only=load_module_only, custom_load_fn=custom_load_fn) - load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization() - or self.bfloat16_enabled()) + load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) if load_zero_checkpoint: - success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + if load_optimizer_states and not load_module_only: + success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + else: + success = False if not success: self.optimizer._restore_from_bit16_weights() @@ -2830,7 +2832,7 @@ def _load_checkpoint(self, optim_checkpoint = None if load_module_only: deepspeed_states = ['module'] - if self.optimizer is not None and self.fp16_enabled(): + if self.optimizer is not None: self.optimizer.refresh_fp32_params() else: has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() From 3110c38852ceb8d531f9577cbf6b74db5cbe5838 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 19 Jan 2024 01:42:18 +0800 Subject: [PATCH 8/8] params partition for skip_init (#4722) Some models use ```skip_init``` to initialize weights. ```skip_init``` first initializes on a meta device in ```__init__``` of a module and then uses ```to_empty()```. This conflicts with the deepspeed hook ```module.__init__``` mechanism. it's necessary to wait for ```skip_init``` to finish before executing ```_post_init_method```. However, the ```from ... import skip_init``` behavior typically occurs outside the context, there seems to be no good way to directly hook into ```skip_init```. Hence, the approach here is to delay the execution of ```_post_init_method``` to resolve this issue. Known affected models include HuggingFace models like chatglm2 and chatglm3." 
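A minimal sketch of the pattern this enables (config values are illustrative, and a distributed launch such as the deepspeed launcher is assumed):

```python
import torch
from torch.nn.utils import skip_init
import deepspeed

ds_config = {
    "train_batch_size": 4,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},
}

# skip_init constructs the module on the meta device and then materializes it
# with to_empty(); the hooks added here defer partitioning until that point.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    layer = skip_init(torch.nn.Linear, 1024, 1024)  # partitioned after to_empty()
```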
--------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> --- .../runtime/zero/partition_parameters.py | 60 ++++++++++++++- tests/unit/runtime/zero/test_zero.py | 77 +++++++++++++++++++ 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 992dcd446ad6..030a050b88e2 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -312,6 +312,7 @@ def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtyp torch.half, torch.bfloat16, torch.float ], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" self.wrapped_cls = set() + self.skip_init_depth = 0 self.quantized_initialization = None if ds_config is not None and ds_config.weight_quantization_config and ds_config.weight_quantization_config.quantized_initialization: @@ -435,6 +436,51 @@ def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: return wrapped_apply + def hook_for_skip_init(module): + # this function is intended for handling the logic of torch.nn.utils.skip_init + # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta' + # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device). + def partition_after_empty_init(f): + + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + _module = f(module, *args, **kwargs) + # here is the post-hook for module.apply(empty_like...) + # after module.apply(empty_like...), the module has completed its empty init on real device + # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init + self._post_init_method(_module) + return _module + + return wrapper + + def post_wrapper_to_empty(f): + # append some wrapper restoration after to_empty() call + @functools.wraps(f) + def wrapper(*args, **kwargs): + res = f(*args, **kwargs) + # restore _apply hook + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class_apply(subclass) + # self restore + module.to_empty = f + return res + + return wrapper + + def _enable_class_apply(cls): + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) + + def _disable_class_apply(cls): + cls._apply = cls._old_apply_of_skip_init_hook + + # add hooks for to_empty: apply_(empty_like) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class_apply(subclass) + + # add a restore hook when exiting skip_init + module.to_empty = post_wrapper_to_empty(module.to_empty) + def partition_after(f): @functools.wraps(f) @@ -456,16 +502,25 @@ def wrapper(module, *args, **kwargs): is_child_module = True setattr(module, "_ds_child_entered", True) - f(module, *args, **kwargs) + init_on_meta = 'device' in kwargs and kwargs['device'] == 'meta' + if init_on_meta: + self.skip_init_depth += 1 + f(module, *args, **kwargs) + if init_on_meta and self.skip_init_depth == 1: + # check and handle the logic of empty_init + hook_for_skip_init(module) if is_child_module: # child's __init__ is done, now we can run a single post_init on the child object delattr(module, "_ds_child_entered") print_rank_0(f'Running post_init for {module.__class__.__name__}', force=False) - self._post_init_method(module) + if 
self.skip_init_depth == 0: + self._post_init_method(module) print_rank_0(f'After initializing followed by post init for {module.__class__.__name__}', force=False) + if init_on_meta: + self.skip_init_depth -= 1 return wrapper @@ -512,7 +567,6 @@ def _init_subclass(cls, **kwargs): self.patched = True def unpatch_init_and_builtins(self): - if self.patched: def _disable_class(cls): diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 6d66ff704416..bc31e3b9a968 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -14,6 +14,7 @@ from torch.nn.modules.container import ModuleList from torch.nn.modules.loss import L1Loss from torch.nn.parameter import Parameter +from torch.nn.utils import skip_init from unit.common import DistributedTest from unit.simple_model import SimpleModel, random_dataloader @@ -1193,6 +1194,82 @@ def create_tensor(vals): _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) +class TestParamPartitioningSkipInit(DistributedTest): + world_size = 2 + + def test(self): + config_dict = { + "train_batch_size": 4, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 3 + }, + } + hidden_dim = 10 + + class SubModel(torch.nn.Module): + + def __init__(self, input_size, output_size, dropout_prob=0.5, device=None): + super(SubModel, self).__init__() + self.linear = torch.nn.Linear(input_size, output_size, device=device) + self.dropout = torch.nn.Dropout(dropout_prob) + self.module_list = torch.nn.ModuleList([torch.nn.Linear(input_size, output_size, device=device)]) + + def forward(self, x): + x = self.linear(x) + x = self.dropout(x) + x = self.module_list[0](x) + return x + + class MyModel(torch.nn.Module): + + def __init__(self, hidden_dim): + super(MyModel, self).__init__() + self.l1 = skip_init(Linear, hidden_dim, hidden_dim) + self.l2 = skip_init(SubModel, hidden_dim, hidden_dim) + self.l3 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cel = torch.nn.CrossEntropyLoss() + self.l4 = skip_init(SubModel, hidden_dim, hidden_dim) + + def forward(self, x, y): + x = self.l1(x) + x = self.l2(x) + x = self.l3(x) + x = self.l4(x) + loss = self.cel(x, y) + val = [x, loss] + return val + + with deepspeed.zero.Init(config=config_dict): + model = MyModel(hidden_dim) + world_size = dist.get_world_size() + ds_tensor_numel = math.ceil(hidden_dim * hidden_dim / world_size) + assert model.l1.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l2.linear.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l2.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel + assert model.l3.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l4.linear.weight.ds_tensor.numel() == ds_tensor_numel + assert model.l4.module_list[0].weight.ds_tensor.numel() == ds_tensor_numel + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device) + dist.barrier() + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + loss = loss[1] + model.backward(loss) + model.step() + + class TestZeroOffloadStage1(DistributedTest): world_size = 2