Commit
Ruff lint (#1646)
lint
metascroy authored Jan 30, 2025
1 parent 7815262 commit 48fdd31
Showing 3 changed files with 85 additions and 59 deletions.
@@ -21,12 +21,11 @@
)
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
)

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)
@@ -41,12 +40,14 @@
handler.setFormatter(formatter)
logger.addHandler(handler)


class Target(Enum):
"""Enum that indicates the backend target"""

NATIVE = auto()
ATEN = auto()


def target_from_str(target: str) -> Target:
if target.lower() == "native":
return Target.NATIVE
@@ -55,6 +56,7 @@ def target_from_str(target: str) -> Target:
else:
raise ValueError(f"Invalid target: {target}")


class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
bit_width: Optional[int]
group_size: Optional[int]
@@ -157,7 +159,10 @@ def from_plain(
):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
assert layout.target in {
Target.NATIVE,
Target.ATEN,
}, f"Unexpected target: {layout.target}"

# TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
# when AOTI supports int
@@ -167,10 +172,14 @@ def from_plain(
k_tensor = torch.empty(0, k, dtype=torch.int8)

if layout.target == Target.ATEN:
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target requires torch version >= 2.6.0"
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target requires torch version >= 2.6.0"
int_data = int_data.add(8)
int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8)
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
int_data, scale, bias, layout.group_size, k, n
)
return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)

if layout.has_weight_zeros:
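As an aside on the packing expression reformatted above (not part of this commit): the int4 weights are shifted from [-8, 7] into [0, 15] by add(8), and each pair of adjacent values along the last dimension is then packed into a single uint8, with the odd column in the high nibble and the even column in the low nibble. A minimal, self-contained sketch of that round trip; the tensor values and shapes here are purely illustrative:

import torch

# illustrative 2x4 block of int4 weight values in [-8, 7]
int_data = torch.tensor([[-8, -1, 0, 7],
                         [3, -4, 5, -6]], dtype=torch.int8)

shifted = int_data.add(8).to(torch.uint8)          # values now in [0, 15]
packed = (shifted[:, 1::2] << 4) | shifted[:, ::2]  # two int4 values per byte

# unpack to verify the round trip
low = (packed & 0x0F).to(torch.int8) - 8            # even columns
high = (packed >> 4).to(torch.int8) - 8             # odd columns
assert torch.equal(torch.stack([low, high], dim=-1).reshape(2, 4), int_data)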
@@ -248,12 +257,11 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
bias is None or layout.target == Target.ATEN # Aten target allows bias
bias is None or layout.target == Target.ATEN # Aten target allows bias
)


def _linear_impl(input_tensor, weight_tensor, bias):

def _impl_2d_native(input_tensor, weight_tensor):
assert input_tensor.dim() == 2
assert weight_tensor.dim() == 2
@@ -299,14 +307,13 @@ def _impl_2d_aten(input_tensor, weight_tensor):
group_size = weight_tensor.tensor_impl.get_layout().group_size
packed_weight = weight_tensor.tensor_impl.packed_weight
return torch.ops.aten._dyn_quant_matmul_4bit(
input_tensor, packed_weight, group_size, k_, n)
input_tensor, packed_weight, group_size, k_, n
)

target = weight_tensor.tensor_impl.get_layout().target

if target == Target.ATEN:
assert (
TORCH_VERSION_AT_LEAST_2_6 == 1
), "Target.ATEN requires torch >= 2.6.0"
assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
_impl_2d = _impl_2d_aten
elif target == Target.NATIVE:
_impl_2d = _impl_2d_native
@@ -327,6 +334,7 @@ def _impl_2d_aten(input_tensor, weight_tensor):
res = res.reshape(*lead_shape, m, n)
return res


register_aqt_quantized_linear_dispatch(
_linear_check,
_linear_impl,
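A side note on the reshaping above (not part of this commit): both _impl_2d_native and _impl_2d_aten only accept 2-D inputs, so the surrounding _linear_impl flattens any leading batch dimensions into the row dimension before the matmul and restores them afterwards. A hedged sketch of that pattern, with an ordinary torch.nn.functional.linear standing in for the quantized 2-D kernel; the function and variable names here are illustrative:

import torch

def linear_via_2d_kernel(input_tensor: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    n, k = weight.shape                              # weight is (n, k)
    lead_shape = input_tensor.shape[:-2]             # leading batch dims, possibly empty
    m = input_tensor.shape[-2]
    flat = input_tensor.reshape(-1, k)               # collapse to (prod(lead_shape) * m, k)
    res = torch.nn.functional.linear(flat, weight)   # stand-in for the 2-D quantized matmul
    return res.reshape(*lead_shape, m, n)            # restore (*lead_shape, m, n)

x = torch.randn(2, 3, 5, 8)   # lead_shape = (2, 3), m = 5, k = 8
w = torch.randn(4, 8)         # n = 4, k = 8
assert linear_via_2d_kernel(x, w).shape == (2, 3, 5, 4)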
@@ -354,12 +362,17 @@ def from_hp_to_intx(
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
_layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
use_hqq: bool = False,
bias: Optional[torch.Tensor] = None
bias: Optional[torch.Tensor] = None,
):
assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
assert (
use_hqq == False
), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
assert isinstance(
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
_layout, PackedLinearInt8DynamicActivationIntxWeightLayout
), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
assert (
_layout.target == Target.ATEN
), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)

@@ -405,4 +418,7 @@ def from_hp_to_intx(
dtype=input_float.dtype,
)

to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx

to_packedlinearint8dynamicactivationintxweight_quantized_intx = (
PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
)
89 changes: 50 additions & 39 deletions torchao/experimental/quant_api.py
@@ -4,8 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import logging
import sys
from typing import Optional, Union

import torch
@@ -15,22 +15,21 @@
quantize_per_channel_group,
)

from torchao.dtypes import PlainLayout
from torchao.quantization.granularity import (
PerGroup,
PerRow,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_6,
)
from torchao.dtypes import PlainLayout

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s")
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

@@ -494,8 +493,8 @@ def quantize(self, model: nn.Module) -> nn.Module:

from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
PackedLinearInt8DynamicActivationIntxWeightLayout,
to_packedlinearint8dynamicactivationintxweight_quantized_intx,
Target,
to_packedlinearint8dynamicactivationintxweight_quantized_intx,
)
from torchao.quantization.linear_activation_quantized_tensor import (
to_linear_activation_quantized,
@@ -515,7 +514,9 @@ def int8_dynamic_activation_intx_weight(
has_weight_zeros: bool = False,
weight_mapping_type=MappingType.ASYMMETRIC,
act_mapping_type=MappingType.ASYMMETRIC,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
target="native"
), # PlainLayout() also works, but will be slow
):
"""
Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
@@ -540,13 +541,10 @@
"""

def is_torchao_op_skippable(layout):
return (
isinstance(layout, PlainLayout) or
(
isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and
layout.target == Target.ATEN
)
)
return isinstance(layout, PlainLayout) or (
isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
and layout.target == Target.ATEN
)

if not is_torchao_op_skippable(layout):
try:
@@ -574,7 +572,10 @@ def is_torchao_op_skippable(layout):
)
bit_width = dtype_to_bit_width[weight_dtype]
layout_arg = layout
propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN
propagate_bias = (
isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout)
and layout_arg.target == Target.ATEN
)

def apply(weight, bias: Optional[torch.Tensor] = None):
if isinstance(granularity, PerGroup):
@@ -612,35 +613,45 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
target="aten" if layout.target == Target.ATEN else "native",
)
if layout.target == Target.ATEN:
if weight_dtype != torch.int4 or \
has_weight_zeros != True or \
weight_mapping_type == MappingType.ASYMMETRIC:
if (
weight_dtype != torch.int4
or has_weight_zeros != True
or weight_mapping_type == MappingType.ASYMMETRIC
):
raise NotImplementedError(
f"target 'aten' requires:\n"
f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
f"- has_weight_zeros to be True,\n"
f"- weight_dtype to be torch.int4,\n"
f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
"target 'aten' requires:\n"
"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
"- has_weight_zeros to be True,\n"
"- weight_dtype to be torch.int4,\n"
"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
)
assert TORCH_VERSION_AT_LEAST_2_6, f"aten target requires torch version >= 2.6.0"
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target requires torch version >= 2.6.0"
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype
tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx

quantizer_args = [weight,
weight_mapping_type,
(1, group_size),
torch.int32,
quant_min,
quant_max,
torch.finfo(torch.float32).eps,
scale_dtype,
torch.int8,
has_weight_zeros,
ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
layout,
False] + ([bias] if propagate_bias else [])
scale_dtype = (
torch.bfloat16
) # KleidiAI kernel requires bfloat16 scale_dtype
tensor_quantizer = (
to_packedlinearint8dynamicactivationintxweight_quantized_intx
)

quantizer_args = [
weight,
weight_mapping_type,
(1, group_size),
torch.int32,
quant_min,
quant_max,
torch.finfo(torch.float32).eps,
scale_dtype,
torch.int8,
has_weight_zeros,
ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE,
layout,
False,
] + ([bias] if propagate_bias else [])

weight = tensor_quantizer(*quantizer_args)
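Not part of this diff, but for orientation: int8_dynamic_activation_intx_weight is intended to be passed to quantize_, which is what the test file below does. A hedged usage sketch based on the signature shown in this diff; the import path for the experimental API, the weight_dtype/granularity keyword spellings, and the toy module are assumptions made for illustration:

import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(256, 256))  # illustrative module

quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,   # low-bit weight dtype
        granularity=PerGroup(32),  # group-wise weight quantization
        # The default layout is PackedLinearInt8DynamicActivationIntxWeightLayout(target="native");
        # target="aten" instead routes int4 weights through the aten _dyn_quant_* kernels and,
        # per the checks above, additionally requires has_weight_zeros=True.
    ),
)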

@@ -17,11 +17,9 @@
int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
PerGroup,
PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass
from torchao.quantization.quant_primitives import MappingType


@@ -57,7 +55,8 @@ def test_accuracy(self):
has_weight_zeros=has_weight_zeros,
weight_mapping_type=weight_mapping_type,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
target="aten"), # default
target="aten"
), # default
),
)

