Skip to content

Commit

Permalink
Implement fp8 quant for layernorm and rmsnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
ruanjm committed Jan 15, 2025
1 parent 04dd314 commit d7d27a6
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 20 deletions.
2 changes: 1 addition & 1 deletion example/ck_tile/02_layernorm2d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)

target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS})

Expand Down
8 changes: 5 additions & 3 deletions example/ck_tile/02_layernorm2d/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def get_if_str(idx, total, lase_else = True):
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'}
'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t'}

def BOOL_MAP(b_) -> str:
if b_:
Expand Down Expand Up @@ -504,12 +505,13 @@ def get_blobs(self, args):
h_traits = layernorm_fwd_codegen.h_traits
h_instance = layernorm_fwd_codegen.h_instance

dynamic_quant_out_dtype = ['int8']
dynamic_quant_out_dtype = ['int8', 'fp8']
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
types_8bit = ('int8', 'fp8')
types_16bit = ('int16', 'fp16', 'bf16')
#fused_add_list = [0, 1, 2]
Expand Down
32 changes: 28 additions & 4 deletions example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ auto get_elimit<ck_tile::bf16_t>()
return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::int8_t>()
{
double rtol = 1e-2;
double atol = 1.0;
return ck_tile::make_tuple(rtol, atol);
}

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
Expand Down Expand Up @@ -97,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
int xbias = arg_parser.get_int("xbias");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
if(fused_quant == 1 && prec_o != "int8")
if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
{
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
std::cout
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
<< std::endl;
return false;
}

Expand Down Expand Up @@ -291,7 +301,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
Expand Down Expand Up @@ -334,7 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data());
}

auto [rtol, atol] = get_elimit<InDataType>();
auto [rtol, atol] = get_elimit<OutDataType>();

if(x_stride == n)
{
Expand Down Expand Up @@ -452,6 +466,16 @@ int main(int argc, char* argv[])
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_mv)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}

return -3;
}
2 changes: 1 addition & 1 deletion example/ck_tile/02_layernorm2d/script/smoke_test.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"

for fquant in "" "-fquant=1 -prec_o=int8"; do
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
Expand Down
2 changes: 1 addition & 1 deletion example/ck_tile/10_rmsnorm2d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)

# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)

target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})

Expand Down
8 changes: 5 additions & 3 deletions example/ck_tile/10_rmsnorm2d/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_if_str(idx, total, lase_else = True):
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::fp16_t',
'bf16' : 'ck_tile::bf16_t',
'int8' : 'ck_tile::int8_t'}
'int8' : 'ck_tile::int8_t',
'fp8' : 'ck_tile::fp8_t'}

def BOOL_MAP(b_) -> str:
if b_:
Expand Down Expand Up @@ -477,12 +478,13 @@ def get_blobs(self):
h_traits = rmsnorm_fwd_codegen.h_traits
h_instance = rmsnorm_fwd_codegen.h_instance

dynamic_quant_out_dtype = ['int8']
dynamic_quant_out_dtype = ['int8', 'fp8']
# some predefined support range
# (prec_i,prec_o) for simplicity this string will be used as key for dict
scale_list = [('fp32,fp32')]
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out
('fp16,int8'), ('bf16,int8'),
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
#fused_add_list = [0, 1, 2]
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
fused_add_list = [0, 1]
Expand Down
22 changes: 19 additions & 3 deletions example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sy = "fp32";
}

if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8")
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
{
std::cout << "if fused_quant is 1, only support \"-prec_o=int8\" case" << std::endl;
std::cout
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
<< std::endl;
return false;
}

Expand Down Expand Up @@ -248,7 +250,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
constexpr ComputeDataType kMaxY =
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
: 0.0;
ComputeDataType y_scale = absmax / kMaxY;
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
Expand Down Expand Up @@ -400,6 +406,16 @@ int main(int argc, char* argv[])
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
!save_rms)
{
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
}

return -3;
}
4 changes: 2 additions & 2 deletions example/ck_tile/10_rmsnorm2d/script/smoke_test.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"

for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
Expand All @@ -27,7 +27,7 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
Expand Down
8 changes: 6 additions & 2 deletions include/ck_tile/host/check_err.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -337,7 +337,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
Expand Down

0 comments on commit d7d27a6

Please sign in to comment.