#13364: fixed fp32 accumulation error in sdpa
caixunshiren committed Oct 18, 2024
1 parent c8374e4 commit 22f0c90
Showing 2 changed files with 17 additions and 9 deletions.
@@ -33,7 +33,7 @@ def fa_rand(*shape):
     return normal_1 + normal_2 * bernoulli


-def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype):
+def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype, use_high_precision_compute=False):
     torch.manual_seed(1234)

     program_config = ttnn.SDPAProgramConfig(
@@ -43,12 +43,20 @@ def run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype
         exp_approx_mode=False,
     )

-    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
-        math_fidelity=ttnn.MathFidelity.HiFi2,
-        math_approx_mode=True,
-        fp32_dest_acc_en=False,
-        packer_l1_acc=False,
-    )
+    if use_high_precision_compute:
+        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi4,
+            math_approx_mode=False,
+            fp32_dest_acc_en=True,
+            packer_l1_acc=False,
+        )
+    else:
+        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=True,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )

     Q = fa_rand(b, nh, s, d)
     K = fa_rand(b, nkv, s, d)
@@ -118,7 +126,7 @@ def test_sdpa_tt_large_seq(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size,
     if nh == 8 and q_chunk_size == 128 and k_chunk_size == 128:
         pytest.skip("Can cause OOM if profiling is enabled.")
     ttnn.device.DisablePersistentKernelCache()
-    run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype)
+    run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype, use_high_precision_compute=True)


 @pytest.mark.skip(reason="Skip perf test in CI")
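The config selected above only matters where the test actually invokes the SDPA op, which falls outside the hunks shown here. Below is a minimal sketch of that call site, assuming the code continues inside run_test_sdpa_tt and that ttnn.transformer.scaled_dot_product_attention accepts program_config and compute_kernel_config keywords; names not shown in this diff are assumptions:

    V = fa_rand(b, nkv, s, d)  # assumed to be generated like K; not shown in this diff
    tt_Q = ttnn.from_torch(Q, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
    tt_K = ttnn.from_torch(K, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
    tt_V = ttnn.from_torch(V, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=device)
    tt_out = ttnn.transformer.scaled_dot_product_attention(
        tt_Q,
        tt_K,
        tt_V,
        is_causal=True,
        program_config=program_config,
        compute_kernel_config=compute_kernel_config,
    )
    out = ttnn.to_torch(tt_out)  # compared against a torch reference in the full test

With use_high_precision_compute=True (as in test_sdpa_tt_large_seq above), the HiFi4 / fp32_dest_acc_en=True config is the one forwarded to the op.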
@@ -320,7 +320,7 @@ operation::ProgramWithCallbacks sdpa_multi_core(
                               : tt::DataFormat::Float16_b;
     tt::DataFormat out_df = tt::tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype());
     tt::DataFormat scalar_df = tt::DataFormat::Float16_b;
-    tt::DataFormat im_df = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
+    tt::DataFormat im_df = tt::DataFormat::Float16_b; // need to disable fp32 cbs (Issue #13364) fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
     tt::DataFormat stats_df = im_df;

     uint32_t q_tile_size = tt::tt_metal::detail::TileSize(q_df);
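As background for the commit title, the sketch below (pure PyTorch, illustrative only; it does not model the device kernels) shows the kind of drift that appears when the attention matmul chain runs in bfloat16 rather than accumulating in fp32 — the class of discrepancy the high-precision test path above helps expose:

    import torch

    torch.manual_seed(1234)
    b, nh, s, d = 1, 8, 2048, 64
    Q = torch.randn(b, nh, s, d)
    K = torch.randn(b, nh, s, d)
    V = torch.randn(b, nh, s, d)

    def sdpa_in_dtype(q, k, v, dtype):
        # Run QK^T and PV in the requested dtype; softmax itself in fp32.
        scores = (q.to(dtype) @ k.to(dtype).transpose(-2, -1)) / (d**0.5)
        attn = torch.softmax(scores.float(), dim=-1).to(dtype)
        return (attn @ v.to(dtype)).float()

    fp32_out = sdpa_in_dtype(Q, K, V, torch.float32)
    bf16_out = sdpa_in_dtype(Q, K, V, torch.bfloat16)
    print((fp32_out - bf16_out).abs().max())  # the gap widens as s and d grow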
