Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug]Fix init issue for rms_norm in sequence_parallel. #448

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions megatron/model/fused_rmsnorm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from megatron import get_args

import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import intel_extension_for_pytorch as ipex # noqa

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False):
super().__init__()
self.eps = eps
self.weight = Parameter(torch.ones(dim))
self.weight = Parameter(torch.ones(dim,
dtype=get_args().params_dtype))
self.sequence_parallel = sequence_parallel
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

def forward(self, x):
output = torch.xpu.IpexRmsNorm(x, self.weight.shape, self.weight, self.eps)
Expand Down
5 changes: 3 additions & 2 deletions megatron/model/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
def __init__(self, dim: int, eps: float = 1e-6, sequence_parallel=False):
super().__init__()
self.eps = eps
init_device = None
Expand All @@ -19,7 +19,8 @@ def __init__(self, dim: int, eps: float = 1e-6):
device=init_device,
dtype=get_args().params_dtype))
init.ones_(self.weight)
setattr(self.weight, 'sequence_parallel', sequence_parallel)
self.sequence_parallel = sequence_parallel
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
Expand Down