From cbf8ad0effce6e70da0876822af54499db795fee Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 9 May 2023 21:36:16 +0200 Subject: [PATCH] Add support for upstream gptq cuda version Co-authored-by: qwopqwop200 --- gptq/datautils.py | 4 +- gptq/fused_attn.py | 124 +++++++++++++++++ gptq/gptj.py | 68 ++++++---- gptq/gptneox.py | 91 ++++++++----- gptq/gptq.py | 21 ++- gptq/llama.py | 69 ++++++---- gptq/modelutils.py | 14 +- gptq/mpt.py | 310 ++++++++++++++++++++++++++----------------- gptq/opt.py | 37 ++++-- gptq/quant_v3.py | 322 +++++++++++++++++++++++++++++++++++++++++++++ setup.py | 5 +- 11 files changed, 845 insertions(+), 220 deletions(-) create mode 100644 gptq/fused_attn.py create mode 100644 gptq/quant_v3.py diff --git a/gptq/datautils.py b/gptq/datautils.py index 6d5f7296..327c901d 100644 --- a/gptq/datautils.py +++ b/gptq/datautils.py @@ -58,10 +58,10 @@ def get_ptb(nsamples, seed, seqlen, model, use_fast=False): def get_c4(nsamples, seed, seqlen, model, use_fast=False): from datasets import load_dataset traindata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=True + 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False ) valdata = load_dataset( - 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation',use_auth_token=True + 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation',use_auth_token=False ) from transformers import AutoTokenizer diff --git a/gptq/fused_attn.py b/gptq/fused_attn.py new file mode 100644 index 00000000..e236b873 --- /dev/null +++ b/gptq/fused_attn.py @@ -0,0 +1,124 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb + +from .quant_v3 import * + + +class QuantLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size, + num_heads, + qkv_proj, + o_proj, + rotary_emb, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads}).") + self.qkv_proj = qkv_proj + self.o_proj = o_proj + self.rotary_emb = rotary_emb + + def _shape(self, tensor, seq_len, bsz): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): + """Input shape: Batch x Time x Channel""" + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len 
+= past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + is_causal = past_key_value is None + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + if use_cache: + # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor + # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + + past_key_value = (key_states, value_states) if use_cache else None + + with torch.backends.cuda.sdp_kernel(enable_math=False): + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def make_quant_attn(model): + """ + Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. + """ + for name, m in model.named_modules(): + if not isinstance(m, LlamaAttention): + continue + + q_proj = m.q_proj + k_proj = m.k_proj + v_proj = m.v_proj + + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + + qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False) + qkv_layer.qweight = qweights + qkv_layer.qzeros = qzeros + qkv_layer.scales = scales + qkv_layer.g_idx = g_idx + qkv_layer.bias = bias + + attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) + + if '.' 
in name: + parent_name = name.rsplit('.', 1)[0] + child_name = name[len(parent_name) + 1:] + parent = model.get_submodule(parent_name) + else: + parent_name = '' + parent = model + child_name = name + + #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + + setattr(parent, child_name, attn) diff --git a/gptq/gptj.py b/gptq/gptj.py index 6e338884..5ebf7b3e 100644 --- a/gptq/gptj.py +++ b/gptq/gptj.py @@ -6,8 +6,13 @@ import transformers -from .modelutils import find_layers, make_quant -from .quant_v2 import quantize, Quantizer, QuantLinear +from .gptq import GPTQ +from .modelutils import DEV, find_layers, GPTQVERSION, make_quant + +if GPTQVERSION == 1: + from .quant_v2 import quantize, Quantizer, QuantLinear +elif GPTQVERSION == 2: + from .quant_v3 import quantize, Quantizer, QuantLinear def get_gptj(model): @@ -118,18 +123,32 @@ def tmp(_, inp, out): h.remove() for name in subset: - print(i, name) - print("Quantizing ...") - scale, zero = gptq[name].fasterquant( - percdamp=args.percdamp, - groupsize=args.groupsize, - actorder=args.act_order, - ) - quantizers["transformer.h.%d.%s" % (i, name)] = ( - gptq[name].quantizer, - scale, - zero, - ) + print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") + if GPTQVERSION == 1: + scale, zero = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["transformer.h.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + ) + elif GPTQVERSION == 2: + scale, zero, g_idx = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["transformer.h.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + ) + else: + raise NotImplementedError("Unsupported GPTQVERSION") gptq[name].free() for j in range(args.nsamples): @@ -208,7 +227,9 @@ def forward(self, inp, **kwargs): subset = find_layers(layer) for name in subset: quantizer = Quantizer() - quantizer.configure(args.wbits, perchannel=True, sym=False, mse=False) + quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False + ) W = subset[name].weight.data quantizer.find_params(W, weight=True) subset[name].weight.data = quantize( @@ -260,9 +281,14 @@ def gptj_pack(model, quantizers, wbits, groupsize): print("Packing ...") for name in qlayers: print(name) - quantizers[name], scale, zero = quantizers[name] - quantizers[name], scale, zero = quantizers[name].cpu(), scale.cpu(), zero.cpu() - qlayers[name].pack(layers[name], scale, zero) + if GPTQVERSION == 1: + quantizers[name], scale, zero = quantizers[name] + qlayers[name].pack(layers[name], scale, zero) + elif GPTQVERSION == 2: + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + else: + raise NotImplementedError("Unsupported GPTQVERSION") print("Done.") return model @@ -401,7 +427,7 @@ def sync(): if __name__ == "__main__": import argparse - from datautils import * + from .datautils import * parser = argparse.ArgumentParser() @@ -490,9 +516,7 @@ def sync(): args.load = args.load.as_posix() if args.load: - model = load_quant( - args.model, args.load, args.wbits, args.groupsize - ) + model = load_quant(args.model, args.load, args.wbits, args.groupsize) else: model = get_gptj(args.model) model.eval() diff --git a/gptq/gptneox.py b/gptq/gptneox.py index eb4e0f31..69192b19 100644 --- a/gptq/gptneox.py +++ b/gptq/gptneox.py @@ -6,8 +6,13 @@ 
import transformers -from .modelutils import find_layers, make_quant -from .quant_v2 import quantize, Quantizer, QuantLinear +from .gptq import GPTQ +from .modelutils import DEV, find_layers, GPTQVERSION, make_quant + +if GPTQVERSION == 1: + from .quant_v2 import quantize, Quantizer, QuantLinear +elif GPTQVERSION == 2: + from .quant_v3 import quantize, Quantizer, QuantLinear def get_gptneox(model): @@ -118,18 +123,32 @@ def tmp(_, inp, out): h.remove() for name in subset: - print(i, name) - print("Quantizing ...") - scale, zero = gptq[name].fasterquant( - percdamp=args.percdamp, - groupsize=args.groupsize, - actorder=args.act_order, - ) - quantizers["gpt_neox.layers.%d.%s" % (i, name)] = ( - gptq[name].quantizer, - scale, - zero, - ) + print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") + if GPTQVERSION == 1: + scale, zero = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["gpt_neox.layers.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + ) + elif GPTQVERSION == 2: + scale, zero, g_idx = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["gpt_neox.layers.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + ) + else: + raise NotImplementedError("Unsupported GPTQVERSION") gptq[name].free() for j in range(args.nsamples): @@ -208,7 +227,9 @@ def forward(self, inp, **kwargs): subset = find_layers(layer) for name in subset: quantizer = Quantizer() - quantizer.configure(args.wbits, perchannel=True, sym=False, mse=False) + quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False + ) W = subset[name].weight.data quantizer.find_params(W, weight=True) subset[name].weight.data = quantize( @@ -260,9 +281,14 @@ def gptneox_pack(model, quantizers, wbits, groupsize): print("Packing ...") for name in qlayers: print(name) - quantizers[name], scale, zero = quantizers[name] - quantizers[name], scale, zero = quantizers[name].cpu(), scale.cpu(), zero.cpu() - qlayers[name].pack(layers[name], scale, zero) + if GPTQVERSION == 1: + quantizers[name], scale, zero = quantizers[name] + qlayers[name].pack(layers[name], scale, zero) + elif GPTQVERSION == 2: + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + else: + raise NotImplementedError("Unsupported GPTQVERSION") print("Done.") return model @@ -401,7 +427,7 @@ def sync(): if __name__ == "__main__": import argparse - from datautils import * + from .datautils import * parser = argparse.ArgumentParser() @@ -490,9 +516,7 @@ def sync(): args.load = args.load.as_posix() if args.load: - model = load_quant( - args.model, args.load, args.wbits, args.groupsize - ) + model = load_quant(args.model, args.load, args.wbits, args.groupsize) else: model = get_gptneox(args.model) model.eval() @@ -511,6 +535,19 @@ def sync(): quantizers = gptneox_sequential(model, dataloader, DEV) print(time.time() - tick) + if args.benchmark: + gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] + if len(gpus) > 1: + llama_multigpu(model, gpus) + else: + model = model.to(DEV) + if args.benchmark: + input_ids = next(iter(dataloader))[0][:, : args.benchmark] + benchmark(model, input_ids, check=args.check) + + if args.load: + exit() + if args.eval: datasets = ["wikitext2", "ptb", "c4"] if args.new_eval: @@ -531,13 +568,3 @@ def sync(): from safetensors.torch import 
save_file as safe_save safe_save(model.state_dict(), args.save_safetensors) - - if args.benchmark: - gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] - if len(gpus) > 1: - gptneox_multigpu(model, gpus) - else: - model = model.to(DEV) - if args.benchmark: - input_ids = next(iter(dataloader))[0][:, : args.benchmark] - benchmark(model, input_ids, check=args.check) diff --git a/gptq/gptq.py b/gptq/gptq.py index fa1ecd40..5269c734 100644 --- a/gptq/gptq.py +++ b/gptq/gptq.py @@ -5,7 +5,11 @@ import torch.nn as nn import transformers -from .quant_v2 import * +from .modelutils import GPTQVERSION +if GPTQVERSION == 1: + from .quant_v2 import quantize +elif GPTQVERSION == 2: + from .quant_v3 import quantize DEBUG = False @@ -93,6 +97,8 @@ def fasterquant( H = torch.linalg.cholesky(H, upper=True) Hinv = H + if GPTQVERSION == 2: + g_idx = [] scale = [] zero = [] now_idx = 1 @@ -144,7 +150,11 @@ def fasterquant( torch.cuda.synchronize() print('time %.2f' % (time.time() - tick)) print('error', torch.sum(Losses).item()) - + + if GPTQVERSION == 2: + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if actorder: invperm = torch.argsort(perm) Q = Q[:, invperm] @@ -160,7 +170,12 @@ def fasterquant( zero.append(self.quantizer.zero) scale = torch.cat(scale,dim=1) zero = torch.cat(zero,dim=1) - return scale,zero + if GPTQVERSION == 1: + return scale, zero + elif GPTQVERSION == 2: + return scale,zero,g_idx + else: + raise NotImplementedError("Unsupported GPTQVERSION") def free(self): if DEBUG: diff --git a/gptq/llama.py b/gptq/llama.py index 25a2028b..30a7dc50 100644 --- a/gptq/llama.py +++ b/gptq/llama.py @@ -5,8 +5,13 @@ import transformers -from .modelutils import find_layers, make_quant -from .quant_v2 import quantize, Quantizer, QuantLinear +from .gptq import GPTQ +from .modelutils import DEV, find_layers, GPTQVERSION, make_quant +if GPTQVERSION == 1: + from .quant_v2 import quantize, Quantizer, QuantLinear +elif GPTQVERSION == 2: + from .quant_v3 import quantize, Quantizer, QuantLinear +from .fused_attn import make_quant_attn def get_llama(model): @@ -105,10 +110,15 @@ def tmp(_, inp, out): h.remove() for name in subset: - print(i, name) - print('Quantizing ...') - scale,zero = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer,scale,zero) + print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') + if GPTQVERSION == 1: + scale,zero = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu()) + elif GPTQVERSION == 2: + scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + else: + raise NotImplementedError("Unsupported GPTQVERSION") gptq[name].free() for j in range(args.nsamples): @@ -181,7 +191,7 @@ def forward(self, inp, **kwargs): for name in subset: quantizer = Quantizer() quantizer.configure( - args.wbits, perchannel=True, sym=False, mse=False + args.wbits, perchannel=True, sym=args.sym, mse=False ) W = subset[name].weight.data quantizer.find_params(W, weight=True) @@ -229,9 +239,14 @@ def llama_pack(model, 
quantizers, wbits, groupsize): print('Packing ...') for name in qlayers: print(name) - quantizers[name],scale,zero = quantizers[name] - quantizers[name],scale,zero = quantizers[name].cpu(),scale.cpu(),zero.cpu() - qlayers[name].pack(layers[name], scale, zero) + if GPTQVERSION == 1: + quantizers[name],scale,zero = quantizers[name] + qlayers[name].pack(layers[name], scale, zero) + elif GPTQVERSION == 2: + quantizers[name],scale,zero,g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + else: + raise NotImplementedError("Unsupported GPTQVERSION") print('Done.') return model @@ -264,6 +279,9 @@ def noop(*args, **kwargs): model.load_state_dict(safe_load(checkpoint)) else: model.load_state_dict(torch.load(checkpoint)) + + if GPTQVERSION > 1: + make_quant_attn(model) model.seqlen = 2048 print('Done.') @@ -354,7 +372,7 @@ def sync(): if __name__ == '__main__': import argparse - from datautils import * + from .datautils import * parser = argparse.ArgumentParser() @@ -453,6 +471,19 @@ def sync(): tick = time.time() quantizers = llama_sequential(model, dataloader, DEV) print(time.time() - tick) + + if args.benchmark: + gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] + if len(gpus) > 1: + llama_multigpu(model, gpus) + else: + model = model.to(DEV) + if args.benchmark: + input_ids = next(iter(dataloader))[0][:, :args.benchmark] + benchmark(model, input_ids, check=args.check) + + if args.load: + exit() if args.eval: datasets = ['wikitext2', 'ptb', 'c4'] @@ -472,16 +503,6 @@ def sync(): if args.save_safetensors: llama_pack(model, quantizers, args.wbits, args.groupsize) from safetensors.torch import save_file as safe_save - safe_save(model.state_dict(), args.save_safetensors) - - if args.benchmark: - gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] - if len(gpus) > 1: - llama_multigpu(model, gpus) - else: - model = model.to(DEV) - if args.benchmark: - input_ids = next(iter(dataloader))[0][:, :args.benchmark] - benchmark(model, input_ids, check=args.check) - if args.load: - exit() + state_dict = model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + safe_save(state_dict, args.save_safetensors) diff --git a/gptq/modelutils.py b/gptq/modelutils.py index 6de84684..f2931910 100644 --- a/gptq/modelutils.py +++ b/gptq/modelutils.py @@ -1,12 +1,15 @@ +import os import torch import torch.nn as nn from . import quant_v1 from . import quant_v2 - -GPTQVERSION = 1 +from . 
import quant_v3 +GPTQVERSION = int(os.environ.get("GPTQVERSION", 1)) +if GPTQVERSION < 0 or GPTQVERSION > 2: + raise NotImplementedError(f"Unsupported gptq version: {GPTQVERSION}") DEV = torch.device('cuda:0') @@ -26,9 +29,14 @@ def set_gptq_version(version): GPTQVERSION = version +def get_gptq_version(): + return GPTQVERSION + + def make_quant(*args, **kwargs): if GPTQVERSION == 0: return quant_v1.make_quant(*args, **kwargs) - if GPTQVERSION == 1: return quant_v2.make_quant(*args, **kwargs) + if GPTQVERSION == 2: + return quant_v3.make_quant(*args, **kwargs) diff --git a/gptq/mpt.py b/gptq/mpt.py index 5cf5b9c8..541d67d5 100644 --- a/gptq/mpt.py +++ b/gptq/mpt.py @@ -5,27 +5,34 @@ import transformers -from .gptq import * -from .modelutils import * -from .quant_v2 import * +from .gptq import GPTQ +from .modelutils import DEV, find_layers, GPTQVERSION, make_quant + +if GPTQVERSION == 1: + from .quant_v2 import quantize, Quantizer, QuantLinear +elif GPTQVERSION == 2: + from .quant_v3 import quantize, Quantizer, QuantLinear from hf_bleeding_edge.mpt import MPTConfig, MPTForCausalLM def get_mpt(model): import torch + def skip(*args, **kwargs): pass + torch.nn.init.kaiming_uniform_ = skip torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip - model = MPTForCausalLM.from_pretrained(model, torch_dtype='auto') + model = MPTForCausalLM.from_pretrained(model, torch_dtype="auto") model.seqlen = 2048 return model + @torch.no_grad() def mpt_sequential(model, dataloader, dev): - print('Starting ...') + print("Starting ...") use_cache = model.config.use_cache model.config.use_cache = False @@ -39,17 +46,19 @@ def mpt_sequential(model, dataloader, dev): inps = torch.zeros( (args.nsamples, model.seqlen, model.config.d_model), dtype=dtype, device=dev ) - cache = {'i': 0, 'attention_mask': None} + cache = {"i": 0, "attention_mask": None} class Catcher(nn.Module): def __init__(self, module): super().__init__() self.module = module + def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] raise ValueError + layers[0] = Catcher(layers[0]) for batch in dataloader: try: @@ -64,9 +73,9 @@ def forward(self, inp, **kwargs): torch.cuda.empty_cache() outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] + attention_mask = cache["attention_mask"] - print('Ready.') + print("Ready.") quantizers = {} for i in range(len(layers)): @@ -74,14 +83,14 @@ def forward(self, inp, **kwargs): full = find_layers(layer) if args.true_sequential: sequential = [ - ['attn.Wqkv'], - ['attn.out_proj'], - ['ffn.up_proj'], - ['ffn.down_proj'] + ["attn.Wqkv"], + ["attn.out_proj"], + ["ffn.up_proj"], + ["ffn.down_proj"], ] else: sequential = [list(full.keys())] - + for names in sequential: subset = {n: full[n] for n in names} gptq = {} @@ -91,11 +100,13 @@ def forward(self, inp, **kwargs): gptq[name].quantizer.configure( args.wbits, perchannel=True, sym=args.sym, mse=False ) - + def add_batch(name): def tmp(_, inp, out): gptq[name].add_batch(inp[0].data, out.data) + return tmp + handles = [] for name in subset: handles.append(subset[name].register_forward_hook(add_batch(name))) @@ -105,12 +116,34 @@ def tmp(_, inp, out): h.remove() for name in subset: - print(i, name) - print('Quantizing ...') - scale,zero = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['transformer.blocks.%d.%s' % (i, 
name)] = (gptq[name].quantizer,scale,zero) + print(f"Quantizing {name} in layer {i+1}/{len(layers)}...") + if GPTQVERSION == 1: + scale, zero = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["transformer.blocks.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + ) + elif GPTQVERSION == 2: + scale, zero, g_idx = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + ) + quantizers["transformer.blocks.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + ) + else: + raise NotImplementedError("Unsupported GPTQVERSION") gptq[name].free() - + for j in range(args.nsamples): outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0] @@ -125,9 +158,10 @@ def tmp(_, inp, out): return quantizers + @torch.no_grad() def mpt_eval(model, testenc, dev): - print('Evaluating ...') + print("Evaluating ...") testenc = testenc.input_ids nsamples = testenc.numel() // model.seqlen @@ -143,20 +177,22 @@ def mpt_eval(model, testenc, dev): inps = torch.zeros( (nsamples, model.seqlen, model.config.d_model), dtype=dtype, device=dev ) - cache = {'i': 0, 'attention_mask': None} + cache = {"i": 0, "attention_mask": None} class Catcher(nn.Module): def __init__(self, module): super().__init__() self.module = module + def forward(self, inp, **kwargs): - inps[cache['i']] = inp - cache['i'] += 1 - cache['attention_mask'] = kwargs['attention_mask'] + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] raise ValueError + layers[0] = Catcher(layers[0]) for i in range(nsamples): - batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev) + batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) try: model(batch) except ValueError: @@ -168,8 +204,8 @@ def forward(self, inp, **kwargs): torch.cuda.empty_cache() outs = torch.zeros_like(inps) - attention_mask = cache['attention_mask'] - + attention_mask = cache["attention_mask"] + for i in range(len(layers)): print(i) layer = layers[i].to(dev) @@ -179,7 +215,7 @@ def forward(self, inp, **kwargs): for name in subset: quantizer = Quantizer() quantizer.configure( - args.wbits, perchannel=True, sym=False, mse=False + args.wbits, perchannel=True, sym=args.sym, mse=False ) W = subset[name].weight.data quantizer.find_params(W, weight=True) @@ -206,11 +242,11 @@ def forward(self, inp, **kwargs): hidden_states = model.transformer.norm_f(hidden_states) lm_logits = model.lm_head(hidden_states) shift_logits = lm_logits[:, :-1, :].contiguous() - shift_labels = testenc[ - :, (i * model.seqlen):((i + 1) * model.seqlen) - ][:, 1:] + shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) neg_log_likelihood = loss.float() * model.seqlen nlls.append(neg_log_likelihood) ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) @@ -218,25 +254,34 @@ def forward(self, inp, **kwargs): model.config.use_cache = use_cache + # TODO: perform packing on GPU def mpt_pack(model, quantizers, wbits, groupsize): layers = find_layers(model) layers = {n: layers[n] for n in quantizers} make_quant(model, quantizers, wbits, groupsize) qlayers = find_layers(model, [QuantLinear]) - 
print('Packing ...') + print("Packing ...") for name in qlayers: print(name) - quantizers[name],scale,zero = quantizers[name] - quantizers[name],scale,zero = quantizers[name].cpu(),scale.cpu(),zero.cpu() - qlayers[name].pack(layers[name], scale, zero) - print('Done.') + if GPTQVERSION == 1: + quantizers[name], scale, zero = quantizers[name] + qlayers[name].pack(layers[name], scale, zero) + elif GPTQVERSION == 2: + quantizers[name], scale, zero, g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + else: + raise NotImplementedError("Unsupported GPTQVERSION") + print("Done.") return model + def load_quant(model, checkpoint, wbits, groupsize=-1): config = MPTConfig.from_pretrained(model) + def noop(*args, **kwargs): pass + torch.nn.init.kaiming_uniform_ = noop torch.nn.init.uniform_ = noop torch.nn.init.normal_ = noop @@ -248,45 +293,49 @@ def noop(*args, **kwargs): torch.set_default_dtype(torch.float) model = model.eval() layers = find_layers(model) - for name in ['lm_head']: + for name in ["lm_head"]: if name in layers: del layers[name] make_quant(model, layers, wbits, groupsize) del layers - - print('Loading model ...') - if checkpoint.endswith('.safetensors'): + + print("Loading model ...") + if checkpoint.endswith(".safetensors"): from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint)) else: model.load_state_dict(torch.load(checkpoint)) model.seqlen = 2048 - print('Done.') + print("Done.") return model + def mpt_multigpu(model, gpus): model.transformer.wte = model.transformer.wte.to(gpus[0]) - if hasattr(model.transformer, 'norm') and model.transformer.norm_f: + if hasattr(model.transformer, "norm") and model.transformer.norm_f: model.transformer.norm_f = model.transformer.norm_f.to(gpus[-1]) import copy + model.lm_head = copy.deepcopy(model.lm_head).to(gpus[-1]) - cache = {'mask': None} + cache = {"mask": None} class MoveModule(nn.Module): def __init__(self, module): super().__init__() self.module = module self.dev = next(iter(self.module.parameters())).device + def forward(self, *inp, **kwargs): inp = list(inp) if inp[0].device != self.dev: inp[0] = inp[0].to(self.dev) - if cache['mask'] is None or cache['mask'].device != self.dev: - cache['mask'] = kwargs['attention_mask'].to(self.dev) - kwargs['attention_mask'] = cache['mask'] + if cache["mask"] is None or cache["mask"].device != self.dev: + cache["mask"] = kwargs["attention_mask"].to(self.dev) + kwargs["attention_mask"] = cache["mask"] tmp = self.module(*inp, **kwargs) return tmp @@ -297,31 +346,36 @@ def forward(self, *inp, **kwargs): model.gpus = gpus + def benchmark(model, input_ids, check=False): - input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV) + input_ids = input_ids.to(model.gpus[0] if hasattr(model, "gpus") else DEV) torch.cuda.synchronize() - cache = {'past': None} + cache = {"past": None} + def clear_past(i): def tmp(layer, inp, out): - if cache['past']: - cache['past'][i] = None + if cache["past"]: + cache["past"][i] = None + return tmp + for i, layer in enumerate(model.transformer.blocks): layer.register_forward_hook(clear_past(i)) - print('Benchmarking ...') + print("Benchmarking ...") if check: loss = nn.CrossEntropyLoss() - tot = 0. 
+ tot = 0.0 def sync(): - if hasattr(model, 'gpus'): + if hasattr(model, "gpus"): for gpu in model.gpus: torch.cuda.synchronize(gpu) else: torch.cuda.synchronize() + max_memory = 0 with torch.no_grad(): attention_mask = torch.ones((1, input_ids.numel()), device=DEV) @@ -329,113 +383,119 @@ def sync(): for i in range(input_ids.numel()): tick = time.time() out = model( - input_ids[:, i:i+1], - past_key_values=cache['past'], - attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)) + input_ids[:, i : i + 1], + past_key_values=cache["past"], + attention_mask=attention_mask[:, : (i + 1)].reshape((1, -1)), ) sync() times.append(time.time() - tick) print(i, times[-1]) - max_memory = max(max_memory,torch.cuda.memory_allocated() / 1024 /1024) + max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024) if check and i != input_ids.numel() - 1: - tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float() - cache['past'] = list(out.past_key_values) + tot += loss( + out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV) + ).float() + cache["past"] = list(out.past_key_values) del out sync() import numpy as np - print('Median:', np.median(times)) + + print("Median:", np.median(times)) if check: - print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item()) - print('max memory(MiB):',max_memory) + print("PPL:", torch.exp(tot / (input_ids.numel() - 1)).item()) + print("max memory(MiB):", max_memory) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - from datautils import * + from .datautils import * parser = argparse.ArgumentParser() + parser.add_argument("model", type=str, help="mpt model to load") parser.add_argument( - 'model', type=str, - help='mpt model to load' + "dataset", + type=str, + choices=["wikitext2", "ptb", "c4"], + help="Where to extract calibration data from.", ) parser.add_argument( - 'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'], - help='Where to extract calibration data from.' + "--seed", type=int, default=0, help="Seed for sampling the calibration data." ) parser.add_argument( - '--seed', - type=int, default=0, help='Seed for sampling the calibration data.' + "--nsamples", type=int, default=128, help="Number of calibration data samples." ) parser.add_argument( - '--nsamples', type=int, default=128, - help='Number of calibration data samples.' + "--percdamp", + type=float, + default=0.01, + help="Percent of the average Hessian diagonal to use for dampening.", ) parser.add_argument( - '--percdamp', type=float, default=.01, - help='Percent of the average Hessian diagonal to use for dampening.' + "--nearest", action="store_true", help="Whether to run the RTN baseline." ) parser.add_argument( - '--nearest', action='store_true', - help='Whether to run the RTN baseline.' + "--wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="#bits to use for quantization; use 16 for evaluating base model.", ) parser.add_argument( - '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], - help='#bits to use for quantization; use 16 for evaluating base model.' + "--trits", action="store_true", help="Whether to use trits for quantization." ) parser.add_argument( - '--trits', action='store_true', - help='Whether to use trits for quantization.' 
+ "--groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", ) + parser.add_argument("--eval", action="store_true", help="evaluate quantized model.") parser.add_argument( - '--groupsize', type=int, default=-1, - help='Groupsize to use for quantization; default uses full row.' + "--save", + type=str, + default="", + help="Save quantized checkpoint under this name.", ) parser.add_argument( - '--eval', action='store_true', - help='evaluate quantized model.' + "--save_safetensors", + type=str, + default="", + help="Save quantized `.safetensors` checkpoint under this name.", ) + parser.add_argument("--load", type=str, default="", help="Load quantized model.") parser.add_argument( - '--save', type=str, default='', - help='Save quantized checkpoint under this name.' + "--benchmark", + type=int, + default=0, + help="Number of tokens to use for benchmarking.", ) parser.add_argument( - '--save_safetensors', type=str, default='', - help='Save quantized `.safetensors` checkpoint under this name.' + "--check", + action="store_true", + help="Whether to compute perplexity during benchmarking for verification.", ) parser.add_argument( - '--load', type=str, default='', - help='Load quantized model.' + "--sym", action="store_true", help="Whether to perform symmetric quantization." ) parser.add_argument( - '--benchmark', type=int, default=0, - help='Number of tokens to use for benchmarking.' + "--act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", ) parser.add_argument( - '--check', action='store_true', - help='Whether to compute perplexity during benchmarking for verification.' + "--true-sequential", + action="store_true", + help="Whether to run in true sequential model.", ) parser.add_argument( - '--sym', action='store_true', - help='Whether to perform symmetric quantization.' - ) - parser.add_argument( - '--act-order', action='store_true', - help='Whether to apply the activation order GPTQ heuristic' - ) - parser.add_argument( - '--true-sequential', action='store_true', - help='Whether to run in true sequential model.' 
- ) - parser.add_argument( - '--new-eval', action='store_true', - help='Whether to use the new PTB and C4 eval' + "--new-eval", action="store_true", help="Whether to use the new PTB and C4 eval" ) args = parser.parse_args() if type(args.load) is not str: args.load = args.load.as_posix() - + if args.load: model = load_quant(args.model, args.load, args.wbits, args.groupsize) else: @@ -443,19 +503,24 @@ def sync(): model.eval() dataloader, testloader = get_loaders( - args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen, use_fast=True + args.dataset, + nsamples=args.nsamples, + seed=args.seed, + model=args.model, + seqlen=model.seqlen, + use_fast=True, ) if not args.load and args.wbits < 16 and not args.nearest: tick = time.time() quantizers = mpt_sequential(model, dataloader, DEV) print(time.time() - tick) - + if args.eval: - datasets = ['wikitext2', 'ptb', 'c4'] + datasets = ["wikitext2", "ptb", "c4"] if args.new_eval: - datasets = ['wikitext2', 'ptb-new', 'c4-new'] - for dataset in datasets: + datasets = ["wikitext2", "ptb-new", "c4-new"] + for dataset in datasets: dataloader, testloader = get_loaders( dataset, seed=args.seed, model=args.model, seqlen=model.seqlen ) @@ -464,21 +529,22 @@ def sync(): if args.save: mpt_pack(model, quantizers, args.wbits, args.groupsize) - torch.save(model.state_dict(), args.save) + torch.save(model.state_dict(), args.save) if args.save_safetensors: mpt_pack(model, quantizers, args.wbits, args.groupsize) from safetensors.torch import save_file as safe_save + safe_save(model.state_dict(), args.save_safetensors) - + if args.benchmark: - gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())] + gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] if len(gpus) > 1: mpt_multigpu(model, gpus) else: model = model.to(DEV) if args.benchmark: - input_ids = next(iter(dataloader))[0][:, :args.benchmark] + input_ids = next(iter(dataloader))[0][:, : args.benchmark] benchmark(model, input_ids, check=args.check) if args.load: exit() diff --git a/gptq/opt.py b/gptq/opt.py index e200dbfc..492b5be8 100644 --- a/gptq/opt.py +++ b/gptq/opt.py @@ -3,10 +3,12 @@ import torch import torch.nn as nn -import transformers - -from .modelutils import find_layers, make_quant -from .quant_v2 import quantize, Quantizer, QuantLinear +from .gptq import GPTQ +from .modelutils import DEV, find_layers, GPTQVERSION, make_quant +if GPTQVERSION == 1: + from .quant_v2 import quantize, Quantizer, QuantLinear +elif GPTQVERSION == 2: + from .quant_v3 import quantize, Quantizer, QuantLinear def get_opt(model): @@ -102,8 +104,14 @@ def tmp(_, inp, out): for name in subset: print(f'Quantizing {name} in layer {i+1}/{len(layers)}...') - scale,zero = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) - quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu()) + if GPTQVERSION == 1: + scale,zero = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu()) + elif GPTQVERSION == 2: + scale,zero,g_idx = gptq[name].fasterquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order) + quantizers['model.decoder.layers.%d.%s' % (i, name)] = (gptq[name].quantizer.cpu(),scale.cpu(),zero.cpu(),g_idx.cpu()) + else: + raise NotImplementedError("Unsupported GPTQVERSION") 
gptq[name].free() for j in range(args.nsamples): @@ -236,12 +244,18 @@ def opt_pack(model, quantizers, wbits, groupsize): print('Packing ...') for name in qlayers: print(name) - quantizers[name],scale,zero = quantizers[name] - qlayers[name].pack(layers[name], scale, zero) + if GPTQVERSION == 1: + quantizers[name],scale,zero = quantizers[name] + qlayers[name].pack(layers[name], scale, zero) + elif GPTQVERSION == 2: + quantizers[name],scale,zero,g_idx = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + else: + raise NotImplementedError("Unsupported GPTQVERSION") print('Done.') return model -def load_quant(model, checkpoint, wbits, groupsize=-1): +def load_quant(model, checkpoint, wbits, groupsize): from transformers import OPTConfig, OPTForCausalLM config = OPTConfig.from_pretrained(model) def noop(*args, **kwargs): @@ -403,7 +417,7 @@ def sync(): ) parser.add_argument( '--eval', action='store_true', - help='Evaluate quantized model.' + help='evaluate quantized model.' ) parser.add_argument( '--save', type=str, default='', @@ -490,5 +504,6 @@ def sync(): opt_pack(model, quantizers, args.wbits, args.groupsize) from safetensors.torch import save_file as safe_save state_dict = model.state_dict() - state_dict['lm_head.weight'] = state_dict['lm_head.weight'].clone() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} safe_save(state_dict, args.save_safetensors) + diff --git a/gptq/quant_v3.py b/gptq/quant_v3.py new file mode 100644 index 00000000..e3cf5bed --- /dev/null +++ b/gptq/quant_v3.py @@ -0,0 +1,322 @@ +import numpy as np +import torch +import torch.nn as nn +import math + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure( + self, + bits, perchannel=False, sym=True, + mse=False, norm=2.4, grid=100, maxshrink=.8, + trits=False + ): + + self.maxq = torch.tensor(2 ** bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + 
xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return quantize(x, self.scale, self.zero, self.maxq) + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) + +import quant_cuda_v3 as quant_cuda + +def make_quant(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant(child, names, bits, groupsize, name + '.' 
+ name1 if name != '' else name1) + +class QuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128): + super().__init__() + if bits not in [2,3,4,8]: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.groupsize = groupsize if groupsize != -1 else infeatures + self.maxq = 2 ** self.bits - 1 + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype = torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures),dtype=torch.float16)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2,4,8]: + self.register_buffer('wf',torch.tensor(list(range(0,32,self.bits)), dtype=torch.int32).unsqueeze(0),persistent=False) + elif self.bits == 3: + self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0],], dtype=torch.int32).reshape(1,3,12), persistent=False) + + self.kernel_switch_threshold = kernel_switch_threshold + + def pack(self, linear, scales, zeros, g_idx = None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None]) + intweight = torch.cat(intweight,dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros( + (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32 + ) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32//self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1; + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2,4,8]: + for j in range(i, i + (32//self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - 
i)) + i += 32//self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + else: + raise NotImplementedError("Only 2,3,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + x = x.reshape(-1,x.shape[-1]) + if self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold: + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 3: + quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 4: + quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + elif self.bits == 8: + quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx) + out = out.half() + else: + if self.bits in [2,4,8]: + zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8) + torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:,:,0,10] = (zeros[:,:,0,10]&0x3) | ((zeros[:,:,1,0] << 2)&0x4) + zeros[:,:,1,11] = (zeros[:,:,1,11]&0x1) | ((zeros[:,:,2,0] << 1)&0x6) + zeros = zeros & 0x7 + zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0]//3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1))&0x7 + weight[:,0,10] = (weight[:,0,10]&0x3) | ((weight[:,1,0] << 2)&0x4) + weight[:,1,11] = (weight[:,1,11]&0x1) | ((weight[:,2,0] << 1)&0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:,0,:11], weight[:,1,1:12], weight[:,2,1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + num_itr = self.g_idx.shape[0]//x.shape[-1] + if num_itr == 1: + weights = (self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()])) + else: + num_dim = self.g_idx.shape[0]//num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:,i*num_dim:(i+1)*num_dim] + weight_i = weight[:,i*num_dim:(i+1)*num_dim] + zeros_i = zeros[:,i*num_dim:(i+1)*num_dim] + g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - 
zeros_i[g_idx_i.long()])) + weights = torch.cat(weights,dim=1) + out = torch.matmul(x.half(), weights) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out diff --git a/setup.py b/setup.py index b2c36ccf..d238c5ac 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="gptq-koboldai", - version="0.0.1", + version="0.0.2", install_requires=[ "hf_bleeding_edge", "torch", @@ -17,6 +17,9 @@ cpp_extension.CUDAExtension( "quant_cuda_v2", ["quant_cuda_v2/quant_cuda.cpp", "quant_cuda_v2/quant_cuda_kernel.cu"], ), + cpp_extension.CUDAExtension( + "quant_cuda_v3", ["quant_cuda_v3/quant_cuda.cpp", "quant_cuda_v3/quant_cuda_kernel.cu"], + ), ], cmdclass={"build_ext": cpp_extension.BuildExtension}, )
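A minimal usage sketch follows (not part of the patch), assuming the package has been installed via the updated setup.py so that the quant_cuda_v3 extension is built. The model id, checkpoint filename, and the 4-bit/128-group settings are placeholders; the one firm requirement, per the gptq/modelutils.py change above, is that GPTQVERSION is set in the environment before any gptq module is imported, because the value is read at import time.

import os
os.environ["GPTQVERSION"] = "2"  # 0 = quant_v1, 1 = quant_v2 (default), 2 = the new quant_v3 path

# Importing after setting GPTQVERSION makes llama.py pick quant_v3, and because
# GPTQVERSION > 1, load_quant() also fuses the q/k/v projections via
# make_quant_attn() from gptq/fused_attn.py.
from gptq.llama import load_quant

model = load_quant(
    "llama-7b-hf",                     # placeholder: HF model directory or id
    "llama-7b-4bit-128g.safetensors",  # placeholder: quantized checkpoint (.safetensors or torch .pt)
    4,                                 # wbits
    128,                               # groupsize
)
model = model.to("cuda:0").eval()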