Add a unit test for MoE layer. #7069

Merged (1 commit) on May 16, 2024
experimental/torch_xla2/test/moe/model.py (260 additions, 0 deletions)
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)

@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
    intermediate_size: Optional[int] = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
num_experts: int = 8
num_activated_experts: int = 2

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])


transformer_configs = {
"Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
}

class KVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
assert input_pos.shape[0] == k_val.shape[2]

k_out = self.k_cache
v_out = self.v_cache
k_out[:, :, input_pos] = k_val
v_out[:, :, input_pos] = v_val

return k_out, v_out

class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config

self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

self.freqs_cis: Optional[Tensor] = None
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1

def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
x = self.norm(x)
logits = self.output(x)
return logits

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.block_sparse_moe = MOEFeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)

def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
out = h + self.block_sparse_moe(self.ffn_norm(h))
return out


class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0

total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)

def load_hook(self, state_dict, prefix, *args):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)

q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

y = self.wo(y)
return y


class ConditionalFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
        w1_weights = self.w1[expert_indices]  # [T, A, I, D] where I = intermediate_size
Collaborator review comment: The memory consumption here seems pretty high? It's like you have num_token duplicated weights? So this will not introduce extra storage? Or maybe you are just considering inference where T is just 1? (A rough size sketch follows at the end of this file's diff.)

        w3_weights = self.w3[expert_indices]  # [T, A, I, D]
        w2_weights = self.w2[expert_indices]  # [T, A, D, I]
x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
return expert_outs


class MOEFeedForward(nn.Module):
def __init__(self, config) -> None:
super().__init__()
self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
self.cond_ffn = ConditionalFeedForward(config)
self.dim = config.dim
self.num_activated_experts = config.num_activated_experts
def forward(self, x: Tensor) -> Tensor:
x = x.view(-1, self.dim)
# T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
# x: [T, D]
scores = self.gate(x) # [T, E]
expert_weights = F.softmax(scores, dim=-1)
expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
expert_outs = self.cond_ffn(x, expert_indices)
return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)


class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight


def precompute_freqs_cis(
seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
t = torch.arange(seq_len, device=freqs.device)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
],
-1,
)

x_out2 = x_out2.flatten(3)
return x_out2.type_as(x)
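
To make the collaborator's memory question above concrete, here is a rough size sketch. It is not part of the PR; the variable names are illustrative and the figures assume the Mixtral-8x7B configuration listed in transformer_configs.

# Rough estimate of what the per-token gather w1[expert_indices] materializes in
# ConditionalFeedForward.forward. Illustrative only; not part of the PR.
num_tokens = 1          # T: a single decode step
activated = 2           # A: num_activated_experts
dim = 4096              # D: model dimension
intermediate = 14336    # I: intermediate_size
bytes_per_elem = 2      # bfloat16

gathered_bytes = num_tokens * activated * intermediate * dim * bytes_per_elem
print(f"w1[expert_indices]: ~{gathered_bytes / 2**30:.2f} GiB")  # ~0.22 GiB when T == 1
# At T == 1 the duplication is cheap, but a 1024-token prefill would need roughly 224 GiB
# for this gather alone, which is why this layout mainly suits single-token decoding.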
experimental/torch_xla2/test/moe/moe_test.py (75 additions, 0 deletions)
@@ -0,0 +1,75 @@
import torch_xla2
import torch_xla2.interop
import torch
import unittest
import jax


from test.moe import model


class TestMoe(unittest.TestCase):

def _make_tiny_config(self):
return model.ModelArgs(
block_size = 128,
vocab_size = 32000,
n_layer = 4,
n_head = 4,
dim = 128,
intermediate_size = None,
n_local_heads = -1,
head_dim = 32,
rope_base = 10000,
norm_eps = 1e-5,
num_experts = 8,
num_activated_experts = 2,
)

def _random_init(self, model):
new_state_dict = {}

for k, v in model.state_dict().items():
new_state_dict[k] = torch.randn_like(v)

model.load_state_dict(new_state_dict, assign=True)
return model



def test_moe_layer(self):
model_args = self._make_tiny_config()

moe_layer = model.MOEFeedForward(model_args)
moe_layer = self._random_init(moe_layer)
seqlen = 32
x = torch.randn((seqlen, model_args.dim))
res = moe_layer(x)

env = torch_xla2.default_env()
model_xla = env.to_xla(moe_layer)
x_xla = env.to_xla(x)
with jax.default_matmul_precision('float32'):
res_xla = model_xla(x_xla)
res2 = torch_xla2.tensor.j2t(res_xla._elem)
print('max diff', torch.max((res - res2).abs()))

self.assertTrue(
torch.allclose(res2, res, atol=1e-2))

# test can jit

def f(weights, x):
return torch.func.functional_call(moe_layer, weights, (x, ))

fjitted = torch_xla2.interop.jax_jit(f)
weights_xla = env.to_xla(moe_layer.state_dict())

print(fjitted(weights_xla, x_xla))





if __name__ == '__main__':
unittest.main()
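
For quick manual experimentation outside the unittest harness, the same interop pattern the test exercises can be used directly. The following is a minimal sketch, not part of the PR, that reuses only calls already present in moe_test.py (default_env, to_xla, interop.jax_jit, torch.func.functional_call) and assumes it is run from the experimental/torch_xla2 directory so that the "from test.moe import model" import resolves.

# Minimal sketch (not part of the PR): driving the MoE layer through torch_xla2 by hand,
# using only the APIs that moe_test.py above already relies on.
import torch
import torch_xla2
import torch_xla2.interop
from test.moe import model  # assumes the working directory is experimental/torch_xla2

args = model.ModelArgs(dim=128, num_experts=8, num_activated_experts=2)
moe = model.MOEFeedForward(args)
# The parameters are created with torch.empty, so give them random values first.
moe.load_state_dict({k: torch.randn_like(v) for k, v in moe.state_dict().items()}, assign=True)

env = torch_xla2.default_env()
x = torch.randn(16, args.dim)

def call(weights, inputs):
  # functional_call makes the weights explicit arguments, so jax_jit can trace over them.
  return torch.func.functional_call(moe, weights, (inputs,))

jitted = torch_xla2.interop.jax_jit(call)
out = jitted(env.to_xla(moe.state_dict()), env.to_xla(x))
print(out)  # a 16 x 128 XLA-backed tensor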
experimental/torch_xla2/torch_xla2/ops/jaten.py (5 additions, 3 deletions)
@@ -318,7 +318,6 @@ def _aten_index_put(self, indexes, values, accumulate=False):
 @op(torch.ops.aten._unsafe_index)
 @op(torch.ops.aten.index.Tensor)
 def _aten_index(self, indexes):
-  print(indexes)
   indexes = [slice(None, None, None) if i is None else i for i in indexes]
   indexes = tuple(indexes)
   return self[indexes]
@@ -1671,8 +1670,11 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None):
     - indices: The indices of the top k values in the original array.
   """
   if dim is None:
-    input = input.flatten()
-    dim = 0
+    # last dim is chosen
+    dim = input.ndim - 1
+
+  if dim < 0:
+    dim = dim + input.ndim
 
   if not largest:
     input = -input  # Find top-k of negated input if we want the smallest
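
For context on the _aten_topk change above: torch.topk operates over the last dimension when no dim argument is given, and the updated lowering now mirrors that instead of flattening the input. A minimal sketch of the expected semantics (illustrative, not part of the PR):

# Minimal sketch (not part of the PR): torch.topk defaults to the last dimension,
# which is the behavior the updated jaten lowering reproduces when dim is None.
import torch

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])
values, indices = torch.topk(x, k=2)  # no dim given, so the last dimension is used
print(values)   # tensor([[5., 3.], [6., 4.]])
print(indices)  # tensor([[1, 2], [2, 0]])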