-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#9553: Add prefix scan op for Mamba prefill
- Loading branch information
Showing
9 changed files
with
601 additions
and
0 deletions.
There are no files selected for viewing
95 changes: 95 additions & 0 deletions
95
tests/tt_eager/python_api_testing/unit_testing/misc/test_ssm_prefix_scan.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
import tt_lib as ttl | ||
import pytest | ||
from loguru import logger | ||
|
||
from models.utility_functions import tt2torch_tensor, comp_pcc, skip_for_grayskull | ||
|
||
|
||
def sequential_prefix_scan(a, bx):
    """Reference implementation of the linear recurrence ``h[i] = a[i] * h[i-1] + bx[i]``.

    The scan runs along dimension 2 of the inputs and starts from a zero
    hidden state.

    Args:
        a: Tensor of shape ``(B, C, L, EN)`` holding the per-step multipliers.
        bx: Tensor of the same shape holding the per-step additive terms.

    Returns:
        Tensor of shape ``(B, C, L, EN)`` (default torch dtype, on ``a.device``)
        with the hidden state after every step.
    """
    (B, C, L, EN) = bx.shape
    hidden_states = torch.zeros((B, C, L, EN), device=a.device)
    # Carry the previous state explicitly instead of reading index ``i - 1``:
    # at ``i == 0`` that index wraps around to the last row, which only
    # happened to be correct because the row had not been written yet.
    previous = torch.zeros((B, C, EN), device=a.device)
    for i in range(L):
        hidden_states[:, :, i] = a[:, :, i] * previous + bx[:, :, i]
        previous = hidden_states[:, :, i]
    return hidden_states
|
||
|
||
def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
    """Run ``ssm_prefix_scan`` on device and check it against the torch reference.

    Random ``(1, 1, L, E * N)`` inputs are generated with a fixed seed, placed
    in width-sharded L1 memory split over ``num_cores`` cores, and the op's
    output is compared with ``sequential_prefix_scan`` by PCC (threshold 0.999).

    Args:
        L: Length of the scanned dimension.
        E: First factor of the tensor width.
        N: Second factor of the tensor width.
        num_cores: Number of cores the width is sharded across.
        dtype: Device data type used for both inputs and the output.
        device: Device handle the tensors are moved to.
    """
    torch.manual_seed(0)

    width = E * N
    torch_a = torch.randn((1, 1, L, width))
    torch_bx = torch.randn((1, 1, L, width))
    expected = sequential_prefix_scan(torch_a, torch_bx)

    # Each core receives the full length and an equal slice of the width.
    grid_size = device.compute_with_storage_grid_size()
    core_ranges = ttl.tensor.num_cores_to_corerange_set(num_cores, grid_size, True)
    shard_spec = ttl.tensor.ShardSpec(
        ttl.tensor.CoreRangeSet(core_ranges),
        [L, width // num_cores],
        ttl.tensor.ShardOrientation.ROW_MAJOR,
        False,
    )
    memory_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED,
        ttl.tensor.BufferType.L1,
        shard_spec,
    )

    tt_a = ttl.tensor.Tensor(torch_a, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)
    tt_bx = ttl.tensor.Tensor(torch_bx, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)

    actual = ttl.operations.primary.transformers.ssm_prefix_scan(
        tt_a,
        tt_bx,
        output_mem_config=memory_config,
        output_dtype=dtype,
    )
    assert list(actual.get_legacy_shape()) == list(expected.shape)
    assert actual.dtype == dtype

    actual = tt2torch_tensor(actual)

    passing_pcc, output_pcc = comp_pcc(actual, expected, 0.999)
    logger.debug(f"Out passing={passing_pcc}")
    logger.debug(f"Output pcc={output_pcc}")

    assert passing_pcc
