diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py
index c9e5f04af65..2d2173901bd 100644
--- a/test/stablehlo/test_pt2e_qdq.py
+++ b/test/stablehlo/test_pt2e_qdq.py
@@ -84,6 +84,23 @@ def test_per_channel_qdq(self):
     self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
     self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)
 
+  def test_per_channel_qdq_scalar_scale(self):
+    device = xm.xla_device()
+    x = torch.randn(2, 3, 4, 5).to(device)
+    scale = torch.tensor([3.2]).to(device)
+    zero_point = torch.tensor([-1], dtype=torch.int64).to(device)
+    x = torch.ops.quantized_decomposed.quantize_per_channel(
+        x, scale, zero_point, 2, -128, 127, torch.int8)
+    x = torch.ops.quantized_decomposed.dequantize_per_channel(
+        x, scale, zero_point, 2, -128, 127, torch.int8)
+    stablehlo_txt = xm.get_stablehlo([x])
+    self.assertEqual(
+        stablehlo_txt.count(
+            'tensor<2x3x4x5x!quant.uniform<i8:f32:2, {3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1}>>'
+        ), 2)
+    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
+    self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)
+
   def test_resnet18(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
@@ -116,7 +133,6 @@ def test_resnet18(self):
     save_torch_module_as_tf_saved_model(m, args, tmp_path)
     self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))
 
-  @unittest.skip
   def test_resnet18_per_channel(self):
     # Step 1: export resnet18
     args = (torch.randn(1, 3, 224, 224),)
diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py
index 05d5e7bb815..c52efe6c26b 100644
--- a/torch_xla/experimental/quantized.py
+++ b/torch_xla/experimental/quantized.py
@@ -77,6 +77,16 @@ def _xla_quantize(input: torch.Tensor,
                   dtype: torch.dtype,
                   axis: int = -1):
   _check_scale_zp(input, scale, zero_point, axis, dtype)
+  if axis != -1:
+    # PT2E may generate a scalar scale/zero_point for per-channel quant,
+    # which is not expected: https://github.com/pytorch/pytorch/issues/126189
+    # StableHLO uniform_quantize requires the scale/zero_point size to be the
+    # same as the size of the axis that is quantized along.
+    # We broadcast the scalar to the size of the axis for now.
+    if input.shape[axis] != scale.numel() and scale.numel() == 1:
+      scale = scale.cpu().broadcast_to((input.shape[axis],))
+    if input.shape[axis] != zero_point.numel() and zero_point.numel() == 1:
+      zero_point = zero_point.cpu().broadcast_to((input.shape[axis],))
   # Scale and zero_point need to be unpacked(materialized before enter LTC),
   # because the quant param will be attached to tensor Shape in HLO/StableHLO.
   scale_np = _unpack_tensor_to_list(scale)
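
For reference, a minimal standalone sketch of the broadcast this patch applies (plain
PyTorch, no XLA device needed; the helper name broadcast_quant_params is hypothetical,
the patch inlines the same logic in _xla_quantize):

import torch

def broadcast_quant_params(input: torch.Tensor, scale: torch.Tensor,
                           zero_point: torch.Tensor, axis: int):
  # When PT2E hands per-channel quant a 1-element scale/zero_point, expand it
  # to one entry per channel along the quantized axis, as StableHLO expects.
  if input.shape[axis] != scale.numel() and scale.numel() == 1:
    scale = scale.broadcast_to((input.shape[axis],))
  if input.shape[axis] != zero_point.numel() and zero_point.numel() == 1:
    zero_point = zero_point.broadcast_to((input.shape[axis],))
  return scale, zero_point

x = torch.randn(2, 3, 4, 5)
scale, zp = broadcast_quant_params(
    x, torch.tensor([3.2]), torch.tensor([-1], dtype=torch.int64), axis=2)
print(scale)  # tensor([3.2000, 3.2000, 3.2000, 3.2000]) -- one entry per channel
print(zp)     # tensor([-1, -1, -1, -1])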