forked from qwopqwop200/GPTQ-for-LLaMa
Add support for upstream gptq cuda version
Co-authored-by: qwopqwop200 <[email protected]>
Commit cbf8ad0, 1 parent de567bd
Showing 11 changed files with 845 additions and 220 deletions.
@@ -0,0 +1,124 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

from .quant_v3 import *


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        rotary_emb,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

        # The fused projection emits q, k and v concatenated along the channel dimension;
        # split them back into three tensors of width hidden_size each.
        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # scaled_dot_product_attention does not expose per-head weights, so none are
        # returned even when output_attentions is requested.
        attn_weights = None

        return attn_output, attn_weights, past_key_value


def make_quant_attn(model):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
    """
    for name, m in model.named_modules():
        if not isinstance(m, LlamaAttention):
            continue

        q_proj = m.q_proj
        k_proj = m.k_proj
        v_proj = m.v_proj

        # Concatenate the packed weights, zero points, scales and group indices of the three
        # projections so a single QuantLinear computes q, k and v in one quantized matmul.
        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, q_proj.bias is not None)
        qkv_layer.qweight = qweights
        qkv_layer.qzeros = qzeros
        qkv_layer.scales = scales
        qkv_layer.g_idx = g_idx
        qkv_layer.bias = bias

        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)

        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent_name = ''
            parent = model
            child_name = name

        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, attn)
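
For context, here is a minimal usage sketch of how the fused attention would typically be wired in after loading a GPTQ-quantized LLaMA checkpoint. It is not part of this commit: the import paths, the `load_quant` loader and its signature, the model id, and the checkpoint filename are all assumptions that should be adjusted to match your checkout of the repo.

# Hedged usage sketch (not from this commit): load a quantized model, then fuse attention.
import torch
from transformers import AutoTokenizer

# Assumed imports: substitute the actual loader and module path from your checkout.
from llama import load_quant                  # assumed checkpoint loader
from fused_attn import make_quant_attn        # assumed import path for this file

MODEL = "decapoda-research/llama-7b-hf"       # hypothetical model id
CHECKPOINT = "llama7b-4bit-128g.pt"           # hypothetical GPTQ checkpoint

# Assumed signature: load_quant(model, checkpoint, wbits, groupsize).
model = load_quant(MODEL, CHECKPOINT, 4, 128)
model = model.eval().to("cuda")

# Swap every LlamaAttention for QuantLlamaAttention backed by a fused QKV QuantLinear.
make_quant_attn(model)

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=False)
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))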