diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 22391128..11404971 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -38,6 +38,7 @@ class Arguments:
 
     # Compute arguments.
     memory_optimized_mlp : bool = False
+    mlp_type : str = 'mlp'
     grouped_mlp : bool = False
     quantize_inputs_num_bits: int = -1 # -1 = no quantization
     quantize_rematerialize_num_bits: int = -1
diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py
new file mode 100644
index 00000000..02a95bc9
--- /dev/null
+++ b/megablocks/layers/dmlp_registry.py
@@ -0,0 +1,35 @@
+from typing import Union
+from megablocks.layers import mlp
+from megablocks.layers import glu
+from megablocks.layers.arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+    'mlp': {'grouped': mlp.GroupedMLP, 'sparse': mlp.SparseMLP},
+    'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU},
+}
+
+def get(args: Arguments) -> MlpType:
+    """Returns an MLP for use in a dMoE instance.
+
+    Uses the provided arguments to instantiate the appropriate
+    MLP instance. This only contains MLPs for use in dMoEs
+    (i.e., only for the dropless versions of MoEs).
+
+    Args:
+        args: propagated Arguments dataclass.
+
+    Returns:
+        An instantiated MLP constructed using the input args.
+
+    """
+    if args.mlp_type not in _REGISTRY:
+        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+    mlp_impl = 'grouped' if args.grouped_mlp else 'sparse'
+
+    if mlp_impl not in _REGISTRY[args.mlp_type]:
+        raise ValueError(f'{args.mlp_type} does not support {mlp_impl} backend.')
+
+    return _REGISTRY[args.mlp_type][mlp_impl](args)
\ No newline at end of file
diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py
index 4968aa63..837e31ed 100644
--- a/megablocks/layers/dmoe.py
+++ b/megablocks/layers/dmoe.py
@@ -1,6 +1,6 @@
 from megablocks.layers import common
-from megablocks.layers import mlp
 from megablocks.layers import moe
+from megablocks.layers import dmlp_registry
 from megablocks.layers import mpu
 from megablocks.layers import router
 from megablocks.layers.arguments import Arguments
@@ -9,11 +9,9 @@
 import stk
 import torch
 
-
 def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x
 
-
 class ParallelDroplessMLP(moe.ParallelMLP):
 
     def __init__(self, args : Arguments):
@@ -21,13 +19,7 @@ def __init__(self, args : Arguments):
         self.hidden_size = args.hidden_size
         self.ffn_hidden_size = mpu.features_per_rank(args)
         self.blocking = 128
-
-        # Grouped or sparse MLP.
-        self.mlp = (
-            mlp.GroupedMLP(args)
-            if args.grouped_mlp
-            else mlp.SparseMLP(args)
-        )
+        self.mlp = dmlp_registry.get(args)
 
         # Calculate the number of bits needed to represent the column indices
         # in the intermediate sparse matrix.
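Reviewer note, not part of the patch: a minimal usage sketch of the new `mlp_type` field and the `dmlp_registry` dispatch introduced above. The sizes and `init_method` are illustrative, and a CUDA device with the bf16 path is assumed, mirroring the tests later in this diff.

```python
from functools import partial

import torch

from megablocks.layers import dmlp_registry
from megablocks.layers.arguments import Arguments

# Illustrative configuration; field values mirror glu_test.py below.
args = Arguments(
    hidden_size=512,
    ffn_hidden_size=1024,
    moe_num_experts=1,
    moe_top_k=1,
    init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.1),
    mlp_type='glu',      # new field; defaults to 'mlp', which keeps the previous behavior
    grouped_mlp=False,   # 'sparse' backend -> SparseGLU; True -> GroupedGLU
    fp16=False,
    bf16=True)

# ParallelDroplessMLP now resolves its expert MLP through the registry
# instead of hard-coding GroupedMLP/SparseMLP.
expert_mlp = dmlp_registry.get(args)  # SparseGLU instance for this configuration
```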
diff --git a/megablocks/layers/dmoe_test.py b/megablocks/layers/dmoe_test.py
index b932ef46..f0ef8953 100644
--- a/megablocks/layers/dmoe_test.py
+++ b/megablocks/layers/dmoe_test.py
@@ -30,6 +30,7 @@ def test_modules(
         memory_optimized_mlp=True,
         quantize_inputs_num_bits=num_input_bits,
         quantize_rematerialize_num_bits=num_remat_bits,
+        mlp_type='mlp',
         grouped_mlp=grouped_mlp,
         fp16=False,
         bf16=True)
diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
new file mode 100644
index 00000000..0a1fe6b5
--- /dev/null
+++ b/megablocks/layers/glu.py
@@ -0,0 +1,60 @@
+from megablocks.layers import common
+from megablocks.layers import gelu
+from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
+from megablocks.layers import mpu
+from megablocks.layers.arguments import Arguments, InitFn
+from megablocks import grouped_gemm_util as gg
+import stk
+import torch
+import torch.nn.functional as F
+
+
+class SparseGLU(SparseMLP):
+
+    def __init__(self, args : Arguments):
+        super().__init__(args)
+        self.v1 = torch.nn.Parameter(torch.empty(
+            self._num_rows_per_rank,
+            args.hidden_size,
+            device=args.device,
+            dtype=common.dtype(args)))
+        with torch.no_grad():
+            self.v1.copy_(create_dmoe_expert_weights(
+                args, args.moe_num_experts, args.ffn_hidden_size,
+                args.hidden_size, args.init_method))
+
+        mpu.set_expert_model_parallel_attributes(
+            self.v1, self._should_set_parallelism_attribute)
+
+        if self.args.moe_weight_parallelism:
+            raise NotImplementedError("Weight parallelism not yet supported with GLU.")
+        elif self.args.memory_optimized_mlp:
+            raise NotImplementedError("Memory optimized implementation not yet supported with GLU.")
+
+    def forward(self, x, topo):
+        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))
+
+        # Compute the GLU.
+        x1 = stk.ops.sdd(x, w1.t(), topo)
+        x2 = stk.ops.sdd(x, v1.t(), topo)
+
+        x1 = stk.ops.mul(gelu.gelu(x1), x2)
+
+        return stk.ops.dsd(x1, w2)
+
+class GroupedGLU(SparseGLU):
+    def forward(self, x, tokens_per_expert):
+        batch_sizes = tokens_per_expert.cpu().to(torch.long)
+        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))
+
+        # Re-shape the weights for the grouped GEMMs.
+        ne = mpu.experts_per_rank(self.args)
+        w1 = w1.view(ne, -1, self.args.hidden_size)
+        v1 = v1.view(ne, -1, self.args.hidden_size)
+        w2 = w2.view(ne, -1, self.args.hidden_size)
+
+        # Compute the MLP.
+        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
+        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
+        x1 = F.gelu(x1, approximate="tanh") * x2
+        return gg.ops.gmm(x1, w2, batch_sizes)
diff --git a/megablocks/layers/glu_test.py b/megablocks/layers/glu_test.py
new file mode 100644
index 00000000..5d7319da
--- /dev/null
+++ b/megablocks/layers/glu_test.py
@@ -0,0 +1,88 @@
+import unittest
+from functools import partial
+
+from absl.testing import parameterized
+from megablocks.layers.arguments import Arguments
+from megablocks.layers.glu import SparseGLU, GroupedGLU
+from megablocks.layers import testing
+
+import torch
+import stk
+import numpy as np
+
+def test_modules(
+        hidden_size,
+        ffn_hidden_size,
+        grouped_mlp=False):
+    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
+    args = Arguments(
+        hidden_size=hidden_size,
+        ffn_hidden_size=ffn_hidden_size,
+        moe_num_experts=1,
+        moe_top_k=1,
+        init_method=init_method,
+        memory_optimized_mlp=False,
+        mlp_type='glu',
+        grouped_mlp=grouped_mlp,
+        fp16=False,
+        bf16=True)
+
+    glu = testing.GLU(args)
+    dmoe_glu = GroupedGLU(args) if grouped_mlp else SparseGLU(args)
+
+    dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
+    glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
+
+    with torch.no_grad():
+        glu.w1.copy_(dmoe_glu.w1.T)
+        glu.v1.copy_(dmoe_glu.v1.T)
+        glu.w2.copy_(dmoe_glu.w2)
+
+    return args, glu, dmoe_glu
+
+_DENSE_TESTS = (
+    (16, 1024, 512),
+    (8, 2048, 512),
+)
+
+class GLUTest(parameterized.TestCase):
+
+    @parameterized.parameters(*_DENSE_TESTS)
+    def testGLU_ForwardGroupedMLP(self, bs, sl, hs):
+        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
+
+        _, glu, dmoe_glu = test_modules(
+            hidden_size=hs,
+            ffn_hidden_size=hs * 2,
+            grouped_mlp=True)
+
+        expected_out = glu(x)
+        tokens_per_expert = torch.tensor([bs * sl]).cuda()
+        out = dmoe_glu(x.view(bs * sl, hs), tokens_per_expert)
+        out = out.view(sl, bs, hs)
+
+        self.assertSequenceEqual(out.shape, x.shape)
+        self.assertSequenceEqual(expected_out.shape, x.shape)
+        self.assertTrue(testing.allclose(out, expected_out))
+
+    @parameterized.parameters(*_DENSE_TESTS)
+    def testGLU_ForwardSparseMLP(self, bs, sl, hs):
+        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()
+
+        _, glu, dmoe_glu = test_modules(
+            hidden_size=hs,
+            ffn_hidden_size=hs * 2,
+            grouped_mlp=False)
+
+        expected_out = glu(x)
+        with torch.no_grad():
+            topo = stk.random.mask(bs * sl, hs * 2, 0, blocking=128).cuda()
+        out = dmoe_glu(x.view(bs * sl, hs), topo)
+        out = out.view(sl, bs, hs)
+
+        self.assertSequenceEqual(out.shape, x.shape)
+        self.assertSequenceEqual(expected_out.shape, x.shape)
+        self.assertTrue(testing.allclose(out, expected_out))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index abe1cbdd..7894b43c 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -304,18 +304,18 @@ class SparseMLP(torch.nn.Module):
     def __init__(self, args : Arguments):
         super().__init__()
         self.args = args
-        num_rows_per_rank = (
+        self._num_rows_per_rank = (
            (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
            mpu.get_weight_parallel_world_size(args)
         )
 
         self.w1 = torch.nn.Parameter(torch.empty(
-            num_rows_per_rank,
+            self._num_rows_per_rank,
             args.hidden_size,
             device=args.device,
             dtype=common.dtype(args)))
         self.w2 = torch.nn.Parameter(torch.empty(
-            num_rows_per_rank,
+            self._num_rows_per_rank,
             args.hidden_size,
             device=args.device,
             dtype=common.dtype(args)))
@@ -336,12 +336,12 @@ def __init__(self, args : Arguments):
                 args, args.moe_num_experts, args.ffn_hidden_size,
                 args.hidden_size, args.output_layer_init_method))
 
-        should_set_attribute = (
+        self._should_set_parallelism_attribute = (
             args.moe_expert_model_parallelism or args.moe_weight_parallelism)
         mpu.set_expert_model_parallel_attributes(
-            self.w1, should_set_attribute)
+            self.w1, self._should_set_parallelism_attribute)
         mpu.set_expert_model_parallel_attributes(
-            self.w2, should_set_attribute)
+            self.w2, self._should_set_parallelism_attribute)
 
         self.gradient_scale = None
         if self.args.moe_expert_model_parallelism:
diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py
index 764f6dab..530026ec 100644
--- a/megablocks/layers/testing.py
+++ b/megablocks/layers/testing.py
@@ -30,3 +30,17 @@ def __init__(self, args : Arguments):
     def forward(self, x):
         return torch.matmul(F.gelu(
             torch.matmul(x, self.w1), approximate="tanh"), self.w2)
+
+class GLU(FFN):
+
+    def __init__(self, args : Arguments):
+        super().__init__(args)
+        self.v1 = torch.nn.Parameter(torch.empty(
+            args.hidden_size,
+            args.ffn_hidden_size,
+            device=args.device,
+            dtype=torch.float16 if args.fp16 else torch.float32))
+
+    def forward(self, x):
+        x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1)
+        return torch.matmul(x1, self.w2)
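Reviewer note, not part of the patch: the per-expert computation implemented by `SparseGLU` and `GroupedGLU` is the same gated unit as the dense `testing.GLU` reference above. A plain-PyTorch sketch for a single expert with illustrative shapes is shown below; the expert weights are stored in the `[ffn_hidden_size, hidden_size]` rows-per-rank layout, which is why the kernels use `w1.t()` / `trans_b=True` and the test copies `dmoe_glu.w1.T` into the dense module.

```python
import torch
import torch.nn.functional as F

tokens, hidden_size, ffn_hidden_size = 64, 512, 1024  # illustrative sizes

# Single-expert weights in the [rows, hidden_size] layout used by SparseMLP/SparseGLU.
w1 = torch.randn(ffn_hidden_size, hidden_size)
v1 = torch.randn(ffn_hidden_size, hidden_size)
w2 = torch.randn(ffn_hidden_size, hidden_size)

x = torch.randn(tokens, hidden_size)

# GLU: gelu-gated first projection, elementwise product with the second
# projection, then the down projection back to hidden_size.
x1 = F.gelu(x @ w1.t(), approximate="tanh") * (x @ v1.t())
out = x1 @ w2  # [tokens, hidden_size]
```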