From 7a8d2e89032e41c54f7729c337a7714f23d22c8f Mon Sep 17 00:00:00 2001 From: Saaketh Date: Wed, 4 Sep 2024 22:13:27 -0400 Subject: [PATCH] change --- megablocks/layers/glu.py | 4 ++-- megablocks/layers/mlp.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 1d9c82b..cbe0c91 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -67,7 +67,7 @@ class MemoryOptimizedGroupedGLU(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -102,7 +102,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 0959c44..6e6f4d8 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -18,13 +18,13 @@ class ScaleGradient(torch.autograd.Function): @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None @@ -188,7 +188,7 @@ class MemoryOptimizedMLP(torch.autograd.Function): """Sparse MLP with manually scheduled memory reuse.""" @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, topo, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -230,7 +230,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn): return dsd_out @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') @@ -398,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -431,7 +431,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx: Any, ddsd_out: torch.Tensor): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.')