diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py
index 3997f4e9b8fe9..ff8e807ecb600 100644
--- a/tests/kernels/test_fused_quant_layernorm.py
+++ b/tests/kernels/test_fused_quant_layernorm.py
@@ -4,6 +4,7 @@
 import torch
 
 import vllm._custom_ops as ops
+from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 DTYPES = [torch.bfloat16, torch.float]
@@ -27,11 +28,11 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
 
+
 def ref_rms_norm(rms_norm_layer: RMSNorm,
                  x: torch.Tensor,
                  residual: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
+        -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if residual is not None:
         residual = residual.clone()
         out, residual = rms_norm_layer.forward_native(x, residual)
@@ -40,13 +41,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm,
     return out, residual
 
 
-def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
-                                x: torch.Tensor,
-                                quant_dtype: torch.dtype,
-                                residual: Optional[torch.Tensor],
-                                scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
+                                x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                residual: Optional[torch.Tensor],
+                                scale_ub: Optional[torch.Tensor]) \
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
 
     if scale_ub is not None:
         assert quant_dtype == torch.float8_e4m3fn
@@ -64,22 +65,23 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
 
     return torch_out, scales, residual
 
+
 def ref_impl(rms_norm_layer: RMSNorm,
              x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
                                        residual, scale_ub)
 
 
-def ops_dynamic_per_token_quant(weight: torch.Tensor,
-                                x: torch.Tensor,
-                                quant_dtype: torch.dtype,
-                                residual: Optional[torch.Tensor],
-                                scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+def ops_dynamic_per_token_quant(weight: torch.Tensor,
+                                x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                residual: Optional[torch.Tensor],
+                                scale_ub: Optional[torch.Tensor]) \
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if residual is not None:
         residual = residual.clone()
     out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
@@ -87,12 +89,13 @@
                                                        residual)
     return out, scales, residual
 
+
 def ops_impl(weight: torch.Tensor,
              x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
                                        scale_ub)
 
@@ -139,9 +142,9 @@ def test_rms_norm(
         scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')
 
     ref_out, ref_scales, ref_residual = \
-            ref_impl(layer, x, quant_dtype, residual, scale_ub)
+        ref_impl(layer, x, quant_dtype, residual, scale_ub)
     ops_out, ops_scales, ops_residual = \
-            ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
+        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
@@ -154,3 +157,11 @@ def test_rms_norm(
                           ops_out.to(dtype=torch.float32))
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
+
+    output = torch.empty_like(x, dtype=quant_dtype)
+    scales = torch.empty((x.numel() // x.shape[-1], 1),
+                         device=x.device,
+                         dtype=torch.float32)
+
+    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
+            (output, x, layer.weight, scales, 1e-5, scale_ub, residual))
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 1964b934e1986..d6002630ee02c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -269,20 +269,6 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
-if hasattr(torch.ops._C, "rms_norm_dynamic_per_token_quant"):
-
-    @register_fake("_C::rms_norm_dynamic_per_token_quant")
-    def _rms_norm_dynamic_per_token_quant_fake(
-            output: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scales: torch.Tensor,
-            epsilon: float,
-            scale_ub: Optional[torch.Tensor] = None,
-            residual: Optional[torch.Tensor] = None) -> None:
-        return None
-
-
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
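
Note on the opcheck call added in the test above: the opcheck helper in tests/kernels/utils wraps torch.library.opcheck, which invokes the registered op and cross-checks its schema and its FakeTensor behavior (the path torch.compile traces). Below is a minimal standalone sketch of an equivalent check; it assumes PyTorch >= 2.4 (where torch.library.opcheck is available), a CUDA device, and a vLLM build whose _C extension registers rms_norm_dynamic_per_token_quant. The shapes and dtypes are illustrative, not taken from the diff.

    import torch

    num_tokens, hidden_size = 16, 1024
    x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
    weight = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")

    # Pre-allocated outputs, mirroring the buffers built in the test:
    # quantized activations plus one float32 scale per token.
    output = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty((num_tokens, 1), dtype=torch.float32, device="cuda")

    # Argument order follows the fake removed from vllm/_custom_ops.py:
    # (output, input, weight, scales, epsilon, scale_ub, residual) -> None.
    torch.library.opcheck(
        torch.ops._C.rms_norm_dynamic_per_token_quant.default,
        (output, x, weight, scales, 1e-5, None, None))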