From db6e5eebe0b5e0b116e1c74f16bf299bb502c9cb Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 7 Nov 2024 13:29:20 -0800
Subject: [PATCH] Move _tp_plan setting to post_init

---
 src/transformers/modeling_utils.py            | 49 +++++++++++--------
 .../models/gemma/modeling_gemma.py            |  1 -
 .../models/gemma2/modeling_gemma2.py          |  1 -
 src/transformers/models/glm/modeling_glm.py   |  1 -
 .../models/llama/modeling_llama.py            |  1 -
 5 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ee2406e747a..bde262843d8 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1399,6 +1399,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # Has support for a `QuantoQuantizedCache` instance as `past_key_values`
     _supports_quantized_cache = False
 
+    # A tensor parallel plan to be applied to the model when TP is enabled. For
+    # top-level models, this attribute is currently defined in respective model
+    # code. For base models, this attribute comes from
+    # `config.base_model_tp_plan` during `post_init`.
+    _tp_plan = None
+
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
         """
@@ -1443,6 +1449,9 @@ def post_init(self):
         """
         self.init_weights()
         self._backward_compatibility_gradient_checkpointing()
+        # If current model is a base model, attach `base_model_tp_plan` from config
+        if self.base_model is self:
+            self._tp_plan = self.config.base_model_tp_plan
 
     def dequantize(self):
         """
@@ -3475,9 +3484,8 @@ def from_pretrained(
 
         tp_plan = kwargs.pop("tp_plan", None)
         if tp_plan is not None and tp_plan != "auto":
-            raise ValueError(
-                f"tp_plan supports 'auto' only for now but got {tp_plan}."
-            )
+            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
 
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
@@ -4095,9 +4103,7 @@ def from_pretrained(
             init_contexts.append(init_empty_weights())
         elif tp_plan is not None:
             if not torch.distributed.is_initialized():
-                raise ValueError(
-                    "Tensor Parallel requires torch.distributed to be initialized first."
-                )
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
 
             # Get device type (e.g. "cuda")
             device_type = torch.distributed.distributed_c10d._device_capability()[0]
@@ -5063,21 +5069,22 @@ def tensor_parallel(self, device_mesh):
         # parallelize a model.
         def tplize(mod: torch.nn.Module) -> None:
             tp_plan = getattr(mod, "_tp_plan", None)
-            if tp_plan:
-                logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
-                # In model configs, we use a neutral type (string) to specify
-                # parallel styles, here we translate them into torch TP types.
-                # Using tree_map because `tp_plan` is a dict.
-                tp_plan = torch.utils._pytree.tree_map(
-                    translate_to_torch_parallel_style,
-                    tp_plan,
-                )
-                # Apply TP to current module.
-                torch.distributed.tensor.parallel.parallelize_module(
-                    mod,
-                    device_mesh=device_mesh,
-                    parallelize_plan=tp_plan,
-                )
+            if tp_plan is None:
+                return
+            logger.debug(f"Applying tensor parallel to {mod.__class__.__name__}: {tp_plan}")
+            # In model configs, we use a neutral type (string) to specify
+            # parallel styles, here we translate them into torch TP types.
+            # Using tree_map because `tp_plan` is a dict.
+            tp_plan = torch.utils._pytree.tree_map(
+                translate_to_torch_parallel_style,
+                tp_plan,
+            )
+            # Apply TP to current module.
+            torch.distributed.tensor.parallel.parallelize_module(
+                mod,
+                device_mesh=device_mesh,
+                parallelize_plan=tp_plan,
+            )
 
         # `apply` is a native method of `nn.Module` that recursively applies a
         # function to every submodule.
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 650e57de6d9..6fead73eced 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -722,7 +722,6 @@ def __init__(self, config: GemmaConfig):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 2e6168bfb70..6a3d8f27fb1 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -742,7 +742,6 @@ def __init__(self, config: Gemma2Config):
         self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index d108e645bf5..58a89d90b44 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -708,7 +708,6 @@ def __init__(self, config: GlmConfig):
             dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
         )
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
         if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index a967cfe685e..679296648a9 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -812,7 +812,6 @@ def __init__(self, config: LlamaConfig):
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
         self.gradient_checkpointing = False
-        self._tp_plan = config.base_model_tp_plan
        if getattr(config, "pretraining_tp", 1) != 1:
             logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
 
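
Usage sketch (not part of the patch): the snippet below shows how the `tp_plan="auto"` entry point touched by this change is expected to be driven once `post_init` attaches `_tp_plan` to the base model. The checkpoint id, the NCCL backend, and the torchrun launch are illustrative assumptions, not anything this patch defines.

# tp_demo.py -- illustrative sketch only; model id and launch setup are assumptions.
# Launch with something like: torchrun --nproc-per-node 2 tp_demo.py
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# `from_pretrained(..., tp_plan="auto")` requires torch.distributed to be
# initialized first, per the check kept in this patch.
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
torch.cuda.set_device(device)
torch.distributed.init_process_group("nccl")

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# "auto" is currently the only accepted value; the base model picks up its
# `_tp_plan` from `config.base_model_tp_plan` in `post_init()`.
model = AutoModelForCausalLM.from_pretrained(model_id, tp_plan="auto")

inputs = tokenizer("Tensor parallelism shards the weights", return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits
if rank == 0:
    print(logits.shape)

torch.distributed.destroy_process_group()

Each rank applies the translated plan via `parallelize_module`, so the per-module `_tp_plan` dictionaries (the base model's now coming from `config.base_model_tp_plan`) determine how the submodules are sharded across the device mesh.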