Add compiler hints for Triton permutation kernels.
tgale96 committed Jul 25, 2023
1 parent 2501ec4 commit abe9ae4
Showing 1 changed file with 28 additions and 28 deletions.
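
Note on the change (annotation, not part of the original commit): the diff below renames the runtime column-count argument to NUM_COLUMNS, declares it tl.constexpr and makes it the autotune key, and wraps the pointer arithmetic in tl.multiple_of / tl.max_contiguous so the Triton compiler can assume aligned, contiguous row accesses when it vectorizes the copies. The snippet that follows is a minimal, self-contained sketch of the same pattern; the kernel _scaled_row_copy, its wrapper, and the shapes are hypothetical illustrations of the hints, not code from megablocks.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _scaled_row_copy(
    a,
    b,
    scales,
    NUM_COLUMNS : tl.constexpr,
    BLOCK_X : tl.constexpr):
    # One program instance per row. With NUM_COLUMNS known at compile time,
    # the row offset is provably a multiple of it; tl.multiple_of hands that
    # fact to the compiler so it can emit aligned, vectorized loads/stores.
    row = tl.program_id(0)
    a += tl.multiple_of(row * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(row * NUM_COLUMNS, NUM_COLUMNS)

    # tl.max_contiguous marks the BLOCK_X offsets as contiguous, which helps
    # the backend coalesce the masked accesses in the column loop.
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    scale = tl.load(scales + row)
    for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
        mask = offsets < NUM_COLUMNS
        x = tl.load(a + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)
        tl.store(b + offsets, x.to(b.dtype.element_ty), mask=mask)
        offsets += BLOCK_X


def scaled_row_copy(x, scales):
    # 'x' and 'scales' are CUDA tensors; 'scales' holds one multiplier per row.
    # NUM_COLUMNS is passed by keyword, mirroring the wrappers in this diff:
    # each distinct column count gets its own specialized, autotuned binary.
    out = torch.empty_like(x)
    _scaled_row_copy[(x.shape[0],)](x, out, scales, NUM_COLUMNS=x.shape[1])
    return out

The usual constexpr trade-off applies: every new column count triggers a fresh compile-and-autotune pass, which is cheap here because the hidden size is fixed for a given model.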
56 changes: 28 additions & 28 deletions megablocks/backend/kernels.py
@@ -185,18 +185,18 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
triton.Config({'BLOCK_X': 128}, num_warps=4),
triton.Config({'BLOCK_X': 256}, num_warps=4),
],
-key=['num_columns'],
+key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy_wgrad(
x,
grad,
wgrad,
-num_columns,
indices,
bin_ids,
bins,
padded_bins,
+NUM_COLUMNS : tl.constexpr,
TOP_K : tl.constexpr,
BLOCK_X : tl.constexpr):
# Our index into 'tokens * top_k'.
@@ -220,14 +220,14 @@ def _padded_copy_wgrad(

# Offset the input and output pointers.
wgrad += index_out
-grad += (index_out // TOP_K) * num_columns
-x += index_x * num_columns
-offsets = tl.arange(0, BLOCK_X)
+grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
-iterations = tl.cdiv(num_columns, BLOCK_X)
-for i in range(tl.cdiv(num_columns, BLOCK_X)):
-mask = offsets < num_columns
+iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
+mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
acc += data * scale
@@ -258,11 +258,11 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
x,
grad,
out,
-x.shape[1],
indices,
bin_ids,
bins,
padded_bins,
+NUM_COLUMNS=x.shape[1],
TOP_K=top_k)
return out

@@ -280,18 +280,18 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
triton.Config({'BLOCK_X': 128}, num_warps=4),
triton.Config({'BLOCK_X': 256}, num_warps=4),
],
-key=['num_columns'],
+key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
a,
b,
num_experts,
expert_capacity,
-num_columns,
indices,
weights,
bins,
+NUM_COLUMNS : tl.constexpr,
TOP_K : tl.constexpr,
BLOCK_X : tl.constexpr,
A_TO_B : tl.constexpr,
@@ -324,9 +324,9 @@ def _binned_copy(
# need to reduce the result. Using atomics is slow, so we
# do the reduce step in a second kernel.
offset = index_a // TOP_K if A_TO_B else index_a
-a += offset * num_columns
-b += index_b * num_columns
-offsets = tl.arange(0, BLOCK_X)
+a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
+b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

# Load the scale, if requested.
scale = tl.load(weights + index_a) if SCALE else 1
@@ -337,9 +337,9 @@ def _binned_copy(
iptr = a if A_TO_B else b
optr = b if A_TO_B else a

-iterations = tl.cdiv(num_columns, BLOCK_X)
-for i in range(tl.cdiv(num_columns, BLOCK_X)):
-mask = offsets < num_columns
+iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
+mask = offsets < NUM_COLUMNS
x = tl.load(iptr + offsets, mask=mask)
x = x.to(tl.float32) * scale.to(tl.float32)

@@ -369,10 +369,10 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
out,
num_experts,
expert_capacity,
-x.shape[1],
indices,
weights,
bins,
+NUM_COLUMNS=x.shape[1],
A_TO_B=True,
TOP_K=top_k,
SCALE=weights is not None)
@@ -400,10 +400,10 @@ def binned_scatter(x, indices, weights, bins, top_k):
x,
num_experts,
expert_capacity,
-hidden_size,
indices,
weights,
bins,
+NUM_COLUMNS=hidden_size,
A_TO_B=False,
TOP_K=top_k,
SCALE=weights is not None)
@@ -425,7 +425,7 @@ def binned_scatter(x, indices, weights, bins, top_k):
triton.Config({'BLOCK_X': 128}, num_warps=4),
triton.Config({'BLOCK_X': 256}, num_warps=4),
],
-key=['num_columns'],
+key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy_wgrad(
@@ -434,9 +434,9 @@ def _binned_copy_wgrad(
wgrad,
num_experts,
expert_capacity,
-num_columns,
indices,
bins,
+NUM_COLUMNS : tl.constexpr,
TOP_K : tl.constexpr,
BLOCK_X : tl.constexpr):
# Load our indices into the output.
@@ -462,14 +462,14 @@ def _binned_copy_wgrad(

# Offset the input and output pointers.
wgrad += index_out
-grad += (index_out // TOP_K) * num_columns
-x += index_x * num_columns
-offsets = tl.arange(0, BLOCK_X)
+grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
+x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
-iterations = tl.cdiv(num_columns, BLOCK_X)
-for i in range(tl.cdiv(num_columns, BLOCK_X)):
-mask = offsets < num_columns
+iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
+for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
+mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
acc += data * scale
@@ -500,8 +500,8 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
out,
num_experts,
expert_capacity,
-hidden_size,
indices,
bins,
+NUM_COLUMNS=hidden_size,
TOP_K=top_k)
return out
