diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index 1454f496..abe1cbdd 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -216,18 +216,19 @@ def backward(ctx, ddsd_out):
 
         # unpack saved tensors; ugly because quantizing changes tensor count
         # dtype = ctx.dtype
-        w1, w2 = ctx.saved_tensors[:2]
-        topo_tensors = ctx.saved_tensors[2:8]
+        saved_tensors = ctx.saved_tensors
+        w1, w2 = saved_tensors[:2]
+        topo_tensors = saved_tensors[2:8]
         # either 1 or 2 tensors for MLP input after the always-present tensors
         if ctx.num_input_bits == -1:
-            x = ctx.saved_tensors[8]
+            x = saved_tensors[8]
         else:
-            x_q, x_scales = ctx.saved_tensors[8:10]
+            x_q, x_scales = saved_tensors[8:10]
         # either 1 or 2 tensors at the end for saved GELU input / sdd output
         if ctx.num_remat_bits == -1:
-            sdd_out_data = ctx.saved_tensors[-1]
+            sdd_out_data = saved_tensors[-1]
         else:
-            hidden_q, hidden_scales = ctx.saved_tensors[-2:]
+            hidden_q, hidden_scales = saved_tensors[-2:]
 
         # rematerialize gelu output
         if ctx.num_remat_bits == -1:
@@ -434,20 +435,21 @@ def backward(ctx, ddsd_out):
 
         # Unpack saved tensors; ugly because quantizing changes tensor count
         # dtype = ctx.dtype
-        w1, w2 = ctx.saved_tensors[:2]
-        batch_sizes = ctx.saved_tensors[2]
+        saved_tensors = ctx.saved_tensors
+        w1, w2 = saved_tensors[:2]
+        batch_sizes = saved_tensors[2]
 
         # Either 1 or 2 tensors for MLP input after the always-present tensors
         if ctx.num_input_bits == -1:
-            x = ctx.saved_tensors[3]
+            x = saved_tensors[3]
         else:
-            x_q, x_scales = ctx.saved_tensors[3:5]
+            x_q, x_scales = saved_tensors[3:5]
 
         # Either 1 or 2 tensors at the end for saved GELU input / sdd output
         if ctx.num_remat_bits == -1:
-            sdd_out = ctx.saved_tensors[-1]
+            sdd_out = saved_tensors[-1]
         else:
-            hidden_q, hidden_scales = ctx.saved_tensors[-2:]
+            hidden_q, hidden_scales = saved_tensors[-2:]
 
         # Rematerialize gelu output.
         if ctx.num_remat_bits == -1:
diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py
index e973fb86..6e127cfa 100644
--- a/megablocks/ops/padded_scatter.py
+++ b/megablocks/ops/padded_scatter.py
@@ -32,8 +32,9 @@ def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k,
     @custom_bwd
     def backward(ctx, grad):
         grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors
 
-        indices, bin_ids, weights, bins, padded_bins = ctx.saved_tensors[:5]
+        indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
         dgrad = None
         if ctx.needs_input_grad[0]:
             dgrad = kernels.padded_gather(
@@ -48,9 +49,9 @@ def backward(ctx, grad):
         wgrad = None
         if ctx.needs_input_grad[3]:  # need wgrad
             if ctx.num_bits == -1:  # input saved without quantization
-                x = ctx.saved_tensors[-1]
+                x = saved_tensors[-1]
             else:  # dequantize input
-                x_q, x_scales = ctx.saved_tensors[-2:]
+                x_q, x_scales = saved_tensors[-2:]
                 x = turbo.dequantize_signed(
                     x_q, x_scales, num_bits=ctx.num_bits,
                     out_shape=ctx.x_shape)
diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py
index c7f7ca32..f65b6dbd 100644
--- a/megablocks/ops/scatter.py
+++ b/megablocks/ops/scatter.py
@@ -32,8 +32,9 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k,
     @custom_bwd
     def backward(ctx, grad):
         grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors
 
-        indices, bin_ids, weights, bins = ctx.saved_tensors[:4]
+        indices, bin_ids, weights, bins = saved_tensors[:4]
         dgrad = None
         if ctx.needs_input_grad[0]:
             dgrad = kernels.gather(
@@ -47,9 +48,9 @@ def backward(ctx, grad):
         wgrad = None
         if ctx.needs_input_grad[3]:  # need wgrad
             if ctx.num_bits == -1:  # input saved without quantization
-                x = ctx.saved_tensors[-1]
+                x = saved_tensors[-1]
             else:  # dequantize input
-                x_q, x_scales = ctx.saved_tensors[-2:]
+                x_q, x_scales = saved_tensors[-2:]
                 x = turbo.dequantize_signed(
                     x_q, x_scales, num_bits=ctx.num_bits,
                     out_shape=ctx.x_shape)
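The change is the same in every hunk: read ctx.saved_tensors once at the top of backward and index the resulting local tuple, rather than going through the ctx property repeatedly. Each access to ctx.saved_tensors rebuilds the tuple of saved tensors and, when saved-tensor hooks (e.g. activation offloading or recomputation) are registered, can re-run the unpack hook for every tensor, so the repeated accesses in these quantization-dependent unpacking paths add avoidable overhead. Below is a minimal sketch of the pattern; the CachedSavedTensors toy function is illustrative only and not part of megablocks.

    import torch

    class CachedSavedTensors(torch.autograd.Function):
        """Toy example (not megablocks code) of the one-read pattern."""

        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return x * w

        @staticmethod
        def backward(ctx, grad):
            # Read the property once; every additional ctx.saved_tensors
            # access would rebuild the tuple and re-trigger unpack hooks.
            saved_tensors = ctx.saved_tensors
            x, w = saved_tensors[:2]
            return grad * w, grad * x

    # usage: y = CachedSavedTensors.apply(x, w); y.sum().backward()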