Skip to content

Commit

Permalink
Do RMSNorm in native type.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 27, 2024
1 parent ca4b8f3 commit ab13000
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions comfy/ldm/flux/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ def __init__(self, dim: int, dtype=None, device=None, operations=None):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

def forward(self, x: Tensor):
    """RMS-normalize `x` over its last dimension, in `x`'s native dtype.

    Computes the reciprocal root-mean-square of the last axis and rescales
    the normalized tensor by the learned `self.scale` parameter, which is
    cast to `x`'s dtype and device before the multiply.

    Args:
        x: input tensor; normalization is over `dim=-1`.
           # assumes floating-point dtype — TODO confirm against callers

    Returns:
        Tensor of the same shape and dtype as `x`.
    """
    # 1e-6 epsilon keeps rsqrt finite when a row of x is all zeros.
    rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
    # NOTE(review): the whole computation runs in x's native dtype — the
    # previous float32 upcast (x.float() ... .to(x_dtype)) was removed by
    # this commit, so reduced precision in half dtypes is intentional.
    return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)


class QKNorm(torch.nn.Module):
Expand Down

0 comments on commit ab13000

Please sign in to comment.