Commit 30faffb
Merge pull request #33 from mvpatel2000/mvpatel2000/reformat
Unpack saved context once
tgale96 authored Oct 24, 2023
2 parents 52aa1b2 + 1a42b57 commit 30faffb
Showing 3 changed files with 22 additions and 18 deletions.
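
The change is the same in all three files: each backward() now binds ctx.saved_tensors to a local variable once instead of re-reading it at every unpack site. On a torch.autograd.Function context, saved_tensors is a property that re-materializes the saved tuple on each access, so a single read does less repeated work and is easier to follow. A minimal sketch of the pattern (illustrative only, not code from this repository):

    import torch

    class Scale(torch.autograd.Function):
        """Toy op y = x * w that saves both inputs for backward."""

        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return x * w

        @staticmethod
        def backward(ctx, grad):
            saved_tensors = ctx.saved_tensors  # read the property once
            x, w = saved_tensors               # then index the local tuple
            return grad * w, grad * x

Calling Scale.apply(x, w) runs forward; autograd later invokes backward with the upstream gradient.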
26 changes: 14 additions & 12 deletions megablocks/layers/mlp.py
@@ -216,18 +216,19 @@ def backward(ctx, ddsd_out):
         # unpack saved tensors; ugly because quantizing changes tensor count
         #
         dtype = ctx.dtype
-        w1, w2 = ctx.saved_tensors[:2]
-        topo_tensors = ctx.saved_tensors[2:8]
+        saved_tensors = ctx.saved_tensors
+        w1, w2 = saved_tensors[:2]
+        topo_tensors = saved_tensors[2:8]
         # either 1 or 2 tensors for MLP input after the always-present tensors
         if ctx.num_input_bits == -1:
-            x = ctx.saved_tensors[8]
+            x = saved_tensors[8]
         else:
-            x_q, x_scales = ctx.saved_tensors[8:10]
+            x_q, x_scales = saved_tensors[8:10]
         # either 1 or 2 tensors at the end for saved GELU input / sdd output
         if ctx.num_remat_bits == -1:
-            sdd_out_data = ctx.saved_tensors[-1]
+            sdd_out_data = saved_tensors[-1]
         else:
-            hidden_q, hidden_scales = ctx.saved_tensors[-2:]
+            hidden_q, hidden_scales = saved_tensors[-2:]

         # rematerialize gelu output
         if ctx.num_remat_bits == -1:
@@ -434,20 +435,21 @@ def backward(ctx, ddsd_out):
         # Unpack saved tensors; ugly because quantizing changes tensor count
         #
         dtype = ctx.dtype
-        w1, w2 = ctx.saved_tensors[:2]
-        batch_sizes = ctx.saved_tensors[2]
+        saved_tensors = ctx.saved_tensors
+        w1, w2 = saved_tensors[:2]
+        batch_sizes = saved_tensors[2]

         # Either 1 or 2 tensors for MLP input after the always-present tensors
         if ctx.num_input_bits == -1:
-            x = ctx.saved_tensors[3]
+            x = saved_tensors[3]
         else:
-            x_q, x_scales = ctx.saved_tensors[3:5]
+            x_q, x_scales = saved_tensors[3:5]

         # Either 1 or 2 tensors at the end for saved GELU input / sdd output
         if ctx.num_remat_bits == -1:
-            sdd_out = ctx.saved_tensors[-1]
+            sdd_out = saved_tensors[-1]
         else:
-            hidden_q, hidden_scales = ctx.saved_tensors[-2:]
+            hidden_q, hidden_scales = saved_tensors[-2:]

         # Rematerialize gelu output.
         if ctx.num_remat_bits == -1:
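
The diff's own comment, "ugly because quantizing changes tensor count", is why the unpacking stays positional and conditional: with activation quantization enabled (num_bits != -1), forward saves a (values, scales) pair where it would otherwise save a single tensor. A self-contained sketch of that layout, with a toy quantizer standing in for the library's kernels (QuantSaveExample and _toy_quantize are assumed names, not part of megablocks):

    import torch

    def _toy_quantize(x, num_bits):
        # Toy symmetric per-row quantizer; stands in for the real kernels.
        scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
        q = torch.round(x / scales * (2 ** (num_bits - 1) - 1)).to(torch.int8)
        return q, scales

    class QuantSaveExample(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, w, num_bits=-1):
            ctx.num_bits = num_bits
            if num_bits == -1:
                ctx.save_for_backward(w, x)  # one tensor for the input
            else:
                x_q, x_scales = _toy_quantize(x, num_bits)
                ctx.save_for_backward(w, x_q, x_scales)  # two: values + scales
            return x @ w

        @staticmethod
        def backward(ctx, grad):
            saved_tensors = ctx.saved_tensors  # unpack once, as in this commit
            w = saved_tensors[0]
            if ctx.num_bits == -1:
                x = saved_tensors[1]
            else:  # dequantize the saved input before computing wgrad
                x_q, x_scales = saved_tensors[1:]
                x = x_q.to(grad.dtype) / (2 ** (ctx.num_bits - 1) - 1) * x_scales
            return grad @ w.t(), x.t() @ grad, None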
7 changes: 4 additions & 3 deletions megablocks/ops/padded_scatter.py
@@ -32,8 +32,9 @@ def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k,
     @custom_bwd
     def backward(ctx, grad):
         grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors

-        indices, bin_ids, weights, bins, padded_bins = ctx.saved_tensors[:5]
+        indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
         dgrad = None
         if ctx.needs_input_grad[0]:
             dgrad = kernels.padded_gather(
@@ -48,9 +49,9 @@ def backward(ctx, grad):
         wgrad = None
         if ctx.needs_input_grad[3]: # need wgrad
             if ctx.num_bits == -1: # input saved without quantization
-                x = ctx.saved_tensors[-1]
+                x = saved_tensors[-1]
             else: # dequantize input
-                x_q, x_scales = ctx.saved_tensors[-2:]
+                x_q, x_scales = saved_tensors[-2:]
                 x = turbo.dequantize_signed(
                     x_q, x_scales, num_bits=ctx.num_bits, out_shape=ctx.x_shape)
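
A second pattern visible in both scatter ops: each gradient is gated behind ctx.needs_input_grad, so the gather or dequantize work is skipped entirely when autograd does not need that gradient. A minimal sketch of the gating (GatedGrads is hypothetical, not the repository's op):

    import torch

    class GatedGrads(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return x * w

        @staticmethod
        def backward(ctx, grad):
            saved_tensors = ctx.saved_tensors
            x, w = saved_tensors
            # Compute each input's gradient only if autograd needs it.
            dx = grad * w if ctx.needs_input_grad[0] else None
            dw = grad * x if ctx.needs_input_grad[1] else None
            return dx, dw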
7 changes: 4 additions & 3 deletions megablocks/ops/scatter.py
@@ -32,8 +32,9 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k,
     @custom_bwd
     def backward(ctx, grad):
         grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors

-        indices, bin_ids, weights, bins = ctx.saved_tensors[:4]
+        indices, bin_ids, weights, bins = saved_tensors[:4]
         dgrad = None
         if ctx.needs_input_grad[0]:
             dgrad = kernels.gather(
@@ -47,9 +48,9 @@ def backward(ctx, grad):
         wgrad = None
         if ctx.needs_input_grad[3]: # need wgrad
             if ctx.num_bits == -1: # input saved without quantization
-                x = ctx.saved_tensors[-1]
+                x = saved_tensors[-1]
             else: # dequantize input
-                x_q, x_scales = ctx.saved_tensors[-2:]
+                x_q, x_scales = saved_tensors[-2:]
                 x = turbo.dequantize_signed(
                     x_q, x_scales, num_bits=ctx.num_bits, out_shape=ctx.x_shape)
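
Because the commit is a pure refactor, the natural check is that gradients are unchanged; for custom Functions like these, torch.autograd.gradcheck compares analytical gradients against numerical ones. A hypothetical smoke test (the check helper is an assumption, not part of the repository):

    import torch

    def check(fn, *args):
        # gradcheck compares analytical and numerical gradients in float64.
        args = tuple(a.detach().double().requires_grad_() for a in args)
        assert torch.autograd.gradcheck(fn, args)

    # e.g. for the toy Function sketched earlier:
    # check(lambda x, w: Scale.apply(x, w), torch.randn(4, 3), torch.randn(4, 3))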