|
||
|
||
@skip_for_grayskull("Grayskull not supported")
@pytest.mark.parametrize(
    "dtype",
    (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
)
@pytest.mark.parametrize(
    "L, E, N, num_cores",
    (
        (32, 32, 32, 1),
        (32, 64, 32, 1),
        (32, 2560, 32, 32),
        (32, 5120, 32, 40),
        # (32, 5120, 32, 64) -> 8x8 grid not supported on CI
    ),
)
def test_ssm_reduce(L: int, E: int, N: int, num_cores: int, dtype, device):
    """Check the prefix scan op over every listed shape/core-count and dtype.

    NOTE(review): despite the name, this exercises ``ssm_prefix_scan``; the
    name is kept so existing test IDs do not change.
    """
    run_ssm_prefix_scan(L=L, E=E, N=N, num_cores=num_cores, dtype=dtype, device=device)
|
||
|
||
@skip_for_grayskull("Grayskull not supported")
def test_ssm_prefix_scan_with_program_cache(device, use_program_cache):
    """Run the same prefix scan three times and expect one program cache entry."""
    scan_kwargs = dict(
        L=32,
        E=64,
        N=32,
        num_cores=1,
        dtype=ttl.tensor.DataType.BFLOAT8_B,
        device=device,
    )
    run_ssm_prefix_scan(**scan_kwargs)

    dummy_memory_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.INTERLEAVED,
        ttl.tensor.BufferType.L1,
    )
    dummy_shape = [1, 1, 128, 128]

    for _ in range(2):
        run_ssm_prefix_scan(**scan_kwargs)
        # NOTE(review): the dummy tensor is never read; presumably it is
        # allocated so later runs see different L1 addresses - confirm.
        py_dummy_tensor = torch.randn(dummy_shape)
        tt_dummy_tensor = (
            ttl.tensor.Tensor(py_dummy_tensor, scan_kwargs["dtype"])
            .to(ttl.tensor.Layout.TILE)
            .to(device, dummy_memory_config)
        )

    assert device.num_program_cache_entries() == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
