diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d315f5..c754b29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,19 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 default_language_version: python: python3 repos: +# - repo: local +# hooks: +# - id: pyright +# name: pyright +# entry: pyright +# language: node +# types: [python] +# pass_filenames: false +# args: [--warnings] +# additional_dependencies: ["pyright@1.1.310"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index b831826..ca0120b 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,26 +1,27 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x, ndim): +def assert_is_tensor(x: torch.Tensor, ndim: int): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x): +def assert_is_matrix(x: torch.Tensor): assert_is_tensor(x, 2) -def assert_is_vector(x): +def assert_is_vector(x: torch.Tensor): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a, b): +def assert_equal(a: Any, b: Any): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -43,13 +44,13 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, + a: torch.Tensor, + b: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Any, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -93,7 +94,8 @@ def _padded_copy( iptr = a if A_TO_B else b optr = b if A_TO_B else a - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -103,7 +105,15 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -119,7 +129,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() + output_rows = int(padded_bins[-1].cpu().item()) out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -137,7 +147,14 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out -def gather(x, indices, bin_ids, weights, bins, top_k): +def gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. 
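A note on the padded_gather contract above: weights may be None, and the number of output rows comes from the last entry of padded_bins. A rough pure-PyTorch illustration of how the routing metadata and the output shape fit together (toy values; the real row placement is done by the _padded_copy Triton kernel):

import torch

# Toy routing for 4 tokens, 2 experts, top_k = 1 (hypothetical values).
top_k = 1
x = torch.randn(4, 8)                      # [tokens, hidden]
top_expert = torch.tensor([1, 0, 1, 0])    # expert chosen per token
bin_ids, indices = torch.sort(top_expert)  # sorted expert ids + permutation
tokens_per_expert = torch.bincount(top_expert, minlength=2)
bins = torch.cumsum(tokens_per_expert, 0)       # [2, 4]: end of each expert bin
padded = ((tokens_per_expert + 3) // 4) * 4     # pad each bin to a block of 4
padded_bins = torch.cumsum(padded, 0)           # [4, 8]

# padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k) returns a
# zero-initialized [padded_bins[-1], hidden] buffer (8 x 8 here) with each
# expert's tokens copied to the start of its padded bin.
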
assert_is_matrix(x) assert_is_vector(indices) @@ -169,7 +186,15 @@ def gather(x, indices, bin_ids, weights, bins, top_k): return out -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -202,7 +227,14 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter(x, indices, bin_ids, weights, bins, top_k): +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -225,13 +257,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -263,7 +295,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -275,7 +307,15 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): +def padded_scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -302,7 +342,14 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): return out -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): +def scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, +): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -323,13 +370,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, + a: torch.Tensor, + b: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + weights, #: Optional[torch.Tensor], + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -378,7 +425,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -388,7 +435,14 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): +def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + expert_capacity: int, + top_k: int, +): # Validate the input shapes. 
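On the scatter wrapper above: it is literally padded_scatter with bins reused as padded_bins, and the visible return statement shows that for top_k > 1 the un-permuted copies are reduced with the routing weights. A small pure-PyTorch sketch of that final combine (toy shapes, not the kernel):

import torch

tokens, top_k, hidden = 3, 2, 4
expert_outputs = torch.randn(tokens, top_k, hidden)  # per-(token, k) expert results
expert_weights = torch.rand(tokens, top_k)           # router probabilities

# Mirrors `out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])`,
# with the weighting applied during the copy.
combined = (expert_outputs * expert_weights.unsqueeze(-1)).sum(dim=1)
print(combined.shape)  # torch.Size([3, 4])
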
assert_is_matrix(x) assert_is_vector(indices) @@ -400,7 +454,6 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( x, out, @@ -417,7 +470,13 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): return out -def binned_scatter(x, indices, weights, bins, top_k): +def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -465,13 +524,13 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -505,7 +564,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -517,7 +576,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x, grad, indices, bins, top_k): +def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 07dbc04..6d3f977 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,20 +1,25 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +import warnings +_grouped_gemm_is_available: bool = False try: import grouped_gemm -except ImportError: - grouped_gemm = None + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') def grouped_gemm_is_available(): - return grouped_gemm is not None + return _grouped_gemm_is_available def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available( - ), ('Grouped GEMM not available. Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.') + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg backend = grouped_gemm.backend if grouped_gemm_is_available() else None diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 736d311..a31770b 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,24 +1,24 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Any, Callable, Union -import stk import torch +from stk import Matrix def act_fn( - x: stk.Matrix, + x: Matrix, function: Callable, return_grad_fn: bool = False, **kwargs, -): - assert isinstance(x, stk.Matrix) +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: x.data.requires_grad = True out = function(x.data, **kwargs) - y = stk.Matrix( + y = Matrix( x.size(), out, x.row_indices, diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 82a6f40..5ac7067 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.distributed as dist class AllToAllOp(torch.autograd.Function): @@ -14,7 +15,7 @@ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group - handle = torch.distributed.all_to_all_single( + handle = dist.all_to_all_single( out, x, output_split_sizes=output_split_sizes, @@ -32,7 +33,7 @@ def backward(ctx, grad, _): device=grad.device, dtype=grad.dtype, ) - torch.distributed.all_to_all_single( + dist.all_to_all_single( out, grad, output_split_sizes=ctx.input_split_sizes, diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index ddbe2b7..892cb91 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -6,12 +6,13 @@ from typing import Any, Callable, Optional, Union import torch +import torch.distributed as dist import torch.nn.functional as F import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. -InitFn = Callable[[torch.Tensor], None] +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] _ALLOWED_BITWIDTHS = (-1, 4, 8) @@ -39,7 +40,7 @@ class Arguments: # Parallelism arguments. moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None + expert_parallel_group: Optional[dist.ProcessGroup] = None pipeline_model_parallel_size: int = 1 num_layers_per_virtual_pipeline_stage: Optional[int] = None @@ -51,7 +52,7 @@ class Arguments: # Initialization arguments. 
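The widened InitFn alias above now also covers functools.partial objects; the default init_method a few fields below is exactly such a partial. A self-contained example of how an InitFn is used:

from functools import partial

import torch

init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.02)  # a valid InitFn

weight = torch.empty(16, 32)
init_method(weight)                    # initializes in place; return value is ignored
print(round(weight.std().item(), 3))   # roughly 0.02
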
fp16: bool = True bf16: bool = False - device: torch.device = torch.cuda.current_device() + device: Union[int, torch.device] = torch.cuda.current_device() init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) output_layer_init_method: InitFn = init_method @@ -60,7 +61,7 @@ class Arguments: # shared expert arguments shared_expert: bool = False # enable using shared expert - fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored shared_expert_hidden_size: Optional[ @@ -75,7 +76,7 @@ def __post_init__(self): self.shared_expert_hidden_size = self.ffn_hidden_size -def from_megatron(megatron_args): +def from_megatron(megatron_args: Any): args = Arguments() for field in dataclasses.fields(args): if hasattr(megatron_args, field.name): diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index e683f8a..377b77f 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk +import stk.ops import torch +from stk import Matrix import megablocks.ops as ops from megablocks.layers import common, dmlp_registry, moe, mpu diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 4654576..e510723 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,7 +1,7 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +import stk.ops import torch from megablocks import grouped_gemm_util as gg @@ -80,6 +80,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) @@ -123,6 +124,7 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -196,6 +198,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. 
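For the GLU forward above, layer 0 computes x @ w1.T and x @ v1.T per expert and gates one with the other before the down projection. A dense, single-expert sketch of the same arithmetic (plain matmuls standing in for the grouped GEMMs, GELU standing in for args.activation_fn):

import torch
import torch.nn.functional as F

hidden, ffn = 8, 16
x = torch.randn(5, hidden)     # tokens routed to one expert
w1 = torch.randn(ffn, hidden)
v1 = torch.randn(ffn, hidden)
w2 = torch.randn(ffn, hidden)

gated = F.gelu(x @ w1.t()) * (x @ v1.t())  # layer 0 plus gating
out = gated @ w2                           # layer 1, back to the hidden size
print(out.shape)  # torch.Size([5, 8])
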
+ assert gg.ops is not None x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) x1 = self.args.activation_fn(x1) * x2 diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 809e317..4acbd94 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -4,6 +4,7 @@ import gc import torch +import torch.distributed as dist from megablocks.layers import arguments, dmoe @@ -92,9 +93,9 @@ def grad_numel(x): if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _TESTS: diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index f7cb782..e8f2d7b 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -4,6 +4,8 @@ from typing import Any import stk +import stk.backend.triton_kernels +import stk.ops import torch from packaging import version @@ -17,20 +19,20 @@ class ScaleGradient(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, x, scale): + def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, grad): + def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None scale_gradient = ScaleGradient.apply -def resolve_dtensor(weight): +def resolve_dtensor(weight: torch.Tensor): if version.parse(torch.__version__) >= version.parse('2.0.0'): from torch.distributed._tensor import DTensor if isinstance(weight, DTensor): @@ -408,6 +410,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) # activation_fn @@ -429,7 +432,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): + def backward(ctx: Any, ddsd_out: torch.Tensor): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') @@ -449,6 +452,7 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -513,6 +517,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. + assert gg.ops is not None x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x = self.args.activation_fn(x) return gg.ops.gmm(x, w2, batch_sizes) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index e5eaaa8..9ba5edb 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,8 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple import numpy as np import torch +import torch.distributed as dist import megablocks.ops as ops from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry @@ -110,6 +112,7 @@ def __init__(self, args: Arguments): # Expert MLP. self.mlp = mlp.MLP(args) + self.bias: Optional[torch.Tensor] if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. 
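ScaleGradient in the mlp.py hunk above is an identity in the forward pass that scales the incoming gradient on the way back. A minimal stand-alone version of the same trick (without the AMP decorators):

import torch

class ScaleGrad(torch.autograd.Function):
    """Identity forward; multiplies the gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # no gradient w.r.t. scale

x = torch.ones(3, requires_grad=True)
ScaleGrad.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])
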
@@ -127,12 +130,12 @@ def __init__(self, args: Arguments): # Select the forward function for the operating mode. self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - def expert_capacity(self, tokens): + def expert_capacity(self, tokens: int) -> int: world_size = mpu.get_expert_parallel_world_size(self.args) tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) - def load_balancing_loss(self, tokens_per_expert, expert_scores): + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): """Calculate the load balancing loss contribution.""" assert len(expert_scores.size()) == 2 tokens, num_experts = expert_scores.size() @@ -146,7 +149,8 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert): + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Sort the expert ids to produce the scatter/gather # indices for the permutation. # @@ -154,7 +158,9 @@ def indices_and_bins(self, top_expert): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output # Histogram the expert ids to identify the number of # tokens routed to each expert. @@ -166,23 +172,32 @@ def indices_and_bins(self, top_expert): # Calculate the bin bounds for the sorted tokens. bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output # Perform the expert computation. Note that we don't # use biases for these linear operations. @@ -191,7 +206,7 @@ def permute_and_compute( # Un-route the data for the MoE output. return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - def forward_once(self, x, expert_weights, top_experts): + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] # expert_weights: [sl * bs, top-k] # top_experts: [sl * bs, top-k] @@ -202,7 +217,7 @@ def forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. 
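expert_capacity above is just moe_capacity_factor times top_k * tokens * world_size / num_experts, truncated to an int. A quick worked example with hypothetical settings:

# Hypothetical configuration.
top_k, tokens, world_size, num_experts = 2, 4096, 8, 64
moe_capacity_factor = 1.25

tokens_per_expert = top_k * tokens * world_size / num_experts   # 1024.0
expert_capacity = int(moe_capacity_factor * tokens_per_expert)  # 1280
print(expert_capacity)
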
- sl, bs, hs = x.size() + sl, bs, _ = x.size() expert_capacity = self.expert_capacity(sl * bs) if expert_capacity == 0: expert_capacity = torch.max(tokens_per_expert).item() @@ -219,7 +234,7 @@ def forward_once(self, x, expert_weights, top_experts): ) return x, tokens_per_expert - def parallel_forward_once(self, x, expert_weights, top_experts): + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # NOTE: This function implements the same computation as forward_once # but with expert model parallelism. # @@ -257,7 +272,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Pass token count information to the device on which the # target expert resides. parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = torch.distributed.all_to_all_single( + tpe_handle = dist.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, @@ -270,7 +285,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -356,7 +373,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - tokens, hs = x.size() + tokens, _ = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: expert_capacity = torch.max(parallel_tokens_per_expert).item() @@ -405,7 +422,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() - def forward(self, x, scores, expert_weights, top_experts): + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): in_shape = x.size() # Compute the experts. @@ -439,7 +456,7 @@ def __init__(self, args: Arguments): def _init_experts_mlp(self, args: Arguments): return ParallelMLP(args) - def forward(self, x): + def forward(self, x: torch.Tensor): # NOTE: If we're going to cast the activations to lower precision # do it before we permute the tokens to save bandwidth. 
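The dropless fallback mentioned in the comments above sizes every expert bin for the busiest expert whenever expert_capacity comes back as zero. A toy illustration:

import torch

tokens_per_expert = torch.tensor([3, 7, 0, 5])

expert_capacity = 0  # e.g. what expert_capacity() returns for very small batches
if expert_capacity == 0:
    expert_capacity = int(torch.max(tokens_per_expert).item())
print(expert_capacity)  # 7: no expert bin overflows, so no tokens are dropped
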
x = common.cast_if_autocast_enabled(x) diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 239f75f..b232139 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,21 +1,31 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import torch +import torch.distributed as dist from megablocks.layers.arguments import Arguments +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') def get_expert_parallel_world_size(args: Arguments) -> int: - return (torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) def get_expert_parallel_rank(args: Arguments) -> int: - return (torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) def set_expert_model_parallel_attributes( @@ -26,7 +36,7 @@ def set_expert_model_parallel_attributes( setattr(tensor, 'expert_model_parallel', is_parallel) -def param_is_expert_model_parallel(param: torch.Tensor) -> bool: +def param_is_expert_model_parallel(param: MoeParam) -> bool: return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) @@ -42,11 +52,11 @@ def copy_expert_model_parallel_attributes( ) -def synchronized_print(group, *x): - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) for i in range(world_size): - torch.distributed.barrier(group) + dist.barrier(group) if i == rank: print(f'rank = {rank}', *x) @@ -70,9 +80,7 @@ def hidden_sharding_degree(args: Arguments) -> int: raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( - f"Invalid sharding. 'expert_sharding_degree' " - f'({esd}) * hidden_sharding_degree ' - f'({hsd}) != world_size ({world_size}).', + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", ) return hsd diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 42cfbe1..9499870 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch @@ -14,7 +15,7 @@ class _UniformExpertAssignment(torch.autograd.Function): @staticmethod - def forward(ctx, x, num_experts): + def forward(ctx: Any, x: torch.Tensor, num_experts: int): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) @@ -43,18 +44,19 @@ def __init__(self, args: Arguments): ) args.init_method(self.layer.weight) - def jitter(self, x): + def jitter(self, x: torch.Tensor): + assert isinstance(self.args.moe_jitter_eps, float) low = 1.0 - self.args.moe_jitter_eps high = 1.0 + self.args.moe_jitter_eps noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low) - def _top_k(self, scores): + def _top_k(self, scores: torch.Tensor): if self.args.moe_top_k == 1: return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) - def forward(self, x): + def forward(self, x: torch.Tensor): if self.training and self.args.moe_jitter_eps is not None: x = x * self.jitter(x) diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index b3a8537..47b9530 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.distributed as dist from megablocks import benchmark_util from megablocks.layers.all_to_all import all_to_all @@ -29,7 +30,7 @@ def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) + world_size = dist.get_world_size(group) assert (sl % world_size) == 0 send_recv_sizes = [sl // world_size] * world_size @@ -45,14 +46,14 @@ def benchmark(): time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: + if dist.get_rank(group) == 0: benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _ALL_TO_ALL_BENCHMARK: diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 8a22317..89cce1b 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,21 @@ class BinnedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bins, bin_size, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): ctx.save_for_backward(indices, bins) ctx.top_k = top_k return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() 
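_UniformExpertAssignment in the router hunk above ignores the learned scores and deals tokens out round-robin; its forward is just an arange modulo num_experts. A standalone copy of that arithmetic:

import torch

def uniform_expert_assignment(top_experts: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Round-robin expert ids with the shape/dtype of the routed ids.
    out = torch.arange(top_experts.numel(), dtype=top_experts.dtype, device=top_experts.device)
    return torch.remainder(out, num_experts).view(top_experts.shape)

ids = torch.zeros(6, 2, dtype=torch.long)  # whatever the router produced
print(uniform_expert_assignment(ids, num_experts=4))
# tensor([[0, 1], [2, 3], [0, 1], [2, 3], [0, 1], [2, 3]])
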
indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index f65fbe8..f5ce0d6 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,14 @@ class BinnedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): assert len(x.size()) == 3 ctx.bin_size = x.size(1) ctx.top_k = top_k @@ -24,7 +32,7 @@ def forward(ctx, x, indices, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 09b23ab..bf0482a 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -18,7 +20,7 @@ class ExclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int): if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) @@ -35,7 +37,7 @@ def forward(ctx, x, dim): class InclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index a335273..41b09a1 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,21 @@ class GatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 7660e82..7855233 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
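The cumsum ops in this hunk turn per-expert token counts into bin boundaries; the inclusive/exclusive distinction, in plain PyTorch with toy counts:

import torch

tokens_per_expert = torch.tensor([3, 0, 5, 2], dtype=torch.int32)

inclusive = torch.cumsum(tokens_per_expert, dim=0)  # [3, 3, 8, 10] -> bin ends
exclusive = inclusive - tokens_per_expert           # [0, 3, 3, 8]  -> bin starts
print(inclusive.tolist(), exclusive.tolist())
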
import torch @@ -18,7 +20,7 @@ class HistogramOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, max_val): + def forward(ctx: Any, x: torch.Tensor, max_val: float): return ops.histogram(x, max_val) diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index b57a518..f272a77 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,15 @@ class PaddedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( @@ -27,7 +36,7 @@ def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins, padded_bins = ctx.saved_tensors diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 1ca1605..9ff81dd 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,16 @@ class PaddedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( indices, @@ -36,7 +46,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 61bb04b..7e9e09d 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,11 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 +import torch -def repeat(x, tiling): +def repeat(x: torch.Tensor, tiling: torch.Size): if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index b7cb9c3..2dbec35 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
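The repeat helper above only differs from Tensor.repeat by skipping the copy for an all-ones tiling; a standalone version:

import torch

def repeat(x: torch.Tensor, tiling) -> torch.Tensor:
    # Avoid the copy torch.Tensor.repeat would make when nothing is tiled.
    if all(t == 1 for t in tiling):
        return x
    return x.repeat(*tiling)

x = torch.arange(4)
print(repeat(x, (1,)) is x)      # True: the input is returned untouched
print(repeat(x, (2,)).tolist())  # [0, 1, 2, 3, 0, 1, 2, 3]
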
import torch @@ -17,14 +19,14 @@ class ReplicateOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, bins, num_outputs): + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): ctx.save_for_backward(bins) out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): bins, = ctx.saved_tensors out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 2c59a78..6cf6bc8 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -4,7 +4,7 @@ import torch -def round_up(x, value): +def round_up(x: torch.Tensor, value: int): assert isinstance(value, int) assert x.dtype == torch.int32 diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 33f051c..a5aaafc 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional + import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +14,15 @@ class ScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k @@ -21,7 +31,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors @@ -58,5 +68,5 @@ def scatter( weights: torch.Tensor, bins: torch.Tensor, top_k: int, -): +) -> Optional[torch.Tensor]: return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 12ec8f3..4fb0aab 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional, Tuple + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
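The round_up hunk above only shows the signature and the asserts; assuming it rounds each int32 entry up to a multiple of value (which is how padded bin boundaries are typically produced), a plain-PyTorch reference would look like this:

import torch

def round_up_reference(x: torch.Tensor, value: int) -> torch.Tensor:
    # Assumed behaviour: round every entry up to the next multiple of `value`.
    assert isinstance(value, int)
    assert x.dtype == torch.int32
    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value

bins = torch.tensor([3, 8, 10], dtype=torch.int32)
print(round_up_reference(bins, 4).tolist())  # [4, 8, 12]
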
import torch @@ -24,7 +26,7 @@ class SortOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, end_bit=None): + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] x_out = torch.empty_like(x) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index aa81334..e00c1aa 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,8 +1,9 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +import torch -def sum(x, dim=0): +def sum(x: torch.Tensor, dim: int = 0): if x.shape[dim] == 1: return x.squeeze(dim=dim) return x.sum(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index ba4ade0..b41b5fa 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -19,11 +21,11 @@ class TopologyOp(torch.autograd.Function): @staticmethod def forward( - ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns, + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, ): out = torch.empty( output_block_rows * output_block_columns, diff --git a/megablocks/py.typed b/megablocks/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index c72dbdf..17e1b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 # build requirements diff --git a/setup.py b/setup.py index fa15ee4..202e3da 100644 --- a/setup.py +++ b/setup.py @@ -143,4 +143,5 @@ install_requires=install_requires, extras_require=extra_deps, python_requires='>=3.9', + package_data={_PACKAGE_NAME: ['py.typed']}, )
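The ops.sum helper above squeezes a size-1 dimension instead of reducing it; a standalone copy (renamed so it does not shadow the builtin):

import torch

def sum_dim(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Same behaviour as megablocks.ops.sum: drop a size-1 dim, reduce otherwise.
    if x.shape[dim] == 1:
        return x.squeeze(dim=dim)
    return x.sum(dim=dim)

a = torch.ones(1, 5)
b = torch.ones(3, 5)
print(sum_dim(a, 0).shape, sum_dim(b, 0).shape)  # torch.Size([5]) torch.Size([5])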