Skip to content

Commit

Permalink
no type checking in kernels.py (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok authored Aug 30, 2024
1 parent bc88977 commit 9b77d16
Showing 1 changed file with 43 additions and 101 deletions.
144 changes: 43 additions & 101 deletions megablocks/backend/kernels.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

import torch
import triton
import triton.language as tl


def assert_is_tensor(x, ndim):
    """Raise ``ValueError`` unless ``x`` has exactly ``ndim`` dimensions.

    Args:
        x: Object exposing an ``ndim`` attribute (typically a
            ``torch.Tensor``).
        ndim: Required number of dimensions.

    Raises:
        ValueError: If ``x.ndim`` differs from ``ndim``.
    """
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
    """Raise ``ValueError`` unless ``x`` is 2-dimensional.

    Delegates to ``assert_is_tensor(x, 2)``, so the error raised on a
    mismatch is the one produced by that helper.
    """
    assert_is_tensor(x, 2)


def assert_is_vector(x):
    """Raise ``ValueError`` unless ``x`` is 1-dimensional.

    Args:
        x: Object exposing an ``ndim`` attribute (typically a
            ``torch.Tensor``).

    Raises:
        ValueError: If ``x.ndim`` is not 1.
    """
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    """Raise ``ValueError`` unless ``a == b``.

    The comparison uses ``!=`` directly, so ``a`` and ``b`` may be any
    values whose inequality yields a plain truth value (ints, tuples, ...).

    Raises:
        ValueError: If ``a != b``; the message reports both values.
    """
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.')

Expand All @@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any):
)
@triton.jit
def _padded_copy(
a: torch.Tensor,
b: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Any,
bins: torch.Tensor,
padded_bins: torch.Tensor,
a,
b,
indices,
bin_ids,
weights,
bins,
padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -105,15 +104,7 @@ def _padded_copy(
offsets += BLOCK_X


def padded_gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand All @@ -129,7 +120,7 @@ def padded_gather(

# NOTE: Because of the padding, the output size is dynamic.
# We load the final padded bin bound to get the output rows.
output_rows = int(padded_bins[-1].cpu().item())
output_rows = padded_bins[-1].cpu().item()
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
_padded_copy[(indices.shape[0],)](
x,
Expand All @@ -147,14 +138,7 @@ def padded_gather(
return out


def gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
def gather(x, indices, bin_ids, weights, bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand Down Expand Up @@ -186,15 +170,7 @@ def gather(
return out


def padded_scatter(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
) -> torch.Tensor:
def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand Down Expand Up @@ -227,14 +203,7 @@ def padded_scatter(
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
    """Call ``padded_scatter`` with ``bins`` used as the padded bins too.

    All arguments are forwarded unchanged; ``bins`` is passed for both the
    ``bins`` and ``padded_bins`` parameters of ``padded_scatter``.

    Returns:
        Whatever ``padded_scatter`` returns for these arguments.
    """
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


Expand All @@ -257,13 +226,13 @@ def scatter(
)
@triton.jit
def _padded_copy_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
x,
grad,
wgrad,
indices,
bin_ids,
bins,
padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -307,15 +276,7 @@ def _padded_copy_wgrad(
tl.store(wgrad, out)


def padded_scatter_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_matrix(grad)
Expand All @@ -342,14 +303,7 @@ def padded_scatter_wgrad(
return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
    """Call ``padded_scatter_wgrad`` with ``bins`` used as the padded bins too.

    All arguments are forwarded unchanged; ``bins`` is passed for both the
    ``bins`` and ``padded_bins`` parameters of ``padded_scatter_wgrad``.

    Returns:
        Whatever ``padded_scatter_wgrad`` returns for these arguments.
    """
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


Expand All @@ -370,13 +324,13 @@ def scatter_wgrad(
)
@triton.jit
def _binned_copy(
a: torch.Tensor,
b: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
weights, #: Optional[torch.Tensor],
bins: torch.Tensor,
a,
b,
num_experts,
expert_capacity,
indices,
weights,
bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -435,14 +389,7 @@ def _binned_copy(
offsets += BLOCK_X


def binned_gather(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
expert_capacity: int,
top_k: int,
):
def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
Expand All @@ -454,6 +401,7 @@ def binned_gather(

num_experts = bins.shape[0]
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

_binned_copy[(num_experts, expert_capacity)](
x,
out,
Expand All @@ -470,13 +418,7 @@ def binned_gather(
return out


def binned_scatter(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
def binned_scatter(x, indices, weights, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_vector(indices)
Expand Down Expand Up @@ -524,13 +466,13 @@ def binned_scatter(
)
@triton.jit
def _binned_copy_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
bins: torch.Tensor,
x,
grad,
wgrad,
num_experts,
expert_capacity,
indices,
bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
Expand Down Expand Up @@ -576,7 +518,7 @@ def _binned_copy_wgrad(
tl.store(wgrad, out)


def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int):
def binned_scatter_wgrad(x, grad, indices, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_matrix(grad)
Expand Down

0 comments on commit 9b77d16

Please sign in to comment.