From 81284808f1a73268199c7c427feda0de84c8ce56 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 6 Feb 2024 13:18:18 -0800 Subject: [PATCH] Remove assert on zero_point dtype of quantize/dequantize op (#6475) --- torch_xla/experimental/quantized.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch_xla/experimental/quantized.py b/torch_xla/experimental/quantized.py index 7fed638895a..df7d4645ba2 100644 --- a/torch_xla/experimental/quantized.py +++ b/torch_xla/experimental/quantized.py @@ -50,9 +50,15 @@ def _check_scale_zp(input, scale, zero_point, axis, dtype): # The followings are checked: # 1. scale, zp are 1D tensor. # 2. Lenghth of scale, zp matched the (de)quant dim. - # 3. zp dtype is the same as the quantized integer type. + # 3. dtype must be integer type + # 4. zero_point values must be within the range of dtype. assert len(scale.shape) == 1 and len(zero_point.shape) == 1 - assert zero_point.dtype == dtype + assert 'int' in str(dtype) + assert torch.equal( + zero_point, + torch.clamp(zero_point, + torch.iinfo(dtype).min, + torch.iinfo(dtype).max)) if axis == -1: assert scale.numel() == 1 and zero_point.numel() == 1 else: