Skip to content

Commit

Permalink
fix3
Browse files Browse the repository at this point in the history
  • Loading branch information
josejg committed Apr 1, 2024
1 parent 03c9784 commit fa68616
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,19 @@ def __init__(
'triton_rms_norm requires Flash Attention to be installed. ' +
'Please pip install flash-attn.')

if not isinstance(normalized_shape, int):
raise ValueError('TritonRMSNorm only supports 1D tensors')

self.rms_norm_fn = rms_norm_fn

self.weight = torch.nn.Parameter(
torch.ones(normalized_shape, device=device, dtype=dtype))

def forward(self, x: torch.Tensor):
# Flash Attention expects a flat tensor
weight = self.weight.flatten()
return self.rms_norm_fn(
x,
weight,
self.weight,
None, # no bias
residual=None,
eps=self.eps,
Expand Down
14 changes: 9 additions & 5 deletions tests/models/test_rmsnorm_triton_vs_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@pytest.mark.gpu
@pytest.mark.parametrize('normalized_shape', [32, 128])
@pytest.mark.parametrize('normalized_shape', [32, 128, 4096])
def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]],
device: str = 'cuda'):
# Compare Triton and PyTorch Eager implementations of RMSNorm
Expand Down Expand Up @@ -53,20 +53,24 @@ def test_rmsnorm_triton_vs_eager(normalized_shape: Union[int, List[int]],
loss0.backward()
loss1.backward()

torch.testing.assert_close(y0, y1, rtol=1e-2, atol=1e-2)
rtol = 1e-6
atol = 1e-6

torch.testing.assert_close(y0, y1, rtol=rtol, atol=atol)

p0 = eager_rmsnorm.weight
p1 = triton_rmsnorm.weight

# weight check
torch.testing.assert_close(p0, p1, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(p0, p1, rtol=rtol, atol=atol)
# weight gradient check
assert p0.grad is not None
assert p1.grad is not None
assert torch.norm(p0.grad - p1.grad) <= 1e-2 + 1e-2 * torch.norm(p0.grad)
assert torch.norm(p0.grad - p1.grad) <= atol + rtol * torch.norm(p0.grad)

# input gradient check
assert x0.grad is not None
assert x1.grad is not None
# Relaxed to an L2-norm-based check.
assert torch.norm(x0.grad - x1.grad) <= 1e-2 + 1e-2 * torch.norm(x0.grad)
assert torch.norm(x0.grad - x1.grad) <= atol + rtol * torch.norm(x0.grad)

0 comments on commit fa68616

Please sign in to comment.