change: replace deprecated torch.cuda.amp.custom_fwd/custom_bwd decorators with torch.amp.custom_fwd/custom_bwd(device_type='cuda')
snarayan21 committed Sep 5, 2024
1 parent d516545 commit 8119e43
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions megablocks/layers/glu.py
@@ -67,7 +67,7 @@ class MemoryOptimizedGroupedGLU(torch.autograd.Function):
     """GroupedMLP with manually scheduled memory reuse."""

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -102,7 +102,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
         return dsd_out

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, ddsd_out):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
12 changes: 6 additions & 6 deletions megablocks/layers/mlp.py
@@ -18,13 +18,13 @@
 class ScaleGradient(torch.autograd.Function):

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx: Any, x: torch.Tensor, scale: float):
         ctx.scale = scale
         return x

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx: torch.Tensor, grad: torch.Tensor):
         return grad * ctx.scale, None

@@ -188,7 +188,7 @@ class MemoryOptimizedMLP(torch.autograd.Function):
     """Sparse MLP with manually scheduled memory reuse."""

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, w2, topo, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -230,7 +230,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn):
         return dsd_out

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx, ddsd_out):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
@@ -398,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
     """GroupedMLP with manually scheduled memory reuse."""

     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -431,7 +431,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
         return dsd_out

     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.custom_bwd(device_type='cuda')
     def backward(ctx: Any, ddsd_out: torch.Tensor):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
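For reference, here is a minimal, self-contained sketch of the decorator migration this commit applies in both files, modeled on the ScaleGradient class above. The version requirement is an assumption: in recent PyTorch releases (around 2.4), torch.cuda.amp.custom_fwd/custom_bwd are deprecated in favor of the device-aware torch.amp.custom_fwd/custom_bwd, which take an explicit device_type argument but otherwise preserve the same AMP casting behavior.

# Minimal sketch (assumes PyTorch >= 2.4, where the torch.cuda.amp decorators
# emit deprecation warnings in favor of the device-aware torch.amp variants).
from typing import Any

import torch


class ScaleGradient(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass."""

    @staticmethod
    @torch.amp.custom_fwd(device_type='cuda')  # was: @torch.cuda.amp.custom_fwd
    def forward(ctx: Any, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return x

    @staticmethod
    @torch.amp.custom_bwd(device_type='cuda')  # was: @torch.cuda.amp.custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        # One gradient per forward input: grad w.r.t. x, and None for the
        # non-tensor scale argument.
        return grad * ctx.scale, None


# Usage: acts as the identity on the forward pass, but multiplies the incoming
# gradient by 0.5 during backprop (requires a CUDA device for this sketch).
x = torch.randn(4, 8, device='cuda', requires_grad=True)
ScaleGradient.apply(x, 0.5).sum().backward()

The swap is purely a rename plus the device_type keyword; no change to the autograd logic or the AMP-driven dtype handling is intended.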
