from contextlib import nullcontext
from typing import Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from base_llama import LlamaPreTrainedModel, LlamaConfig
from rope import apply_rotary_emb
from utils import *
# Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467)
# borrowed from the official Llama implementation:
# https://github.com/facebookresearch/llama/blob/main/llama/model.py
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Compute the root mean square normalization. Use Equation 4 under
Section 4 of https://arxiv.org/abs/1910.07467 as a reference. Add
the given epsilon value (self.eps) to the tensor's norm (i.e. inside
the square root in Equation 4) before normalizing the tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
return x / rms
def forward(self, x):
"""
Apply the root mean square normalizer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
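# Illustrative sketch (not part of the model): RMSNorm is shape-preserving; it rescales each
# feature vector to roughly unit root-mean-square and then applies the learned per-dimension gain.
#   norm = RMSNorm(dim=8)
#   x = torch.randn(2, 4, 8)                       # (batch, seq, dim)
#   y = norm(x)                                    # same shape: torch.Size([2, 4, 8])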
class Attention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.n_kv_heads = config.n_heads if config.n_kv_heads is None else config.n_kv_heads
assert config.n_heads % self.n_kv_heads == 0
model_parallel_size = 1
self.n_local_heads = config.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = config.dim // config.n_heads
self.max_seq_len = config.max_seq_len
self.compute_query = nn.Linear(config.dim, config.n_heads * self.head_dim, bias=False)
self.compute_key = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
self.compute_value = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
self.compute_output = nn.Linear(config.n_heads * self.head_dim, config.dim, bias=False)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.dropout = config.dropout
def compute_query_key_value_scores(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor) -> torch.Tensor:
'''
Jointly compute Scaled Dot Product Attention (see Section 3.2.1 in
https://arxiv.org/abs/1706.03762 for details). The query, key, and
value tensors each have shape (bs, n_local_heads, seqlen, head_dim).
        An optimal implementation will jointly compute attention for multiple
        heads (n_local_heads of them) at once using matrix/tensor operations.
Make sure to use attention_dropout (self.attn_dropout) on the computed
attention matrix before applying it to the value tensor.
'''
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
attn = scores.softmax(dim=-1)
        attn = self.attn_dropout(attn)
return torch.matmul(attn, value)
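    # Illustrative sketch (assumes PyTorch >= 2.0; not used by the model): with dropout disabled
    # and no mask, the manual computation above matches the fused kernel in
    # F.scaled_dot_product_attention.
    #   q = k = v = torch.randn(1, 2, 5, 16)       # (bs, n_heads, seqlen, head_dim)
    #   manual = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(16), dim=-1) @ v
    #   fused = F.scaled_dot_product_attention(q, k, v)
    #   torch.allclose(manual, fused, atol=1e-6)   # True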
def forward(
self,
x: torch.Tensor
):
'''
Llama2 uses Grouped-Query Attention. The details of GQA are actually
not critical to solving this assignment; you are simply asked to
compute Scaled Dot Product Attention (see above for details). GQA is
a memory optimization to compute multi-head attention efficiently. See
Section 2.2 in https://arxiv.org/abs/2305.13245 or
https://ai.plainenglish.io/understanding-llama2-kv-cache-grouped-query-attention-rotary-embedding-and-more-c17e5f49a6d7
for details.
'''
batch_size, seqlen, _ = x.shape
query = self.compute_query(x)
key = self.compute_key(x)
value = self.compute_value(x)
query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim)
key = key.view(batch_size, seqlen, self.n_local_kv_heads, self.head_dim)
value = value.view(batch_size, seqlen, self.n_local_kv_heads, self.head_dim)
# RoPE relative positional embeddings
query, key = apply_rotary_emb(query, key, self.head_dim, self.max_seq_len)
# Grouped multiquery attention: expand out keys and values.
# Convert both to:
# (bs, seqlen, n_local_heads, head_dim)
key = torch.repeat_interleave(key, dim=2, repeats=self.n_rep)
value = torch.repeat_interleave(value, dim=2, repeats=self.n_rep)
# make heads into a batch dimension
query = query.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = self.compute_query_key_value_scores(query, key, value)
# restore time as batch dimension and concat heads
output = output.transpose(1, 2).contiguous().view(batch_size, seqlen, -1)
# final projection into the residual stream
output = self.resid_dropout(self.compute_output(output))
return output
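# Illustrative sketch (hypothetical sizes): how the grouped-query expansion above gives every
# query head a matching key/value head when n_heads=8 and n_kv_heads=2 (so n_rep=4).
#   k = torch.randn(1, 10, 2, 64)                  # (bs, seqlen, n_kv_heads, head_dim)
#   k = torch.repeat_interleave(k, dim=2, repeats=4)
#   k.shape                                        # torch.Size([1, 10, 8, 64])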
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
hidden_dim = int(2 * hidden_dim / 3)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.dropout = nn.Dropout(dropout)
def SwiGLU(self, x: torch.Tensor) -> torch.Tensor:
'''
        Compute the SwiGLU activation function (see Section 2 in
        https://arxiv.org/abs/2204.02311 for details).
'''
return F.silu(self.w1(x)) * self.w3(x)
def forward(self, x):
return self.dropout(self.w2(self.SwiGLU(x)))
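# Illustrative arithmetic (Llama-2 7B-style sizes, used here only as an example): with dim=4096
# and multiple_of=256, hidden_dim defaults to 4*4096=16384, is scaled to 2/3 of that, and is then
# rounded up to the next multiple of 256.
#   dim, multiple_of = 4096, 256
#   hidden_dim = int(2 * (4 * dim) / 3)            # 10922
#   multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)   # 11008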
class LlamaLayer(nn.Module):
def __init__(self, layer_id: int, config: LlamaConfig):
super().__init__()
self.n_heads = config.n_heads
self.dim = config.dim
self.head_dim = config.dim // config.n_heads
self.attention = Attention(config)
self.feed_forward = FeedForward(
dim=config.dim,
hidden_dim=config.hidden_dim,
multiple_of=config.multiple_of,
dropout=config.dropout,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
self.ffn_norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
def forward(self, x):
'''
This is the forward pass of the basic transformer building block. This is a
modernized version of the block shown on the left of Figure 1 on
https://arxiv.org/pdf/1706.03762.pdf.
        The transformer block should consist of:
        1) layer normalization of the input (via Root Mean Square layer normalization)
        2) self-attention on the layer-normalized input
        3) a residual connection (i.e., add the input to the output of the self-attention)
        4) layer normalization on the output of the self-attention
        5) a feed-forward network on the layer-normalized output of the self-attention
        6) a residual connection from the unnormalized self-attention output to the
           output of the feed-forward network
'''
x = x + self.attention(self.attention_norm(x))
x = x + self.feed_forward(self.ffn_norm(x))
return x
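# Illustrative sketch (hypothetical config values; LlamaConfig is assumed to accept these fields
# as keyword arguments, as suggested by load_pretrained below): a LlamaLayer is shape-preserving,
# which is what lets the Llama model stack n_layers of them in a plain loop.
#   cfg = LlamaConfig(dim=64, n_heads=4, n_kv_heads=4, hidden_dim=None, multiple_of=32,
#                     vocab_size=512, n_layers=2, max_seq_len=128, dropout=0.0,
#                     layer_norm_eps=1e-5)
#   block = LlamaLayer(layer_id=0, config=cfg)
#   block(torch.randn(2, 16, 64)).shape            # torch.Size([2, 16, 64])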
class Llama(LlamaPreTrainedModel):
def __init__(self, config: LlamaConfig):
'''
You will probably never need to call this function, unless you decide
to pretrain a Llama model from scratch.
'''
super().__init__(config)
self.params = config
self.vocab_size = config.vocab_size
self.n_layers = config.n_layers
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.dropout = nn.Dropout(config.dropout)
self.layers = torch.nn.ModuleList()
for layer_id in range(config.n_layers):
self.layers.append(LlamaLayer(layer_id, config))
self.norm = RMSNorm(config.dim, eps=config.layer_norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
# share the unembedding parameters with the embedding parameters
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
        # RoPE is applied on the fly inside each attention layer (via rope.apply_rotary_emb),
        # so no positional-embedding table needs to be precomputed here
# init all weights
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('compute_output.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layers))
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
_batch_size, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
for layer in self.layers:
h = layer(h)
h = self.norm(h)
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
return logits, h
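    # Illustrative sketch (for an already-constructed model = Llama(config)): the training path
    # returns logits for every position, while the inference path keeps only the last position.
    #   tokens = torch.randint(0, config.vocab_size, (2, 16))
    #   logits, h = model(tokens, targets=tokens)  # logits: (2, 16, vocab_size)
    #   logits, h = model(tokens)                  # logits: (2, 1, vocab_size)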
@torch.inference_mode()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
Also note this is a super inefficient version of sampling with no key/value cache.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
logits = logits[:, -1, :] # crop to just the final time step
if temperature == 0.0:
# select the single most likely index
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
else:
'''
Perform temperature sampling:
1) identify the logits at the final step.
                2) scale (divide) these logits by the given temperature.
3) normalize the scaled logits with a softmax to obtain scaled probabilities.
4) sample from the scaled probability distribution.
'''
if top_k is not None:
                    values, indices = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
mask = torch.ones_like(logits).scatter_(-1, indices, 0.0)
logits.masked_fill_(mask.bool(), float('-inf'))
scaled_logits = logits / temperature
probs = F.softmax(scaled_logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
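    # Illustrative sketch (toy numbers): the top-k filter above keeps only the k largest logits
    # and pushes the rest to -inf before temperature scaling and sampling.
    #   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    #   values, indices = torch.topk(logits, 2, dim=-1)
    #   mask = torch.ones_like(logits).scatter_(-1, indices, 0.0)
    #   filtered = logits.masked_fill(mask.bool(), float('-inf'))
    #   F.softmax(filtered / 0.8, dim=-1)          # only the two kept entries get nonzero mass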
def load_pretrained(checkpoint):
device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
dtype = "float32"
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# init from a model saved in a specific directory
checkpoint_dict = torch.load(checkpoint, map_location=device)
config = LlamaConfig(**checkpoint_dict['model_args'])
model = Llama(config)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)
return model
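# Illustrative usage (the checkpoint path is hypothetical):
#   model = load_pretrained('path/to/checkpoint.pt')
#   model.eval()
#   prompt = torch.tensor([[1]])                   # a single BOS-style token id
#   completion = model.generate(prompt, max_new_tokens=32, temperature=0.8, top_k=50)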