diff --git a/experimental/torch_xla2/test/moe/__init__.py b/experimental/torch_xla2/test/moe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/experimental/torch_xla2/test/moe/model.py b/experimental/torch_xla2/test/moe/model.py new file mode 100644 index 00000000000..9249ac9dce0 --- /dev/null +++ b/experimental/torch_xla2/test/moe/model.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + num_experts: int = 8 + num_activated_experts: int = 2 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + @classmethod + def from_name(cls, name: str): + if name in transformer_configs: + return cls(**transformer_configs[name]) + # fuzzy search + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] + assert len(config) == 1, name + return cls(**transformer_configs[config[0]]) + + +transformer_configs = { + "Mixtral-8x7B-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x = self.tok_embeddings(idx) + + for i, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + return logits + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.block_sparse_moe = MOEFeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.block_sparse_moe(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class ConditionalFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) + + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: + w1_weights = self.w1[expert_indices] # [T, A, D, D] + w3_weights = self.w3[expert_indices] # [T, A, D, D] + w2_weights = self.w2[expert_indices] # [T, A, D, D] + x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) + x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) + expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) + return expert_outs + + +class MOEFeedForward(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) + self.cond_ffn = ConditionalFeedForward(config) + self.dim = config.dim + self.num_activated_experts = config.num_activated_experts + def forward(self, x: Tensor) -> Tensor: + x = x.view(-1, self.dim) + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts + # x: [T, D] + scores = self.gate(x) # [T, E] + expert_weights = F.softmax(scores, dim=-1) + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_outs = self.cond_ffn(x, expert_indices) + return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/experimental/torch_xla2/test/moe/moe_test.py b/experimental/torch_xla2/test/moe/moe_test.py new file mode 100644 index 00000000000..f8d4a22e3f2 --- /dev/null +++ b/experimental/torch_xla2/test/moe/moe_test.py @@ -0,0 +1,75 @@ +import torch_xla2 +import torch_xla2.interop +import torch +import unittest +import jax + + +from test.moe import model + + +class TestMoe(unittest.TestCase): + + def _make_tiny_config(self): + return model.ModelArgs( + block_size = 128, + vocab_size = 32000, + n_layer = 4, + n_head = 4, + dim = 128, + intermediate_size = None, + n_local_heads = -1, + head_dim = 32, + rope_base = 10000, + norm_eps = 1e-5, + num_experts = 8, + num_activated_experts = 2, + ) + + def _random_init(self, model): + new_state_dict = {} + + for k, v in model.state_dict().items(): + new_state_dict[k] = torch.randn_like(v) + + model.load_state_dict(new_state_dict, assign=True) + return model + + + + def test_moe_layer(self): + model_args = self._make_tiny_config() + + moe_layer = model.MOEFeedForward(model_args) + moe_layer = self._random_init(moe_layer) + seqlen = 32 + x = torch.randn((seqlen, model_args.dim)) + res = moe_layer(x) + + env = torch_xla2.default_env() + model_xla = env.to_xla(moe_layer) + x_xla = env.to_xla(x) + with jax.default_matmul_precision('float32'): + res_xla = model_xla(x_xla) + res2 = torch_xla2.tensor.j2t(res_xla._elem) + print('max diff', torch.max((res - res2).abs())) + + self.assertTrue( + torch.allclose(res2, res, atol=1e-2)) + + # test can jit + + def f(weights, x): + return torch.func.functional_call(moe_layer, weights, (x, )) + + fjitted = torch_xla2.interop.jax_jit(f) + weights_xla = env.to_xla(moe_layer.state_dict()) + + print(fjitted(weights_xla, x_xla)) + + + + + +if __name__ == '__main__': + unittest.main() diff --git a/experimental/torch_xla2/torch_xla2/ops/jaten.py b/experimental/torch_xla2/torch_xla2/ops/jaten.py index 3c92e5e290d..a628a648441 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jaten.py +++ b/experimental/torch_xla2/torch_xla2/ops/jaten.py @@ -318,7 +318,6 @@ def _aten_index_put(self, indexes, values, accumulate=False): @op(torch.ops.aten._unsafe_index) @op(torch.ops.aten.index.Tensor) def _aten_index(self, indexes): - print(indexes) indexes = [slice(None, None, None) if i is None else i for i in indexes] indexes = tuple(indexes) return self[indexes] @@ -1671,8 +1670,11 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): - indices: The indices of the top k values in the original array. """ if dim is None: - input = input.flatten() - dim = 0 + # last dim is chosen + dim = input.ndim - 1 + + if dim < 0: + dim = dim + input.ndim if not largest: input = -input # Find top-k of negated input if we want the smallest