From 7554fb04b05aac8749664e9ca6299ae11c1583bf Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 16:10:14 +0530 Subject: [PATCH 01/16] Update arguments.py --- megatron/arguments.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 49b3d8e4c6..9228da6ee9 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -399,7 +399,8 @@ def validate_args(args, defaults={}): args.async_tensor_model_parallel_allreduce = False if not args.use_dataset_only: - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if deepspeed.accelerator.get_accelerator().device_name() == "cuda" \ + and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if args.sequence_parallel: raise RuntimeError( "Using sequence parallelism requires setting the environment variable " From 46dabf9d7867435d45d373357908e490fa7d872e Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 16:12:07 +0530 Subject: [PATCH 02/16] Update layers.py --- megatron/core/tensor_parallel/layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 2245113c9c..0f42b42d09 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -450,7 +450,8 @@ def linear_with_grad_accumulation_and_async_allreduce( ] if not linear_with_grad_accumulation_and_async_allreduce.warned: - if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": + if get_accelerator().device_name() == "cuda" \ + and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if sequence_parallel: warnings.warn( "When using sequence parallelism it is recommended to set the " From dbc64a3e741eeea8dbec85b2d737ecba1a5c7068 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:01:33 +0530 Subject: [PATCH 03/16] Update fused_layer_norm.py --- megatron/model/fused_layer_norm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 2f3b89014b..f0a91a7a54 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -4,6 +4,7 @@ https://github.com/NVIDIA/apex with some changes. """ +from deepspeed.accelerator.real_accelerator import get_accelerator import numbers import torch from torch.nn.parameter import Parameter @@ -13,6 +14,7 @@ import inspect from megatron.core.utils import make_viewless_tensor +from megatron import get_args try: from apex.contrib.layer_norm.layer_norm import FastLayerNormFN @@ -56,8 +58,12 @@ def __init__(self, normalized_shape, eps=1e-5, normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape)) - self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.weight = Parameter(torch.empty(*normalized_shape, + device=get_accelerator().current_device_name(), + dtype=get_args().params_dtype)) + self.bias = Parameter(torch.empty(*normalized_shape, + device=get_accelerator().current_device_name(), + dtype=get_args().params_dtype)) self.reset_parameters() self.no_persist_layer_norm = no_persist_layer_norm self.sequence_parallel = sequence_parallel From 5de04d3a550a75bfb7bcf64e4d864228a087f8f9 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:03:34 +0530 Subject: [PATCH 04/16] Update gpt_model.py --- megatron/model/gpt_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 8968c96655..6017562316 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -393,9 +393,12 @@ def _to_float16(inputs): if args.normalization == 'layernorm': self.specs.append(LayerSpec(LayerNorm, args.hidden_size, - eps=args.layernorm_epsilon)) + eps=args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) else: - self.specs.append(LayerSpec(RMSNorm, args.hidden_size, args.layernorm_epsilon)) + self.specs.append(LayerSpec(RMSNorm, args.hidden_size, + args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) def _logits_helper(embedding, lm_output): """A wrapper to massage inputs/outputs from pipeline. """ From 54b13d837b2c59ef90a8fdbb12be60a5aa62eb88 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:04:08 +0530 Subject: [PATCH 05/16] Update language_model.py --- megatron/model/language_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index ec2ae1877a..74a7148f74 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -256,8 +256,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # Dropout. if self.sequence_parallel: - # already partition sequence, do not need scatter_to_sequence_parallel_region - # embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + # already partition sequence, do not need scatter_to_sequence_parallel_region ? + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) with tensor_parallel.get_cuda_rng_tracker().fork(): embeddings = self.embedding_dropout(embeddings) else: From 9618f214faec6870acbc07f17632d1673978cb4c Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:07:37 +0530 Subject: [PATCH 06/16] Update rmsnorm.py --- megatron/model/rmsnorm.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/megatron/model/rmsnorm.py b/megatron/model/rmsnorm.py index 60e8978171..f25f61abc4 100644 --- a/megatron/model/rmsnorm.py +++ b/megatron/model/rmsnorm.py @@ -1,4 +1,10 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +from deepspeed.accelerator import get_accelerator +from megatron import get_args + import torch +from torch.nn import init from torch.nn.parameter import Parameter # Taken from facebookresearch/llama @@ -6,11 +12,15 @@ class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps - self.weight = Parameter(torch.ones(dim)) + self.weight = Parameter(torch.empty(dim, + device=get_accelerator().current_device_name(), + dtype=get_args().params_dtype)) + init.ones_(self.weight) + setattr(self.weight, 'sequence_parallel', sequence_parallel) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) - return output * self.weight \ No newline at end of file + return output * self.weight From de5adcf510fd7f6e58f74135ba09708875658a0d Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:12:11 +0530 Subject: [PATCH 07/16] Update transformer.py --- megatron/model/transformer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e79abea3cf..711b725060 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -922,7 +922,8 @@ def __init__(self, config, config.hidden_size, eps=config.layernorm_epsilon) else: - self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) # Self attention. self.self_attention = ParallelAttention( config, @@ -948,7 +949,8 @@ def __init__(self, config, config.hidden_size, eps=config.layernorm_epsilon) else: - self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) # Cross attention. if self.layer_type in (LayerType.decoder, LayerType.retro_decoder, @@ -968,7 +970,9 @@ def __init__(self, config, apply_layernorm_1p=args.apply_layernorm_1p, mem_efficient_ln=args.mem_efficient_ln) else: - self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, + config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) # MLP self.num_experts = num_experts @@ -1771,7 +1775,8 @@ def build_layer(layer_number, n_e): config.hidden_size, eps=config.layernorm_epsilon) else: - self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) def _get_layer(self, layer_number): return self.layers[layer_number] From 7bcd36138961df75dd758c2a10a9862c1f6684ec Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:13:30 +0530 Subject: [PATCH 08/16] Update utils.py --- megatron/model/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 6c94921c95..19875f089f 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -9,6 +9,7 @@ from megatron import get_args from deepspeed.runtime.zero import GatheredParameters +from deepspeed.accelerator import get_accelerator def init_method_normal(sigma): """Init method based on N(0, sigma).""" @@ -49,7 +50,9 @@ def attention_mask_func(attention_scores, attention_mask): def get_linear_layer(rows, columns, init_method, gather_params_on_init=False): """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns) + layer = torch.nn.Linear(rows, columns, + device=get_accelerator().current_device_name(), + dtype=get_args().params_dtype) if get_args().perform_initialization: with GatheredParameters(layer.weight, modifier_rank=0, enabled=gather_params_on_init): init_method(layer.weight) From 6a167e5d007b5f7ef6292f33917d70c31b36a3e2 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:35:29 +0530 Subject: [PATCH 09/16] Update layers.py --- megatron/core/tensor_parallel/layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 0f42b42d09..67a78853aa 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Parts of the code here are adapted from PyTorch From 520ac3aa95525dc8d44043072b54983b281b0032 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:36:23 +0530 Subject: [PATCH 10/16] Update fused_layer_norm.py --- megatron/model/fused_layer_norm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index f0a91a7a54..f914c9e994 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """This code is copied fron NVIDIA apex: From 04c190948f01f696add5d01e4dc0017918db62f1 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:36:54 +0530 Subject: [PATCH 11/16] Update gpt_model.py --- megatron/model/gpt_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 6017562316..e5e60c43ee 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """GPT-2 model.""" From c9aec0462d0edda42572e0c4360d155dee4db0a1 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:37:11 +0530 Subject: [PATCH 12/16] Update language_model.py --- megatron/model/language_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 74a7148f74..3b8e4e0da1 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Transformer based language model.""" From 2947fa2c1c3fd02dc5ff761da668e7424e4fb6a0 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 23 Jul 2024 18:37:59 +0530 Subject: [PATCH 13/16] Update utils.py --- megatron/model/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 19875f089f..ec269c4e29 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Utilities for models.""" From b436b900966ce8b0a4034cb2fd34e09593dc872d Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Thu, 29 Aug 2024 18:17:53 +0530 Subject: [PATCH 14/16] Update utils.py --- megatron/model/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/megatron/model/utils.py b/megatron/model/utils.py index ec269c4e29..6c94921c95 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -1,4 +1,3 @@ -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. """Utilities for models.""" @@ -10,7 +9,6 @@ from megatron import get_args from deepspeed.runtime.zero import GatheredParameters -from deepspeed.accelerator import get_accelerator def init_method_normal(sigma): """Init method based on N(0, sigma).""" @@ -51,9 +49,7 @@ def attention_mask_func(attention_scores, attention_mask): def get_linear_layer(rows, columns, init_method, gather_params_on_init=False): """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns, - device=get_accelerator().current_device_name(), - dtype=get_args().params_dtype) + layer = torch.nn.Linear(rows, columns) if get_args().perform_initialization: with GatheredParameters(layer.weight, modifier_rank=0, enabled=gather_params_on_init): init_method(layer.weight) From faa0d74233c156210c7e0aaafcc2dc09c486289b Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 3 Sep 2024 21:19:01 +0530 Subject: [PATCH 15/16] Update fused_layer_norm.py --- megatron/model/fused_layer_norm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index f914c9e994..d1ef034397 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -59,11 +59,14 @@ def __init__(self, normalized_shape, eps=1e-5, normalized_shape = (normalized_shape,) self.normalized_shape = torch.Size(normalized_shape) self.eps = eps + init_device = None + if get_accelerator().device_name() == 'hpu': + init_device = get_accelerator().current_device_name() self.weight = Parameter(torch.empty(*normalized_shape, - device=get_accelerator().current_device_name(), + device=init_device, dtype=get_args().params_dtype)) self.bias = Parameter(torch.empty(*normalized_shape, - device=get_accelerator().current_device_name(), + device=init_device, dtype=get_args().params_dtype)) self.reset_parameters() self.no_persist_layer_norm = no_persist_layer_norm From 003fb7b1fe0ce65e936c4357fcc777b59110bcee Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Tue, 3 Sep 2024 21:20:27 +0530 Subject: [PATCH 16/16] Update rmsnorm.py --- megatron/model/rmsnorm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/model/rmsnorm.py b/megatron/model/rmsnorm.py index f25f61abc4..4860d81716 100644 --- a/megatron/model/rmsnorm.py +++ b/megatron/model/rmsnorm.py @@ -12,8 +12,11 @@ class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps + init_device = None + if get_accelerator().device_name() == 'hpu': + init_device = get_accelerator().current_device_name() self.weight = Parameter(torch.empty(dim, - device=get_accelerator().current_device_name(), + device=init_device, dtype=get_args().params_dtype)) init.ones_(self.weight) setattr(self.weight, 'sequence_parallel', sequence_parallel)