From a8291410b75956ca0899e36cd139f3cee164fd21 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 13 Aug 2024 16:11:44 +0000 Subject: [PATCH 01/19] add type hints --- .pre-commit-config.yaml | 12 +++++++++++- megablocks/ops/__init__.py | 2 -- megablocks/ops/histogram.py | 4 +++- megablocks/ops/padded_gather.py | 8 ++++++-- megablocks/ops/padded_scatter.py | 8 ++++++-- megablocks/ops/repeat.py | 5 ++--- megablocks/ops/replicate.py | 6 ++++-- megablocks/ops/round_up.py | 2 +- megablocks/ops/scatter.py | 9 +++++++-- megablocks/ops/sort.py | 4 +++- megablocks/ops/sum.py | 8 -------- megablocks/ops/topology.py | 12 +++++++----- pyproject.toml | 2 +- 13 files changed, 51 insertions(+), 31 deletions(-) delete mode 100644 megablocks/ops/sum.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d315f5..3db2e74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,19 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 default_language_version: python: python3 repos: +- repo: local + hooks: + - id: pyright + name: pyright + entry: pyright + language: node + types: [python] + pass_filenames: false + args: [--warnings] + additional_dependencies: ["pyright@1.1.310"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index b9dc286..709290e 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -13,7 +13,6 @@ from megablocks.ops.round_up import round_up from megablocks.ops.scatter import scatter from megablocks.ops.sort import sort -from megablocks.ops.sum import sum from megablocks.ops.topology import topology __all__ = [ @@ -30,6 +29,5 @@ 'round_up', 'scatter', 'sort', - 'sum', 'topology', ] diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 7660e82..7855233 100644 --- a/megablocks/ops/histogram.py +++ 
b/megablocks/ops/histogram.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -18,7 +20,7 @@ class HistogramOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, max_val): + def forward(ctx: Any, x: torch.Tensor, max_val: float): return ops.histogram(x, max_val) diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index b57a518..815791c 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,10 @@ class PaddedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, + padded_bins: torch.Tensor, top_k: int + ): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( @@ -27,7 +31,7 @@ def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins, padded_bins = ctx.saved_tensors diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 1ca1605..9546522 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,10 @@ class PaddedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, 
weights, bins, padded_bins, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, + bins: torch.Tensor, padded_bins: torch.Tensor, top_k: int + ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( indices, @@ -36,7 +40,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 61bb04b..7e9e09d 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,11 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 +import torch -def repeat(x, tiling): +def repeat(x: torch.Tensor, tiling: torch.Size): if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index b7cb9c3..2dbec35 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
import torch @@ -17,14 +19,14 @@ class ReplicateOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, bins, num_outputs): + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): ctx.save_for_backward(bins) out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): bins, = ctx.saved_tensors out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 2c59a78..6cf6bc8 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -4,7 +4,7 @@ import torch -def round_up(x, value): +def round_up(x: torch.Tensor, value: int): assert isinstance(value, int) assert x.dtype == torch.int32 diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 33f051c..b7ef24d 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +14,10 @@ class ScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, + bins: torch.Tensor, top_k: int + ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k @@ -21,7 +26,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py 
index 12ec8f3..4b88afd 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -24,7 +26,7 @@ class SortOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, end_bit=None): + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None): if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] x_out = torch.empty_like(x) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py deleted file mode 100644 index aa81334..0000000 --- a/megablocks/ops/sum.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - - -def sum(x, dim=0): - if x.shape[dim] == 1: - return x.squeeze(dim=dim) - return x.sum(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index ba4ade0..b41b5fa 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
import torch @@ -19,11 +21,11 @@ class TopologyOp(torch.autograd.Function): @staticmethod def forward( - ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns, + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, ): out = torch.empty( output_block_rows * output_block_columns, diff --git a/pyproject.toml b/pyproject.toml index c72dbdf..17e1b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 # build requirements From 9eb8a20152aed5c03103656adc1f684d20202401 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 13 Aug 2024 16:48:18 +0000 Subject: [PATCH 02/19] more type checks --- megablocks/ops/binned_gather.py | 7 +++++-- megablocks/ops/binned_scatter.py | 7 +++++-- megablocks/ops/cumsum.py | 6 ++++-- megablocks/ops/gather.py | 7 +++++-- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 8a22317..40c399f 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,16 @@ class BinnedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bins, bin_size, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, bin_size: torch.Tensor, top_k: int + ): ctx.save_for_backward(indices, bins) ctx.top_k = top_k return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, 
ctx.top_k) diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index f65fbe8..0cb1392 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,9 @@ class BinnedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, weights, bins, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, weights: torch.Tensor, bins: torch.Tensor, top_k: int + ): assert len(x.size()) == 3 ctx.bin_size = x.size(1) ctx.top_k = top_k @@ -24,7 +27,7 @@ def forward(ctx, x, indices, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 09b23ab..3646969 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
import torch @@ -18,7 +20,7 @@ class ExclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int): if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) @@ -35,7 +37,7 @@ def forward(ctx, x, dim): class InclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int): if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index a335273..a048308 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,16 @@ class GatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, top_k): + def forward( + ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, top_k: int + ): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors From c902a0e468d4f81f52ca6bec1c65400a1fa2d3ed Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 13 Aug 2024 16:54:09 +0000 Subject: [PATCH 03/19] type check router --- megablocks/layers/router.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 42cfbe1..9499870 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch @@ -14,7 +15,7 @@ class 
_UniformExpertAssignment(torch.autograd.Function): @staticmethod - def forward(ctx, x, num_experts): + def forward(ctx: Any, x: torch.Tensor, num_experts: int): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) @@ -43,18 +44,19 @@ def __init__(self, args: Arguments): ) args.init_method(self.layer.weight) - def jitter(self, x): + def jitter(self, x: torch.Tensor): + assert isinstance(self.args.moe_jitter_eps, float) low = 1.0 - self.args.moe_jitter_eps high = 1.0 + self.args.moe_jitter_eps noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low) - def _top_k(self, scores): + def _top_k(self, scores: torch.Tensor): if self.args.moe_top_k == 1: return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) - def forward(self, x): + def forward(self, x: torch.Tensor): if self.training and self.args.moe_jitter_eps is not None: x = x * self.jitter(x) From 3eb88190c21d996019f6ac7156d6ae403fca87e7 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 13 Aug 2024 20:52:37 +0000 Subject: [PATCH 04/19] more type checking --- megablocks/backend/kernels.py | 115 ++++++++++++++++++++-------------- megablocks/layers/moe.py | 34 +++++----- megablocks/layers/mpu.py | 4 +- 3 files changed, 89 insertions(+), 64 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index b831826..f6a4386 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,26 +1,27 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x, ndim): +def assert_is_tensor(x: torch.Tensor, ndim: int): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x): +def assert_is_matrix(x: torch.Tensor): assert_is_tensor(x, 2) -def 
assert_is_vector(x): +def assert_is_vector(x: torch.Tensor): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a, b): +def assert_equal(a: Any, b: Any): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -43,13 +44,13 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, + a: torch.Tensor, + b: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Any, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -93,7 +94,8 @@ def _padded_copy( iptr = a if A_TO_B else b optr = b if A_TO_B else a - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -103,7 +105,10 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_gather( + x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, + padded_bins: torch.Tensor, top_k: int +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -119,7 +124,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. 
- output_rows = padded_bins[-1].cpu().item() + output_rows = int(padded_bins[-1].cpu().item()) out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -137,7 +142,10 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out -def gather(x, indices, bin_ids, weights, bins, top_k): +def gather( + x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, + top_k: int +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -169,7 +177,10 @@ def gather(x, indices, bin_ids, weights, bins, top_k): return out -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_scatter( + x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, + padded_bins: torch.Tensor, top_k: int +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -202,7 +213,9 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter(x, indices, bin_ids, weights, bins, top_k): +def scatter( + x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, bins: torch.Tensor, top_k: int +): return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -225,13 +238,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -263,7 +276,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) 
- for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -275,7 +288,10 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): +def padded_scatter_wgrad( + x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, + padded_bins: torch.Tensor, top_k: int +): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -302,7 +318,9 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): return out -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): +def scatter_wgrad( + x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, top_k: int +): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -323,13 +341,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, + a: torch.Tensor, + b: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -378,7 +396,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -388,7 +406,10 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): +def binned_gather( + x: torch.Tensor, indices: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, expert_capacity: int, + top_k: int +): # Validate the input shapes. 
assert_is_matrix(x) assert_is_vector(indices) @@ -417,7 +438,9 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): return out -def binned_scatter(x, indices, weights, bins, top_k): +def binned_scatter( + x: torch.Tensor, indices: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, top_k: int +): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -465,13 +488,13 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -505,7 +528,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -517,7 +540,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x, grad, indices, bins, top_k): +def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index e5eaaa8..875e47b 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -127,12 +127,12 @@ def __init__(self, args: Arguments): # Select the forward function for the operating mode. 
self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - def expert_capacity(self, tokens): + def expert_capacity(self, tokens: int) -> int: world_size = mpu.get_expert_parallel_world_size(self.args) tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) - def load_balancing_loss(self, tokens_per_expert, expert_scores): + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): """Calculate the load balancing loss contribution.""" assert len(expert_scores.size()) == 2 tokens, num_experts = expert_scores.size() @@ -146,7 +146,7 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert): + def indices_and_bins(self, top_expert: torch.Tensor): # Sort the expert ids to produce the scatter/gather # indices for the permutation. # @@ -171,14 +171,14 @@ def indices_and_bins(self, top_expert): def permute_and_compute( self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) @@ -191,7 +191,7 @@ def permute_and_compute( # Un-route the data for the MoE output. 
return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - def forward_once(self, x, expert_weights, top_experts): + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] # expert_weights: [sl * bs, top-k] # top_experts: [sl * bs, top-k] @@ -202,7 +202,7 @@ def forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - sl, bs, hs = x.size() + sl, bs, _ = x.size() expert_capacity = self.expert_capacity(sl * bs) if expert_capacity == 0: expert_capacity = torch.max(tokens_per_expert).item() @@ -219,7 +219,7 @@ def forward_once(self, x, expert_weights, top_experts): ) return x, tokens_per_expert - def parallel_forward_once(self, x, expert_weights, top_experts): + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # NOTE: This function implements the same computation as forward_once # but with expert model parallelism. # @@ -356,7 +356,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - tokens, hs = x.size() + tokens, _ = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: expert_capacity = torch.max(parallel_tokens_per_expert).item() @@ -405,7 +405,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() - def forward(self, x, scores, expert_weights, top_experts): + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): in_shape = x.size() # Compute the experts. 
@@ -439,7 +439,7 @@ def __init__(self, args: Arguments): def _init_experts_mlp(self, args: Arguments): return ParallelMLP(args) - def forward(self, x): + def forward(self, x: torch.Tensor): # NOTE: If we're going to cast the activations to lower precision # do it before we permute the tokens to save bandwidth. x = common.cast_if_autocast_enabled(x) diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 239f75f..8926acf 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import torch from megablocks.layers.arguments import Arguments @@ -42,7 +44,7 @@ def copy_expert_model_parallel_attributes( ) -def synchronized_print(group, *x): +def synchronized_print(group: Optional[torch.distributed.ProcessGroup], *x: torch.Tensor): world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) for i in range(world_size): From e3f693aff9966bd09f3fcee091f44117ae787325 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 14 Aug 2024 15:44:55 +0000 Subject: [PATCH 05/19] restore sum --- megablocks/ops/sum.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 megablocks/ops/sum.py diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py new file mode 100644 index 0000000..e00c1aa --- /dev/null +++ b/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) From b2842c52ce7008c7503e922b4a0f1b97c62068e0 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 14 Aug 2024 18:49:54 +0000 Subject: [PATCH 06/19] more tests --- megablocks/backend/kernels.py | 67 +++++++++++++++++++++++++-------- megablocks/layers/arguments.py | 8 ++-- megablocks/layers/mlp.py | 8 ++-- megablocks/layers/moe.py | 2 +- 
megablocks/layers/mpu.py | 4 +- megablocks/ops/binned_gather.py | 7 +++- megablocks/ops/cumsum.py | 2 +- megablocks/ops/scatter.py | 13 +++++-- 8 files changed, 78 insertions(+), 33 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index f6a4386..7727abb 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -106,8 +106,13 @@ def _padded_copy( def padded_gather( - x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, ): # Validate the input shapes. assert_is_matrix(x) @@ -143,8 +148,12 @@ def padded_gather( def gather( - x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, - top_k: int + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, ): # Validate the input shapes. assert_is_matrix(x) @@ -178,9 +187,14 @@ def gather( def padded_scatter( - x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int -): + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: # Validate the input shapes. 
assert_is_matrix(x) assert_is_vector(indices) @@ -214,8 +228,13 @@ def padded_scatter( def scatter( - x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, bins: torch.Tensor, top_k: int -): + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -289,8 +308,13 @@ def _padded_copy_wgrad( def padded_scatter_wgrad( - x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, ): # Validate the input shapes. assert_is_matrix(x) @@ -319,7 +343,12 @@ def padded_scatter_wgrad( def scatter_wgrad( - x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, top_k: int + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, ): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -407,8 +436,12 @@ def _binned_copy( def binned_gather( - x: torch.Tensor, indices: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, expert_capacity: int, - top_k: int + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + expert_capacity: int, + top_k: int, ): # Validate the input shapes. assert_is_matrix(x) @@ -439,7 +472,11 @@ def binned_gather( def binned_scatter( - x: torch.Tensor, indices: torch.Tensor, weights: Optional[torch.Tensor], bins: torch.Tensor, top_k: int + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, ): # Validate the input shapes. 
assert_is_tensor(x, 3) diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index ddbe2b7..23146b5 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -11,7 +11,7 @@ import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. -InitFn = Callable[[torch.Tensor], None] +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] _ALLOWED_BITWIDTHS = (-1, 4, 8) @@ -51,7 +51,7 @@ class Arguments: # Initialization arguments. fp16: bool = True bf16: bool = False - device: torch.device = torch.cuda.current_device() + device: Union[int, torch.device] = torch.cuda.current_device() init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) output_layer_init_method: InitFn = init_method @@ -60,7 +60,7 @@ class Arguments: # shared expert arguments shared_expert: bool = False # enable using shared expert - fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored shared_expert_hidden_size: Optional[ @@ -75,7 +75,7 @@ def __post_init__(self): self.shared_expert_hidden_size = self.ffn_hidden_size -def from_megatron(megatron_args): +def from_megatron(megatron_args: Any): args = Arguments() for field in dataclasses.fields(args): if hasattr(megatron_args, field.name): diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index f7cb782..98b820b 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -17,20 +17,20 @@ class ScaleGradient(torch.autograd.Function): @staticmethod 
@torch.cuda.amp.custom_fwd - def forward(ctx, x, scale): + def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, grad): + def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None scale_gradient = ScaleGradient.apply -def resolve_dtensor(weight): +def resolve_dtensor(weight: torch.Tensor): if version.parse(torch.__version__) >= version.parse('2.0.0'): from torch.distributed._tensor import DTensor if isinstance(weight, DTensor): @@ -429,7 +429,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): + def backward(ctx: Any, ddsd_out: torch.Tensor): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 875e47b..32d51bb 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -146,7 +146,7 @@ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: to expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert: torch.Tensor): + def indices_and_bins(self, top_expert: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): # Sort the expert ids to produce the scatter/gather # indices for the permutation. # diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 8926acf..35e3e4f 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -72,9 +72,7 @@ def hidden_sharding_degree(args: Arguments) -> int: raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( - f"Invalid sharding. 'expert_sharding_degree' " - f'({esd}) * hidden_sharding_degree ' - f'({hsd}) != world_size ({world_size}).', + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", ) return hsd diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 40c399f..89cce1b 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -14,7 +14,12 @@ class BinnedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, bin_size: torch.Tensor, top_k: int + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, ): ctx.save_for_backward(indices, bins) ctx.top_k = top_k diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 3646969..bf0482a 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -37,7 +37,7 @@ def forward(ctx: Any, x: torch.Tensor, dim: int): class InclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx: Any, x: torch.Tensor, dim: int): + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index b7ef24d..99d7e9b 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -15,9 +15,14 @@ class ScatterOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, - bins: torch.Tensor, top_k: int - ): + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k @@ -63,5 +68,5 @@ def scatter( weights: torch.Tensor, bins: torch.Tensor, top_k: int, -): +) -> torch.Tensor: return ScatterOp.apply(x, indices, bin_ids, 
weights, bins, top_k) From e42a92b7a9bde252e1e9a38567d11042e3bebe67 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 14 Aug 2024 19:34:43 +0000 Subject: [PATCH 07/19] more type checking --- megablocks/backend/kernels.py | 3 +-- tests/conftest.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 7727abb..ca0120b 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -375,7 +375,7 @@ def _binned_copy( num_experts: int, expert_capacity: int, indices: torch.Tensor, - weights: Optional[torch.Tensor], + weights, #: Optional[torch.Tensor], bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, @@ -454,7 +454,6 @@ def binned_gather( num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( x, out, diff --git a/tests/conftest.py b/tests/conftest.py index 663bda3..151fb54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ import pytest from composer.utils import reproducibility +from icecream import install + +install() # Allowed options for pytest.mark.world_size() WORLD_SIZE_OPTIONS = (1, 2) From 4ebf3d5c1dff4be4801b554181723414bfae8b31 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 21 Aug 2024 15:17:43 +0000 Subject: [PATCH 08/19] more updates --- megablocks/layers/moe.py | 9 ++++++++- megablocks/layers/mpu.py | 8 ++++++-- megablocks/ops/__init__.py | 2 ++ megablocks/ops/binned_scatter.py | 7 ++++++- megablocks/ops/gather.py | 7 ++++++- megablocks/ops/padded_gather.py | 9 +++++++-- megablocks/ops/padded_scatter.py | 10 ++++++++-- megablocks/ops/scatter.py | 4 ++-- 8 files changed, 45 insertions(+), 11 deletions(-) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 32d51bb..b227f70 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,5 +1,6 @@ # Copyright 2024 
Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Tuple import numpy as np import torch @@ -146,7 +147,7 @@ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: to expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor): + def indices_and_bins(self, top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Sort the expert ids to produce the scatter/gather # indices for the permutation. # @@ -167,6 +168,12 @@ def indices_and_bins(self, top_expert: torch.Tensor) -> (torch.Tensor, torch.Ten # Calculate the bin bounds for the sorted tokens. bins = ops.inclusive_cumsum(tokens_per_expert, 0) bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 35e3e4f..41291d9 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -8,6 +8,11 @@ from megablocks.layers.arguments import Arguments +class MoeParam(torch.Tensor): + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') @@ -27,8 +32,7 @@ def set_expert_model_parallel_attributes( assert not hasattr(tensor, 'expert_model_parallel') setattr(tensor, 'expert_model_parallel', is_parallel) - -def param_is_expert_model_parallel(param: torch.Tensor) -> bool: +def param_is_expert_model_parallel(param: MoeParam) -> bool: return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 709290e..b9dc286 100644 --- 
a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -13,6 +13,7 @@ from megablocks.ops.round_up import round_up from megablocks.ops.scatter import scatter from megablocks.ops.sort import sort +from megablocks.ops.sum import sum from megablocks.ops.topology import topology __all__ = [ @@ -29,5 +30,6 @@ 'round_up', 'scatter', 'sort', + 'sum', 'topology', ] diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index 0cb1392..f5ce0d6 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -14,7 +14,12 @@ class BinnedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, weights: torch.Tensor, bins: torch.Tensor, top_k: int + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, ): assert len(x.size()) == 3 ctx.bin_size = x.size(1) diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index a048308..41b09a1 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -14,7 +14,12 @@ class GatherOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, top_k: int + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, ): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 815791c..f272a77 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -14,8 +14,13 @@ class PaddedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + 
bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, ): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 9546522..9ff81dd 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -14,8 +14,14 @@ class PaddedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd def forward( - ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor, - bins: torch.Tensor, padded_bins: torch.Tensor, top_k: int + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 99d7e9b..a5aaafc 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,7 +1,7 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Any +from typing import Any, Optional import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -68,5 +68,5 @@ def scatter( weights: torch.Tensor, bins: torch.Tensor, top_k: int, -) -> torch.Tensor: +) -> Optional[torch.Tensor]: return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) From 78e4dd52d284ce848feb8b7db2688f20c41f7961 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 01:57:44 +0000 Subject: [PATCH 09/19] add py.typed --- megablocks/py.typed | 0 setup.py | 1 + 2 files changed, 1 insertion(+) create mode 100644 megablocks/py.typed diff --git a/megablocks/py.typed b/megablocks/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index fa15ee4..202e3da 100644 --- a/setup.py +++ b/setup.py @@ -143,4 +143,5 @@ install_requires=install_requires, extras_require=extra_deps, python_requires='>=3.9', + 
package_data={_PACKAGE_NAME: ['py.typed']}, ) From 34c302470e377c44f0bb7e09666c16032dc39526 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:00:18 +0000 Subject: [PATCH 10/19] git rid of stk type errors --- megablocks/layers/dmoe.py | 3 ++- megablocks/layers/glu.py | 2 +- megablocks/layers/mlp.py | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index e683f8a..961e8af 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk +import stk.ops +import stk.Matrix import torch import megablocks.ops as ops diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 4654576..1352919 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,7 +1,7 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +import stk.ops import torch from megablocks import grouped_gemm_util as gg diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 98b820b..173bfea 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -5,6 +5,8 @@ import stk import torch +import stk.ops +import stk.backend.triton_kernels from packaging import version from megablocks import grouped_gemm_util as gg From 0cc49f8a2fa650e8e5d2a0dca93c5bf7f57116d6 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:00:52 +0000 Subject: [PATCH 11/19] remove icecream package --- tests/conftest.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 151fb54..663bda3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,6 @@ import pytest from composer.utils import reproducibility -from icecream import install - -install() # Allowed options for pytest.mark.world_size() WORLD_SIZE_OPTIONS = (1, 2) From 4095c1b6a3b0351fa7104bda6164be2280c5556b Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: 
Tue, 27 Aug 2024 02:04:25 +0000 Subject: [PATCH 12/19] fix matrix import --- megablocks/layers/dmoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 961e8af..4406695 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -3,7 +3,7 @@ import numpy as np import stk.ops -import stk.Matrix +from stk import Matrix import torch import megablocks.ops as ops From 8617b394ac48edaab47b2ecbddc940afe8237751 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:14:48 +0000 Subject: [PATCH 13/19] add type hints --- megablocks/layers/activation_fn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 736d311..6093cc9 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,24 +1,24 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Callable, Union, Any -import stk +from stk import Matrix import torch def act_fn( - x: stk.Matrix, + x: Matrix, function: Callable, return_grad_fn: bool = False, **kwargs, -): - assert isinstance(x, stk.Matrix) +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: x.data.requires_grad = True out = function(x.data, **kwargs) - y = stk.Matrix( + y = Matrix( x.size(), out, x.row_indices, From 0a122a0bfa20728ee4328c4ef7e49bc204bf93a4 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:20:00 +0000 Subject: [PATCH 14/19] fix all torch.distibuted type errors --- megablocks/layers/all_to_all.py | 6 +++--- megablocks/layers/arguments.py | 3 ++- megablocks/layers/memory_test.py | 7 ++++--- megablocks/layers/moe.py | 3 ++- megablocks/layers/mpu.py | 13 +++++++------ 5 files changed, 18 insertions(+), 14 deletions(-) 
diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 82a6f40..7fdb644 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch - +import torch.distributed as dist class AllToAllOp(torch.autograd.Function): @@ -14,7 +14,7 @@ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group - handle = torch.distributed.all_to_all_single( + handle = dist.all_to_all_single( out, x, output_split_sizes=output_split_sizes, @@ -32,7 +32,7 @@ def backward(ctx, grad, _): device=grad.device, dtype=grad.dtype, ) - torch.distributed.all_to_all_single( + dist.all_to_all_single( out, grad, output_split_sizes=ctx.input_split_sizes, diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 23146b5..bc344e5 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F +import torch.distributed as dist import megablocks.grouped_gemm_util as grouped_gemm @@ -39,7 +40,7 @@ class Arguments: # Parallelism arguments. 
moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None + expert_parallel_group: Optional[dist.ProcessGroup] = None pipeline_model_parallel_size: int = 1 num_layers_per_virtual_pipeline_stage: Optional[int] = None diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 809e317..4acbd94 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -4,6 +4,7 @@ import gc import torch +import torch.distributed as dist from megablocks.layers import arguments, dmoe @@ -92,9 +93,9 @@ def grad_numel(x): if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _TESTS: diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index b227f70..a59d91b 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.distributed as dist import megablocks.ops as ops from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry @@ -264,7 +265,7 @@ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, t # Pass token count information to the device on which the # target expert resides. 
parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = torch.distributed.all_to_all_single( + tpe_handle = dist.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 41291d9..9d35df2 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -4,6 +4,7 @@ from typing import Optional import torch +import torch.distributed as dist from megablocks.layers.arguments import Arguments @@ -18,11 +19,11 @@ def is_moe_param(tensor: torch.Tensor) -> bool: def get_expert_parallel_world_size(args: Arguments) -> int: - return (torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) def get_expert_parallel_rank(args: Arguments) -> int: - return (torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) def set_expert_model_parallel_attributes( @@ -48,11 +49,11 @@ def copy_expert_model_parallel_attributes( ) -def synchronized_print(group: Optional[torch.distributed.ProcessGroup], *x: torch.Tensor): - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) for i in range(world_size): - torch.distributed.barrier(group) + dist.barrier(group) if i == rank: print(f'rank = {rank}', *x) From 3140c8a5ec069f5acd720aaad898dfd8ab0ca1e8 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:21:00 +0000 Subject: [PATCH 15/19] fix more torch.distibuted type errors --- megablocks/ops/all_to_all_benchmark.py | 11 ++++++----- 1 file 
changed, 6 insertions(+), 5 deletions(-) diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index b3a8537..47b9530 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.distributed as dist from megablocks import benchmark_util from megablocks.layers.all_to_all import all_to_all @@ -29,7 +30,7 @@ def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) + world_size = dist.get_world_size(group) assert (sl % world_size) == 0 send_recv_sizes = [sl // world_size] * world_size @@ -45,14 +46,14 @@ def benchmark(): time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: + if dist.get_rank(group) == 0: benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _ALL_TO_ALL_BENCHMARK: From a1176e486fce545b50695efa5a516b0ddcf0436b Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 02:42:34 +0000 Subject: [PATCH 16/19] fix all gmm type errors --- megablocks/grouped_gemm_util.py | 18 +++++++++++------- megablocks/layers/glu.py | 3 +++ megablocks/layers/mlp.py | 3 +++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 07dbc04..9bbadad 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,20 +1,24 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +import warnings +_grouped_gemm_is_available: bool = False try: import grouped_gemm -except 
ImportError: - grouped_gemm = None - + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn("Grouped GEMM not available.") def grouped_gemm_is_available(): - return grouped_gemm is not None + return _grouped_gemm_is_available def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available( - ), ('Grouped GEMM not available. Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.') + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg backend = grouped_gemm.backend if grouped_gemm_is_available() else None diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 1352919..e510723 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -80,6 +80,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) @@ -123,6 +124,7 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -196,6 +198,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. + assert gg.ops is not None x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) x1 = self.args.activation_fn(x1) * x2 diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 173bfea..2dc1224 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -410,6 +410,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). 
+ assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) # activation_fn @@ -451,6 +452,7 @@ def backward(ctx: Any, ddsd_out: torch.Tensor): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -515,6 +517,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. + assert gg.ops is not None x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x = self.args.activation_fn(x) return gg.ops.gmm(x, w2, batch_sizes) From 53d2224664adfca7a05b32992d4d946da73ab71f Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 03:02:14 +0000 Subject: [PATCH 17/19] more type checking --- megablocks/layers/moe.py | 16 ++++++++++++---- megablocks/ops/sort.py | 5 +++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index a59d91b..1040519 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,6 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Tuple +from typing import Tuple, Optional import numpy as np import torch @@ -112,6 +112,7 @@ def __init__(self, args: Arguments): # Expert MLP. self.mlp = mlp.MLP(args) + self.bias: Optional[torch.Tensor] if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. @@ -156,7 +157,9 @@ def indices_and_bins(self, top_expert: torch.Tensor) -> Tuple[torch.Tensor, torc # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output # Histogram the expert ids to identify the number of # tokens routed to each expert. 
@@ -168,6 +171,7 @@ def indices_and_bins(self, top_expert: torch.Tensor) -> Tuple[torch.Tensor, torc # Calculate the bin bounds for the sorted tokens. bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins assert isinstance(indices, torch.Tensor) @@ -190,7 +194,9 @@ def permute_and_compute( ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output # Perform the expert computation. Note that we don't # use biases for these linear operations. @@ -278,7 +284,9 @@ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, t # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output # Compute the number of tokens that will be received from each # device and permute the input data across the devices. diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 4b88afd..a551557 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,7 +1,7 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional +from typing import Any, Optional, Tuple # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. 
@@ -26,7 +26,7 @@ class SortOp(torch.autograd.Function): @staticmethod - def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None): + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] x_out = torch.empty_like(x) @@ -36,3 +36,4 @@ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None): sort = SortOp.apply + From 1155ac627af3df65b47f1c6d2ed1c1691963afbc Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 03:02:56 +0000 Subject: [PATCH 18/19] comment out type checking --- .pre-commit-config.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3db2e74..c754b29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,16 +4,16 @@ default_language_version: python: python3 repos: -- repo: local - hooks: - - id: pyright - name: pyright - entry: pyright - language: node - types: [python] - pass_filenames: false - args: [--warnings] - additional_dependencies: ["pyright@1.1.310"] +# - repo: local +# hooks: +# - id: pyright +# name: pyright +# entry: pyright +# language: node +# types: [python] +# pass_filenames: false +# args: [--warnings] +# additional_dependencies: ["pyright@1.1.310"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: From ec2fd2bb3b24000eec64896c99292a011c6e100b Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Tue, 27 Aug 2024 03:03:29 +0000 Subject: [PATCH 19/19] update --- megablocks/grouped_gemm_util.py | 5 +++-- megablocks/layers/activation_fn.py | 4 ++-- megablocks/layers/all_to_all.py | 1 + megablocks/layers/arguments.py | 2 +- megablocks/layers/dmoe.py | 2 +- megablocks/layers/mlp.py | 4 ++-- megablocks/layers/moe.py | 5 +++-- megablocks/layers/mpu.py | 3 +++ megablocks/ops/sort.py | 1 - 9 files changed, 16 insertions(+), 11 deletions(-) diff --git 
a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 9bbadad..6d3f977 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -7,7 +7,8 @@ import grouped_gemm _grouped_gemm_is_available = True except ImportError as error: - warnings.warn("Grouped GEMM not available.") + warnings.warn('Grouped GEMM not available.') + def grouped_gemm_is_available(): return _grouped_gemm_is_available @@ -17,7 +18,7 @@ def assert_grouped_gemm_is_available(): msg = ( 'Grouped GEMM not available. Please run ' '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', - ) + ) assert _grouped_gemm_is_available, msg diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 6093cc9..a31770b 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,10 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Union, Any +from typing import Any, Callable, Union -from stk import Matrix import torch +from stk import Matrix def act_fn( diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 7fdb644..5ac7067 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist + class AllToAllOp(torch.autograd.Function): @staticmethod diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index bc344e5..892cb91 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -6,8 +6,8 @@ from typing import Any, Callable, Optional, Union import torch -import torch.nn.functional as F import torch.distributed as dist +import torch.nn.functional as F import megablocks.grouped_gemm_util as grouped_gemm diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 4406695..377b77f 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -3,8 +3,8 @@ import 
numpy as np import stk.ops -from stk import Matrix import torch +from stk import Matrix import megablocks.ops as ops from megablocks.layers import common, dmlp_registry, moe, mpu diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 2dc1224..e8f2d7b 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -4,9 +4,9 @@ from typing import Any import stk -import torch -import stk.ops import stk.backend.triton_kernels +import stk.ops +import torch from packaging import version from megablocks import grouped_gemm_util as gg diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 1040519..9ba5edb 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,6 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Tuple, Optional +from typing import Optional, Tuple import numpy as np import torch @@ -149,7 +149,8 @@ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: to expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Sort the expert ids to produce the scatter/gather # indices for the permutation. 
# diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 9d35df2..b232139 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -10,10 +10,12 @@ class MoeParam(torch.Tensor): + def __init__(self): super().__init__(self) self.expert_model_parallel: bool + def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') @@ -33,6 +35,7 @@ def set_expert_model_parallel_attributes( assert not hasattr(tensor, 'expert_model_parallel') setattr(tensor, 'expert_model_parallel', is_parallel) + def param_is_expert_model_parallel(param: MoeParam) -> bool: return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index a551557..4fb0aab 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -36,4 +36,3 @@ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[t sort = SortOp.apply -