Add example for a decoder only model (#7082)
Showing 2 changed files with 291 additions and 0 deletions.
decoder_only_model.py (218 lines added):
from dataclasses import dataclass
import math

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class DecoderOnlyConfig:
  hidden_size: int = 1024
  num_hidden_layers: int = 2
  num_attention_heads: int = 8
  num_key_value_heads: int = 4
  intermediate_size: int = 64 * 1024
  vocab_size: int = 3200


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  """
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
  The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
  (batch, num_attention_heads, seqlen, head_dim).
  """
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  if n_rep == 1:
    return hidden_states
  hidden_states = hidden_states[:, :, None, :, :].expand(
      batch, num_key_value_heads, n_rep, slen, head_dim)
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                               head_dim)


class RMSNorm(nn.Module):

  def __init__(self, hidden_size, eps=1e-6):
    """
    RMSNorm is equivalent to LlamaRMSNorm.
    """
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.variance_epsilon = eps

  def forward(self, hidden_states):
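    # Computes y = w * x / sqrt(mean(x^2, dim=-1) + eps); the statistics are
    # taken in float32, then the result is cast back to the input dtype.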
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance +
                                                self.variance_epsilon)
    return self.weight * hidden_states.to(input_dtype)


# 1. no kv cache
# 2. no rotary embedding
# 3. no attention_mask
class GroupQueryAttention(nn.Module):
  """Stripped-down version of LlamaAttention"""

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__()
    self.config = config

    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
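    # With the default config this is 8 // 4 = 2: each key/value head is
    # shared by two query heads.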

    self.q_proj = nn.Linear(
        self.hidden_size, self.num_heads * self.head_dim, bias=False)
    self.k_proj = nn.Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.v_proj = nn.Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
    self.o_proj = nn.Linear(
        self.num_heads * self.head_dim, self.hidden_size, bias=False)

  def forward(
      self,
      hidden_states: torch.Tensor,
  ) -> torch.Tensor:
    bsz, q_len, _ = hidden_states.size()
    # [B, S, H] -> [B, S, n_head * head_dim]
    query_states = self.q_proj(hidden_states)
    # [B, S, H] -> [B, S, n_kv_head * head_dim]
    key_states = self.k_proj(hidden_states)
    # [B, S, H] -> [B, S, n_kv_head * head_dim]
    value_states = self.v_proj(hidden_states)

    # [B, S, n_head * head_dim] -> [B, n_head, S, head_dim]
    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    # [B, S, n_kv_head * head_dim] -> [B, n_kv_head, S, head_dim]
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    # [B, S, n_kv_head * head_dim] -> [B, n_kv_head, S, head_dim]
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    # [B, n_kv_head, S, head_dim] -> [B, n_head, S, head_dim]
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    # [B, n_kv_head, S, head_dim] -> [B, n_head, S, head_dim]
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # [B, n_head, S, head_dim] @ T([B, n_head, S, head_dim]) -> [B, n_head, S, S]
    attn_weights = torch.einsum('bnsh,bnkh->bnsk', query_states,
                                key_states) / math.sqrt(self.head_dim)
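    # Note: with no attention mask (see the class comment), every position
    # attends to every position, including future tokens.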

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    # [B, n_head, S, S] @ [B, n_head, S, head_dim] -> [B, n_head, S, head_dim]
    attn_output = torch.einsum('bnsk,bnkh->bnsh', attn_weights, value_states)

    # [B, n_head, S, head_dim] -> [B, S, n_head, head_dim]
    attn_output = attn_output.transpose(1, 2).contiguous()
    # [B, S, n_head, head_dim] -> [B, S, H]
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    # [B, S, H] -> [B, S, H]
    attn_output = self.o_proj(attn_output)

    return attn_output

class MLP(nn.Module):
  """Stripped-down version of LlamaMLP"""

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__()
    self.config = config

    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.gate_proj = nn.Linear(
        self.hidden_size, self.intermediate_size, bias=False)
    self.up_proj = nn.Linear(
        self.hidden_size, self.intermediate_size, bias=False)
    self.down_proj = nn.Linear(
        self.intermediate_size, self.hidden_size, bias=False)
    self.act_fn = F.silu

  def forward(self, x):
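    # Llama-style gated MLP (SwiGLU): down_proj(silu(gate_proj(x)) * up_proj(x)).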
    # [B, S, H] -> [B, S, I]
    up_proj = self.up_proj(x)
    # [B, S, H] -> [B, S, I]
    gate_proj = self.act_fn(self.gate_proj(x))
    # ([B, S, I] * [B, S, I]) -> [B, S, H]
    down_proj = self.down_proj(gate_proj * up_proj)
    return down_proj


class DecoderLayer(nn.Module):

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__()
    self.hidden_size = config.hidden_size
    self.self_attn = GroupQueryAttention(config=config)
    self.mlp = MLP(config)
    self.input_layernorm = RMSNorm(config.hidden_size)
    self.post_attention_layernorm = RMSNorm(config.hidden_size)

  def forward(
      self,
      hidden_states: torch.Tensor,
      **kwargs,
  ) -> torch.FloatTensor:
    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states = self.self_attn(hidden_states=hidden_states)
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states


# 1. no gradient_checkpointing
# 2. no padding_idx
# 3. no kv cache
class DecoderOnlyModel(nn.Module):

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__()
    self.vocab_size = config.vocab_size
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
    self.layers = nn.ModuleList(
        [DecoderLayer(config) for _ in range(config.num_hidden_layers)])
    self.norm = RMSNorm(config.hidden_size)
    self.output = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

  def forward(
      self,
      input_ids: torch.LongTensor,
  ) -> torch.Tensor:
    inputs_embeds = self.embed_tokens(input_ids)

    # no positional embeddings in this simplified model
    hidden_states = inputs_embeds

    # decoder layers
    for decoder_layer in self.layers:
      hidden_states = decoder_layer(hidden_states)

    hidden_states = self.norm(hidden_states)
    # [B, S, H] -> [B, S, V]
    return self.output(hidden_states)
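
A quick way to sanity-check the model (a sketch, not part of the commit; it shrinks the config so the forward pass stays cheap):

import torch
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

config = DecoderOnlyConfig(hidden_size=256, intermediate_size=1024)
model = DecoderOnlyModel(config)
input_ids = torch.randint(0, config.vocab_size, (2, 8))  # [B=2, S=8]
logits = model(input_ids)
print(logits.shape)  # torch.Size([2, 8, 3200]) -> [B, S, V]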
Training script (73 lines added):
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import time
import itertools

import torch
import torch_xla
import torch.nn as nn


class TrainDecoderOnlyBase:

  def __init__(self):
    self.config = DecoderOnlyConfig()
    self.batch_size = 16
    self.seq_len = 512
    self.num_steps = 300
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of the ImageNet dataset.
    # For the purpose of this example, we are going to use fake data.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64),
              torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
        sample_count=self.train_dataset_len // self.batch_size)
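    # Each sample is an (input_ids, target) pair of shape
    # [batch_size, seq_len]; all-zero token ids are fine here since the goal
    # is to exercise compilation and throughput, not model quality.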

    self.device = torch_xla.device()
    self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
    self.model = DecoderOnlyModel(self.config).to(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
    self.loss_fn = nn.CrossEntropyLoss()

  def _train_update(self, step, loss, tracker, epoch):
    print(f'epoch: {epoch}, step: {step}, loss: {loss}, rate: {tracker.rate()}')

  def run_optimizer(self):
    self.optimizer.step()

  def train_loop_fn(self, loader, epoch):
    tracker = xm.RateTracker()
    self.model.train()
    loader = itertools.islice(loader, self.num_steps)
    for step, (data, target) in enumerate(loader):
      self.optimizer.zero_grad()
      logits = self.model(data)
      loss = self.loss_fn(
          logits.view(-1, self.config.vocab_size), target.view(-1))
      loss.backward()
      self.run_optimizer()
      tracker.add(self.batch_size)
      if step % 10 == 0:
        xm.add_step_closure(
            self._train_update, args=(step, loss, tracker, epoch))
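        # add_step_closure defers the host-side print until the step's results
        # are materialized, instead of forcing an early device-to-host sync on
        # the loss tensor.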

  def start_training(self):
    for epoch in range(1, self.num_epochs + 1):
      xm.master_print('Epoch {} train begin {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
      self.train_loop_fn(self.train_device_loader, epoch)
      xm.master_print('Epoch {} train end {}'.format(
          epoch, time.strftime('%l:%M%p %Z on %b %d, %Y')))
    xm.wait_device_ops()


if __name__ == '__main__':
  base = TrainDecoderOnlyBase()
  base.start_training()
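
run_optimizer is split out as its own method, which lets subclasses swap in a different optimizer step. For example, a multi-replica variant (a sketch using torch_xla's standard xm.optimizer_step API, not part of this commit) could all-reduce gradients before stepping:

class TrainDecoderOnlyReplicated(TrainDecoderOnlyBase):  # hypothetical subclass

  def run_optimizer(self):
    # xm.optimizer_step all-reduces gradients across replicas, then calls
    # optimizer.step().
    xm.optimizer_step(self.optimizer)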