Add numeric stable option for softmax (#13068)
#0: add numeric stable option for softmax
yugaoTT authored Sep 26, 2024
1 parent 6decc7f commit ce89b5c
Showing 9 changed files with 484 additions and 124 deletions.
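For context on what the new option does: a numerically stable softmax subtracts the per-row maximum before exponentiating, i.e. softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))), which keeps every exponent non-positive and avoids overflow for the large-magnitude inputs the new tests use. Below is a minimal PyTorch-only sketch of that identity (illustrative values only, not the device kernel):

import torch
import torch.nn.functional as F

x = torch.tensor([[100.0, 1000.0], [-1000.0, -100.0]], dtype=torch.bfloat16)

# Naive softmax: exp() overflows to inf (or underflows to all zeros) in bfloat16, so both rows become nan.
naive = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)

# Stable softmax: subtracting the row max keeps every exponent <= 0, so exp() stays finite.
shifted = x - x.max(dim=-1, keepdim=True).values
stable = torch.exp(shifted) / torch.exp(shifted).sum(dim=-1, keepdim=True)

print(naive)                                                          # contains nan
print(torch.allclose(stable.float(), F.softmax(x.float(), dim=-1)))   # True

In the tests below the option is exercised as numeric_stable=True on ttnn.softmax, ttnn.scale_mask_softmax, and ttnn.scale_mask_softmax_in_place.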
156 changes: 155 additions & 1 deletion tests/ttnn/unit_tests/operations/test_softmax.py
@@ -9,10 +9,164 @@

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import skip_for_wormhole_b0
from models.utility_functions import skip_for_wormhole_b0, is_grayskull
from models.utility_functions import torch_random


@pytest.mark.parametrize(
"input_vector",
[[100.0, 101.0], [100.0, 1000.0], [-100.0, -101.0], [-1000.0, -100.0], [-100, -108, -99, -100, -101, -98]],
)
def test_softmax_stable_neg_values(device, input_vector):
torch.manual_seed(0)

torch_input_tensor = torch.tensor([[[input_vector]]], dtype=torch.bfloat16)
torch_output_tensor = F.softmax(torch_input_tensor, dim=-1, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.softmax(input_tensor, dim=-1, numeric_stable=True)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.999)


def run_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx):
torch.manual_seed(0)

scale = 1.0
attention_mask = torch.rand(batch_size, 1, 1, w)
attention_mask = (attention_mask > 0.5).float()
attention_mask = attention_mask.masked_fill(attention_mask == 0, torch.tensor(float("-inf"), dtype=torch.bfloat16))
attention_mask = attention_mask.masked_fill(attention_mask == 1, 0)
attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

torch_input_tensor = torch_random((batch_size, 1, h, w), -1000, 1000, dtype=torch.bfloat16)
if not skip_scale_mask:
torch_output_tensor = torch_input_tensor * scale + attention_mask
else:
torch_output_tensor = torch_input_tensor
torch_output_tensor = F.softmax(torch_output_tensor, dim=-1, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)

if is_grayskull():
compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=math_approx,
)
else:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi4,
math_approx_mode=math_approx,
fp32_dest_acc_en=False,
packer_l1_acc=False,
)

if not skip_scale_mask:
output_tensor = ttnn.scale_mask_softmax(
input_tensor, scale, attention_mask_t, compute_kernel_config=compute_kernel_config, numeric_stable=True
)
else:
output_tensor = ttnn.softmax(
input_tensor, dim=-1, compute_kernel_config=compute_kernel_config, numeric_stable=True
)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.999)


@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("h", [32, 128])
@pytest.mark.parametrize("w", [1024, 1500])
@pytest.mark.parametrize("skip_scale_mask", [True, False])
@pytest.mark.parametrize("math_approx", [True, False])
def test_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx, use_program_cache):
for _ in range(2):
run_softmax_stable_with_program_cache(device, batch_size, h, w, skip_scale_mask, math_approx)
# dummy tensor to change tensor alloc
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.from_torch(
py_dummy_tensor,
dtype=ttnn.DataType.BFLOAT16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 1


def run_softmax_sharded_stable(device, batch_size, num_heads, h, w, skip_scale_mask):
torch.manual_seed(0)

grid_size = (batch_size, num_heads)

scale = 1.0
attention_mask = torch.rand(batch_size, 1, 1, w)
attention_mask = (attention_mask > 0.5).float()
attention_mask = attention_mask.masked_fill(attention_mask == 0, torch.tensor(float("-inf"), dtype=torch.bfloat16))
attention_mask = attention_mask.masked_fill(attention_mask == 1, 0)
attention_mask_t = ttnn.from_torch(attention_mask, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

torch_input_tensor = torch_random((batch_size, num_heads, h, w), -1000, 1000, dtype=torch.bfloat16)
if not skip_scale_mask:
torch_output_tensor = torch_input_tensor * scale + attention_mask
else:
torch_output_tensor = torch_input_tensor
torch_output_tensor = F.softmax(torch_output_tensor, dim=-1, dtype=torch.bfloat16)

memory_config = ttnn.create_sharded_memory_config(
torch_input_tensor.shape,
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
program_config = ttnn.SoftmaxShardedMultiCoreProgramConfig(
compute_with_storage_grid_size=grid_size,
subblock_w=6,
block_h=h // 32,
block_w=w // 32,
)

input_tensor = ttnn.from_torch(
torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, memory_config=memory_config
)
if not skip_scale_mask:
output_tensor = ttnn.scale_mask_softmax_in_place(
input_tensor, scale, attention_mask_t, program_config=program_config, numeric_stable=True
)
else:
output_tensor = ttnn.scale_mask_softmax_in_place(
input_tensor, program_config=program_config, numeric_stable=True
)
output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor, 0.999)


@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("num_heads", [4])
@pytest.mark.parametrize("h", [384])
@pytest.mark.parametrize("w", [384])
@pytest.mark.parametrize("skip_scale_mask", [True, False])
def test_softmax_sharded_stable_with_program_cache(
device, batch_size, num_heads, h, w, skip_scale_mask, use_program_cache
):
for _ in range(2):
run_softmax_sharded_stable(device, batch_size, num_heads, h, w, skip_scale_mask)
# dummy tensor to change tensor alloc
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.from_torch(
py_dummy_tensor,
dtype=ttnn.DataType.BFLOAT16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 1


@pytest.mark.parametrize("batch_size", [1, 16])
@pytest.mark.parametrize("h", [32, 64])
@pytest.mark.parametrize("w", [32, 64])
@@ -13,7 +13,8 @@
#include "compute_kernel_api/softmax.h"
#include "compute_kernel_api/reduce.h"

#include "debug/dprint.h"
#include "debug/dprint_tensix.h"

ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }
@@ -25,6 +26,46 @@ ALWI void REL() { release_dst(tt::DstMode::Half); }
// The buffer for the att mask is currently sized as (1t,Wt) so we only reuse it for one HtWt-sized batch of x
// then read another Wt tiles of mask for the next batch

void calc_numeric_stable(uint32_t Wt, uint32_t ndst, uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) {
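    // Numerically stable softmax front end. Pass 1: reduce the Wt tiles in cb_in to a single
    // per-row max tile in cb_max (cb_bcast_scaler supplies the reduce scaler). Pass 2:
    // broadcast-subtract that max from every input tile, exponentiate, and pack exp(x - max)
    // into cb_out. On exit, cb_in and cb_max have been popped and Wt tiles are ready in cb_out.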
// calculate max val per row
ACQ();
unpack_reconfig_data_format(cb_in, cb_bcast_scaler);
cb_reserve_back(cb_max, 1);
cb_wait_front(cb_bcast_scaler, 1);
reduce_init_delta<false, PoolType::MAX, ReduceDim::REDUCE_ROW>();
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(cb_in, wt+1);
constexpr uint32_t bcast_scaler0 = 0;
reduce_tile<PoolType::MAX, ReduceDim::REDUCE_ROW>(cb_in, cb_bcast_scaler, wt, bcast_scaler0, 0);
}
reduce_revert_delta<ReduceDim::REDUCE_ROW>();
pack_tile(0, cb_max);
cb_push_back(cb_max, 1);
REL();

    // compute exp(x - max(x)) per tile and pack the results into cb_out
exp_tile_init<EXP_APPROX>();
unpack_reconfig_data_format_srcb(cb_max);
cb_wait_front(cb_max, 1);
sub_bcast_cols_init_short();
for (uint32_t wt = 0; wt < Wt; wt += ndst) {
ACQ();
for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
sub_tiles_bcast_cols(cb_in, cb_max, wt+wt8, 0, wt8);
}
cb_reserve_back(cb_out, ndst);
for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
            exp_tile<EXP_APPROX>(wt8); // exp on DST[wt8]
            pack_tile(wt8, cb_out); // pack DST[wt8] into cb_out, reusing the buffer circularly
}
cb_push_back(cb_out, ndst);
REL();
}
cb_pop_front(cb_in, Wt);
cb_pop_front(cb_max, 1);
cb_wait_front(cb_out, Wt);
}

namespace NAMESPACE {
void MAIN {

@@ -50,7 +91,12 @@ void MAIN {
constexpr auto cb_recipsumexps = tt::CB::c_intermed1;
constexpr auto cb_in0 = tt::CB::c_in0;
constexpr auto cb_out0 = tt::CB::c_out0;

#ifdef NUMERIC_STABLE
constexpr auto cb_max = tt::CB::c_intermed2;
constexpr auto cb_x = tt::CB::c_intermed4;
#else
constexpr auto cb_x = cb_exps;
#endif

cb_wait_front(cb_bcast_scaler, 1); // comes from the reader

@@ -81,38 +127,50 @@ void MAIN {
}
unpack_reconfig_data_format(cb_scale_mask, cb_fused_attn);

        #ifndef NUMERIC_STABLE
            exp_tile_init<EXP_APPROX>();
        #endif

#ifdef CAUSAL_MASK
add_tiles_init();
#else
add_bcast_rows_init_short();
#endif
for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
ACQ();
cb_wait_front(cb_scale_mask, ndst);
#ifdef CAUSAL_MASK
            cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles
            for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
                add_tiles(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // add the causal attention mask to the scaled input
            }
#else
if (wait_mask) {
cb_wait_front(cb_fused_attn, wt+ndst); // cumulative wait for up to Wt tiles, only at first ht
}

            for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
                add_tiles_bcast_rows(cb_scale_mask, cb_fused_attn, wt8, wt+wt8, wt8); // add the row-broadcast attention mask to the scaled input
            }
#endif
cb_pop_front(cb_scale_mask, ndst);
            cb_reserve_back(cb_x, ndst);
            for (uint32_t wt8 = 0; wt8 < ndst; wt8++) {
                #ifndef NUMERIC_STABLE
                    exp_tile<EXP_APPROX>(wt8); // exp on DST[wt8]
                #endif
                pack_tile(wt8, cb_x); // pack DST[wt8] into cb_x (aliases cb_exps when NUMERIC_STABLE is off), reusing it circularly
            }
            cb_push_back(cb_x, ndst);
REL();
}

        // Numerically stable path: exp was skipped above, so cb_x holds scale*x + mask;
        // compute the per-row max, then pack exp(x - max) into cb_exps.
#ifdef NUMERIC_STABLE
calc_numeric_stable(Wt, ndst, cb_x, cb_bcast_scaler, cb_max, cb_exps);
#endif

#ifdef CAUSAL_MASK
cb_pop_front(cb_fused_attn, Wt);
#else
@@ -132,7 +190,9 @@
unpack_reconfig_data_format(cb_in0, cb_in0);
pack_reconfig_data_format(cb_exps);
copy_tile_to_dst_init_short(); // need to copy from CB to DST to be able to run sfpu math
        #ifndef NUMERIC_STABLE
            exp_tile_init<EXP_APPROX>();
        #endif
if (mask_padded_data) {
for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
ACQ();
@@ -149,32 +209,46 @@
}
cb_pop_front(cb_in0, ndst);

                cb_reserve_back(cb_x, ndst);
                for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
                    #ifndef NUMERIC_STABLE
                        exp_tile<EXP_APPROX>(wt8); // exp on DST[wt8]
                    #endif
                    pack_tile(wt8, cb_x); // DST[wt8] -> cb_x
                }
                cb_push_back(cb_x, ndst);
REL();
}

            // Numerically stable path: exp was skipped above, so cb_x holds the padding-masked input;
            // compute the per-row max, then pack exp(x - max) into cb_exps.
#ifdef NUMERIC_STABLE
calc_numeric_stable(Wt, ndst, cb_x, cb_bcast_scaler, cb_max, cb_exps);
#endif

} else {
            // Numerically stable path: calc_numeric_stable reads cb_in0 directly, computes the
            // per-row max, then packs exp(x - max) into cb_exps (popping cb_in0 itself).
            #ifdef NUMERIC_STABLE
                calc_numeric_stable(Wt, ndst, cb_in0, cb_bcast_scaler, cb_max, cb_exps);
            #else
                for (uint32_t wt = 0; wt < Wt; wt+=ndst) {
                    ACQ();
                    cb_wait_front(cb_in0, ndst);
                    for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
                        copy_tile(cb_in0, wt8, wt8); // copy from c_in[0] to DST[wt8]
                    }
                    cb_pop_front(cb_in0, ndst);

                    cb_reserve_back(cb_exps, ndst);
                    for (uint32_t wt8 = 0; wt8 < ndst; ++wt8) {
                        exp_tile<EXP_APPROX>(wt8); // exp on DST[wt8]
                        pack_tile(wt8, cb_exps); // DST[wt8] -> cb_exps
                    }
                    cb_push_back(cb_exps, ndst);
                    REL();
                }
            #endif
}

unpack_reconfig_data_format(cb_exps, cb_bcast_scaler);