Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reenable disabled pt2e test #7059

Merged
merged 2 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ def test_per_channel_qdq(self):
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

def test_per_channel_qdq_scalar_scale(self):
device = xm.xla_device()
x = torch.randn(2, 3, 4, 5).to(device)
scale = torch.tensor([3.2]).to(device)
zero_point = torch.tensor([-1], dtype=torch.int64).to(device)
x = torch.ops.quantized_decomposed.quantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertEqual(
stablehlo_txt.count(
'tensor<2x3x4x5x!quant.uniform<i8:f32:2, {3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1,3.200000e+00:-1}>>'
), 2)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_quantize"), 1)
self.assertEqual(stablehlo_txt.count("stablehlo.uniform_dequantize"), 1)

def test_resnet18(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
Expand Down Expand Up @@ -116,7 +133,6 @@ def test_resnet18(self):
save_torch_module_as_tf_saved_model(m, args, tmp_path)
self.assertTrue(os.path.exists(os.path.join(tmp_path, 'saved_model.pb')))

@unittest.skip
def test_resnet18_per_channel(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/experimental/quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ def _xla_quantize(input: torch.Tensor,
dtype: torch.dtype,
axis: int = -1):
_check_scale_zp(input, scale, zero_point, axis, dtype)
if axis != -1:
# PT2E generate scalar for per-channel quant, which shouldn't be expected
# https://github.com/pytorch/pytorch/issues/126189
# StableHLO uniform_quantize requires the scalar/zero_point size to be the
# same as the size of the axis that is quantized along.
# We will broadcast the scalar to the size of the axis for now.
if input.shape[axis] != scale.numel() and scale.numel() == 1:
scale = scale.cpu().broadcast_to((input.shape[axis],))
if input.shape[axis] != zero_point.numel() and zero_point.numel() == 1:
zero_point = zero_point.cpu().broadcast_to((input.shape[axis],))
# Scale and zero_point need to be unpacked(materialized before enter LTC),
# because the quant param will be attached to tensor Shape in HLO/StableHLO.
scale_np = _unpack_tensor_to_list(scale)
Expand Down
Loading