
Commit: PR comments: opcheck
Signed-off-by: luka <[email protected]>
ProExpertProg committed Dec 11, 2024
1 parent 0a9a96f commit 720d537
Showing 2 changed files with 29 additions and 32 deletions.
47 changes: 29 additions & 18 deletions tests/kernels/test_fused_quant_layernorm.py
@@ -4,6 +4,7 @@
 import torch
 
 import vllm._custom_ops as ops
+from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
 
 DTYPES = [torch.bfloat16, torch.float]
@@ -27,11 +28,11 @@
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
 
+
 def ref_rms_norm(rms_norm_layer: RMSNorm,
                  x: torch.Tensor,
                  residual: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-
+        -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     if residual is not None:
         residual = residual.clone()
     out, residual = rms_norm_layer.forward_native(x, residual)
@@ -40,13 +41,13 @@ def ref_rms_norm(rms_norm_layer: RMSNorm,
 
     return out, residual
 
-def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
-                                x: torch.Tensor,
-                                quant_dtype: torch.dtype,
-                                residual: Optional[torch.Tensor],
-                                scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-
+
+def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
+                                x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                residual: Optional[torch.Tensor],
+                                scale_ub: Optional[torch.Tensor]) \
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if scale_ub is not None:
         assert quant_dtype == torch.float8_e4m3fn
 
@@ -64,35 +65,37 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
 
     return torch_out, scales, residual
 
 
 def ref_impl(rms_norm_layer: RMSNorm,
              x: torch.Tensor,
              quant_dtype: torch.dtype,
              residual: Optional[torch.Tensor],
              scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
                                        residual, scale_ub)
 
-def ops_dynamic_per_token_quant(weight: torch.Tensor,
-                                x: torch.Tensor,
-                                quant_dtype: torch.dtype,
-                                residual: Optional[torch.Tensor],
-                                scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-
+
+def ops_dynamic_per_token_quant(weight: torch.Tensor,
+                                x: torch.Tensor,
+                                quant_dtype: torch.dtype,
+                                residual: Optional[torch.Tensor],
+                                scale_ub: Optional[torch.Tensor]) \
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     if residual is not None:
         residual = residual.clone()
     out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
                                                        quant_dtype, scale_ub,
                                                        residual)
     return out, scales, residual
 
 
 def ops_impl(weight: torch.Tensor,
              x: torch.Tensor,
              quant_dtype: torch.dtype,
              residual: Optional[torch.Tensor],
              scale_ub: Optional[torch.Tensor]) \
-    -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
                                        scale_ub)
@@ -139,9 +142,9 @@ def test_rms_norm(
     scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')
 
     ref_out, ref_scales, ref_residual = \
-            ref_impl(layer, x, quant_dtype, residual, scale_ub)
+        ref_impl(layer, x, quant_dtype, residual, scale_ub)
     ops_out, ops_scales, ops_residual = \
-            ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
+        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
 
     assert ref_out.dtype == quant_dtype
     assert ops_out.dtype == quant_dtype
@@ -154,3 +157,11 @@ def test_rms_norm(
                           ops_out.to(dtype=torch.float32))
     if add_residual:
         assert torch.allclose(ref_residual, ops_residual)
+
+    output = torch.empty_like(x, dtype=quant_dtype)
+    scales = torch.empty((x.numel() // x.shape[-1], 1),
+                         device=x.device,
+                         dtype=torch.float32)
+
+    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
+            (output, x, layer.weight, scales, 1e-5, scale_ub, residual))
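
For context: opcheck (imported from tests.kernels.utils) is a thin wrapper around torch.library.opcheck, which runs a custom op through PyTorch's operator tests (schema validation, fake-tensor propagation, autograd registration) and raises if any of them disagree with the real implementation. Below is a minimal standalone sketch of the pattern the new test lines exercise; the shapes, dtype choices, and eps value are illustrative assumptions, not the test's actual parameterization.

# Standalone sketch, assuming vLLM was built with the _C extension.
# Shapes and dtypes here are illustrative; the real test sweeps them.
import torch

from tests.kernels.utils import opcheck

num_tokens, hidden_size = 16, 256
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda")

# The op writes into pre-allocated outputs: one quantized row per
# token plus a per-token scale.
output = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scales = torch.empty((num_tokens, 1), device=x.device, dtype=torch.float32)

# opcheck invokes the op and cross-checks it against its registered
# schema and fake implementation, raising on any mismatch.
opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
        (output, x, weight, scales, 1e-5, None, None))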
14 changes: 0 additions & 14 deletions vllm/_custom_ops.py
@@ -269,20 +269,6 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
-if hasattr(torch.ops._C, "rms_norm_dynamic_per_token_quant"):
-
-    @register_fake("_C::rms_norm_dynamic_per_token_quant")
-    def _rms_norm_dynamic_per_token_quant_fake(
-            output: torch.Tensor,
-            input: torch.Tensor,
-            weight: torch.Tensor,
-            scales: torch.Tensor,
-            epsilon: float,
-            scale_ub: Optional[torch.Tensor] = None,
-            residual: Optional[torch.Tensor] = None) -> None:
-        return None
-
-
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
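On the _custom_ops.py side, the hand-written fake registration for the op is deleted. Since rms_norm_dynamic_per_token_quant only writes into tensors passed as arguments and returns nothing, its fake implementation is trivial, which is presumably why the explicit Python-side registration was judged redundant; the opcheck call added to the test above is what guards against a fake/schema regression. As a hedged sketch, the same check can be made via the public torch.library API, restricted to the schema and fake-tensor tests (reusing output, x, weight, and scales from the sketch above):

# Sketch using the public API; the test_utils names are the standard
# torch.library.opcheck test identifiers.
torch.library.opcheck(
    torch.ops._C.rms_norm_dynamic_per_token_quant,
    (output, x, weight, scales, 1e-5, None, None),
    test_utils=("test_schema", "test_faketensor"),
)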
