Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
lsy323 committed Nov 7, 2023
1 parent f1d8181 commit 3f1dfde
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions test/stablehlo/test_pt2e_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,25 @@ class PT2EExportTest(unittest.TestCase):

def test_per_tensor_qdq(self):
device = xm.xla_device()
x = torch.randn(2,3,4,5).to(device)
x = torch.ops.quantized_decomposed.quantize_per_tensor(x, 0.4, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(x, 0.4, 2, -128, 127, torch.int8)
x = torch.randn(2, 3, 4, 5).to(device)
x = torch.ops.quantized_decomposed.quantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_tensor(
x, 0.4, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt)
self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)

@unittest.skip("Currently Failing")
def test_per_channel_qdq(self):
device = xm.xla_device()
x = torch.randn(2,3,4,5).to(device)
x = torch.randn(2, 3, 4, 5).to(device)
scale = torch.tensor([3.2, 5.3, -0.1, 10])
zero_point = torch.tensor([1, 2, -1, -2])
x = torch.ops.quantized_decomposed.quantize_per_channel(x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_channel(x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.quantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
x = torch.ops.quantized_decomposed.dequantize_per_channel(
x, scale, zero_point, 2, -128, 127, torch.int8)
stablehlo_txt = xm.get_stablehlo([x])
self.assertTrue("stablehlo.uniform_quantize" in stablehlo_txt)
self.assertTrue("stablehlo.uniform_dequantize" in stablehlo_txt)
Expand Down

0 comments on commit 3f1dfde

Please sign in to comment.