Enable Sequence Parallelism #429

Merged: 16 commits, Sep 4, 2024
3 changes: 2 additions & 1 deletion megatron/arguments.py
@@ -399,7 +399,8 @@ def validate_args(args, defaults={}):
args.async_tensor_model_parallel_allreduce = False

if not args.use_dataset_only:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if deepspeed.accelerator.get_accelerator().device_name() == "cuda" \
and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
"Using sequence parallelism requires setting the environment variable "
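For context on this hunk: the CUDA_DEVICE_MAX_CONNECTIONS check is now gated on the DeepSpeed accelerator actually reporting a CUDA device, so non-CUDA backends (e.g. HPU) can enable sequence parallelism without that environment variable. A minimal standalone sketch of the pattern follows; the function name is illustrative, and the error text is paraphrased since the full message is collapsed in the diff above.

```python
import os

from deepspeed.accelerator import get_accelerator


def validate_sequence_parallel_env(sequence_parallel: bool) -> None:
    # Only CUDA devices need CUDA_DEVICE_MAX_CONNECTIONS=1; it serializes kernel
    # launch order so sequence-parallel communication kernels can overlap compute.
    if get_accelerator().device_name() == "cuda" \
            and os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1":
        if sequence_parallel:
            raise RuntimeError(
                "Sequence parallelism on CUDA requires setting "
                "CUDA_DEVICE_MAX_CONNECTIONS to 1"
            )
```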
4 changes: 3 additions & 1 deletion megatron/core/tensor_parallel/layers.py
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
@@ -450,7 +451,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
]

if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if get_accelerator().device_name() == "cuda" \
and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
11 changes: 9 additions & 2 deletions megatron/model/fused_layer_norm.py
@@ -1,9 +1,11 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """

from deepspeed.accelerator.real_accelerator import get_accelerator
import numbers
import torch
from torch.nn.parameter import Parameter
@@ -13,6 +15,7 @@
import inspect

from megatron.core.utils import make_viewless_tensor
from megatron import get_args

try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
@@ -56,8 +59,12 @@ def __init__(self, normalized_shape, eps=1e-5,
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.weight = Parameter(torch.empty(*normalized_shape,
device=get_accelerator().current_device_name(),
dtype=get_args().params_dtype))
self.bias = Parameter(torch.empty(*normalized_shape,
device=get_accelerator().current_device_name(),
dtype=get_args().params_dtype))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel
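The parameter change above replaces torch.Tensor(...) (an uninitialized float32 CPU tensor) with torch.empty(...) allocated directly on the accelerator device in the configured params dtype. A self-contained sketch of that allocation pattern follows; the class and the cpu/float32 defaults are illustrative, whereas the real code takes the device from get_accelerator().current_device_name() and the dtype from get_args().params_dtype.

```python
import torch
from torch.nn.parameter import Parameter


class SimpleLayerNorm(torch.nn.Module):
    """Illustrative LayerNorm showing device/dtype-aware parameter allocation (not the Megatron class)."""

    def __init__(self, normalized_shape, eps=1e-5, device="cpu", dtype=torch.float32):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        # Allocate directly on the target device/dtype, mirroring the hunk above,
        # instead of materializing float32 CPU tensors and casting them later.
        self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
        self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        return torch.nn.functional.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
```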
8 changes: 6 additions & 2 deletions megatron/model/gpt_model.py
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""GPT-2 model."""
@@ -393,9 +394,12 @@ def _to_float16(inputs):
if args.normalization == 'layernorm':
self.specs.append(LayerSpec(LayerNorm,
args.hidden_size,
eps=args.layernorm_epsilon))
eps=args.layernorm_epsilon,
sequence_parallel=args.sequence_parallel))
else:
self.specs.append(LayerSpec(RMSNorm, args.hidden_size, args.layernorm_epsilon))
self.specs.append(LayerSpec(RMSNorm, args.hidden_size,
args.layernorm_epsilon,
sequence_parallel=args.sequence_parallel))

def _logits_helper(embedding, lm_output):
"""A wrapper to massage inputs/outputs from pipeline. """
5 changes: 3 additions & 2 deletions megatron/model/language_model.py
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""
@@ -256,8 +257,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):

# Dropout.
if self.sequence_parallel:
# already partition sequence, do not need scatter_to_sequence_parallel_region
# embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            # is the sequence already partitioned here, making scatter_to_sequence_parallel_region unnecessary?
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
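For reference, scatter_to_sequence_parallel_region splits the [sequence, batch, hidden] embedding tensor along the sequence dimension so each tensor-parallel rank keeps only its own shard before dropout; the change above restores that call. Below is a single-process sketch of the splitting itself, with no process group; rank and world_size are passed explicitly for illustration only.

```python
import torch


def scatter_along_sequence_dim(embeddings: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's contiguous shard of a [seq, batch, hidden] tensor along dim 0 (illustrative)."""
    seq_len = embeddings.size(0)
    assert seq_len % world_size == 0, "sequence length must divide evenly across ranks"
    shard = seq_len // world_size
    return embeddings[rank * shard:(rank + 1) * shard].contiguous()


# Example: an 8-token sequence split across 2 "ranks" keeps 4 tokens each.
emb = torch.randn(8, 2, 16)
local = scatter_along_sequence_dim(emb, rank=0, world_size=2)
assert local.shape == (4, 2, 16)
```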
14 changes: 12 additions & 2 deletions megatron/model/rmsnorm.py
@@ -1,16 +1,26 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

from deepspeed.accelerator import get_accelerator
from megatron import get_args

import torch
from torch.nn import init
from torch.nn.parameter import Parameter

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel: bool = False):
super().__init__()
self.eps = eps
self.weight = Parameter(torch.ones(dim))
self.weight = Parameter(torch.empty(dim,
device=get_accelerator().current_device_name(),
dtype=get_args().params_dtype))
init.ones_(self.weight)
setattr(self.weight, 'sequence_parallel', sequence_parallel)

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
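A usage note on the setattr above: marking the weight with a sequence_parallel attribute is typically how Megatron-style training code later finds replicated norm parameters whose gradients need an all-reduce across the tensor-parallel group when sequence parallelism is enabled. A minimal sketch of consuming that marker follows; the helper name is illustrative, and the collective is left as a comment because it needs an initialized process group.

```python
import torch


def sequence_parallel_params(model: torch.nn.Module):
    """Yield parameters tagged with the sequence_parallel attribute (illustrative helper)."""
    for param in model.parameters():
        if getattr(param, "sequence_parallel", False):
            yield param


# In a real run, each tagged gradient would be summed across the tensor-parallel group, e.g.:
# for p in sequence_parallel_params(model):
#     torch.distributed.all_reduce(p.grad, group=tensor_model_parallel_group)
```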
13 changes: 9 additions & 4 deletions megatron/model/transformer.py
@@ -922,7 +922,8 @@ def __init__(self, config,
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
# Self attention.
self.self_attention = ParallelAttention(
config,
@@ -948,7 +949,8 @@ def __init__(self, config,
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
# Cross attention.
if self.layer_type in (LayerType.decoder,
LayerType.retro_decoder,
@@ -968,7 +970,9 @@
apply_layernorm_1p=args.apply_layernorm_1p,
mem_efficient_ln=args.mem_efficient_ln)
else:
self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.post_inter_attention_layernorm = RMSNorm(config.hidden_size,
config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)

# MLP
self.num_experts = num_experts
@@ -1771,7 +1775,8 @@ def build_layer(layer_number, n_e):
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)

def _get_layer(self, layer_number):
return self.layers[layer_number]
6 changes: 5 additions & 1 deletion megatron/model/utils.py
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Utilities for models."""
@@ -9,6 +10,7 @@
from megatron import get_args

from deepspeed.runtime.zero import GatheredParameters
from deepspeed.accelerator import get_accelerator

def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
@@ -49,7 +51,9 @@ def attention_mask_func(attention_scores, attention_mask):

def get_linear_layer(rows, columns, init_method, gather_params_on_init=False):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
layer = torch.nn.Linear(rows, columns,
device=get_accelerator().current_device_name(),
dtype=get_args().params_dtype)
if get_args().perform_initialization:
with GatheredParameters(layer.weight, modifier_rank=0, enabled=gather_params_on_init):
init_method(layer.weight)