
Commit

broadcast scale zp
Siyuan Liu committed May 14, 2024
1 parent b64d8a2 commit 28d6b50
Showing 2 changed files with 27 additions and 1 deletion.
18 changes: 17 additions & 1 deletion test/stablehlo/test_pt2e_qdq.py
@@ -84,6 +84,23 @@ def test_per_channel_qdq(self):
    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

  def test_per_channel_qdq_scalar_scale(self):
    device = xm.xla_device()
    x = torch.randn(2, 3, 4, 5).to(device)
    scale = torch.tensor([3.2]).to(device)
    zero_point = torch.tensor([-1], dtype=torch.int64).to(device)
    x = torch.ops.quantized_decomposed.quantize_per_channel(
        x, scale, zero_point, 2, -128, 127, torch.int8)
    x = torch.ops.quantized_decomposed.dequantize_per_channel(
        x, scale, zero_point, 2, -128, 127, torch.int8)
    stablehlo_txt = xm.get_stablehlo([x])
    self.assertEqual(
        stablehlo_txt.count(
            'tensor<2x3x4x5x!quant.uniform<i8:f32:2, {3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1}>>'
        ), 2)
    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

  def test_resnet18(self):
    # Step 1: export resnet18
    args = (torch.randn(1, 3, 224, 224),)
@@ -116,7 +133,6 @@ def test_resnet18(self):
    save_torch_module_as_tf_saved_model(m, args, tmp_path)
    self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

  @unittest.skip
  def test_resnet18_per_channel(self):
    # Step 1: export resnet18
    args = (torch.randn(1, 3, 224, 224),)
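As a quick sanity check on the new test_per_channel_qdq_scalar_scale assertion (a standalone sketch, not part of the commit): the test quantizes along axis 2 of a (2, 3, 4, 5) tensor, so that axis has 4 channels and the scalar scale/zero_point pair 3.2/-1 should appear 4 times in the !quant.uniform annotation.

import torch

x = torch.randn(2, 3, 4, 5)
axis = 2
num_channels = x.shape[axis]  # 4 channels along the quantized axis
pairs = ",".join(["3.200000e+00:-1"] * num_channels)
print(f"!quant.uniform<i8:f32:{axis}, {{{pairs}}}>")
# !quant.uniform<i8:f32:2, {3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1}>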
10 changes: 10 additions & 0 deletions torch_xla/experimental/quantized.py
@@ -77,6 +77,16 @@ def _xla_quantize(input: torch.Tensor,
                  dtype: torch.dtype,
                  axis: int = -1):
  _check_scale_zp(input, scale, zero_point, axis, dtype)
  if axis != -1:
    # PT2E generates a scalar scale/zero_point for per-channel quant, which
    # shouldn't be expected:
    # https://github.com/pytorch/pytorch/issues/126189
    # StableHLO uniform_quantize requires the scale/zero_point size to be the
    # same as the size of the axis that is quantized along.
    # We broadcast the scalar to the size of the axis for now.
    if input.shape[axis] != scale.numel() and scale.numel() == 1:
      scale = scale.cpu().broadcast_to((input.shape[axis],))
    if input.shape[axis] != zero_point.numel() and zero_point.numel() == 1:
      zero_point = zero_point.cpu().broadcast_to((input.shape[axis],))
  # Scale and zero_point need to be unpacked (materialized before entering LTC),
  # because the quant params will be attached to the tensor Shape in HLO/StableHLO.
  scale_np = _unpack_tensor_to_list(scale)
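For reference, here is a standalone sketch (not part of the commit) of the broadcast the hunk above now performs when PT2E hands _xla_quantize a scalar scale/zero_point for per-channel quantization; the tensor shape and axis below are illustrative assumptions.

import torch

x = torch.randn(2, 3, 4, 5)
axis = 2  # quantized axis, size 4
scale = torch.tensor([3.2])  # scalar scale as produced by PT2E
zero_point = torch.tensor([-1], dtype=torch.int64)

# StableHLO uniform_quantize expects one (scale, zero_point) pair per element
# along the quantized axis, so a single value is broadcast to that size.
if x.shape[axis] != scale.numel() and scale.numel() == 1:
  scale = scale.broadcast_to((x.shape[axis],))
if x.shape[axis] != zero_point.numel() and zero_point.numel() == 1:
  zero_point = zero_point.broadcast_to((x.shape[axis],))

print(scale)  # tensor([3.2000, 3.2000, 3.2000, 3.2000])
print(zero_point)  # tensor([-1, -1, -1, -1])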
