diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
index e510723..cbe0c91 100644
--- a/megablocks/layers/glu.py
+++ b/megablocks/layers/glu.py
@@ -67,7 +67,7 @@ class MemoryOptimizedGroupedGLU(torch.autograd.Function):
     """GroupedMLP with manually scheduled memory reuse."""
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -102,7 +102,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
         return dsd_out
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
     def backward(ctx, ddsd_out):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index e8f2d7b..6e6f4d8 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -18,13 +18,13 @@ class ScaleGradient(torch.autograd.Function):
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
     def forward(ctx: Any, x: torch.Tensor, scale: float):
         ctx.scale = scale
         return x
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
     def backward(ctx: torch.Tensor, grad: torch.Tensor):
         return grad * ctx.scale, None
@@ -188,7 +188,7 @@ class MemoryOptimizedMLP(torch.autograd.Function):
     """Sparse MLP with manually scheduled memory reuse."""
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, w2, topo, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -230,7 +230,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn):
         return dsd_out
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
     def backward(ctx, ddsd_out):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
@@ -398,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
     """GroupedMLP with manually scheduled memory reuse."""
 
     @staticmethod
-    @torch.cuda.amp.custom_fwd
+    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
     def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
         # Cast inputs using ctx dtype from AMP
         if ctx._fwd_used_autocast:
@@ -431,7 +431,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
         return dsd_out
 
     @staticmethod
-    @torch.cuda.amp.custom_bwd
+    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
     def backward(ctx: Any, ddsd_out: torch.Tensor):
         if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
             raise ValueError('Expected all MLP inputs to need grad.')
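
Note for reviewers: torch.cuda.amp.custom_fwd and custom_bwd are deprecated as of PyTorch 2.4 in favor of the device-agnostic decorators that take an explicit device_type, which is what this patch switches to. Below is a minimal sketch (not part of the patch) of the new decorators on a standalone torch.autograd.Function, mirroring the ScaleGradient pattern from mlp.py; it assumes PyTorch 2.4+ and a CUDA device, and the toy usage at the bottom is illustrative only.

import torch


class ScaleGradient(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        # One gradient per forward input; scale is a Python float, so None.
        return grad * ctx.scale, None


if __name__ == '__main__' and torch.cuda.is_available():
    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    # custom_fwd records the autocast state on ctx during forward and
    # custom_bwd replays it during backward, so the op behaves consistently
    # with or without autocast enabled.
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        y = ScaleGradient.apply(x, 0.5)
        loss = (y * y).sum()
    loss.backward()
    print(x.grad.dtype)  # torch.float32: gradients accumulate in the leaf's dtype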