diff --git a/model.py b/model.py index 5f190ed..4df7c92 100644 --- a/model.py +++ b/model.py @@ -48,9 +48,12 @@ def __init__(self, config, idx_layer): self.idx_layer = idx_layer self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: + print("Using slow attention. Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) def forward(self, x): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) @@ -61,14 +64,20 @@ def forward(self, x): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if self.scale_attn_by_inverse_layer_idx: - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)) / float(self.idx_layer + 1)) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) else: - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + + if self.scale_attn_by_inverse_layer_idx: + y = y / float(self.idx_layer + 1) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection