breaking(gptq): quantizing only 1 layer yields high perplexity
3outeille committed May 7, 2023
1 parent cf14124 commit c2bbe64
Showing 7 changed files with 892 additions and 302 deletions.
239 changes: 0 additions & 239 deletions quantize/compress_rwkv.py

This file was deleted.

2 changes: 1 addition & 1 deletion quantize/gptq/datautils.py
@@ -4,7 +4,7 @@
 import pathlib
 import tokenizers
 import random
-from rwkv.model import RWKV
+from myRWKV import RWKV
 
 from datasets import load_dataset
 
5 changes: 3 additions & 2 deletions quantize/gptq/quant.py
@@ -202,7 +202,8 @@ def pack(self, weight, bias, scales, zeros, g_idx = None):
 
     intweight = []
     for idx in range(self.infeatures):
-        intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+        #OLD: intweight.append(torch.round((weight.data[:,idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
+        intweight.append(torch.round((weight.data[idx, :] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:,None])
     intweight = torch.cat(intweight,dim=1)
     intweight = intweight.t().contiguous()
     intweight = intweight.numpy().astype(np.uint32)
@@ -411,7 +412,7 @@ def pack(self, linear, scales, zeros, g_idx = None):
     qweight = qweight.astype(np.int32)
     self.qweight = torch.from_numpy(qweight)
 
-    zeros -= 1;
+    zeros -= 1
     zeros = zeros.numpy().astype(np.uint32)
     qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
     i = 0
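
The one-line indexing fix above is the heart of this commit: pack() iterates over input features, and with weight stored as (infeatures, outfeatures) the row weight.data[idx, :] holds input feature idx's weights across all output columns, whereas the old column slice weight.data[:, idx] quantized along the wrong axis. A minimal sketch of the same arithmetic under assumed shapes, with a single quantization group and without the bit-range clamping the real code applies:

    import torch

    # Hypothetical sizes with a single quantization group (g_idx all zero),
    # just to illustrate the arithmetic of pack().
    infeatures, outfeatures = 4, 8
    weight = torch.randn(infeatures, outfeatures)      # assumed (infeatures, outfeatures) layout
    scales = torch.full((1, outfeatures), 0.1)         # one scale per output column for group 0
    zeros = torch.full((1, outfeatures), 8.0)          # integer zero-points for group 0
    g_idx = torch.zeros(infeatures, dtype=torch.long)  # every input feature maps to group 0

    scale_zeros = zeros * scales                       # precomputed offset, as in pack()

    cols = []
    for idx in range(infeatures):
        g = g_idx[idx]
        # Row idx holds this input feature's weights across all output columns,
        # which is why the fix indexes weight[idx, :] rather than weight[:, idx].
        q = torch.round((weight[idx, :] + scale_zeros[g]) / scales[g]).to(torch.int)
        cols.append(q[:, None])
    intweight = torch.cat(cols, dim=1).t().contiguous()  # (infeatures, outfeatures) of ints

    # Dequantizing should land within half a quantization step of the original.
    recon = intweight.float() * scales[0] - scale_zeros[0]
    assert (recon - weight).abs().max() <= scales.max() / 2 + 1e-6

With the axis fixed, dequantization reproduces every weight to within half a quantization step, which is what keeps perplexity intact after packing.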
15 changes: 9 additions & 6 deletions quantize/measure_perplexity.py
@@ -10,6 +10,8 @@
 import torch
 from typing import List
 from rwkv.model import RWKV
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '0'
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Measure perplexity and per-token latency of an RWKV model on a given text file')
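
For context, the rwkv package consults these environment variables when rwkv.model is imported, so they are typically set beforehand; a small sketch of the usual ordering (an assumption about import-time behavior, not something this diff changes):

    import os

    # Assumed: rwkv reads these flags at import time, so set them first.
    os.environ['RWKV_JIT_ON'] = '1'   # enable the TorchScript path
    os.environ['RWKV_CUDA_ON'] = '0'  # stay on the pure PyTorch kernels

    from rwkv.model import RWKV  # import deliberately after the env setup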
@@ -56,9 +58,10 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 
 # ---
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-# device=torch.device('cpu')
+# device = torch.device('cpu')
 
-model = RWKV(model=args.model_path, strategy='cuda fp16i8')
+#TODO: Why is PERPLEXITY SO DAMN HIGH ?
+model = RWKV(model=args.model_path, strategy='cuda fp16')
 
 logits, state = None, None
 loss_sum: torch.Tensor = torch.tensor([0.0], device=device)
@@ -72,7 +75,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 for i in range(run_count):
     token: int = test_tokens[i]
     target: int = test_tokens[i + 1]
-
+
     logits, state = model.forward([token], None if i == 0 else state)
 
     if ignore_first_n_tokens == 0 or i + 1 >= ignore_first_n_tokens:
@@ -105,7 +108,7 @@ def format_loss_with_perplexity(loss: torch.Tensor) -> str:
 print(f'Average latency: {int((time.time() - start) * 1000 / run_count)} ms per token')
 
 print()
-print(f'Model: {os.path.basename(args.model_path)}, '
-      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens, '
-      f'Ignored first {ignore_first_n_tokens} tokens, '
+print(f'Model: {os.path.basename(args.model_path)}\n'
+      f'data: {os.path.basename(args.dataset_path)} with {token_count} tokens\n'
+      f'Ignored first {ignore_first_n_tokens} tokens\n'
       f'averages: {format_loss_with_perplexity(loss_sum / loss_count)}')
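
The 'averages' figure printed here converts accumulated cross-entropy into perplexity as exp(mean loss). A self-contained sketch of that bookkeeping with toy logits (only the helper's name comes from the script; its exact formatting here is an assumption):

    import torch
    import torch.nn.functional as F

    def format_loss_with_perplexity(loss: torch.Tensor) -> str:
        # Perplexity is exp of the average per-token cross-entropy.
        return f'loss {loss.item():.3f}, perplexity {torch.exp(loss).item():.3f}'

    # Toy data standing in for the model's outputs: 5 steps over a 10-token vocab.
    vocab, steps = 10, 5
    logits = torch.randn(steps, vocab)
    targets = torch.randint(0, vocab, (steps,))

    loss_sum = torch.tensor([0.0])
    loss_count = 0
    for i in range(steps):
        # Same accumulation as the measurement loop: one token at a time.
        loss_sum += F.cross_entropy(logits[i].unsqueeze(0), targets[i].unsqueeze(0))
        loss_count += 1

    print(format_loss_with_perplexity(loss_sum / loss_count))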