Merge pull request #38 from sashaDoubov/sasha/glu
Add GLU support
tgale96 authored Dec 1, 2023
2 parents 8b959f2 + ee5ff20 commit 059ae20
Showing 8 changed files with 207 additions and 16 deletions.
1 change: 1 addition & 0 deletions megablocks/layers/arguments.py
@@ -38,6 +38,7 @@ class Arguments:

    # Compute arguments.
    memory_optimized_mlp : bool = False
    mlp_type : str = 'mlp'
    grouped_mlp : bool = False
    quantize_inputs_num_bits: int = -1 # -1 = no quantization
    quantize_rematerialize_num_bits: int = -1
35 changes: 35 additions & 0 deletions megablocks/layers/dmlp_registry.py
@@ -0,0 +1,35 @@
from typing import Union
from megablocks.layers import mlp
from megablocks.layers import glu
from megablocks.layers.arguments import Arguments

MlpType = Union[mlp.SparseMLP, glu.SparseGLU]

_REGISTRY = {
    'mlp': {'grouped': mlp.GroupedMLP, 'sparse' : mlp.SparseMLP},
    'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU},
}

def get(args: Arguments) -> MlpType:
    """Returns an MLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    MLP instance. This only contains MLPs for use in dMoEs
    (ie. only for the dropless versions of MoEs).

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated MLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    mlp_impl = 'grouped' if args.grouped_mlp else 'sparse'

    if mlp_impl not in _REGISTRY[args.mlp_type]:
        raise ValueError(f'{args.mlp_type} does not support {mlp_impl} backend.')

    return _REGISTRY[args.mlp_type][mlp_impl](args)
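
For context, a minimal usage sketch of the new registry (not part of this diff): setting mlp_type='glu' on the Arguments dataclass makes dmlp_registry.get return a SparseGLU expert layer, or a GroupedGLU when grouped_mlp=True. The concrete field values below are illustrative only and mirror the construction used in glu_test.py.

from functools import partial

import torch

from megablocks.layers import dmlp_registry
from megablocks.layers.arguments import Arguments

# Illustrative configuration; the values here are examples, not defaults.
args = Arguments(
    hidden_size=512,
    ffn_hidden_size=1024,
    moe_num_experts=4,
    moe_top_k=1,
    init_method=partial(torch.nn.init.normal_, mean=0.0, std=0.1),
    mlp_type='glu',     # 'mlp' keeps the existing SparseMLP/GroupedMLP path.
    grouped_mlp=False,  # False -> 'sparse' backend, True -> 'grouped' backend.
    fp16=False,
    bf16=True)

expert_mlp = dmlp_registry.get(args)  # a SparseGLU for this configuration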
12 changes: 2 additions & 10 deletions megablocks/layers/dmoe.py
@@ -1,6 +1,6 @@
from megablocks.layers import common
from megablocks.layers import mlp
from megablocks.layers import moe
from megablocks.layers import dmlp_registry
from megablocks.layers import mpu
from megablocks.layers import router
from megablocks.layers.arguments import Arguments
@@ -9,25 +9,17 @@
import stk
import torch


def promote_scalar(x):
    return x.view(1) if not len(x.size()) else x


class ParallelDroplessMLP(moe.ParallelMLP):

    def __init__(self, args : Arguments):
        super(ParallelDroplessMLP, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128

        # Grouped or sparse MLP.
        self.mlp = (
            mlp.GroupedMLP(args)
            if args.grouped_mlp
            else mlp.SparseMLP(args)
        )
        self.mlp = dmlp_registry.get(args)

        # Calculate the number of bits needed to represent the column indices
        # in the intermediate sparse matrix.
1 change: 1 addition & 0 deletions megablocks/layers/dmoe_test.py
@@ -30,6 +30,7 @@ def test_modules(
        memory_optimized_mlp=True,
        quantize_inputs_num_bits=num_input_bits,
        quantize_rematerialize_num_bits=num_remat_bits,
        mlp_type='mlp',
        grouped_mlp=grouped_mlp,
        fp16=False,
        bf16=True)
60 changes: 60 additions & 0 deletions megablocks/layers/glu.py
@@ -0,0 +1,60 @@
from megablocks.layers import common
from megablocks.layers import gelu
from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
from megablocks.layers import mpu
from megablocks.layers.arguments import Arguments, InitFn
from megablocks import grouped_gemm_util as gg
import stk
import torch
import torch.nn.functional as F


class SparseGLU(SparseMLP):

    def __init__(self, args : Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(torch.empty(
            self._num_rows_per_rank,
            args.hidden_size,
            device=args.device,
            dtype=common.dtype(args)))
        with torch.no_grad():
            self.v1.copy_(create_dmoe_expert_weights(
                args, args.moe_num_experts, args.ffn_hidden_size,
                args.hidden_size, args.init_method))

        mpu.set_expert_model_parallel_attributes(
            self.v1, self._should_set_parallelism_attribute)

        if self.args.moe_weight_parallelism:
            raise NotImplementedError("Weight parallelism not yet supported with GLU.")
        elif self.args.memory_optimized_mlp:
            raise NotImplementedError("Memory optimized implementation not yet supported with GLU.")

    def forward(self, x, topo):
        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))

        # Compute the GLU.
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

        x1 = stk.ops.mul(gelu.gelu(x1), x2)

        return stk.ops.dsd(x1, w2)

class GroupedGLU(SparseGLU):
    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2))

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = w1.view(ne, -1, self.args.hidden_size)
        v1 = v1.view(ne, -1, self.args.hidden_size)
        w2 = w2.view(ne, -1, self.args.hidden_size)

        # Compute the GLU.
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        x1 = F.gelu(x1, approximate="tanh") * x2
        return gg.ops.gmm(x1, w2, batch_sizes)
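
For reference, the per-expert computation implemented above (with block-sparse GEMMs in SparseGLU and grouped GEMMs in GroupedGLU) reduces to the dense sketch below. The function name dense_glu_reference is illustrative, not part of the commit; the math mirrors the dense GLU reference module added to testing.py further down.

import torch
import torch.nn.functional as F

def dense_glu_reference(x, w1, v1, w2):
    # x: [tokens, hidden]; w1, v1, w2: [ffn_hidden, hidden], as in SparseGLU.
    x1 = F.gelu(x @ w1.t(), approximate="tanh")  # gated branch
    x2 = x @ v1.t()                              # linear branch
    return (x1 * x2) @ w2                        # project back to hidden size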
88 changes: 88 additions & 0 deletions megablocks/layers/glu_test.py
@@ -0,0 +1,88 @@
import unittest
from functools import partial

from absl.testing import parameterized
from megablocks.layers.arguments import Arguments
from megablocks.layers.glu import SparseGLU, GroupedGLU
from megablocks.layers import testing

import torch
import stk
import numpy as np

def test_modules(
        hidden_size,
        ffn_hidden_size,
        grouped_mlp=False):
    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
    args = Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=1,
        moe_top_k=1,
        init_method=init_method,
        memory_optimized_mlp=False,
        mlp_type='glu',
        grouped_mlp=grouped_mlp,
        fp16=False,
        bf16=True)

    glu = testing.GLU(args)
    dmoe_glu = GroupedGLU(args) if grouped_mlp else SparseGLU(args)

    dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)
    glu.cuda(torch.cuda.current_device()).to(torch.bfloat16)

    with torch.no_grad():
        glu.w1.copy_(dmoe_glu.w1.T)
        glu.v1.copy_(dmoe_glu.v1.T)
        glu.w2.copy_(dmoe_glu.w2)

    return args, glu, dmoe_glu

_DENSE_TESTS = (
    (16, 1024, 512),
    (8, 2048, 512),
)

class GLUTest(parameterized.TestCase):

    @parameterized.parameters(*_DENSE_TESTS)
    def testGLU_ForwardGroupedMLP(self, bs, sl, hs):
        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()

        _, glu, dmoe_glu = test_modules(
            hidden_size=hs,
            ffn_hidden_size=hs * 2,
            grouped_mlp=True)

        expected_out = glu(x)
        tokens_per_expert = torch.tensor([bs * sl]).cuda()
        out = dmoe_glu(x.view(bs * sl, hs), tokens_per_expert)
        out = out.view(sl, bs, hs)

        self.assertSequenceEqual(out.shape, x.shape)
        self.assertSequenceEqual(expected_out.shape, x.shape)
        self.assertTrue(testing.allclose(out, expected_out))

    @parameterized.parameters(*_DENSE_TESTS)
    def testGLU_ForwardSparseMLP(self, bs, sl, hs):
        x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda()

        _, glu, dmoe_glu = test_modules(
            hidden_size=hs,
            ffn_hidden_size=hs * 2,
            grouped_mlp=False)

        expected_out = glu(x)
        with torch.no_grad():
            topo = stk.random.mask(bs * sl, hs * 2, 0, blocking=128).cuda()
        out = dmoe_glu(x.view(bs * sl, hs), topo)
        out = out.view(sl, bs, hs)

        self.assertSequenceEqual(out.shape, x.shape)
        self.assertSequenceEqual(expected_out.shape, x.shape)
        self.assertTrue(testing.allclose(out, expected_out))

if __name__ == '__main__':
    unittest.main()
12 changes: 6 additions & 6 deletions megablocks/layers/mlp.py
@@ -304,18 +304,18 @@ class SparseMLP(torch.nn.Module):
    def __init__(self, args : Arguments):
        super().__init__()
        self.args = args
        num_rows_per_rank = (
        self._num_rows_per_rank = (
            (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) //
            mpu.get_weight_parallel_world_size(args)
        )

        self.w1 = torch.nn.Parameter(torch.empty(
            num_rows_per_rank,
            self._num_rows_per_rank,
            args.hidden_size,
            device=args.device,
            dtype=common.dtype(args)))
        self.w2 = torch.nn.Parameter(torch.empty(
            num_rows_per_rank,
            self._num_rows_per_rank,
            args.hidden_size,
            device=args.device,
            dtype=common.dtype(args)))
@@ -336,12 +336,12 @@ def __init__(self, args : Arguments):
                args, args.moe_num_experts, args.ffn_hidden_size,
                args.hidden_size, args.output_layer_init_method))

        should_set_attribute = (
        self._should_set_parallelism_attribute = (
            args.moe_expert_model_parallelism or args.moe_weight_parallelism)
        mpu.set_expert_model_parallel_attributes(
            self.w1, should_set_attribute)
            self.w1, self._should_set_parallelism_attribute)
        mpu.set_expert_model_parallel_attributes(
            self.w2, should_set_attribute)
            self.w2, self._should_set_parallelism_attribute)

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
14 changes: 14 additions & 0 deletions megablocks/layers/testing.py
@@ -30,3 +30,17 @@ def __init__(self, args : Arguments):
    def forward(self, x):
        return torch.matmul(F.gelu(
            torch.matmul(x, self.w1), approximate="tanh"), self.w2)

class GLU(FFN):

    def __init__(self, args : Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(torch.empty(
            args.hidden_size,
            args.ffn_hidden_size,
            device=args.device,
            dtype=torch.float16 if args.fp16 else torch.float32))

    def forward(self, x):
        x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1)
        return torch.matmul(x1, self.w2)
