From abe9ae4c3ff8a67d54c8b686355140d5ca46294a Mon Sep 17 00:00:00 2001 From: Trevor Gale Date: Tue, 25 Jul 2023 08:10:56 -0700 Subject: [PATCH] Add compiler hints for Triton permutation kernels. --- megablocks/backend/kernels.py | 56 +++++++++++++++++------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 6e54db44..d2e38997 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -185,18 +185,18 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): triton.Config({'BLOCK_X': 128}, num_warps=4), triton.Config({'BLOCK_X': 256}, num_warps=4), ], - key=['num_columns'], + key=['NUM_COLUMNS'], ) @triton.jit def _padded_copy_wgrad( x, grad, wgrad, - num_columns, indices, bin_ids, bins, padded_bins, + NUM_COLUMNS : tl.constexpr, TOP_K : tl.constexpr, BLOCK_X : tl.constexpr): # Our index into 'tokens * top_k'. @@ -220,14 +220,14 @@ def _padded_copy_wgrad( # Offset the input and output pointers. wgrad += index_out - grad += (index_out // TOP_K) * num_columns - x += index_x * num_columns - offsets = tl.arange(0, BLOCK_X) + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(num_columns, BLOCK_X) - for i in range(tl.cdiv(num_columns, BLOCK_X)): - mask = offsets < num_columns + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) acc += data * scale @@ -258,11 +258,11 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): x, grad, out, - x.shape[1], indices, bin_ids, bins, padded_bins, + NUM_COLUMNS=x.shape[1], TOP_K=top_k) return out @@ -280,7 +280,7 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): triton.Config({'BLOCK_X': 128}, num_warps=4), triton.Config({'BLOCK_X': 256}, num_warps=4), ], - key=['num_columns'], + key=['NUM_COLUMNS'], ) @triton.jit def _binned_copy( @@ -288,10 +288,10 @@ def _binned_copy( b, num_experts, expert_capacity, - num_columns, indices, weights, bins, + NUM_COLUMNS : tl.constexpr, TOP_K : tl.constexpr, BLOCK_X : tl.constexpr, A_TO_B : tl.constexpr, @@ -324,9 +324,9 @@ def _binned_copy( # need to reduce the result. Using atomics is slow, so we # do the reduce step in a second kernel. offset = index_a // TOP_K if A_TO_B else index_a - a += offset * num_columns - b += index_b * num_columns - offsets = tl.arange(0, BLOCK_X) + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) # Load the scale, if requested. scale = tl.load(weights + index_a) if SCALE else 1 @@ -337,9 +337,9 @@ def _binned_copy( iptr = a if A_TO_B else b optr = b if A_TO_B else a - iterations = tl.cdiv(num_columns, BLOCK_X) - for i in range(tl.cdiv(num_columns, BLOCK_X)): - mask = offsets < num_columns + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -369,10 +369,10 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): out, num_experts, expert_capacity, - x.shape[1], indices, weights, bins, + NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, SCALE=weights is not None) @@ -400,10 +400,10 @@ def binned_scatter(x, indices, weights, bins, top_k): x, num_experts, expert_capacity, - hidden_size, indices, weights, bins, + NUM_COLUMNS=hidden_size, A_TO_B=False, TOP_K=top_k, SCALE=weights is not None) @@ -425,7 +425,7 @@ def binned_scatter(x, indices, weights, bins, top_k): triton.Config({'BLOCK_X': 128}, num_warps=4), triton.Config({'BLOCK_X': 256}, num_warps=4), ], - key=['num_columns'], + key=['NUM_COLUMNS'], ) @triton.jit def _binned_copy_wgrad( @@ -434,9 +434,9 @@ def _binned_copy_wgrad( wgrad, num_experts, expert_capacity, - num_columns, indices, bins, + NUM_COLUMNS : tl.constexpr, TOP_K : tl.constexpr, BLOCK_X : tl.constexpr): # Load our indices into the output. @@ -462,14 +462,14 @@ def _binned_copy_wgrad( # Offset the input and output pointers. wgrad += index_out - grad += (index_out // TOP_K) * num_columns - x += index_x * num_columns - offsets = tl.arange(0, BLOCK_X) + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) acc = tl.zeros((BLOCK_X,), dtype=tl.float32) - iterations = tl.cdiv(num_columns, BLOCK_X) - for i in range(tl.cdiv(num_columns, BLOCK_X)): - mask = offsets < num_columns + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) acc += data * scale @@ -500,8 +500,8 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): out, num_experts, expert_capacity, - hidden_size, indices, bins, + NUM_COLUMNS=hidden_size, TOP_K=top_k) return out