diff --git a/tests/ttnn/unit_tests/operations/test_softmax.py b/tests/ttnn/unit_tests/operations/test_softmax.py
index d3a165df8dc..bf8e285cd5d 100644
--- a/tests/ttnn/unit_tests/operations/test_softmax.py
+++ b/tests/ttnn/unit_tests/operations/test_softmax.py
@@ -9,10 +9,164 @@
 import ttnn
 
 from tests.ttnn.utils_for_testing import assert_with_pcc
-from models.utility_functions import skip_for_wormhole_b0
+from models.utility_functions import skip_for_wormhole_b0, is_grayskull
 from models.utility_functions import torch_random
 
 
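+# Numeric-stability check: softmax is shift-invariant (softmax(x) == softmax(x - c)),
+# and the stable kernel uses c = max(x) so every exponent is <= 0. Without that
+# shift, exp(x) overflows the fp32/bfloat16 exponent range (max ~3.4e38) once x
+# exceeds ~88.7, so rows like [100.0, 101.0] break a naive exp(x)/sum(exp(x)).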
+@pytest.mark.parametrize(
+    "input_vector",
+    [[100.0, 101.0], [100.0, 1000.0], [-100.0, -101.0], [-1000.0, -100.0], [-100, -108, -99, -100, -101, -98]],
+)
+def test_softmax_stable_neg_values(device, input_vector):
+    torch.manual_seed(0)
+
+    torch_input_tensor = torch.tensor([[[input_vector]]], dtype=torch.bfloat16)
+    torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+    output_tensor = ttnn.softmax(input_tensor, dim=-1, numeric_stable=True)
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
+
+
+def run_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx):
+    torch.manual_seed(0)
+
+    scale = 1.0
+    attention_mask = torch.rand(batch_size, 1, 1, w)
+    attention_mask = (attention_mask > 0.5).float()
+    attention_mask = attention_mask.masked_fill(attention_mask == 0, torch.tensor(float("-inf"), dtype=torch.bfloat16))
+    attention_mask = attention_mask.masked_fill(attention_mask == 1, 0)
+    attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+
+    torch_input_tensor = torch_random((batch_size, 1, h, w), -1000, 1000, dtype=torch.bfloat16)
+    if not skip_scale_mask:
+        torch_output_tensor = torch_input_tensor * scale + attention_mask
+    else:
+        torch_output_tensor = torch_input_tensor
+    torch_output_tensor = F.softmax(torch_output_tensor, dim=-1, dtype=torch.bfloat16)
+
+    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
+
+    if is_grayskull():
+        compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi4,
+            math_approx_mode=math_approx,
+        )
+    else:
+        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi4,
+            math_approx_mode=math_approx,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )
+
+    if not skip_scale_mask:
+        output_tensor = ttnn.scale_mask_softmax(
+            input_tensor, scale, attention_mask_t, compute_kernel_config=compute_kernel_config, numeric_stable=True
+        )
+    else:
+        output_tensor = ttnn.softmax(
+            input_tensor, dim=-1, compute_kernel_config=compute_kernel_config, numeric_stable=True
+        )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
+
+
+@pytest.mark.parametrize("batch_size", [1, 8])
+@pytest.mark.parametrize("h", [32, 128])
+@pytest.mark.parametrize("w", [1024, 1500])
+@pytest.mark.parametrize("skip_scale_mask", [True, False])
+@pytest.mark.parametrize("math_approx", [True, False])
+def test_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx, use_program_cache):
+    for _ in range(2):
+        run_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx)
+        # dummy tensor to change tensor alloc
+        dummy_shape = [1, 1, 32, 32]
+        py_dummy_tensor = torch.randn(dummy_shape)
+        tt_dummy_tensor = ttnn.from_torch(
+            py_dummy_tensor,
+            dtype=ttnn.DataType.BFLOAT16,
+            layout=ttnn.TILE_LAYOUT,
+            device=device,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+    assert device.num_program_cache_entries() == 1
+
+
+def run_softmax_sharded_stable(device, batch_size, num_heads, h, w, skip_scale_mask):
+    torch.manual_seed(0)
+
+    grid_size = (batch_size, num_heads)
+
+    scale = 1.0
+    attention_mask = torch.rand(batch_size, 1, 1, w)
+    attention_mask = (attention_mask > 0.5).float()
+    attention_mask = attention_mask.masked_fill(attention_mask == 0, torch.tensor(float("-inf"), dtype=torch.bfloat16))
+    attention_mask = attention_mask.masked_fill(attention_mask == 1, 0)
+    attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+
+    torch_input_tensor = torch_random((batch_size, num_heads, h, w), -1000, 1000, dtype=torch.bfloat16)
+    if not skip_scale_mask:
+        torch_output_tensor = torch_input_tensor * scale + attention_mask
+    else:
+        torch_output_tensor = torch_input_tensor
+    torch_output_tensor = F.softmax(torch_output_tensor, dim=-1, dtype=torch.bfloat16)
+
+    memory_config = ttnn.create_sharded_memory_config(
+        torch_input_tensor.shape,
+        core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
+        strategy=ttnn.ShardStrategy.HEIGHT,
+        orientation=ttnn.ShardOrientation.ROW_MAJOR,
+    )
+    program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
+        compute_with_storage_grid_size=grid_size,
+        subblock_w=6,
+        block_h=h // 32,
+        block_w=w // 32,
+    )
+
+    input_tensor = ttnn.from_torch(
+        torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=memory_config
+    )
+    if not skip_scale_mask:
+        output_tensor = ttnn.scale_mask_softmax_in_place(
+            input_tensor, scale, attention_mask_t, program_config=program_config, numeric_stable=True
+        )
+    else:
+        output_tensor = ttnn.scale_mask_softmax_in_place(
+            input_tensor, program_config=program_config, numeric_stable=True
+        )
+    output_tensor = ttnn.to_torch(output_tensor)
+
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
+
+
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("num_heads", [4])
+@pytest.mark.parametrize("h", [384])
+@pytest.mark.parametrize("w", [384])
+@pytest.mark.parametrize("skip_scale_mask", [True, False])
+def test_softmax_sharded_stable_with_program_cache(
+    device, batch_size, num_heads, h, w, skip_scale_mask, use_program_cache
+):
+    for _ in range(2):
+        run_softmax_sharded_stable(device, batch_size, num_heads, h, w, skip_scale_mask)
+        # dummy tensor to change tensor alloc
+        dummy_shape = [1, 1, 32, 32]
+        py_dummy_tensor = torch.randn(dummy_shape)
+        tt_dummy_tensor = ttnn.from_torch(
+            py_dummy_tensor,
+            dtype=ttnn.DataType.BFLOAT16,
+            layout=ttnn.TILE_LAYOUT,
+            device=device,
+            memory_config=ttnn.L1_MEMORY_CONFIG,
+        )
+    assert device.num_program_cache_entries() == 1
+
+
 @pytest.mark.parametrize("batch_size", [1, 16])
 @pytest.mark.parametrize("h", [32, 64])
 @pytest.mark.parametrize("w", [32, 64])
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp
index 84659afd5b9..4587d75fced 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp
@@ -13,7 +13,8 @@
 #include "compute_kernel_api/softmax.h"
 #include "compute_kernel_api/reduce.h"
 
-// #include "debug/dprint.h"
+#include "debug/dprint.h"
+#include "debug/dprint_tensix.h"
 
 ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
 ALWI void REL() { release_dst(tt::DstMode::Half); }
@@ -25,6 +26,46 @@ ALWI void REL() { release_dst(tt::DstMode::Half); }
 // The buffer for the att mask is currently sized as (1t,Wt) so we only reuse it for one HtWt-sized batch of x
 // then read another Wt tiles of mask for the next batch
 
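+// Numerically stable softmax helper: since softmax(x) == softmax(x - max(x)),
+// this first reduces all Wt tiles of a row to their maximum (packed into cb_max),
+// then subtracts the broadcast max and exponentiates in a single fused pass,
+// leaving exp(x - max(x)) tiles in cb_out for the existing sum/recip stages.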
+void calc_numeric_stable(uint32_t Wt, uint32_t ndst, uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) {
+    // calculate max val per row
+    ACQ();
+    unpack_reconfig_data_format(cb_in, cb_bcast_scaler);
+    cb_reserve_back(cb_max, 1);
+    cb_wait_front(cb_bcast_scaler, 1);
+    reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>();
+    for (uint32_t wt = 0; wt < Wt; wt++) {
+        cb_wait_front(cb_in, wt+1);
+        constexpr uint32_t bcast_scaler0 = 0;
+        reduce_tile<PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, wt, bcast_scaler0, 0);
+    }
+    reduce_revert_delta();
+    pack_tile(0, cb_max);
+    cb_push_back(cb_max, 1);
+    REL();
+
+    // calculate x-max(x)
+    exp_tile_init();
+    unpack_reconfig_data_format_srcb(cb_max);
+    cb_wait_front(cb_max, 1);
+    sub_bcast_cols_init_short();
+    for (uint32_t wt = 0; wt < Wt; wt += ndst) {
+        ACQ();
+        for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
+            sub_tiles_bcast_cols(cb_in, cb_max, wt+wt8, 0, wt8);
+        }
+        cb_reserve_back(cb_out, ndst);
+        for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
+            exp_tile(wt8);          // exp on DST[wt8]
+            pack_tile(wt8, cb_out); // DST[wt8] -> cb_out, reused circularly
+        }
+        cb_push_back(cb_out, ndst);
+        REL();
+    }
+    cb_pop_front(cb_in, Wt);
+    cb_pop_front(cb_max, 1);
+    cb_wait_front(cb_out, Wt);
+}
+
 namespace NAMESPACE {
 void MAIN {
 
@@ -50,7 +91,12 @@ void MAIN {
     constexpr auto cb_recipsumexps = tt::CB::c_intermed1;
     constexpr auto cb_in0 = tt::CB::c_in0;
    constexpr auto cb_out0 = tt::CB::c_out0;
-
+    #ifdef NUMERIC_STABLE
+        constexpr auto cb_max = tt::CB::c_intermed2;
+        constexpr auto cb_x = tt::CB::c_intermed4;
+    #else
+        constexpr auto cb_x = cb_exps;
+    #endif
 
     cb_wait_front(cb_bcast_scaler, 1); // comes from the reader
 
@@ -81,38 +127,50 @@ void MAIN {
         }
 
         unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn);
-        exp_tile_init();
+        #ifndef NUMERIC_STABLE
+            exp_tile_init();
+        #endif
+
         #ifdef CAUSAL_MASK
-        add_tiles_init();
+            add_tiles_init();
         #else
-        add_bcast_rows_init_short();
+            add_bcast_rows_init_short();
         #endif
         for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
             ACQ();
             cb_wait_front(cb_scale_mask, ndst);
             #ifdef CAUSAL_MASK
-            cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles
-            for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
-                add_tiles(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile *= 1/(sum(exp(x)))
-            }
+                cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles
+                for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
+                    add_tiles(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile += causal attention mask
+                }
             #else
-            if (wait_mask) {
-                cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles, only at first ht
-            }
+                if (wait_mask) {
+                    cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles, only at first ht
+                }
 
-            for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
-                add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile *= 1/(sum(exp(x)))
-            }
+                for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
+                    add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // tile += attention mask (row broadcast)
+                }
             #endif
             cb_pop_front(cb_scale_mask, ndst);
-            cb_reserve_back(cb_exps, ndst);
+            cb_reserve_back(cb_x, ndst);
             for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
-                exp_tile(wt8); // exp on DST[0]
-                pack_tile(wt8, cb_exps); // reuse the exps buffer again, this time in a circular manner
+                #ifndef NUMERIC_STABLE
+                    exp_tile(wt8); // exp on DST[wt8]
+                #endif
+                pack_tile(wt8, cb_x); // reuse the buffer again, this time in a circular manner
             }
-            cb_push_back(cb_exps, ndst);
+            cb_push_back(cb_x, ndst);
             REL();
         }
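+        // With NUMERIC_STABLE defined, the loop above only packed the scaled and
+        // masked input into cb_x without exponentiating; calc_numeric_stable below
+        // supplies the second pass that writes exp(x - max(x)) into cb_exps.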
+
+        // add numeric_stable
+        // fuse exp with sub tiles
+        #ifdef NUMERIC_STABLE
+            calc_numeric_stable(Wt, ndst, cb_x, cb_bcast_scaler, cb_max, cb_exps);
+        #endif
+
         #ifdef CAUSAL_MASK
         cb_pop_front(cb_fused_attn, Wt);
         #else
@@ -132,7 +190,9 @@ void MAIN {
             unpack_reconfig_data_format(cb_in0, cb_in0);
             pack_reconfig_data_format(cb_exps);
             copy_tile_to_dst_init_short(); // need to copy from CB to DST to be able to run sfpu math
-            exp_tile_init();
+            #ifndef NUMERIC_STABLE
+                exp_tile_init();
+            #endif
             if (mask_padded_data) {
                 for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
                     ACQ();
@@ -149,32 +209,46 @@ void MAIN {
                    }
                     cb_pop_front(cb_in0, ndst);
 
-                    cb_reserve_back(cb_exps, ndst);
+                    cb_reserve_back(cb_x, ndst);
                     for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
-                        exp_tile(wt8); // exp on DST[0]
-                        pack_tile(wt8, cb_exps); // DST[0]->cb_id[wt]
+                        #ifndef NUMERIC_STABLE
+                            exp_tile(wt8); // exp on DST[wt8]
+                        #endif
+                        pack_tile(wt8, cb_x); // DST[wt8] -> cb_x
                     }
-                    cb_push_back(cb_exps, ndst);
+                    cb_push_back(cb_x, ndst);
                     REL();
                 }
+
+                // add numeric_stable
+                // fuse exp with sub tiles
+                #ifdef NUMERIC_STABLE
+                    calc_numeric_stable(Wt, ndst, cb_x, cb_bcast_scaler, cb_max, cb_exps);
+                #endif
+
             } else {
-                for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
-                    ACQ();
-                    cb_wait_front(cb_in0, ndst);
-                    for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
-                        copy_tile(cb_in0, wt8, wt8); // copy from c_in[0] to DST[0]
-                    }
-                    cb_pop_front(cb_in0, ndst);
+                // add numeric_stable
+                // fuse exp with sub tiles
+                #ifdef NUMERIC_STABLE
+                    calc_numeric_stable(Wt, ndst, cb_in0, cb_bcast_scaler, cb_max, cb_exps);
+                #else
+                    for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
+                        ACQ();
+                        cb_wait_front(cb_in0, ndst);
+                        for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
+                            copy_tile(cb_in0, wt8, wt8); // copy from c_in[0] to DST[0]
+                        }
+                        cb_pop_front(cb_in0, ndst);
 
-                    cb_reserve_back(cb_exps, ndst);
-                    for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
-                        exp_tile(wt8); // exp on DST[0]
-                        pack_tile(wt8, cb_exps); // DST[0]->cb_id[wt]
+                        cb_reserve_back(cb_exps, ndst);
+                        for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
+                            exp_tile(wt8); // exp on DST[0]
+                            pack_tile(wt8, cb_exps); // DST[0]->cb_id[wt]
+                        }
+                        cb_push_back(cb_exps, ndst);
+                        REL();
                     }
-                    cb_push_back(cb_exps, ndst);
-                    REL();
-                }
+                #endif
             }
 
             unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp
index 1d4fcf0c3bd..d955ee4f300 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp
@@ -13,9 +13,55 @@
 #include "compute_kernel_api/softmax.h"
 #include "compute_kernel_api/reduce.h"
 
+#include "debug/dprint.h"
+
 ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
 ALWI void REL() { release_dst(tt::DstMode::Half); }
 
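+// Sharded variant of the stable-softmax helper: the same max-reduce followed by
+// a fused subtract+exp as in softmax.cpp, but block_w, subblock_w and
+// num_subblocks_w are compile-time constants, so tiles are processed
+// subblock by subblock.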
+template <uint32_t block_w, uint32_t subblock_w, uint32_t num_subblocks_w>
+ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) {
+    // calculate max val per row
+    ACQ();
+    unpack_reconfig_data_format(cb_in, cb_bcast_scaler);
+    cb_reserve_back(cb_max, 1);
+    reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>();
+    cb_wait_front(cb_bcast_scaler, 1);
+    for (uint32_t w = 0; w < block_w; w++) {
+        constexpr uint32_t bcast_scaler0 = 0;
+        reduce_tile<PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, w, bcast_scaler0, 0);
+    }
+    reduce_revert_delta();
+    pack_tile(0, cb_max);
+    cb_push_back(cb_max, 1);
+    REL();
+
+    // calculate x-max(x)
+    exp_tile_init();
+    unpack_reconfig_data_format_srcb(cb_max);
+    cb_wait_front(cb_max, 1);
+    sub_bcast_cols_init_short();
+    uint32_t index_subblock_w_offset = 0;
+    for (uint32_t j = 0; j < num_subblocks_w; j++) {
+        ACQ();
+        for (uint32_t w = 0; w < subblock_w; w++) {
+            uint32_t index = w + index_subblock_w_offset;
+            sub_tiles_bcast_cols(cb_in, cb_max, index, 0, w);
+        }
+        cb_reserve_back(cb_out, subblock_w);
+        for (uint32_t w = 0; w < subblock_w; w++) {
+            exp_tile(w);
+            pack_tile(w, cb_out);
+        }
+        cb_push_back(cb_out, subblock_w);
+        REL();
+        index_subblock_w_offset += subblock_w;
+    }
+    cb_pop_front(cb_in, block_w);
+    cb_pop_front(cb_max, 1);
+    cb_wait_front(cb_out, block_w);
+}
+
 namespace NAMESPACE {
 void MAIN {
 
@@ -34,6 +80,12 @@ void MAIN {
     constexpr auto cb_recipsumexps = tt::CB::c_intermed1;
     constexpr auto cb_scale_mask = tt::CB::c_intermed2;
     constexpr auto cb_out0 = tt::CB::c_out0;
+    #ifdef NUMERIC_STABLE
+        constexpr auto cb_max = tt::CB::c_intermed3;
+        constexpr auto cb_x = tt::CB::c_intermed4;
+    #else
+        constexpr auto cb_x = cb_exps;
+    #endif
 
     constexpr int dst0 = 0;
     int index_subblock_w_offset = 0;
@@ -45,7 +97,6 @@ void MAIN {
         unpack_reconfig_data_format(cb_in0, cb_fused_scale);
         pack_reconfig_data_format(cb_scale_mask);
         cb_wait_front(cb_fused_scale, 1);
-        // UNPACK(( DPRINT << TSLICE(cb_fused_scale, 0, SliceRange::h0_w0_32()) << ENDL() ));
         mul_tiles_bcast_scalar_init_short();
         index_subblock_w_offset = 0;
         for (uint32_t j = 0; j < num_subblocks_w; j++) {
@@ -78,7 +129,9 @@ void MAIN {
         add_bcast_rows_init_short();
         #endif
 
-        exp_tile_init();
+        #ifndef NUMERIC_STABLE
+            exp_tile_init();
+        #endif
         for (uint32_t j = 0; j < num_subblocks_w; j++) {
             ACQ();
             #ifdef CAUSAL_MASK
@@ -92,46 +145,60 @@ void MAIN {
                 add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, index, index, w);
             }
             #endif
-            cb_reserve_back(cb_exps, subblock_w);
+            cb_reserve_back(cb_x, subblock_w);
             for (uint32_t w = 0; w < subblock_w; w++) {
-                exp_tile(w);
-                pack_tile(w, cb_exps);
+                #ifndef NUMERIC_STABLE
+                    exp_tile(w);
+                #endif
+                pack_tile(w, cb_x);
             }
-            cb_push_back(cb_exps, subblock_w);
+            cb_push_back(cb_x, subblock_w);
             REL();
             index_subblock_w_offset += subblock_w;
         }
         cb_pop_front(cb_scale_mask, block_w);
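+        // With NUMERIC_STABLE defined, cb_x now holds the entire scaled and masked
+        // row (block_w tiles); the helper below applies the max-reduce and fused
+        // subtract+exp that the loop above skipped.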
+
+        // add numeric_stable
+        // fuse exp with sub tiles
+        #ifdef NUMERIC_STABLE
+            cb_wait_front(cb_x, block_w);
+            calc_numeric_stable<block_w, subblock_w, num_subblocks_w>(cb_x, cb_bcast_scaler, cb_max, cb_exps);
+        #endif
+
         #ifdef CAUSAL_MASK
         cb_pop_front(cb_fused_attn, block_w);
         #endif
 
         unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);
     #else
-        unpack_reconfig_data_format(cb_in0, cb_in0);
-        pack_reconfig_data_format(cb_exps);
-        // exp(x)
-        index_subblock_w_offset = 0;
-        copy_tile_to_dst_init_short();
-        exp_tile_init();
-        for (uint32_t j = 0; j < num_subblocks_w; j++) {
-            ACQ();
-            for (uint32_t w = 0; w < subblock_w; w++) {
-                index = w + index_subblock_w_offset;
-                copy_tile(cb_in0, index, w);
-            }
-            cb_reserve_back(cb_exps, subblock_w);
-            for (uint32_t w = 0; w < subblock_w; w++) {
-                exp_tile(w);
-                pack_tile(w, cb_exps);
+
+        #ifdef NUMERIC_STABLE
+            calc_numeric_stable<block_w, subblock_w, num_subblocks_w>(cb_in0, cb_bcast_scaler, cb_max, cb_exps);
+        #else
+            unpack_reconfig_data_format(cb_in0, cb_in0);
+            pack_reconfig_data_format(cb_exps);
+            // exp(x)
+            index_subblock_w_offset = 0;
+            copy_tile_to_dst_init_short();
+            exp_tile_init();
+            for (uint32_t j = 0; j < num_subblocks_w; j++) {
+                ACQ();
+                for (uint32_t w = 0; w < subblock_w; w++) {
+                    index = w + index_subblock_w_offset;
+                    copy_tile(cb_in0, index, w);
+                }
+                cb_reserve_back(cb_exps, subblock_w);
+                for (uint32_t w = 0; w < subblock_w; w++) {
+                    exp_tile(w);
+                    pack_tile(w, cb_exps);
+                }
+                cb_push_back(cb_exps, subblock_w);
+                REL();
+                index_subblock_w_offset += subblock_w;
             }
-            cb_push_back(cb_exps, subblock_w);
-            REL();
-            index_subblock_w_offset += subblock_w;
-        }
-        cb_pop_front(cb_in0, block_w);
-        unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);
+            cb_pop_front(cb_in0, block_w);
+            unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);
+        #endif
     #endif // FUSED_SCALE_MASK
 
     // sum(exp(x))
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
index e102095185f..1958009fea3 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/multi_core/softmax_op_multi_core.cpp
@@ -33,7 +33,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
     const std::optional<const Tensor> mask,
     std::optional<float> scale,
     bool causal_mask,
-    DeviceComputeKernelConfig compute_kernel_config
+    DeviceComputeKernelConfig compute_kernel_config,
+    bool numeric_stable
 ) {
 
     const auto shape = input_tensor.get_legacy_shape();
@@ -116,13 +117,15 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
     uint32_t block_size = fp32_dest_acc_en ? find_max_divisor(Wt, 4) : find_max_divisor(Wt, 8);
 
     // These tile capacity counts for CBs need to match the number of tiles expected by the kernel (softmax.cpp)
-    uint32_t in0_t = block_size*2;
+    uint32_t in0_t = numeric_stable ? tt::div_up(Wt, block_size)*block_size : block_size*2;
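+    // With numeric_stable the input CB must hold a full row (Wt tiles rounded up
+    // to a multiple of block_size), because the row max has to be computed before
+    // any tile can be exponentiated; plain block_size*2 double-buffering no longer
+    // covers that.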
"1" : "0"; auto softmax_kernels_id = CreateKernel( program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp", all_device_cores, @@ -224,6 +230,16 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core( } CircularBufferConfig c_in5_config = CircularBufferConfig(in5_t * mask_tile_size, {{tt::CB::c_in5, mask_cb_data_format}}).set_page_size(tt::CB::c_in5, mask_tile_size); cb_in5_id = CreateCircularBuffer( program, all_device_cores, c_in5_config); + std::optional cb_intermed2_id; + std::optional cb_intermed4_id; + if (numeric_stable) { + // cb_max + auto c_intermed2_config = CircularBufferConfig(im2_t * in0_tile_size, {{tt::CB::c_intermed2, in0_cb_data_format}}).set_page_size(tt::CB::c_intermed2, in0_tile_size); + cb_intermed2_id = CreateCircularBuffer( program, all_device_cores, c_intermed2_config ); + // cb_x + auto c_x_config = CircularBufferConfig(in0_t * in0_tile_size, {{tt::CB::c_intermed4, in0_cb_data_format}}).set_page_size(tt::CB::c_intermed4, in0_tile_size); + cb_intermed4_id = CreateCircularBuffer( program, all_device_cores, c_x_config); + } uint32_t src_addr = src0_buffer->address(); uint32_t mask_addr = mask.has_value() ? mask.value().buffer()->address() : 0; @@ -289,7 +305,10 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core( cb_intermed3_id, cb_in3_id, cb_in4_id, - causal_mask + causal_mask, + numeric_stable, + cb_intermed2_id, + cb_intermed4_id ] ( const void* operation, @@ -325,12 +344,13 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core( uint32_t block_size = find_max_divisor(Wt, 8); // These tile capacity counts for CBs need to match the number of tiles expected by the kernel (softmax.cpp) - uint32_t in0_t = block_size*2; + uint32_t in0_t = numeric_stable ? tt::div_up(Wt, block_size)*block_size : block_size*2; uint32_t out0_t = block_size*2; uint32_t im1_t = 1; // 1/sum(exp(x)) uint32_t in2_t = 1; // scaler for reduce coming from reader uint32_t in3_t = 1; // 1/sqrt() scaler tile cb for fused scale/mask/softmax variant uint32_t in4_t = tt::div_up(Wt, block_size)*block_size; // attention mask (N,C,32,W) - Wt is reused for each Ht, NC is cycled + uint32_t im2_t = 1; // cb_exps - keeps exps in tt::CB in L1 to avoid recomputing uint32_t im0_t = block_size*tt::div_up(Wt, block_size); @@ -364,6 +384,10 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core( UpdateCircularBufferTotalSize(program, cb_in3_id.value(), in3_t * scalar_tile_size); UpdateCircularBufferTotalSize(program, cb_in4_id.value(), in4_t * mask_tile_size); } + if (numeric_stable) { + UpdateCircularBufferTotalSize(program, cb_intermed2_id.value(), im2_t * in0_tile_size); + UpdateCircularBufferTotalSize(program, cb_intermed4_id.value(), in0_t * in0_tile_size); + } uint32_t curr_row = 0; union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax @@ -452,7 +476,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core( uint32_t subblock_wt, uint32_t block_ht, uint32_t block_wt, - DeviceComputeKernelConfig compute_kernel_config + DeviceComputeKernelConfig compute_kernel_config, + bool numeric_stable ) { //////////////////////////////////////////////////////////////////////////// // Device Setup @@ -568,6 +593,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core( uint32_t im2_CB_size = block_wt * im_tile_size; // output buffer size uint32_t out_CB_size = block_wt * block_ht * out0_tile_size; + // numeric_stable cb max + uint32_t max_CB_size = 1 * 
 
     ////////////////////////////////////////////////////////////////////////////
     //                         Application Setup
     ////////////////////////////////////////////////////////////////////////////
@@ -649,6 +677,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
         subblock_wt,
         num_subblocks_w,
     };
+    if (numeric_stable) {
+        softmax_defines["NUMERIC_STABLE"] = "1";
+    }
     softmax_defines["EXP_APPROX"] = math_approx_mode ? "1" : "0";
     auto softmax_kernels_id = CreateKernel(
         program, "ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp", all_device_cores,
@@ -704,6 +735,16 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     auto c_intermed1_config = CircularBufferConfig(im1_CB_size, {{tt::CB::c_intermed1, im_cb_data_format}})
         .set_page_size(tt::CB::c_intermed1, im_tile_size);
     auto cb_intermed1_id = CreateCircularBuffer( program, all_device_cores, c_intermed1_config );
+    if (numeric_stable) {
+        // cb_max
+        auto c_intermed3_config = CircularBufferConfig(max_CB_size, {{tt::CB::c_intermed3, in0_cb_data_format}})
+            .set_page_size(tt::CB::c_intermed3, in0_tile_size);
+        auto cb_intermed3_id = CreateCircularBuffer( program, all_device_cores, c_intermed3_config );
+        // cb_x
+        auto c_intermed4_config = CircularBufferConfig(x_CB_size, {{tt::CB::c_intermed4, in0_cb_data_format}})
+            .set_page_size(tt::CB::c_intermed4, in0_tile_size);
+        auto cb_intermed4_id = CreateCircularBuffer( program, all_device_cores, c_intermed4_config );
+    }
 
     // Runtime Args
     uint32_t mask_addr = mask.has_value() ? mask->buffer()->address() : 0;
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
index feb75819095..ff404ec8859 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.cpp
@@ -136,10 +136,11 @@ operation::ProgramWithCallbacks Softmax::create_program(
                     program_config.subblock_w,
                     program_config.block_h,
                     program_config.block_w,
-                    this->compute_kernel_config);
+                    this->compute_kernel_config,
+                    this->numeric_stable);
             } else {
-                return scale_mask_softmax_multi_core(input_tensor, output_tensor, mask, this->scale, causal_mask, this->compute_kernel_config);
+                return scale_mask_softmax_multi_core(input_tensor, output_tensor, mask, this->scale, causal_mask, this->compute_kernel_config, this->numeric_stable);
             }
         },
         this->program_config
@@ -159,42 +160,42 @@ const operation::Hash Softmax::compute_program_hash(
         this->output_mem_config);
 }
 
-Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
-    return scale_mask_softmax_in_place(input_tensor, std::nullopt, std::nullopt, program_config, false, compute_kernel_config);
+Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const bool numeric_stable) {
+    return scale_mask_softmax_in_place(input_tensor, std::nullopt, std::nullopt, program_config, false, compute_kernel_config, numeric_stable);
 }
 
-Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config, const bool is_causal_mask, std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config, const bool is_causal_mask, std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const bool numeric_stable) {
     std::vector<Tensor> dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_op(
-        [scale, mask, program_config, is_causal_mask, compute_kernel_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+        [scale, mask, program_config, is_causal_mask, compute_kernel_config, numeric_stable] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
             auto& input_tensor = input_tensors.at(0);
             auto& mask = optional_input_tensors.at(0);
             auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
-            return operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val}, {input_tensor}, {mask});
+            return operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val, .numeric_stable=numeric_stable}, {input_tensor}, {mask});
         }, {input_tensor}, dummy_output_tensors, {mask});
     return input_tensor;
 }
 
-Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const bool numeric_stable) {
     std::vector<Tensor> dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_op(
-        [scale, mask, program_config, compute_kernel_config](const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+        [scale, mask, program_config, compute_kernel_config, numeric_stable](const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
             auto& input_tensor = input_tensors.at(0);
             auto& mask = optional_input_tensors.at(0);
             auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
-            return operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=true, .compute_kernel_config=kernel_config_val, .is_scale_causal_mask_hw_dims_softmax=true}, {input_tensor}, {mask});
+            return operation::run(Softmax{.scale=scale, .inplace=true, .output_mem_config=input_tensor.memory_config(), .program_config=program_config, .is_causal_mask=true, .compute_kernel_config=kernel_config_val, .is_scale_causal_mask_hw_dims_softmax=true, .numeric_stable=numeric_stable}, {input_tensor}, {mask});
         }, {input_tensor}, dummy_output_tensors, {mask});
     return input_tensor;
 }
 
-Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
-    return scale_mask_softmax(input_tensor, std::nullopt, std::nullopt, output_mem_config, false, compute_kernel_config);
+Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config, std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const bool numeric_stable) {
+    return scale_mask_softmax(input_tensor, std::nullopt, std::nullopt, output_mem_config, false, compute_kernel_config, numeric_stable);
 }
 
-Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const MemoryConfig& output_mem_config, const bool is_causal_mask, std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const MemoryConfig& output_mem_config, const bool is_causal_mask, std::optional<const DeviceComputeKernelConfig> compute_kernel_config, const bool numeric_stable) {
     std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
     operation::launch_with_autoformat(
-        [scale, mask, output_mem_config, is_causal_mask, compute_kernel_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
+        [scale, mask, output_mem_config, is_causal_mask, compute_kernel_config, numeric_stable] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
             auto& input_tensor = input_tensors.at(0);
             auto& mask = optional_input_tensors.at(0);
             tt::tt_metal::LegacyShape input_pad_shape = ttnn::operations::experimental::auto_format::AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape());
@@ -211,7 +212,7 @@ Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale
                 mask_format_params = {.pad_shape=mask_pad_shape, .pad_value=-std::numeric_limits<float>::infinity(), .target_layout=Layout::TILE};
             }
             auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false);
-            return operation::run_with_autoformat(Softmax{.scale=scale, .inplace=false, .output_mem_config=output_mem_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val}, {input_tensor}, {input_format_params}, {Layout::TILE}, {mask}, {mask_format_params});
+            return operation::run_with_autoformat(Softmax{.scale=scale, .inplace=false, .output_mem_config=output_mem_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val, .numeric_stable=numeric_stable}, {input_tensor}, {input_format_params}, {Layout::TILE}, {mask}, {mask_format_params});
         }, {input_tensor}, output_tensors, {mask});
     return output_tensors.at(0);
 }
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
index caccc7837b7..b546ac0016b 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/softmax_op.hpp
@@ -25,6 +25,7 @@ struct Softmax {
     const bool is_causal_mask;
     const DeviceComputeKernelConfig compute_kernel_config;
     const bool is_scale_causal_mask_hw_dims_softmax;
+    const bool numeric_stable;
 
     void validate(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
     std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
@@ -46,7 +47,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_multi_core(
     const std::optional<const Tensor> mask,
     std::optional<float> scale,
     bool causal_mask,
-    DeviceComputeKernelConfig compute_kernel_config
+    DeviceComputeKernelConfig compute_kernel_config,
+    bool numeric_stable
 );
 
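+// numeric_stable selects the exp(x - max(x)) formulation; it defaults to false in
+// every entry point below, so existing call sites keep the original single-pass
+// exp path.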
 // hw_dims_only_causal_mask - represents if the causal mask is of shape [1, 1, h, w]
@@ -62,28 +64,29 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     uint32_t subblock_wt,
     uint32_t block_ht,
     uint32_t block_wt,
-    DeviceComputeKernelConfig compute_kernel_config
+    DeviceComputeKernelConfig compute_kernel_config,
+    bool numeric_stable
 );
 
 // softmax
-Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, const bool numeric_stable = false);
 // const ref prevents in-place
-Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+Tensor softmax_in_place(Tensor& input_tensor, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, const bool numeric_stable = false);
 
 // computes
 // tmp1 = bcast_hw_mul(scale, x)  ; shape of scale is [1,1,32,32]
 // tmp2 = bcast_add_w->h(tmp1, mask) ; shape of attn mask is [1,N,32,W]
 // y = softmax(tmp2)              ; r=result
 // If scale == 0.0f then just y = softmax(x) is computed
-Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional<float> scale = std::nullopt, std::optional<const Tensor> mask = std::nullopt, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, const bool is_causal_mask = false, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+Tensor scale_mask_softmax_in_place(Tensor& input_tensor, std::optional<float> scale = std::nullopt, std::optional<const Tensor> mask = std::nullopt, const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{}, const bool is_causal_mask = false, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, const bool numeric_stable = false);
 
 // Experimental feature. Does the same same as above, with the following assumptions:
 // 1. Input must be sharded
 // 2. Scale must exist
 // 3. Attention mask must be interleaved and be of this shape [1, 1, H, W]
 // 4. Causal mask argument is set to true.
-Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config = SoftmaxShardedMultiCoreProgramConfig{}, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+Tensor scale_causal_mask_hw_dims_softmax_in_place(Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const SoftmaxProgramConfig& program_config = SoftmaxShardedMultiCoreProgramConfig{}, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, const bool numeric_stable = false);
 
-Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const bool is_causal_mask = false, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional<float> scale, std::optional<const Tensor> mask, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const bool is_causal_mask = false, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt, const bool numeric_stable = false);
 
 }  // namespace ttnn::operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp
index 91a69740d72..e4774606a41 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp
@@ -14,7 +14,8 @@ ttnn::Tensor ExecuteSoftmax::invoke(
     const ttnn::Tensor& input_tensor,
     const int dim_arg,
     const std::optional<ttnn::MemoryConfig>& memory_config,
-    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
+    const bool numeric_stable) {
     auto input_shape = input_tensor.get_shape();
     auto rank = input_shape.size();
@@ -26,7 +27,7 @@ ttnn::Tensor ExecuteSoftmax::invoke(
     auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
     if (dim == rank - 1) {
         auto output_tensor = ttnn::operations::normalization::softmax(
-            input_tensor_4D, memory_config.value_or(input_tensor.memory_config()), compute_kernel_config);
+            input_tensor_4D, memory_config.value_or(input_tensor.memory_config()), compute_kernel_config, numeric_stable);
         return ttnn::reshape(output_tensor, input_shape);
     } else {
         auto dim_4D = dim + 4 - rank;
@@ -41,26 +42,28 @@ ttnn::Tensor ExecuteScaleMaskSoftmax::invoke(
     const std::optional<const ttnn::Tensor> mask,
     const std::optional<ttnn::MemoryConfig>& memory_config,
     const bool is_causal_mask,
-    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
+    const bool numeric_stable) {
     auto input_shape = input_tensor.get_shape();
     auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
 
     auto output_tensor =
-        ttnn::operations::normalization::scale_mask_softmax(input_tensor_4D, scale, mask, memory_config.value_or(input_tensor.memory_config()), is_causal_mask, compute_kernel_config);
+        ttnn::operations::normalization::scale_mask_softmax(input_tensor_4D, scale, mask, memory_config.value_or(input_tensor.memory_config()), is_causal_mask, compute_kernel_config, numeric_stable);
     return ttnn::reshape(output_tensor, input_shape);
 }
 
 ttnn::Tensor ExecuteSoftmaxInPlace::invoke(
     const ttnn::Tensor& input_tensor,
     const SoftmaxProgramConfig& program_config,
-    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
+    const bool numeric_stable) {
     auto input_shape = input_tensor.get_shape();
     auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
 
     auto output_tensor =
-        ttnn::operations::normalization::softmax_in_place(input_tensor_4D, program_config, compute_kernel_config);
+        ttnn::operations::normalization::softmax_in_place(input_tensor_4D, program_config, compute_kernel_config, numeric_stable);
     return ttnn::reshape(output_tensor, input_shape);
 }
 
@@ -70,13 +73,14 @@ ttnn::Tensor ExecuteScaleMaskSoftmaxInPlace::invoke(
     const std::optional<const ttnn::Tensor> mask,
     const SoftmaxProgramConfig& program_config,
     const bool is_causal_mask,
-    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
+    const bool numeric_stable) {
     auto input_shape = input_tensor.get_shape();
     auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
 
     auto output_tensor =
-        ttnn::operations::normalization::scale_mask_softmax_in_place(input_tensor_4D, scale, mask, program_config, is_causal_mask, compute_kernel_config);
+        ttnn::operations::normalization::scale_mask_softmax_in_place(input_tensor_4D, scale, mask, program_config, is_causal_mask, compute_kernel_config, numeric_stable);
     return ttnn::reshape(output_tensor, input_shape);
 }
 
@@ -85,13 +89,14 @@ ttnn::Tensor ExecuteScaleCausalMaskHWSoftmaxInPlace::invoke(
     const std::optional<float> scale,
     const std::optional<const ttnn::Tensor> mask,
     const SoftmaxProgramConfig& program_config,
-    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config) {
+    const std::optional<const DeviceComputeKernelConfig> compute_kernel_config,
+    const bool numeric_stable) {
     auto input_shape = input_tensor.get_shape();
     auto input_tensor_4D = ttnn::unsqueeze_to_4D(input_tensor);
 
     auto output_tensor =
-        ttnn::operations::normalization::scale_causal_mask_hw_dims_softmax_in_place(input_tensor_4D, scale, mask, program_config, compute_kernel_config);
+        ttnn::operations::normalization::scale_causal_mask_hw_dims_softmax_in_place(input_tensor_4D, scale, mask, program_config, compute_kernel_config, numeric_stable);
     return ttnn::reshape(output_tensor, input_shape);
 }
 
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp
index bbed0083953..ca3fd6b0a22 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.hpp
@@ -18,7 +18,8 @@ struct ExecuteSoftmax {
         const ttnn::Tensor& input_tensor,
         const int dim_arg,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
-        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+        const bool numeric_stable = false);
 };
 
 struct ExecuteScaleMaskSoftmax {
@@ -29,7 +30,8 @@ struct ExecuteScaleMaskSoftmax {
         const std::optional<const ttnn::Tensor> mask = std::nullopt,
         const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
         const bool is_causal_mask = false,
-        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+        const bool numeric_stable = false);
 };
 
 struct ExecuteSoftmaxInPlace {
@@ -38,7 +40,8 @@ struct ExecuteSoftmaxInPlace {
     static ttnn::Tensor invoke(
         const ttnn::Tensor& input_tensor,
         const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{},
-        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+        const bool numeric_stable = false);
 };
 
 struct ExecuteScaleMaskSoftmaxInPlace {
@@ -50,7 +53,8 @@ struct ExecuteScaleMaskSoftmaxInPlace {
         const std::optional<const ttnn::Tensor> mask = std::nullopt,
         const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{},
         const bool is_causal_mask = false,
-        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+        const bool numeric_stable = false);
 };
 
 struct ExecuteScaleCausalMaskHWSoftmaxInPlace {
@@ -61,7 +65,8 @@ struct ExecuteScaleCausalMaskHWSoftmaxInPlace {
         const std::optional<float> scale = std::nullopt,
         const std::optional<const ttnn::Tensor> mask = std::nullopt,
         const SoftmaxProgramConfig& program_config = SoftmaxDefaultProgramConfig{},
-        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt);
+        const std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt,
+        const bool numeric_stable = false);
 };
 
 }  // namespace operations::normalization
diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.cpp
index b1375fe46ba..f2041096091 100644
--- a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax_pybind.cpp
@@ -66,14 +66,16 @@ void bind_normalization_softmax_operation(py::module& module) {
            const ttnn::Tensor& input_tensor,
            const int8_t dim,
            const std::optional<ttnn::MemoryConfig>& memory_config,
-           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config) {
-               return self(input_tensor, dim, memory_config, compute_kernel_config);
+           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config,
+           const bool numeric_stable) {
+               return self(input_tensor, dim, memory_config, compute_kernel_config, numeric_stable);
            },
        py::arg("input_tensor").noconvert(),
        py::arg("dim") = -1,
        py::kw_only(),
        py::arg("memory_config") = std::nullopt,
-       py::arg("compute_kernel_config").noconvert() = std::nullopt});
+       py::arg("compute_kernel_config").noconvert() = std::nullopt,
+       py::arg("numeric_stable").noconvert() = false});
 }
 
 void bind_normalization_scale_mask_softmax_operation(py::module& module) {
@@ -110,8 +112,9 @@ void bind_normalization_scale_mask_softmax_operation(py::module& module) {
            const std::optional<const ttnn::Tensor> mask,
            const std::optional<ttnn::MemoryConfig>& memory_config,
            const bool is_causal_mask,
-           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config) {
-               return self(input_tensor, scale, mask, memory_config, is_causal_mask, compute_kernel_config);
+           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config,
+           const bool numeric_stable) {
+               return self(input_tensor, scale, mask, memory_config, is_causal_mask, compute_kernel_config, numeric_stable);
            },
        py::arg("input_tensor").noconvert(),
        py::arg("scale").noconvert() = std::nullopt,
@@ -119,7 +122,8 @@ void bind_normalization_scale_mask_softmax_operation(py::module& module) {
        py::kw_only(),
        py::arg("memory_config") = std::nullopt,
        py::arg("is_causal_mask") = false,
-       py::arg("compute_kernel_config") = std::nullopt});
+       py::arg("compute_kernel_config") = std::nullopt,
+       py::arg("numeric_stable") = false});
 }
 
 void bind_normalization_softmax_in_place_operation(py::module& module) {
@@ -150,13 +154,15 @@ void bind_normalization_softmax_in_place_operation(py::module& module) {
            [] (const OperationType& self,
                const ttnn::Tensor& input_tensor,
                const SoftmaxProgramConfig& program_config,
-               const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config) {
-                   return self(input_tensor, program_config, compute_kernel_config);
+               const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config,
+               const bool numeric_stable) {
+                   return self(input_tensor, program_config, compute_kernel_config, numeric_stable);
                },
            py::arg("input_tensor").noconvert(),
            py::kw_only(),
            py::arg("program_config") = SoftmaxDefaultProgramConfig{},
-           py::arg("compute_kernel_config") = std::nullopt});
+           py::arg("compute_kernel_config") = std::nullopt,
+           py::arg("numeric_stable") = false});
 }
 
 void bind_normalization_scale_mask_softmax_in_place_operation(py::module& module) {
@@ -190,8 +196,9 @@ void bind_normalization_scale_mask_softmax_in_place_operation(py::module& module
            const std::optional<const ttnn::Tensor> mask,
            const SoftmaxProgramConfig& program_config,
            const bool is_causal_mask,
-           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config) {
-               return self(input_tensor, scale, mask, program_config, is_causal_mask, compute_kernel_config);
+           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config,
+           const bool numeric_stable) {
+               return self(input_tensor, scale, mask, program_config, is_causal_mask, compute_kernel_config, numeric_stable);
            },
        py::arg("input_tensor").noconvert(),
        py::arg("scale").noconvert() = std::nullopt,
@@ -199,7 +206,8 @@ void bind_normalization_scale_mask_softmax_in_place_operation(py::module& module
        py::kw_only(),
        py::arg("program_config") = SoftmaxDefaultProgramConfig{},
        py::arg("is_causal_mask") = false,
-       py::arg("compute_kernel_config") = std::nullopt});
+       py::arg("compute_kernel_config") = std::nullopt,
+       py::arg("numeric_stable") = false});
 }
 
 void bind_normalization_scale_causal_mask_hw_dims_softmax_in_place_operation(py::module& module) {
@@ -232,15 +240,17 @@ void bind_normalization_scale_causal_mask_hw_dims_softmax_in_place_operation(py:
            const std::optional<float> scale,
            const std::optional<const ttnn::Tensor> mask,
            const SoftmaxProgramConfig& program_config,
-           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config) {
-               return self(input_tensor, scale, mask, program_config, compute_kernel_config);
+           const std::optional<const DeviceComputeKernelConfig>& compute_kernel_config,
+           const bool numeric_stable) {
+               return self(input_tensor, scale, mask, program_config, compute_kernel_config, numeric_stable);
            },
        py::arg("input_tensor").noconvert(),
        py::arg("scale").noconvert() = std::nullopt,
        py::arg("mask").noconvert() = std::nullopt,
        py::kw_only(),
        py::arg("program_config") = SoftmaxDefaultProgramConfig{},
-       py::arg("compute_kernel_config") = std::nullopt});
+       py::arg("compute_kernel_config") = std::nullopt,
+       py::arg("numeric_stable") = false});
 }
 
 void bind_normalization_softmax(py::module& module) {