consolidate pack() into packer cls
Qubitium committed Jan 28, 2025
1 parent 1a8d17a commit 71f1f8b
Showing 5 changed files with 79 additions and 154 deletions.
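
The commit replaces three per-kernel copies of pack() (in exllama.py, torch.py and tritonv2.py) with a single Packer mixin defined in gptqmodel/nn_modules/qlinear/__init__.py; each kernel class now simply adds Packer to its bases. Schematically (class names as in the diff, bodies elided):

# Before: every kernel carried its own pack() implementation
class TorchQuantLinear(BaseQuantLinear):
    def pack(self, linear, scales, zeros, g_idx=None): ...

# After: pack() lives once on the mixin and is inherited
class Packer():
    def pack(self, linear, scales, zeros, g_idx=None): ...

class ExllamaQuantLinear(BaseQuantLinear, Packer): ...
class TorchQuantLinear(BaseQuantLinear, Packer): ...
class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin, Packer): ...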
72 changes: 72 additions & 0 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -19,9 +19,81 @@
import numpy as np
import torch as t # conflict with torch.py
import torch.nn as nn
import transformers

from ...models._const import DEVICE, PLATFORM

class Packer():
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().to(dtype=t.float16)
if linear.bias is not None:
self.bias = linear.bias.clone().to(dtype=t.float16)

intweight = t.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(t.int32)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(self.pack_np_math_dtype)

qweight = np.zeros((intweight.shape[0] // self.pack_dtype_bits * self.bits, intweight.shape[1]),
dtype=self.pack_np_dtype)
if self.bits in [2, 4, 8]:
for row in range(qweight.shape[0]):
for j in range(self.pack_factor):
qweight[row] |= intweight[row * self.pack_factor + j] << (self.bits * j)
elif self.bits == 3:
for row in range(qweight.shape[0]):
row_offset = row * 10 # Cache row * 10
row_offset_plus_10 = row_offset + 10 # Cache row * 10 + 10
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j)
qweight[row] |= intweight[row_offset_plus_10] << 30
row += 1
qweight[row] |= (intweight[row_offset_plus_10] >> 2) & 1
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j + 1)
qweight[row] |= intweight[row_offset_plus_10] << 31
row += 1
qweight[row] |= (intweight[row_offset_plus_10] >> 1) & 0x3
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j + 2)

self.qweight = t.from_numpy(qweight.astype(self.pack_np_dtype))

zeros = zeros.numpy().astype(self.pack_np_math_dtype)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // self.pack_dtype_bits * self.bits),
dtype=self.pack_np_math_dtype)
if self.bits in [2, 4, 8]:
for col in range(qzeros.shape[1]):
for j in range(self.pack_factor):
qzeros[:, col] |= zeros[:, col * self.pack_factor + j] << (self.bits * j)
elif self.bits == 3:
for col in range(qzeros.shape[1]):
col_offset = col * 10 # Cache col * 10
col_offset_plus_10 = col_offset + 10 # Cache col * 10 + 10
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j)
qzeros[:, col] |= zeros[:, col_offset_plus_10] << 30
col += 1
qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 2) & 1
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 1)
qzeros[:, col] |= zeros[:, col_offset_plus_10] << 31
col += 1
qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 1) & 0x3
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 2)

self.qzeros = t.from_numpy(qzeros.astype(self.pack_np_dtype))

class BaseQuantLinear(nn.Module):
SUPPORTS_BITS: List[int] = None
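The bits in [2, 4, 8] branch above packs pack_factor consecutive quantized rows into each word of qweight (the loop assumes pack_factor == pack_dtype_bits // bits), with code j landing in bits [bits*j, bits*(j+1)). A self-contained round-trip sketch of that layout for bits=4 with toy shapes (plain numpy, not the gptqmodel API):

import numpy as np

bits = 4
pack_factor = 32 // bits  # 8 four-bit codes per 32-bit word
intweight = np.arange(16, dtype=np.uint32).reshape(16, 1)  # 16 quantized rows, 1 column

# Same inner loop as the bits in [2, 4, 8] branch of Packer.pack()
qweight = np.zeros((intweight.shape[0] // pack_factor, intweight.shape[1]), dtype=np.uint32)
for row in range(qweight.shape[0]):
    for j in range(pack_factor):
        qweight[row] |= intweight[row * pack_factor + j] << (bits * j)

# Inverse mapping: code j of a word w is (w >> (bits * j)) & (2**bits - 1)
unpacked = np.stack([(qweight >> (bits * j)) & 0xF for j in range(pack_factor)],
                    axis=1).reshape(16, 1)
assert np.array_equal(unpacked, intweight)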
42 changes: 2 additions & 40 deletions gptqmodel/nn_modules/qlinear/exllama.py
@@ -25,7 +25,7 @@
import torch.nn.functional as F
import transformers

from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseQuantLinear, Packer

from ...models._const import DEVICE, PLATFORM

@@ -59,7 +59,7 @@ def ext_q4_matmul(x, q4, q4_width):
return output.view(outshape)


class ExllamaQuantLinear(BaseQuantLinear):
class ExllamaQuantLinear(BaseQuantLinear, Packer):
SUPPORTS_BITS = [4]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_DESC_ACT = [True, False]
@@ -158,44 +158,6 @@ def post_init(self):
self.qweight.device.index,
)

def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()

intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)

qweight = np.zeros((intweight.shape[0] // self.pack_dtype_bits * self.bits, intweight.shape[1]), dtype=np.uint32)
for row in range(qweight.shape[0]):
i = row * (self.pack_factor)
for j in range(self.pack_factor):
qweight[row] |= intweight[i + j] << (self.bits * j)

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)

zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // self.pack_dtype_bits * self.bits), dtype=np.uint32)
for col in range(qzeros.shape[1]):
i = col * (self.pack_factor)
for j in range(self.pack_factor):
qzeros[:, col] |= zeros[:, i + j] << (self.bits * j)

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)

def forward(self, x):
if x.dtype != torch.float16:
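The rounding line in pack() computes intweight = round(W/s + z) per input-channel group: scale_zeros = zeros * scales folds the zero-point into the numerator, and g_idx maps each input column to its group. A standalone sketch of that identity and the resulting dequantization error bound (the shapes mirror the transposed layout pack() uses; the sizes and constants are illustrative only, and a real quantizer would also clamp the codes to [0, 2**bits - 1]):

import torch

out_f, in_f, group = 8, 16, 4
W = torch.randn(out_f, in_f)
g_idx = torch.arange(in_f) // group                 # one group id per input column
scales = torch.full((out_f, in_f // group), 0.05)   # per (out_feature, group) step size
zeros = torch.full((out_f, in_f // group), 8.0)     # per-group integer zero-point

s = scales.t().contiguous()                         # (groups, out_f), as in pack()
z = zeros.t().contiguous()
scale_zeros = z * s
q = torch.round((W + scale_zeros[g_idx].T) / s[g_idx].T).to(torch.int32)

# Dequantization recovers W to within half a quantization step:
W_hat = (q.float() - z[g_idx].T) * s[g_idx].T
assert (W - W_hat).abs().max() <= 0.5 * 0.05 + 1e-6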
71 changes: 2 additions & 69 deletions gptqmodel/nn_modules/qlinear/torch.py
@@ -21,15 +21,15 @@
import torch.nn.functional as F
import transformers

from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.nn_modules.qlinear import BaseQuantLinear, Packer
from gptqmodel.utils.logger import setup_logger

from ...models._const import DEVICE, PLATFORM


logger = setup_logger()

class TorchQuantLinear(BaseQuantLinear):
class TorchQuantLinear(BaseQuantLinear, Packer):
SUPPORTS_BITS = [2, 3, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_DESC_ACT = [True, False]
@@ -119,74 +119,7 @@ def post_init(self):
self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32,
device=self.g_idx.device)

def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().to(dtype=torch.float16)
if linear.bias is not None:
self.bias = linear.bias.clone().to(dtype=torch.float16)

intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int32)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(self.pack_np_math_dtype)

qweight = np.zeros((intweight.shape[0] // self.pack_dtype_bits * self.bits, intweight.shape[1]), dtype=self.pack_np_dtype)
if self.bits in [2, 4, 8]:
for row in range(qweight.shape[0]):
for j in range(self.pack_factor):
qweight[row] |= intweight[row * self.pack_factor + j] << (self.bits * j)
elif self.bits == 3:
for row in range(qweight.shape[0]):
row_offset = row * 10 # Cache row * 10
row_offset_plus_10 = row_offset + 10 # Cache row * 10 + 10
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j)
qweight[row] |= intweight[row_offset_plus_10] << 30
row += 1
qweight[row] |= (intweight[row_offset_plus_10] >> 2) & 1
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j + 1)
qweight[row] |= intweight[row_offset_plus_10] << 31
row += 1
qweight[row] |= (intweight[row_offset_plus_10] >> 1) & 0x3
for j in range(10):
qweight[row] |= intweight[row_offset + j] << (3 * j + 2)

self.qweight = torch.from_numpy(qweight.astype(self.pack_np_dtype))

zeros = zeros.numpy().astype(self.pack_np_math_dtype)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // self.pack_dtype_bits * self.bits), dtype=self.pack_np_math_dtype)
if self.bits in [2, 4, 8]:
for col in range(qzeros.shape[1]):
for j in range(self.pack_factor):
qzeros[:, col] |= zeros[:, col * self.pack_factor + j] << (self.bits * j)
elif self.bits == 3:
for col in range(qzeros.shape[1]):
col_offset = col * 10 # Cache col * 10
col_offset_plus_10 = col_offset + 10 # Cache col * 10 + 10
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j)
qzeros[:, col] |= zeros[:, col_offset_plus_10] << 30
col += 1
qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 2) & 1
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 1)
qzeros[:, col] |= zeros[:, col_offset_plus_10] << 31
col += 1
qzeros[:, col] |= (zeros[:, col_offset_plus_10] >> 1) & 0x3
for j in range(10):
qzeros[:, col] |= zeros[:, col_offset + j] << (3 * j + 2)

self.qzeros = torch.from_numpy(qzeros.astype(self.pack_np_dtype))

def forward(self, x: torch.Tensor):
if x.size(-1) != self.padded_infeatures:
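For bits=3, 32 codes straddle three 32-bit words: ten whole codes per word plus two codes whose bits spill across word boundaries, which is what the shift-by-30/31 and (x >> k) & mask lines above handle. A self-contained round-trip sketch of the classic GPTQ 3-bit layout, written with source indices that advance across the three words (an assumption on my part; the diff caches row_offset instead):

import numpy as np

def pack3(v):  # v: 32 codes, each in [0, 8)
    w = np.zeros(3, dtype=np.uint32)
    for j in range(10):
        w[0] |= v[j] << (3 * j)            # codes 0..9   -> bits 0..29 of word 0
    w[0] |= (v[10] & 0x3) << 30            # low 2 bits of code 10  -> bits 30..31
    w[1] |= (v[10] >> 2) & 0x1             # high bit of code 10    -> bit 0 of word 1
    for j in range(10):
        w[1] |= v[11 + j] << (3 * j + 1)   # codes 11..20 -> bits 1..30
    w[1] |= (v[21] & 0x1) << 31            # low bit of code 21     -> bit 31
    w[2] |= (v[21] >> 1) & 0x3             # high 2 bits of code 21 -> bits 0..1 of word 2
    for j in range(10):
        w[2] |= v[22 + j] << (3 * j + 2)   # codes 22..31 -> bits 2..31
    return w

def unpack3(w):
    v = [(w[0] >> (3 * j)) & 0x7 for j in range(10)]
    v.append(((w[0] >> 30) & 0x3) | ((w[1] & 0x1) << 2))
    v += [(w[1] >> (3 * j + 1)) & 0x7 for j in range(10)]
    v.append(((w[1] >> 31) & 0x1) | ((w[2] & 0x3) << 1))
    v += [(w[2] >> (3 * j + 2)) & 0x7 for j in range(10)]
    return np.array(v, dtype=np.uint32)

codes = np.random.randint(0, 8, size=32).astype(np.uint32)
assert np.array_equal(unpack3(pack3(codes)), codes)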
44 changes: 2 additions & 42 deletions gptqmodel/nn_modules/qlinear/tritonv2.py
@@ -18,15 +18,12 @@

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from packaging import version

from ...models._const import DEVICE, PLATFORM
from ...utils.logger import setup_logger
from . import BaseQuantLinear

from . import BaseQuantLinear, Packer

try:
import triton
@@ -49,7 +46,7 @@ class TritonModuleMixin:
logger = setup_logger()


class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin, Packer):
SUPPORTS_BITS = [2, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_DESC_ACT = [True, False]
@@ -138,43 +135,6 @@ def post_init(self):
self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32,
device=self.g_idx.device)

def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()

self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx

scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()

intweight = torch.round((W + scale_zeros[self.g_idx].T) / scales[self.g_idx].T).to(torch.int32)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(self.pack_np_math_dtype)

qweight = np.zeros((intweight.shape[0] // self.pack_factor, intweight.shape[1]), dtype=self.pack_np_dtype)
for row in range(qweight.shape[0]):
i = row * self.pack_factor
for j in range(self.pack_factor):
qweight[row] |= intweight[i + j] << (self.bits * j)

self.qweight = torch.from_numpy(qweight.astype(self.pack_np_dtype))

zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // self.pack_factor), dtype=self.pack_np_math_dtype)
for col in range(qzeros.shape[1]):
i = col * self.pack_factor
for j in range(self.pack_factor):
qzeros[:, col] |= zeros[:, i + j] << (self.bits * j)

self.qzeros = torch.from_numpy(qzeros.astype(self.pack_np_dtype))

def forward(self, x):
# if infeatures is padded, we need to pad the input as well
if x.size(-1) != self.padded_infeatures:
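Packer.pack() reads its configuration from attributes the host class (normally a BaseQuantLinear subclass) provides: bits, pack_factor, pack_dtype_bits, pack_np_dtype, pack_np_math_dtype and g_idx. A hedged sketch of that contract using a hypothetical duck-typed host; the attribute values below assume a 4-bit / int32 packing configuration and are not taken from this diff:

import numpy as np
import torch
import torch.nn as nn
from gptqmodel.nn_modules.qlinear import Packer

class Host(Packer):  # hypothetical stand-in for a BaseQuantLinear subclass
    bits = 4
    pack_dtype_bits = 32
    pack_factor = 32 // 4
    pack_np_dtype = np.int32
    pack_np_math_dtype = np.uint32

in_f, out_f, group = 32, 16, 16
host = Host()
host.g_idx = torch.arange(in_f, dtype=torch.int32) // group

linear = nn.Linear(in_f, out_f, bias=False)
# Seed weights that quantize exactly to 4-bit codes so the demo stays in range
linear.weight.data = (torch.randint(0, 16, (out_f, in_f)).float() - 8.0) * 0.02

scales = torch.full((out_f, in_f // group), 0.02)
zeros = torch.full((out_f, in_f // group), 8.0)
host.pack(linear, scales, zeros)

print(host.qweight.shape, host.qzeros.shape, host.scales.shape)
# expected for this configuration: (4, 16) (2, 2) (2, 16)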
4 changes: 1 addition & 3 deletions gptqmodel/nn_modules/qlinear/utils.py
@@ -2,7 +2,6 @@


# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
# Unpack 4-bit values and interpret them as signed integers
unpacked_weights = torch.zeros(
@@ -27,7 +26,6 @@ def unpack_4bit_to_32bit_signed(qweight, qzeros):


# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def dequantize_4bits_weight(layer):
qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
@@ -36,4 +34,4 @@ def dequantize_4bits_weight(layer):
scales = scales.repeat_interleave(group_size, dim=0)
unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
return unpacked_qweight.T, unpacked_qzeros
return unpacked_qweight.T, unpacked_qzeros
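
With the @torch.no_grad() decorators removed, callers can apply the context themselves. A hedged usage sketch; the stand-in layer below is hypothetical, and any module exposing packed 4-bit qweight/qzeros plus fp16 scales (such as one populated by Packer.pack() above) would do:

import torch
from types import SimpleNamespace
from gptqmodel.nn_modules.qlinear.utils import dequantize_4bits_weight

word = sum(5 << (4 * j) for j in range(8))  # eight identical 4-bit codes of 5
layer = SimpleNamespace(                    # hypothetical packed layer: 32 in, 8 out, one group
    qweight=torch.full((32 * 4 // 32, 8), word, dtype=torch.int32),
    qzeros=torch.full((1, 8 * 4 // 32), word, dtype=torch.int32),
    scales=torch.full((1, 8), 0.1, dtype=torch.float16),
)

with torch.no_grad():                       # the context the removed decorators used to supply
    W_hat, qz = dequantize_4bits_weight(layer)
print(W_hat.shape)                          # (out_features, in_features) after the final .T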
