diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index ca0120b..b584cee 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,27 +1,26 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x: torch.Tensor, ndim: int): +def assert_is_tensor(x, ndim): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x: torch.Tensor): +def assert_is_matrix(x): assert_is_tensor(x, 2) -def assert_is_vector(x: torch.Tensor): +def assert_is_vector(x): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a: Any, b: Any): +def assert_equal(a, b): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any): ) @triton.jit def _padded_copy( - a: torch.Tensor, - b: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Any, - bins: torch.Tensor, - padded_bins: torch.Tensor, + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -105,15 +104,7 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -129,7 +120,7 @@ def padded_gather( # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. - output_rows = int(padded_bins[-1].cpu().item()) + output_rows = padded_bins[-1].cpu().item() out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -147,14 +138,7 @@ def padded_gather( return out -def gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def gather(x, indices, bin_ids, weights, bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -186,15 +170,7 @@ def gather( return out -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -227,14 +203,7 @@ def padded_scatter( return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def scatter(x, indices, bin_ids, weights, bins, top_k): return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -257,13 +226,13 @@ def scatter( ) @triton.jit def _padded_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -307,15 +276,7 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -342,14 +303,7 @@ def padded_scatter_wgrad( return out -def scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, -): +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -370,13 +324,13 @@ def scatter_wgrad( ) @triton.jit def _binned_copy( - a: torch.Tensor, - b: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - weights, #: Optional[torch.Tensor], - bins: torch.Tensor, + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -435,14 +389,7 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - expert_capacity: int, - top_k: int, -): +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -454,6 +401,7 @@ def binned_gather( num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( x, out, @@ -470,13 +418,7 @@ def binned_gather( return out -def binned_scatter( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def binned_scatter(x, indices, weights, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -524,13 +466,13 @@ def binned_scatter( ) @triton.jit def _binned_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - bins: torch.Tensor, + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -576,7 +518,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): +def binned_scatter_wgrad(x, grad, indices, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad)