From 3f1dfde34437b88932e273aa29b566e3c2f83d45 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 7 Nov 2023 18:38:15 +0000 Subject: [PATCH] format --- test/stablehlo/test_pt2e_qdq.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/stablehlo/test_pt2e_qdq.py b/test/stablehlo/test_pt2e_qdq.py index cd7313ced09..02733fc9d96 100644 --- a/test/stablehlo/test_pt2e_qdq.py +++ b/test/stablehlo/test_pt2e_qdq.py @@ -19,9 +19,11 @@ class PT2EExportTest(unittest.TestCase): def test_per_tensor_qdq(self): device = xm.xla_device() - x = torch.randn(2,3,4,5).to(device) - x = torch.ops.quantized_decomposed.quantize_per_tensor(x, 0.4, 2, -128, 127, torch.int8) - x = torch.ops.quantized_decomposed.dequantize_per_tensor(x, 0.4, 2, -128, 127, torch.int8) + x = torch.randn(2, 3, 4, 5).to(device) + x = torch.ops.quantized_decomposed.quantize_per_tensor( + x, 0.4, 2, -128, 127, torch.int8) + x = torch.ops.quantized_decomposed.dequantize_per_tensor( + x, 0.4, 2, -128, 127, torch.int8) stablehlo_txt = xm.get_stablehlo([x]) self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt) self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt) @@ -29,11 +31,13 @@ def test_per_tensor_qdq(self): @unittest.skip("Currently Failing") def test_per_channel_qdq(self): device = xm.xla_device() - x = torch.randn(2,3,4,5).to(device) + x = torch.randn(2, 3, 4, 5).to(device) scale = torch.tensor([3.2, 5.3, -0.1, 10]) zero_point = torch.tensor([1, 2, -1, -2]) - x = torch.ops.quantized_decomposed.quantize_per_channel(x, scale, zero_point, 2, -128, 127, torch.int8) - x = torch.ops.quantized_decomposed.dequantize_per_channel(x, scale, zero_point, 2, -128, 127, torch.int8) + x = torch.ops.quantized_decomposed.quantize_per_channel( + x, scale, zero_point, 2, -128, 127, torch.int8) + x = torch.ops.quantized_decomposed.dequantize_per_channel( + x, scale, zero_point, 2, -128, 127, torch.int8) stablehlo_txt = xm.get_stablehlo([x]) self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt) self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)