184 changes: 184 additions & 0 deletions
184
tt_eager/tt_dnn/op_library/transformer_tms/kernels/compute/ssm_prefix_scan.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <cstdint> | ||
|
||
#include "compute_kernel_api/eltwise_binary.h" | ||
#include "compute_kernel_api/tile_move_copy.h" | ||
#include "compute_kernel_api/tilize.h" | ||
#include "compute_kernel_api/untilize.h" | ||
|
||
constexpr uint32_t NUM_TILES_IN_TILIZED_CHUNK = 32; | ||
|
||
constexpr uint32_t cb_a_in = get_compile_time_arg_val(0); | ||
constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1); | ||
|
||
constexpr uint32_t cb_a_tilize_in = get_compile_time_arg_val(2); | ||
constexpr uint32_t cb_bx_tilize_in = get_compile_time_arg_val(3); | ||
|
||
constexpr uint32_t cb_h_prev = get_compile_time_arg_val(4); | ||
constexpr uint32_t cb_ah = get_compile_time_arg_val(5); | ||
constexpr uint32_t cb_h = get_compile_time_arg_val(6); | ||
|
||
constexpr uint32_t cb_tilize_out = get_compile_time_arg_val(7); | ||
constexpr uint32_t cb_out = get_compile_time_arg_val(8); | ||
|
||
constexpr uint32_t cb_zeros = get_compile_time_arg_val(9); | ||
|
||
// Untilizes one chunk from `cb_in` into `cb_out`.
// Only `num_tiles` tiles are waited for and popped on the input side, but the
// untilize itself always covers NUM_TILES_IN_TILIZED_CHUNK tiles and that many
// tiles are reserved and pushed on the output side (i.e. the chunk is padded).
FORCE_INLINE void pack_block_rows_into_tiles(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
    constexpr uint32_t chunk_tiles = NUM_TILES_IN_TILIZED_CHUNK;

    // Point the unpacker and packer at the data formats of these buffers.
    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);
    untilize_init_short(cb_in);

    cb_wait_front(cb_in, num_tiles);
    cb_reserve_back(cb_out, chunk_tiles);

    untilize_block(cb_in, chunk_tiles, cb_out);

    cb_push_back(cb_out, chunk_tiles);
    cb_pop_front(cb_in, num_tiles);

    untilize_uninit(cb_in);
}
|
||
// Tilizes one chunk from `cb_in` into `cb_out`.
// The input side always waits for and pops NUM_TILES_IN_TILIZED_CHUNK tiles
// (the padded chunk), while only `num_tiles` tiles are reserved and pushed on
// the output side.
FORCE_INLINE void pack_block_tiles_into_rows(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
    constexpr uint32_t chunk_tiles = NUM_TILES_IN_TILIZED_CHUNK;

    // Point the unpacker and packer at the data formats of these buffers.
    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);
    tilize_init_short(cb_in, chunk_tiles);

    cb_wait_front(cb_in, chunk_tiles);
    cb_reserve_back(cb_out, num_tiles);

    tilize_block(cb_in, chunk_tiles, cb_out);

    cb_push_back(cb_out, num_tiles);
    cb_pop_front(cb_in, chunk_tiles);

    tilize_uninit(cb_in);
}
|
||
// Multiplies the front tile of `cb_a` with the front tile of `cb_b`, pushes
// the product to `cb_out`, and pops one tile from each input.
FORCE_INLINE void mul(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
    constexpr uint32_t one_tile = 1;
    constexpr uint32_t front_tile = 0;  // tile index within each input buffer
    constexpr uint32_t dst_index = 0;   // destination register index

    unpack_reconfig_data_format(cb_a, cb_b);
    pack_reconfig_data_format(cb_out);
    mul_tiles_init();

    cb_wait_front(cb_a, one_tile);
    cb_wait_front(cb_b, one_tile);
    cb_reserve_back(cb_out, one_tile);

    tile_regs_acquire();
    mul_tiles(cb_a, cb_b, front_tile, front_tile, dst_index);
    tile_regs_commit();

    tile_regs_wait();
    pack_tile(dst_index, cb_out);
    tile_regs_release();

    cb_push_back(cb_out, one_tile);
    cb_pop_front(cb_a, one_tile);
    cb_pop_front(cb_b, one_tile);
}
|
||
// Adds the front tile of `cb_a` to the front tile of `cb_b`, pushes the sum
// to `cb_out`, and pops one tile from each input.
FORCE_INLINE void sum(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
    constexpr uint32_t one_tile = 1;
    constexpr uint32_t front_tile = 0;  // tile index within each input buffer
    constexpr uint32_t dst_index = 0;   // destination register index

    unpack_reconfig_data_format(cb_a, cb_b);
    pack_reconfig_data_format(cb_out);
    add_tiles_init();

    cb_wait_front(cb_a, one_tile);
    cb_wait_front(cb_b, one_tile);
    cb_reserve_back(cb_out, one_tile);

    tile_regs_acquire();
    add_tiles(cb_a, cb_b, front_tile, front_tile, dst_index);
    tile_regs_commit();

    tile_regs_wait();
    pack_tile(dst_index, cb_out);
    tile_regs_release();

    cb_push_back(cb_out, one_tile);
    cb_pop_front(cb_a, one_tile);
    cb_pop_front(cb_b, one_tile);
}
|
||
// Copies the front tile of `cb_in` to the back of `cb_out`.
// The source tile is deliberately left in `cb_in`; the caller pops it when it
// is no longer needed.
FORCE_INLINE void copy(uint32_t cb_in, uint32_t cb_out) {
    constexpr uint32_t one_tile = 1;
    constexpr uint32_t front_tile = 0;  // tile index within the input buffer
    constexpr uint32_t dst_index = 0;   // destination register index

    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);
    copy_tile_to_dst_init_short();

    cb_wait_front(cb_in, one_tile);
    cb_reserve_back(cb_out, one_tile);

    tile_regs_acquire();
    copy_tile(cb_in, front_tile, dst_index);
    tile_regs_commit();

    tile_regs_wait();
    pack_tile(dst_index, cb_out);
    tile_regs_release();

    cb_push_back(cb_out, one_tile);
}
|
||
// Makes one tile available at the front of `cb_zeros`.
// NOTE(review): nothing here writes the tile's contents; presumably the
// buffer is zero-filled before the kernel starts - confirm on the host side.
FORCE_INLINE void setup_cb_zeros() {
    constexpr uint32_t one_tile = 1;
    cb_reserve_back(cb_zeros, one_tile);
    cb_push_back(cb_zeros, one_tile);
}
|
||
// Pushes a copy of the `cb_zeros` tile onto `cb_id`; `cb_zeros` keeps its
// tile because copy() does not pop its source.
FORCE_INLINE void fill_tile_zeros(uint32_t cb_id) {
    copy(cb_zeros, cb_id);
}
|
||
// Runs the recurrence h = a * h_prev + bx for `num_tiles` consecutive tiles.
// Expects one tile already queued in `cb_h_prev` as the initial state; each
// new state is pushed to `cb_out` and also becomes the next `cb_h_prev`.
FORCE_INLINE void compute_ht(uint32_t cb_a, uint32_t cb_bx, uint32_t cb_out, uint32_t num_tiles) {
    for (uint32_t remaining = num_tiles; remaining > 0; --remaining) {
        mul(cb_a, cb_h_prev, cb_ah);  // consumes the a tile and the previous state
        sum(cb_ah, cb_bx, cb_h);      // consumes a * h_prev and the bx tile

        // copy() leaves its source in place, so cb_h can feed both consumers
        // before it is popped.
        copy(cb_h, cb_h_prev);
        copy(cb_h, cb_out);  // TODO: Get rid of this extraneous copy
        cb_pop_front(cb_h, 1);
    }

    // The final state is still queued in cb_h_prev; drop it so the buffer is
    // empty for the next call.
    cb_wait_front(cb_h_prev, 1);
    cb_pop_front(cb_h_prev, 1);
}
|
||
namespace NAMESPACE { | ||
void MAIN { | ||
const uint32_t total_tiles = get_arg_val<uint32_t>(0); | ||
const uint32_t total_tiles_per_row = get_arg_val<uint32_t>(1); | ||
const uint32_t total_tiles_per_col = get_arg_val<uint32_t>(2); | ||
|
||
const uint32_t num_tilize_per_row = | ||
(total_tiles_per_row + NUM_TILES_IN_TILIZED_CHUNK - 1) / NUM_TILES_IN_TILIZED_CHUNK; // ceil(x/y) | ||
|
||
untilize_init(cb_a_in); | ||
binary_op_init_common(cb_a_in, cb_bx_in); | ||
|
||
setup_cb_zeros(); | ||
|
||
// For each row of tiles we want to tilize chunks of 32 tiles to pack the rows into tiles | ||
for (uint32_t row_idx = 0; row_idx < total_tiles_per_col; row_idx++) { | ||
for (uint32_t tilized_chunk_idx = 0; tilized_chunk_idx < num_tilize_per_row; tilized_chunk_idx++) { | ||
fill_tile_zeros(cb_h_prev); | ||
|
||
// If we don't have a full chunk (NUM_TILES_IN_TILIZED_CHUNK tiles) we should figure out how many tiles we | ||
// have left. This only runs 2-3 tiles per shard so no need to unroll. | ||
const uint32_t remaining_tiles_in_chunk = | ||
tilized_chunk_idx == num_tilize_per_row - 1 && total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK != 0 | ||
? total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK | ||
: NUM_TILES_IN_TILIZED_CHUNK; | ||
|
||
pack_block_rows_into_tiles(cb_a_in, cb_a_tilize_in, remaining_tiles_in_chunk); | ||
pack_block_rows_into_tiles(cb_bx_in, cb_bx_tilize_in, remaining_tiles_in_chunk); | ||
|
||
compute_ht(cb_a_tilize_in, cb_bx_tilize_in, cb_tilize_out, NUM_TILES_IN_TILIZED_CHUNK); | ||
|
||
pack_block_tiles_into_rows(cb_tilize_out, cb_out, remaining_tiles_in_chunk); | ||
} | ||
} | ||
} | ||
} // namespace NAMESPACE |
14 changes: 14 additions & 0 deletions
14
tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/reader_ssm_prefix_scan.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0); | ||
constexpr uint32_t cb_a_in = get_compile_time_arg_val(0); | ||
constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1); | ||
|
||
cb_push_back(cb_a_in, num_tiles_per_core); | ||
cb_push_back(cb_bx_in, num_tiles_per_core); | ||
} |
11 changes: 11 additions & 0 deletions
11
tt_eager/tt_dnn/op_library/transformer_tms/kernels/dataflow/writer_ssm_prefix_scan.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "dataflow_api.h" | ||
|
||
void kernel_main() { | ||
uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0); | ||
constexpr uint32_t cb_out = get_compile_time_arg_val(0); | ||
cb_wait_front(cb_out, num_tiles_per_core); | ||
} |
Oops, something went wrong.