Add support for upstream gptq cuda version
Co-authored-by: qwopqwop200 <[email protected]>
0cc4m and qwopqwop200 committed May 9, 2023
1 parent de567bd commit cbf8ad0
Showing 11 changed files with 845 additions and 220 deletions.
4 changes: 2 additions & 2 deletions gptq/datautils.py
@@ -58,10 +58,10 @@ def get_ptb(nsamples, seed, seqlen, model, use_fast=False):
def get_c4(nsamples, seed, seqlen, model, use_fast=False):
from datasets import load_dataset
traindata = load_dataset(
- 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=True
+ 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False
)
valdata = load_dataset(
- 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=True
+ 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False
)

from transformers import AutoTokenizer
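
For reference, a minimal sketch of how this calibration loader is typically called. The parameter names come from the signature above; the model id and sample counts are placeholders, and the assumption that get_c4 returns a list of (input, target) calibration pairs plus a held-out validation encoding follows how the scripts further down iterate their dataloader.

from gptq.datautils import get_c4  # assumed package-style import for gptq/datautils.py

# Hypothetical call: 128 calibration sequences of 2048 tokens each,
# tokenized with the named model's tokenizer.
trainloader, valenc = get_c4(
    nsamples=128, seed=0, seqlen=2048, model="EleutherAI/gpt-j-6b"
)
input_ids = trainloader[0][0]  # each entry is assumed to be an (input, target) pair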
124 changes: 124 additions & 0 deletions gptq/fused_attn.py
@@ -0,0 +1,124 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

from .quant_v3 import *


class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        qkv_proj,
        o_proj,
        rotary_emb,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.rotary_emb = rotary_emb

    def _shape(self, tensor, seq_len, bsz):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        bsz, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        is_causal = past_key_value is None
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        past_key_value = (key_states, value_states) if use_cache else None

        with torch.backends.cuda.sdp_kernel(enable_math=False):
            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def make_quant_attn(model):
    """
    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
    """
    for name, m in model.named_modules():
        if not isinstance(m, LlamaAttention):
            continue

        q_proj = m.q_proj
        k_proj = m.k_proj
        v_proj = m.v_proj

        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False)
        qkv_layer.qweight = qweights
        qkv_layer.qzeros = qzeros
        qkv_layer.scales = scales
        qkv_layer.g_idx = g_idx
        qkv_layer.bias = bias

        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb)

        if '.' in name:
            parent_name = name.rsplit('.', 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.get_submodule(parent_name)
        else:
            parent_name = ''
            parent = model
            child_name = name

        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")

        setattr(parent, child_name, attn)
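
A rough usage sketch for the new fused-attention helper. It assumes a 4-bit LLaMA checkpoint loaded through a load_quant function analogous to the gptj/gptneox loaders shown below; gptq.llama, the model path, and the checkpoint file name are illustrative and not part of this diff.

import torch

from gptq.fused_attn import make_quant_attn
from gptq.llama import load_quant  # assumed LLaMA counterpart of the load_quant calls below

model = load_quant("path/to/llama-7b-hf", "llama-7b-4bit-128g.safetensors", 4, 128)
make_quant_attn(model)  # fuses q/k/v into a single QuantLinear and swaps in QuantLlamaAttention
model = model.to("cuda").eval()

input_ids = torch.tensor([[1, 15043, 3186]], device="cuda")  # arbitrary token ids
with torch.inference_mode():
    logits = model(input_ids).logits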
68 changes: 46 additions & 22 deletions gptq/gptj.py
@@ -6,8 +6,13 @@

import transformers

- from .modelutils import find_layers, make_quant
- from .quant_v2 import quantize, Quantizer, QuantLinear
+ from .gptq import GPTQ
+ from .modelutils import DEV, find_layers, GPTQVERSION, make_quant
+
+ if GPTQVERSION == 1:
+     from .quant_v2 import quantize, Quantizer, QuantLinear
+ elif GPTQVERSION == 2:
+     from .quant_v3 import quantize, Quantizer, QuantLinear


def get_gptj(model):
@@ -118,18 +123,32 @@ def tmp(_, inp, out):
h.remove()

for name in subset:
- print(i, name)
- print("Quantizing ...")
- scale, zero = gptq[name].fasterquant(
-     percdamp=args.percdamp,
-     groupsize=args.groupsize,
-     actorder=args.act_order,
- )
- quantizers["transformer.h.%d.%s" % (i, name)] = (
-     gptq[name].quantizer,
-     scale,
-     zero,
- )
+ print(f"Quantizing {name} in layer {i+1}/{len(layers)}...")
+ if GPTQVERSION == 1:
+     scale, zero = gptq[name].fasterquant(
+         percdamp=args.percdamp,
+         groupsize=args.groupsize,
+         actorder=args.act_order,
+     )
+     quantizers["transformer.h.%d.%s" % (i, name)] = (
+         gptq[name].quantizer.cpu(),
+         scale.cpu(),
+         zero.cpu(),
+     )
+ elif GPTQVERSION == 2:
+     scale, zero, g_idx = gptq[name].fasterquant(
+         percdamp=args.percdamp,
+         groupsize=args.groupsize,
+         actorder=args.act_order,
+     )
+     quantizers["transformer.h.%d.%s" % (i, name)] = (
+         gptq[name].quantizer.cpu(),
+         scale.cpu(),
+         zero.cpu(),
+         g_idx.cpu(),
+     )
+ else:
+     raise NotImplementedError("Unsupported GPTQVERSION")
gptq[name].free()

for j in range(args.nsamples):
@@ -208,7 +227,9 @@ def forward(self, inp, **kwargs):
subset = find_layers(layer)
for name in subset:
quantizer = Quantizer()
- quantizer.configure(args.wbits, perchannel=True, sym=False, mse=False)
+ quantizer.configure(
+     args.wbits, perchannel=True, sym=args.sym, mse=False
+ )
W = subset[name].weight.data
quantizer.find_params(W, weight=True)
subset[name].weight.data = quantize(
@@ -260,9 +281,14 @@ def gptj_pack(model, quantizers, wbits, groupsize):
print("Packing ...")
for name in qlayers:
print(name)
- quantizers[name], scale, zero = quantizers[name]
- quantizers[name], scale, zero = quantizers[name].cpu(), scale.cpu(), zero.cpu()
- qlayers[name].pack(layers[name], scale, zero)
+ if GPTQVERSION == 1:
+     quantizers[name], scale, zero = quantizers[name]
+     qlayers[name].pack(layers[name], scale, zero)
+ elif GPTQVERSION == 2:
+     quantizers[name], scale, zero, g_idx = quantizers[name]
+     qlayers[name].pack(layers[name], scale, zero, g_idx)
+ else:
+     raise NotImplementedError("Unsupported GPTQVERSION")
print("Done.")
return model
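
The pack step mirrors the quantization loop above: a GPTQVERSION 1 entry in the quantizers dict is a (quantizer, scale, zero) triple, while GPTQVERSION 2 appends the per-channel group index g_idx. A small hypothetical helper, not part of this diff, that tolerates both layouts:

def unpack_quantizer_entry(entry):
    # GPTQVERSION == 1 stores (quantizer, scale, zero);
    # GPTQVERSION == 2 stores (quantizer, scale, zero, g_idx).
    if len(entry) == 4:
        quantizer, scale, zero, g_idx = entry
    else:
        quantizer, scale, zero = entry
        g_idx = None
    return quantizer, scale, zero, g_idx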

@@ -401,7 +427,7 @@ def sync():

if __name__ == "__main__":
import argparse
- from datautils import *
+ from .datautils import *

parser = argparse.ArgumentParser()

@@ -490,9 +516,7 @@ def sync():
args.load = args.load.as_posix()

if args.load:
- model = load_quant(
-     args.model, args.load, args.wbits, args.groupsize
- )
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
else:
model = get_gptj(args.model)
model.eval()
91 changes: 59 additions & 32 deletions gptq/gptneox.py
@@ -6,8 +6,13 @@

import transformers

- from .modelutils import find_layers, make_quant
- from .quant_v2 import quantize, Quantizer, QuantLinear
+ from .gptq import GPTQ
+ from .modelutils import DEV, find_layers, GPTQVERSION, make_quant
+
+ if GPTQVERSION == 1:
+     from .quant_v2 import quantize, Quantizer, QuantLinear
+ elif GPTQVERSION == 2:
+     from .quant_v3 import quantize, Quantizer, QuantLinear


def get_gptneox(model):
@@ -118,18 +123,32 @@ def tmp(_, inp, out):
h.remove()

for name in subset:
- print(i, name)
- print("Quantizing ...")
- scale, zero = gptq[name].fasterquant(
-     percdamp=args.percdamp,
-     groupsize=args.groupsize,
-     actorder=args.act_order,
- )
- quantizers["gpt_neox.layers.%d.%s" % (i, name)] = (
-     gptq[name].quantizer,
-     scale,
-     zero,
- )
+ print(f"Quantizing {name} in layer {i+1}/{len(layers)}...")
+ if GPTQVERSION == 1:
+     scale, zero = gptq[name].fasterquant(
+         percdamp=args.percdamp,
+         groupsize=args.groupsize,
+         actorder=args.act_order,
+     )
+     quantizers["gpt_neox.layers.%d.%s" % (i, name)] = (
+         gptq[name].quantizer.cpu(),
+         scale.cpu(),
+         zero.cpu(),
+     )
+ elif GPTQVERSION == 2:
+     scale, zero, g_idx = gptq[name].fasterquant(
+         percdamp=args.percdamp,
+         groupsize=args.groupsize,
+         actorder=args.act_order,
+     )
+     quantizers["gpt_neox.layers.%d.%s" % (i, name)] = (
+         gptq[name].quantizer.cpu(),
+         scale.cpu(),
+         zero.cpu(),
+         g_idx.cpu(),
+     )
+ else:
+     raise NotImplementedError("Unsupported GPTQVERSION")
gptq[name].free()

for j in range(args.nsamples):
@@ -208,7 +227,9 @@ def forward(self, inp, **kwargs):
subset = find_layers(layer)
for name in subset:
quantizer = Quantizer()
- quantizer.configure(args.wbits, perchannel=True, sym=False, mse=False)
+ quantizer.configure(
+     args.wbits, perchannel=True, sym=args.sym, mse=False
+ )
W = subset[name].weight.data
quantizer.find_params(W, weight=True)
subset[name].weight.data = quantize(
@@ -260,9 +281,14 @@ def gptneox_pack(model, quantizers, wbits, groupsize):
print("Packing ...")
for name in qlayers:
print(name)
- quantizers[name], scale, zero = quantizers[name]
- quantizers[name], scale, zero = quantizers[name].cpu(), scale.cpu(), zero.cpu()
- qlayers[name].pack(layers[name], scale, zero)
+ if GPTQVERSION == 1:
+     quantizers[name], scale, zero = quantizers[name]
+     qlayers[name].pack(layers[name], scale, zero)
+ elif GPTQVERSION == 2:
+     quantizers[name], scale, zero, g_idx = quantizers[name]
+     qlayers[name].pack(layers[name], scale, zero, g_idx)
+ else:
+     raise NotImplementedError("Unsupported GPTQVERSION")
print("Done.")
return model

@@ -401,7 +427,7 @@ def sync():

if __name__ == "__main__":
import argparse
- from datautils import *
+ from .datautils import *

parser = argparse.ArgumentParser()

@@ -490,9 +516,7 @@ def sync():
args.load = args.load.as_posix()

if args.load:
- model = load_quant(
-     args.model, args.load, args.wbits, args.groupsize
- )
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
else:
model = get_gptneox(args.model)
model.eval()
@@ -511,6 +535,19 @@ def sync():
quantizers = gptneox_sequential(model, dataloader, DEV)
print(time.time() - tick)

+ if args.benchmark:
+     gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
+     if len(gpus) > 1:
+         gptneox_multigpu(model, gpus)
+     else:
+         model = model.to(DEV)
+     if args.benchmark:
+         input_ids = next(iter(dataloader))[0][:, : args.benchmark]
+         benchmark(model, input_ids, check=args.check)
+
+ if args.load:
+     exit()
+
if args.eval:
datasets = ["wikitext2", "ptb", "c4"]
if args.new_eval:
@@ -531,13 +568,3 @@
from safetensors.torch import save_file as safe_save

safe_save(model.state_dict(), args.save_safetensors)

- if args.benchmark:
-     gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
-     if len(gpus) > 1:
-         gptneox_multigpu(model, gpus)
-     else:
-         model = model.to(DEV)
-     if args.benchmark:
-         input_ids = next(iter(dataloader))[0][:, : args.benchmark]
-         benchmark(model, input_ids, check=args.check)