[Feat]: Add support for kleidiai quantization schemes (#1447)
ng-05 authored Jan 30, 2025
1 parent 463a872 commit 7815262
Showing 5 changed files with 328 additions and 38 deletions.
31 changes: 31 additions & 0 deletions torchao/experimental/docs/readme.md
@@ -98,6 +98,37 @@ quantize_(
)
```

On Arm platforms with PyTorch 2.6.0 or later, the KleidiAI int4 kernels can be used by adjusting the quantization parameters as follows:

```python
import torch

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_primitives import MappingType

my_model = Model()  # your model instance

quantize_(
    my_model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),  # PerRow() is also supported
        has_weight_zeros=True,  # must be True for this scheme
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC also works but increases quantization error
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)
```
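
A minimal smoke test for the configuration above, assuming an Arm build of PyTorch 2.6+ and the imports from the previous snippet; the toy model, shapes, and tolerances are illustrative assumptions:

```python
import copy

import torch
from torch import nn

float_model = nn.Sequential(nn.Linear(128, 256, bias=False))
quant_model = copy.deepcopy(float_model)
quantize_(
    quant_model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)

x = torch.randn(4, 128)
with torch.no_grad():
    # int4 grouped quantization is lossy, so compare with loose tolerances
    torch.testing.assert_close(quant_model(x), float_model(x), atol=0.1, rtol=0.1)
```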

If you get stuck, consult
`torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
for a working example.

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
@@ -5,12 +5,15 @@
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum, auto
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    get_tensor_impl_constructor,
    register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -19,6 +22,13 @@
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
@@ -31,17 +41,33 @@
handler.setFormatter(formatter)
logger.addHandler(handler)

class Target(Enum):
    """Enum that indicates the backend target"""

    NATIVE = auto()
    ATEN = auto()


def target_from_str(target: str) -> Target:
    if target.lower() == "native":
        return Target.NATIVE
    elif target.lower() == "aten":
        return Target.ATEN
    else:
        raise ValueError(f"Invalid target: {target}")
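# Used by PackedLinearInt8DynamicActivationIntxWeightLayout below to parse its
# `target` argument (matching is case-insensitive).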

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
    bit_width: Optional[int]
    group_size: Optional[int]
    has_weight_zeros: Optional[bool]
    # The target platform for the layout, 'native' or 'aten'
    target: Optional[Target]

    def __init__(
        self,
        bit_width: Optional[int] = None,
        group_size: Optional[int] = None,
        has_weight_zeros: Optional[bool] = None,
        target: Optional[str] = "native",
    ):
        if bit_width is not None:
            assert 1 <= bit_width <= 8, "bit_width must be 1 to 8"
Expand All @@ -51,6 +77,7 @@ def __init__(
        self.bit_width = bit_width
        self.group_size = group_size
        self.has_weight_zeros = has_weight_zeros
        self.target = target_from_str(target)

        if not self.has_params_set():
            assert (
@@ -60,13 +87,14 @@ def __init__(
), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

    def extra_repr(self):
        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

    def has_params_set(self) -> bool:
        return (
            (self.bit_width is not None)
            and (self.group_size is not None)
            and (self.has_weight_zeros is not None)
            and (self.target is not None)
        )


@@ -125,9 +153,11 @@ def from_plain(
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        layout: Layout,
        bias: Optional[torch.Tensor] = None,
    ):
        assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"

        # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
        # when AOTI supports int
@@ -136,6 +166,13 @@
        n_tensor = torch.empty(0, n, dtype=torch.int8)
        k_tensor = torch.empty(0, k, dtype=torch.int8)

        if layout.target == Target.ATEN:
            assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch >= 2.6.0"
            # Shift int4 values from [-8, 7] into unsigned [0, 15], then pack
            # two adjacent values along k into one uint8 (odd columns in the
            # high nibble, even columns in the low nibble).
            int_data = int_data.add(8)
            int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
                int_data, scale, bias, layout.group_size, k, n
            )
            return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

        if layout.has_weight_zeros:
            args = [
                int_data.to(torch.int8),
@@ -211,16 +248,13 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
        bias is None or layout.target == Target.ATEN  # aten target allows bias
    )


def _linear_impl(input_tensor, weight_tensor, bias):
    def _impl_2d_native(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

@@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor):
torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
)(*args)

    def _impl_2d_aten(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

        m, k = input_tensor.shape
        n, k_ = weight_tensor.shape
        assert k_ == k
        group_size = weight_tensor.tensor_impl.get_layout().group_size
        packed_weight = weight_tensor.tensor_impl.packed_weight
        return torch.ops.aten._dyn_quant_matmul_4bit(
            input_tensor, packed_weight, group_size, k_, n
        )

    target = weight_tensor.tensor_impl.get_layout().target

    if target == Target.ATEN:
        assert TORCH_VERSION_AT_LEAST_2_6, "Target.ATEN requires torch >= 2.6.0"
        _impl_2d = _impl_2d_aten
    elif target == Target.NATIVE:
        _impl_2d = _impl_2d_native
        assert (
            bias is None
        ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native'"

    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

@@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor):
    res = res.reshape(*lead_shape, m, n)
    return res


# With this registration, F.linear on an AffineQuantizedTensor that carries a
# PackedLinearInt8DynamicActivationIntxWeightLayout is routed through
# _linear_check/_linear_impl above.
register_aqt_quantized_linear_dispatch(
    _linear_check,
    _linear_impl,
)


class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
"""
PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class.
"""

    @classmethod
    def from_hp_to_intx(
        cls,
        input_float: torch.Tensor,
        mapping_type: MappingType,
        block_size: Tuple[int, ...],
        target_dtype: torch.dtype,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
        use_hqq: bool = False,
        bias: Optional[torch.Tensor] = None,
    ):
        assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightTensor does not support HQQ optimization"
        assert isinstance(
            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor only supports PackedLinearInt8DynamicActivationIntxWeightLayout, provided {_layout}"
        assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
        original_shape = input_float.shape
        input_float = _layout.pre_process(input_float)

        scale, zero_point = choose_qparams_affine(
            input_float,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype,
            zero_point_dtype,
            preserve_zero,
            zero_point_domain,
        )
        # choose_qparams_affine is a custom op that does not support returning
        # optional Tensors, so set zero_point to None when its domain is None.
        # TODO: consolidate ZeroPointDomain.NONE and None
        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
            zero_point = None
        data = quantize_affine(
            input_float,
            block_size,
            scale,
            zero_point,
            target_dtype,
            quant_min,
            quant_max,
            zero_point_domain,
        )
        # Note: output will be a uint8 tensor for sub-byte dtypes for now

        data = _layout.post_process(data)
        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
        return cls(
            tensor_impl,
            block_size,
            original_shape,
            quant_min,
            quant_max,
            zero_point_domain,
            dtype=input_float.dtype,
        )

to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
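
For illustration, a hedged sketch of quantizing a bare weight matrix through the alias above. The shapes, group size, and quantization bounds are assumptions for a 4-bit, group-of-32 scheme, and running it requires an Arm build of PyTorch 2.6+ with the KleidiAI aten ops available:

```python
import torch

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
    to_packedlinearint8dynamicactivationintxweight_quantized_intx,
)
from torchao.quantization.quant_primitives import MappingType

weight = torch.randn(256, 128)  # (out_features, in_features); illustrative shape
group_size = 32

qweight = to_packedlinearint8dynamicactivationintxweight_quantized_intx(
    weight,
    MappingType.SYMMETRIC_NO_CLIPPING_ERR,
    (1, group_size),  # block_size: quantize per group of 32 along in_features
    torch.int8,       # target_dtype holding the int4 values
    quant_min=-8,
    quant_max=7,
    _layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
        bit_width=4,
        group_size=group_size,
        has_weight_zeros=True,
        target="aten",
    ),
    bias=None,
)
```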
