diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 32d51bb..b227f70 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -1,5 +1,6 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
+from typing import Tuple
 
 import numpy as np
 import torch
@@ -146,7 +147,7 @@ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: to
             expert_scores.mean(dim=0),
         )
 
-    def indices_and_bins(self, top_expert: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
+    def indices_and_bins(self, top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # Sort the expert ids to produce the scatter/gather
         # indices for the permutation.
         #
@@ -167,6 +168,12 @@ def indices_and_bins(self, top_expert: torch.Tensor) -> (torch.Tensor, torch.Ten
         # Calculate the bin bounds for the sorted tokens.
         bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         bins = bins.view(1) if not len(bins.size()) else bins
+
+        assert isinstance(indices, torch.Tensor)
+        assert isinstance(bin_ids, torch.Tensor)
+        assert isinstance(bins, torch.Tensor)
+        assert isinstance(tokens_per_expert, torch.Tensor)
+
         return indices, bin_ids, bins, tokens_per_expert
 
     def permute_and_compute(
diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py
index 35e3e4f..41291d9 100644
--- a/megablocks/layers/mpu.py
+++ b/megablocks/layers/mpu.py
@@ -8,6 +8,11 @@
 from megablocks.layers.arguments import Arguments
 
 
+class MoeParam(torch.Tensor):
+    def __init__(self):
+        super().__init__(self)
+        self.expert_model_parallel: bool
+
 def is_moe_param(tensor: torch.Tensor) -> bool:
     return hasattr(tensor, 'expert_model_parallel')
@@ -27,8 +32,7 @@
     assert not hasattr(tensor, 'expert_model_parallel')
     setattr(tensor, 'expert_model_parallel', is_parallel)
 
-
-def param_is_expert_model_parallel(param: torch.Tensor) -> bool:
+def param_is_expert_model_parallel(param: MoeParam) -> bool:
     return (hasattr(param, 'expert_model_parallel') and
             param.expert_model_parallel)
diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py
index 709290e..b9dc286 100644
--- a/megablocks/ops/__init__.py
+++ b/megablocks/ops/__init__.py
@@ -13,6 +13,7 @@
 from megablocks.ops.round_up import round_up
 from megablocks.ops.scatter import scatter
 from megablocks.ops.sort import sort
+from megablocks.ops.sum import sum
 from megablocks.ops.topology import topology
 
 __all__ = [
@@ -29,5 +30,6 @@
     'round_up',
     'scatter',
     'sort',
+    'sum',
     'topology',
 ]
diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py
index 0cb1392..f5ce0d6 100644
--- a/megablocks/ops/binned_scatter.py
+++ b/megablocks/ops/binned_scatter.py
@@ -14,7 +14,12 @@ class BinnedScatterOp(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx: Any, x: torch.Tensor, indices: torch.Tensor, weights: torch.Tensor, bins: torch.Tensor, top_k: int
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
     ):
         assert len(x.size()) == 3
         ctx.bin_size = x.size(1)
diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py
index a048308..41b09a1 100644
--- a/megablocks/ops/gather.py
+++ b/megablocks/ops/gather.py
@@ -14,7 +14,12 @@ class GatherOp(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor, top_k: int
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
     ):
         ctx.save_for_backward(indices, bin_ids, bins)
         ctx.top_k = top_k
diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py
index 815791c..f272a77 100644
--- a/megablocks/ops/padded_gather.py
+++ b/megablocks/ops/padded_gather.py
@@ -14,8 +14,13 @@ class PaddedGatherOp(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, bins: torch.Tensor,
-        padded_bins: torch.Tensor, top_k: int
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        bins: torch.Tensor,
+        padded_bins: torch.Tensor,
+        top_k: int,
     ):
         ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
         ctx.top_k = top_k
diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py
index 9546522..9ff81dd 100644
--- a/megablocks/ops/padded_scatter.py
+++ b/megablocks/ops/padded_scatter.py
@@ -14,8 +14,14 @@ class PaddedScatterOp(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx: Any, x: torch.Tensor, indices: torch.Tensor, bin_ids: torch.Tensor, weights: torch.Tensor,
-        bins: torch.Tensor, padded_bins: torch.Tensor, top_k: int
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        padded_bins: torch.Tensor,
+        top_k: int,
     ):
         maybe_x = [x] if ctx.needs_input_grad[3] else []
         ctx.save_for_backward(
diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py
index 99d7e9b..a5aaafc 100644
--- a/megablocks/ops/scatter.py
+++ b/megablocks/ops/scatter.py
@@ -1,7 +1,7 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any
+from typing import Any, Optional
 
 import torch
 from stk.backend.autocast import custom_bwd, custom_fwd
@@ -68,5 +68,5 @@ def scatter(
     weights: torch.Tensor,
     bins: torch.Tensor,
     top_k: int,
-) -> torch.Tensor:
+) -> Optional[torch.Tensor]:
     return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
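
Reviewer note, illustrative only and not part of the diff: the annotation changes above follow from how Python treats annotations. A parenthesized return annotation such as "(torch.Tensor, torch.Tensor)" evaluates to a tuple instance rather than a type, so static checkers reject it; typing.Tuple[...] is the well-formed spelling. The sketch below is a minimal, hypothetical example (split_halves does not exist in megablocks) showing the corrected pattern, plus the setattr tagging idiom that the new MoeParam subclass describes for the type checker.

    # Hypothetical sketch -- mirrors the annotation fix in indices_and_bins().
    from typing import Tuple

    import torch

    def split_halves(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Tuple[...] is required here; a bare "(A, B)" annotation is a
        # tuple object at runtime and is rejected by static checkers.
        mid = x.numel() // 2
        return x[:mid], x[mid:]

    # The tagging idiom used by set_expert_model_parallel_attributes():
    # setattr attaches an attribute that plain torch.Tensor does not declare,
    # which is the gap MoeParam exists to paper over for the checker.
    w = torch.nn.Parameter(torch.zeros(4))
    setattr(w, 'expert_model_parallel', True)
    assert getattr(w, 'expert_model_parallel')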