Skip to content

Commit

Permalink
fixup! sdpa: add tests to directly call sdpa
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Jan 29, 2025
1 parent 51e3dc7 commit 46e8659
Showing 1 changed file with 27 additions and 44 deletions.
71 changes: 27 additions & 44 deletions tests/gtests/internals/test_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,12 @@ sdpa_tensors get_descriptors(dnnl::engine &eng, sdpa_dims_t p) {
fill_mask(mask_data, mask_md);
}

/// This section allows setting the values of the tensors using environment variables.
/// Syntax:
///     <tensor name>[S for scales | Z for zero points]<R for row><C for column>=<index>
///
/// Examples:
///     KR=3 KC=1   Set the value in the Key tensor at (3, 1) to 1; all other values are zero
///     VSR=1 VSC=2 Set the scale for the Value tensor at (1, 2) to 1; all other scales are zero
#if 1
auto &Q = query_data;
auto &K = key_quantized_data;
Expand Down Expand Up @@ -1231,7 +1237,7 @@ sdpa_tensors get_descriptors(dnnl::engine &eng, sdpa_dims_t p) {
}
class sdpa_test : public ::testing::TestWithParam<sdpa_dims_t> {
public:
virtual void SetUp() override {
void SetUp() override {
p = GetParam();
SKIP_IF_CUDA(true, "SDPA primitive tests do not support CUDA");
SKIP_IF_HIP(true, "SDPA primitive tests do not support HIP");
Expand Down Expand Up @@ -1422,10 +1428,6 @@ INSTANTIATE_TEST_SUITE_P(phi3_mini_4k_instruct,
sdpa_dims_t{ 1, 32, 32, 2049, 1, 96, 96, 96, mdt::f16, mdt::f16, mdt::s8, mdt::f16, mdt::s8, mdt::s8, mdt::f16, mdt::s8, mdt::f16, quantize_type::per_token_with_groups, with_key_transposed, mask_type::causal }
), &PrintToString);


//sdpa_dims_t{ 1, 259, 32, 1, 96, 1, 96, 96, mdt::f16, mdt::s8, mdt::f16, mdt::s8, mdt::s8, mdt::f16, mdt::s8, quantize_type::per_token_with_groups, no_key_transposed, with_causal_mask },
//sdpa_dims_t{ 1, 384, 32, 1, 96, 384, 96, 96, mdt::f16, mdt::s8, mdt::f16, mdt::s8, mdt::s8, mdt::f16, mdt::s8, quantize_type::per_token_with_groups, no_key_transposed, with_causal_mask }

// clang-format on

memory as(dnnl::stream &strm, memory &mem, memory::data_type dt) {
Expand Down Expand Up @@ -1650,8 +1652,8 @@ void prim_sdpa_quant(const sdpa_dims_t &p, const sdpa_tensors &t,

template <typename T>
void check_memory(memory &gold, memory &test) {
T *mapped_ptr_f16 = (T *)gold.map_data();
T *mapped_ptr_s8 = (T *)test.map_data();
T *mapped_ptr_gold = (T *)gold.map_data();
T *mapped_ptr_test = (T *)test.map_data();

auto dims = gold.get_desc().get_dims();
auto strides = gold.get_desc().get_strides();
Expand All @@ -1673,18 +1675,18 @@ void check_memory(memory &gold, memory &test) {
for (int i = 0; i < dims[3]; i++) {
auto offset = l * strides[0] + k * strides[1] + j * strides[2]
+ i * strides[3];
auto o_f16 = (float)mapped_ptr_f16[offset];
auto o_s8 = (float)mapped_ptr_s8[offset];
auto o_gold = (float)mapped_ptr_gold[offset];
auto o_test = (float)mapped_ptr_test[offset];
total++;

float abs_diff = abs(o_f16 - o_s8);
bool is_nan = isnan(o_f16) || isnan(o_s8);
float abs_diff = abs(o_gold - o_test);
bool is_nan = isnan(o_gold) || isnan(o_test);

bool is_mismatch = is_nan
|| (abs(o_f16) > 2.f ? abs_diff > abs(o_f16 * fthreshold)
: abs_diff > fthreshold);
|| (abs(o_gold) > 2.f ? abs_diff > abs(o_gold * fthreshold)
: abs_diff > fthreshold);
if (max_diff < abs_diff) {
printf("new max: f16: %f vs s8: %f diff: %f\n", o_f16, o_s8,
printf("new max: gold: %f vs test: %f diff: %f\n", o_gold, o_test,
abs_diff);
max_diff = abs_diff;
}
Expand All @@ -1696,15 +1698,16 @@ void check_memory(memory &gold, memory &test) {
}
if ((is_mismatch && mismatches++ < 32) || is_nan) {
fprintf(stderr,
"Mismatch at (%d,%d,%d,%d): computed %f "
"vs. %f (diff: %f thresh: %f)\n",
l, k, j, i, o_s8, o_f16, abs_diff,
(abs(o_f16) > 2.f ? abs(o_f16 * fthreshold) : fthreshold));
"Mismatch at (%d,%d,%d,%d): test %f "
"vs. gold %f (diff: %f thresh: %f)\n",
l, k, j, i, o_test, o_gold, abs_diff,
(abs(o_gold) > 2.f ? abs(o_gold * fthreshold)
: fthreshold));
}
}

gold.unmap_data(mapped_ptr_f16);
test.unmap_data(mapped_ptr_s8);
gold.unmap_data(mapped_ptr_gold);
test.unmap_data(mapped_ptr_test);

int threshold = total * 0.0004;
std::cout << "max diff: " << max_diff << std::endl;
Expand All @@ -1728,13 +1731,6 @@ GPU_TEST_P(sdpa_test, compare) {
case mask_type::twoD: mask_ptr = &mask; break;
}

//auto sdpaf16_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(),
//t.m_key.get_desc(), t.m_value.get_desc(), mask_ptr, scale_dt,
//t.m_output.get_desc(), invert_scale, p.head_num, p.with_causal_mask, t.sdpa_attr);
//auto sdpaf16_p = sdpa(sdpaf16_pd);

//print_mem(t.m_query_test, "query_quantized");
//print_mem(t.m_key_quantized, "key_quantized");
auto sdpas8_pd = sdpa::primitive_desc(eng, t.m_query_test.get_desc(),
p.with_key_transposed ? t.m_key_t_quantized.get_desc()
: t.m_key_quantized.get_desc(),
Expand All @@ -1755,13 +1751,13 @@ GPU_TEST_P(sdpa_test, compare) {
}
if (scale_dt != mdt::undef) { s8_args[DNNL_ARG_SCALE] = t.m_scale; }

if (((p.kdt != mdt::f16) || (p.kdt != mdt::bf16))
&& p.qtype != quantize_type::no_quantization) {
bool k_is_16_bit_float = ((p.kdt != mdt::f16) || (p.kdt != mdt::bf16));
bool v_is_16_bit_float = ((p.vdt != mdt::f16) || (p.vdt != mdt::bf16));
if (k_is_16_bit_float && p.qtype != quantize_type::no_quantization) {
s8_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = t.m_key_scales;
s8_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS] = t.m_key_zp;
}
if (((p.vdt != mdt::f16) || (p.vdt != mdt::bf16))
&& p.qtype != quantize_type::no_quantization) {
if (v_is_16_bit_float && p.qtype != quantize_type::no_quantization) {
s8_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = t.m_value_scales;
s8_args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES] = t.m_value_zp;
}
Expand All @@ -1773,26 +1769,15 @@ GPU_TEST_P(sdpa_test, compare) {
if (scale_dt != mdt::undef) { f16_args[DNNL_ARG_SCALE] = t.m_scale; }
if (mask_ptr) { f16_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }

//print_mem(t.m_key_scales, "key_scales");
//print_mem(t.m_key_zp, "key_zp");
auto loop_s8 = [&] { sdpas8_p.execute(strm, s8_args); };
loop_s8();
strm.wait();
//auto loop_f16 = [&] { sdpaf16_p.execute(strm, f16_args); };
//loop_f16();
//strm.wait();
prim_sdpa_quant(p, t, eng, strm, t.m_query,
p.with_key_transposed ? t.m_key_t_quantized : t.m_key_quantized,
t.m_key_scales, t.m_key_zp, scale_dt, t.m_scale, t.m_mask,
t.m_value_quantized, t.m_value_scales, t.m_value_zp, t.m_output,
invert_scale);
strm.wait();
//print_mem(t.m_key, "key");
//print_mem(t.m_query, "query");
//print_mem(t.m_value, "value");
//print_mem(t.m_mask, "mask");
//print_mem(t.m_output, "output");
//print_mem(t.m_output_quantized, "output_quantized");

strm.wait();

Expand Down Expand Up @@ -1871,8 +1856,6 @@ TEST_P(sdpa_test, perf) {
if (scale_dt != mdt::undef) { f16_args[DNNL_ARG_SCALE] = t.m_scale; }
if (mask_ptr) { f16_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }

//print_mem(t.m_key_scales, "key_scales");
//print_mem(t.m_key_zp, "key_zp");
auto loop_s8 = [&] { sdpas8_p.execute(strm, s8_args); };

loop_s8();
Expand Down

0 comments on commit 46e8659

Please sign in to comment.