Type Checking (#141)
* add type hints

* more type checks

* type check router

* more type checking

* restore sum

* more tests

* more type checking

* more updates

* add py.typed

* get rid of stk type errors

* remove icecream package

* fix matrix import

* add type hints

* fix all torch.distributed type errors

* fix more torch.distributed type errors

* fix all gmm type errors

* more type checking

* comment out type checking

* update
eitanturok authored Aug 28, 2024
1 parent bce5d7b commit 5b2650a
Showing 31 changed files with 324 additions and 146 deletions.
12 changes: 11 additions & 1 deletion .pre-commit-config.yaml
@@ -1,9 +1,19 @@
# Copyright 2024 MosaicML MegaBlocks authors
# Copyright 2024 Databricks authors
# SPDX-License-Identifier: Apache-2.0

default_language_version:
python: python3
repos:
# - repo: local
# hooks:
# - id: pyright
# name: pyright
# entry: pyright
# language: node
# types: [python]
# pass_filenames: false
# args: [--warnings]
# additional_dependencies: ["[email protected]"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
hooks:
153 changes: 106 additions & 47 deletions megablocks/backend/kernels.py
@@ -1,26 +1,27 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional

import torch
import triton
import triton.language as tl


def assert_is_tensor(x, ndim):
def assert_is_tensor(x: torch.Tensor, ndim: int):
if x.ndim != ndim:
raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
def assert_is_matrix(x: torch.Tensor):
assert_is_tensor(x, 2)


def assert_is_vector(x):
def assert_is_vector(x: torch.Tensor):
if x.ndim != 1:
raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
def assert_equal(a: Any, b: Any):
if a != b:
raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)

@@ -43,13 +44,13 @@ def assert_equal(a, b):
)
@triton.jit
def _padded_copy(
a,
b,
indices,
bin_ids,
weights,
bins,
padded_bins,
a: torch.Tensor,
b: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Any,
bins: torch.Tensor,
padded_bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -93,7 +94,8 @@ def _padded_copy(
iptr = a if A_TO_B else b
optr = b if A_TO_B else a

for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)):
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
x = tl.load(iptr + offsets, mask=mask)
x = x.to(tl.float32) * scale.to(tl.float32)
@@ -103,7 +105,15 @@ def _padded_copy(
offsets += BLOCK_X


def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
def padded_gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -119,7 +129,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):

# NOTE: Because of the padding, the output size is dynamic.
# We load the final padded bin bound to get the output rows.
output_rows = padded_bins[-1].cpu().item()
output_rows = int(padded_bins[-1].cpu().item())
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
_padded_copy[(indices.shape[0],)](
x,
Expand All @@ -137,7 +147,14 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
return out


def gather(x, indices, bin_ids, weights, bins, top_k):
def gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -169,7 +186,15 @@ def gather(x, indices, bin_ids, weights, bins, top_k):
return out


def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
def padded_scatter(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
) -> torch.Tensor:
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -202,7 +227,14 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
def scatter(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
) -> torch.Tensor:
return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


@@ -225,13 +257,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k):
)
@triton.jit
def _padded_copy_wgrad(
x,
grad,
wgrad,
indices,
bin_ids,
bins,
padded_bins,
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -263,7 +295,7 @@ def _padded_copy_wgrad(

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
@@ -275,7 +307,15 @@ def _padded_copy_wgrad(
tl.store(wgrad, out)


def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
def padded_scatter_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_matrix(grad)
@@ -302,7 +342,14 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
def scatter_wgrad(
x: torch.Tensor,
grad: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
top_k: int,
):
return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


@@ -323,13 +370,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
)
@triton.jit
def _binned_copy(
a,
b,
num_experts,
expert_capacity,
indices,
weights,
bins,
a: torch.Tensor,
b: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
weights, #: Optional[torch.Tensor],
bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -378,7 +425,7 @@ def _binned_copy(
optr = b if A_TO_B else a

iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
x = tl.load(iptr + offsets, mask=mask)
x = x.to(tl.float32) * scale.to(tl.float32)
@@ -388,7 +435,14 @@ def _binned_copy(
offsets += BLOCK_X


def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
def binned_gather(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
expert_capacity: int,
top_k: int,
):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -400,7 +454,6 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):

num_experts = bins.shape[0]
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

_binned_copy[(num_experts, expert_capacity)](
x,
out,
@@ -417,7 +470,13 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
return out


def binned_scatter(x, indices, weights, bins, top_k):
def binned_scatter(
x: torch.Tensor,
indices: torch.Tensor,
weights: Optional[torch.Tensor],
bins: torch.Tensor,
top_k: int,
):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_vector(indices)
@@ -465,13 +524,13 @@ def binned_scatter(x, indices, weights, bins, top_k):
)
@triton.jit
def _binned_copy_wgrad(
x,
grad,
wgrad,
num_experts,
expert_capacity,
indices,
bins,
x: torch.Tensor,
grad: torch.Tensor,
wgrad: torch.Tensor,
num_experts: int,
expert_capacity: int,
indices: torch.Tensor,
bins: torch.Tensor,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -505,7 +564,7 @@ def _binned_copy_wgrad(

acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
for i in range(iterations):
for _ in range(iterations):
mask = offsets < NUM_COLUMNS
data = tl.load(x + offsets, mask=mask).to(tl.float32)
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
@@ -517,7 +576,7 @@ def _binned_copy_wgrad(
tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_matrix(grad)
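As a rough usage sketch for the newly annotated gather/scatter helpers above: the routing tensors below are hypothetical and only illustrate the declared shapes and the now-explicit Optional weights argument; the underlying Triton kernel still requires a CUDA device to actually run.

import torch

from megablocks.backend.kernels import padded_gather

# Hypothetical routing: four tokens, two experts, top_k=1.
x = torch.randn(4, 8, device='cuda', dtype=torch.float16)   # (tokens, hidden)
indices = torch.tensor([0, 2, 1, 3], device='cuda')          # original token index per routed slot
bin_ids = torch.tensor([0, 0, 1, 1], device='cuda')          # sorted expert id per routed slot
bins = torch.tensor([2, 4], device='cuda')                    # inclusive cumsum of tokens per expert
padded_bins = torch.tensor([4, 8], device='cuda')             # bins rounded up to a block multiple

# weights is Optional[torch.Tensor] after this change, so passing None type-checks.
out = padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k=1)
print(out.shape)  # torch.Size([8, 8]) -- int(padded_bins[-1]) rows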
17 changes: 11 additions & 6 deletions megablocks/grouped_gemm_util.py
@@ -1,20 +1,25 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import warnings

_grouped_gemm_is_available: bool = False
try:
import grouped_gemm
except ImportError:
grouped_gemm = None
_grouped_gemm_is_available = True
except ImportError as error:
warnings.warn('Grouped GEMM not available.')


def grouped_gemm_is_available():
return grouped_gemm is not None
return _grouped_gemm_is_available


def assert_grouped_gemm_is_available():
assert grouped_gemm_is_available(
), ('Grouped GEMM not available. Please run '
'`pip install git+https://github.com/tgale96/grouped_gemm@main`.')
msg = (
'Grouped GEMM not available. Please run '
'`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
)
assert _grouped_gemm_is_available, msg


backend = grouped_gemm.backend if grouped_gemm_is_available() else None
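A minimal sketch of how downstream code might consume the new module-level availability flag instead of inspecting the grouped_gemm import directly (the helper function here is hypothetical):

import warnings

from megablocks import grouped_gemm_util as gg

def get_grouped_gemm_backend():
    """Return gg.backend when grouped_gemm is installed, else None (illustrative only)."""
    if gg.grouped_gemm_is_available():
        return gg.backend
    warnings.warn('grouped_gemm is not installed; callers should fall back to a dense path.')
    return None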
12 changes: 6 additions & 6 deletions megablocks/layers/activation_fn.py
@@ -1,24 +1,24 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Callable
from typing import Any, Callable, Union

import stk
import torch
from stk import Matrix


def act_fn(
x: stk.Matrix,
x: Matrix,
function: Callable,
return_grad_fn: bool = False,
**kwargs,
):
assert isinstance(x, stk.Matrix)
) -> Union[tuple[Matrix, Any] | Matrix]:
assert isinstance(x, Matrix)
with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
if return_grad_fn:
x.data.requires_grad = True
out = function(x.data, **kwargs)
y = stk.Matrix(
y = Matrix(
x.size(),
out,
x.row_indices,
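An illustrative call to the updated act_fn, assuming stk is installed; the GELU activation and the wrapper function are hypothetical and only show the Matrix-in, Matrix-out path when return_grad_fn is left at its default of False:

import torch.nn.functional as F
import stk

from megablocks.layers.activation_fn import act_fn

def apply_gelu(x: stk.Matrix) -> stk.Matrix:
    # With return_grad_fn=False (the default), act_fn returns only the new Matrix.
    return act_fn(x, F.gelu)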