tenstorrent#13364: enabled fp32 accumulate in sdpa decode
caixunshiren authored and Christopher Taylor committed Nov 9, 2024
1 parent 42d7b28 commit c4569c1
Showing 2 changed files with 4 additions and 15 deletions.
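From the caller's side, this change means an SDPA decode invocation may now request fp32 destination-register accumulation through its compute kernel config. The sketch below illustrates that; the decode entry point (`ttnn.transformer.scaled_dot_product_attention_decode`), its keyword names, and the placeholder tensors are assumptions about the ttnn API of this era rather than content taken from the diff.

```python
# Sketch only: the decode entry point, keyword names, and placeholder tensors
# below are assumptions, not part of this commit's diff.
import ttnn

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,
    math_approx_mode=False,
    fp32_dest_acc_en=True,   # rejected by the op's validate() before this commit
    packer_l1_acc=False,
)

# tt_q, tt_k, tt_v are device tensors in the layout the decode op expects;
# cur_pos holds the per-user decode position. Both are placeholders here.
tt_out = ttnn.transformer.scaled_dot_product_attention_decode(
    tt_q,
    tt_k,
    tt_v,
    cur_pos=cur_pos,
    compute_kernel_config=compute_kernel_config,
)
```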
@@ -190,7 +190,7 @@ def run_test_sdpa_decode_multi_pos(
min_pcc = 0.99
if q_dtype == ttnn.bfloat8_b:
min_pcc = 0.98
-min_pcc = 0.93 if dtype == ttnn.bfloat4_b else min_pcc
+min_pcc = 0.91 if dtype == ttnn.bfloat4_b else min_pcc

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
@@ -330,7 +330,7 @@ def run_test_sdpa_decode_single_iter(
min_pcc = 0.99
if q_dtype == ttnn.bfloat8_b:
min_pcc = 0.98
-min_pcc = 0.93 if dtype == ttnn.bfloat4_b else min_pcc
+min_pcc = 0.91 if dtype == ttnn.bfloat4_b else min_pcc

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
@@ -463,6 +463,7 @@ def run_test_sdpa_decode_single_iter(
[4, 32, 8, 8192, 128, (8, 8), True, True], # llama 3.1 8b
[32, 32, 8, 8192, 128, (8, 8), True, False], # llama 3.1 8b
# [4, 16, 4, 32768, 128, (8, 8), False, False], # llama 3.1 8b
+# [1, 8, 1, 8192*16, 128, (1, 1), False, True], # llama2-70B long seqlen
),
)
def test_sdpa_decode(
@@ -595,7 +596,7 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc
min_pcc = 0.99
if q_dtype == ttnn.bfloat8_b:
min_pcc = 0.98
-min_pcc = 0.93 if kv_dtype == ttnn.bfloat4_b else min_pcc
+min_pcc = 0.91 if kv_dtype == ttnn.bfloat4_b else min_pcc

compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
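The `min_pcc` values above are Pearson-correlation thresholds between the torch reference and the tensor read back from device; this commit relaxes the bfloat4_b floor from 0.93 to 0.91, reflecting the heavier quantization error of that format. The tests use the repository's own comparison helper; the standalone function below is only a sketch of the metric being thresholded.

```python
# Standalone sketch of the PCC metric behind the min_pcc thresholds; the tests
# themselves use the repository's comparison helper rather than this function.
import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient over the flattened tensors.
    e = expected.flatten().to(torch.float64)
    a = actual.flatten().to(torch.float64)
    e, a = e - e.mean(), a - a.mean()
    return float((e @ a) / (e.norm() * a.norm()))

# Toy check: 1% relative noise still scores ~0.9999, comfortably above the 0.99
# default; the looser 0.91 floor absorbs the larger error of bfloat4_b data.
ref = torch.randn(1, 32, 8, 128)
assert pcc(ref, ref + 0.01 * torch.randn_like(ref)) > 0.99
```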
@@ -113,18 +113,6 @@ void ScaledDotProductAttentionDecode::validate(const std::vector<Tensor>& input_
uint32_t num_heads_per_kv = q_shape_unpadded[2]/k_shape[1];
TT_FATAL(q_shape_unpadded[2]%k_shape[1] == 0, "GQA expects Q to have a multiple of K heads, but got {} and {}", q_shape_unpadded[2], k_shape[1]);
}

-    // Check compute kernel config
-    std::visit(
-        [&](auto&& compute_kernel_config) {
-            using T = std::decay_t<decltype(compute_kernel_config)>;
-            if constexpr (std::is_same_v<T, WormholeComputeKernelConfig>) {
-                TT_FATAL(
-                    compute_kernel_config.fp32_dest_acc_en == false,
-                    "FP32 dest acc disabled due to nd pcc and unpacker hang issue.");
-            }
-        },
-        this->compute_kernel_config);
}

std::vector<tt::tt_metal::LegacyShape> ScaledDotProductAttentionDecode::compute_output_shapes(
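With the TT_FATAL guard removed, `fp32_dest_acc_en` is no longer a hard error in the op's validate(), so both accumulation modes are valid inputs. A hypothetical pytest sweep of the flag could look like the sketch below; the test name, the `device` fixture, and the elided body are assumptions, not part of this commit.

```python
# Hypothetical sketch: sweep both accumulation modes now that validate() accepts
# fp32_dest_acc_en=True. Names and fixtures here are illustrative only.
import pytest
import ttnn

@pytest.mark.parametrize("fp32_dest_acc_en", [False, True], ids=["bf16_acc", "fp32_acc"])
def test_sdpa_decode_accumulation_modes(device, fp32_dest_acc_en):
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi4,
        math_approx_mode=False,
        fp32_dest_acc_en=fp32_dest_acc_en,
        packer_l1_acc=False,
    )
    # Build Q/K/V, run the decode op with compute_kernel_config, and compare
    # against the torch reference exactly as the existing tests above do.
```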
