#9553: Add prefix scan op for Mamba prefill
esmalTT committed Jun 27, 2024
1 parent c3042ac commit 26a124c
Showing 9 changed files with 601 additions and 0 deletions.
@@ -0,0 +1,95 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch

import tt_lib as ttl
import pytest
from loguru import logger

from models.utility_functions import tt2torch_tensor, comp_pcc, skip_for_grayskull


def sequential_prefix_scan(a, bx):
    # Reference implementation: h[i] = a[i] * h[i-1] + bx[i], with h[-1] = 0.
    # Indexing i - 1 at i == 0 wraps to the last row, which is still zero because
    # the buffer is zero-initialized, so the recurrence starts cleanly.
    (_, _, L, EN) = bx.shape
    hidden_states = torch.zeros((1, 1, L, EN), device=a.device)
    for i in range(L):
        hidden_states[:, :, i] = a[:, :, i] * hidden_states[:, :, i - 1] + bx[:, :, i]
    return hidden_states
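

# A minimal sketch, illustrative only (the helper below is not part of this commit):
# the recurrence h[i] = a[i] * h[i-1] + bx[i] is associative under the combine rule
# below, which is what makes a parallel prefix scan over (a, bx) pairs possible.
def combine(left, right):
    # Composing h -> a_l * h + bx_l with h -> a_r * h + bx_r gives
    # h -> (a_r * a_l) * h + (a_r * bx_l + bx_r).
    a_l, bx_l = left
    a_r, bx_r = right
    return (a_r * a_l, a_r * bx_l + bx_r)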


def run_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
    torch.manual_seed(0)

    a = torch.randn((1, 1, L, E * N))
    bx = torch.randn((1, 1, L, E * N))

    expected = sequential_prefix_scan(a, bx)

    # Width-shard the inputs: each core gets all L sequence rows and an
    # (E * N) // num_cores slice of the hidden-state width.
    compute_grid_size = device.compute_with_storage_grid_size()
    shard_grid = ttl.tensor.CoreRangeSet(ttl.tensor.num_cores_to_corerange_set(num_cores, compute_grid_size, True))
    shard_spec = ttl.tensor.ShardSpec(
        shard_grid,
        [L, E * N // num_cores],
        ttl.tensor.ShardOrientation.ROW_MAJOR,
        False,
    )
    memory_config = ttl.tensor.MemoryConfig(
        ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.L1, shard_spec
    )
    a = ttl.tensor.Tensor(a, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)
    bx = ttl.tensor.Tensor(bx, dtype).to(ttl.tensor.Layout.TILE).to(device, memory_config)

    actual = ttl.operations.primary.transformers.ssm_prefix_scan(
        a, bx, output_mem_config=memory_config, output_dtype=dtype
    )
    assert list(actual.get_legacy_shape()) == list(expected.shape)
    assert actual.dtype == dtype

    actual = tt2torch_tensor(actual)

    passing_pcc, output_pcc = comp_pcc(actual, expected, 0.999)
    logger.debug(f"Out passing={passing_pcc}")
    logger.debug(f"Output pcc={output_pcc}")

    assert passing_pcc


@skip_for_grayskull("Grayskull not supported")
@pytest.mark.parametrize(
    "dtype",
    (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B),
)
@pytest.mark.parametrize(
    "L, E, N, num_cores",
    (
        (32, 32, 32, 1),
        (32, 64, 32, 1),
        (32, 2560, 32, 32),
        (32, 5120, 32, 40),
        # (32, 5120, 32, 64) -> 8x8 grid not supported on CI
    ),
)
def test_ssm_prefix_scan(L: int, E: int, N: int, num_cores: int, dtype, device):
    run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)


@skip_for_grayskull("Grayskull not supported")
def test_ssm_prefix_scan_with_program_cache(device, use_program_cache):
    L, E, N = 32, 64, 32
    num_cores = 1
    dtype = ttl.tensor.DataType.BFLOAT8_B
    run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)

    dummy_memory_config = ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1)
    dummy_shape = [1, 1, 128, 128]

    for _ in range(2):
        run_ssm_prefix_scan(L, E, N, num_cores, dtype, device)
        # Allocate a dummy tensor between runs to shift L1 addresses and check
        # that the cached program still produces correct results.
        py_dummy_tensor = torch.randn(dummy_shape)
        tt_dummy_tensor = (
            ttl.tensor.Tensor(py_dummy_tensor, dtype).to(ttl.tensor.Layout.TILE).to(device, dummy_memory_config)
        )

    assert device.num_program_cache_entries() == 1
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -173,6 +173,7 @@ set(TT_DNN_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_group_attn_matmul/multi_core_group_attn_matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_eltwise_mul/multi_core_ssm_eltwise_mul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_1d_sum_reduce/multi_core_ssm_1d_sum_reduce.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/transformer_tms/multi_core_ssm_prefix_scan/multi_core_ssm_prefix_scan.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/run_operation.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/split/split_tiled.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/split/split_last_dim_two_chunks_tiled.cpp
@@ -0,0 +1,184 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>

#include "compute_kernel_api/eltwise_binary.h"
#include "compute_kernel_api/tile_move_copy.h"
#include "compute_kernel_api/tilize.h"
#include "compute_kernel_api/untilize.h"

constexpr uint32_t NUM_TILES_IN_TILIZED_CHUNK = 32;

constexpr uint32_t cb_a_in = get_compile_time_arg_val(0);
constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1);

constexpr uint32_t cb_a_tilize_in = get_compile_time_arg_val(2);
constexpr uint32_t cb_bx_tilize_in = get_compile_time_arg_val(3);

constexpr uint32_t cb_h_prev = get_compile_time_arg_val(4);
constexpr uint32_t cb_ah = get_compile_time_arg_val(5);
constexpr uint32_t cb_h = get_compile_time_arg_val(6);

constexpr uint32_t cb_tilize_out = get_compile_time_arg_val(7);
constexpr uint32_t cb_out = get_compile_time_arg_val(8);

constexpr uint32_t cb_zeros = get_compile_time_arg_val(9);
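
// CB roles (an inference from usage below): cb_a_in / cb_bx_in hold the sharded
// inputs; cb_a_tilize_in / cb_bx_tilize_in stage the re-packed chunks consumed by
// the scan; cb_h_prev / cb_ah / cb_h carry the running hidden state; cb_tilize_out
// / cb_out stage the results; cb_zeros provides a zero tile for resetting the state.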

// This function always untilizes NUM_TILES_IN_TILIZED_CHUNK tiles, so callers pad up to that amount
FORCE_INLINE void pack_block_rows_into_tiles(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);

    untilize_init_short(cb_in);

    cb_wait_front(cb_in, num_tiles);
    cb_reserve_back(cb_out, NUM_TILES_IN_TILIZED_CHUNK);

    untilize_block(cb_in, NUM_TILES_IN_TILIZED_CHUNK, cb_out);

    cb_push_back(cb_out, NUM_TILES_IN_TILIZED_CHUNK);
    cb_pop_front(cb_in, num_tiles);

    untilize_uninit(cb_in);
}

// This function always tilizes NUM_TILES_IN_TILIZED_CHUNK tiles, so callers pad up to that amount
FORCE_INLINE void pack_block_tiles_into_rows(uint32_t cb_in, uint32_t cb_out, uint32_t num_tiles) {
    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);

    tilize_init_short(cb_in, NUM_TILES_IN_TILIZED_CHUNK);

    cb_wait_front(cb_in, NUM_TILES_IN_TILIZED_CHUNK);
    cb_reserve_back(cb_out, num_tiles);

    tilize_block(cb_in, NUM_TILES_IN_TILIZED_CHUNK, cb_out);

    cb_push_back(cb_out, num_tiles);
    cb_pop_front(cb_in, NUM_TILES_IN_TILIZED_CHUNK);

    tilize_uninit(cb_in);
}
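
// Note (an inference from the layout, not stated in this commit): with L = 32,
// untilizing a chunk of 32 tiles lays each sequence row out contiguously, so when
// the row-major buffer is reinterpreted as tiles, each 32x32 tile holds one
// sequence position across up to 1024 channels. compute_ht can then apply the
// recurrence tile-by-tile along the sequence, and pack_block_tiles_into_rows
// undoes the shuffle.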

FORCE_INLINE void mul(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
    unpack_reconfig_data_format(cb_a, cb_b);
    pack_reconfig_data_format(cb_out);

    mul_tiles_init();

    cb_wait_front(cb_a, 1);
    cb_wait_front(cb_b, 1);
    cb_reserve_back(cb_out, 1);

    tile_regs_acquire();
    mul_tiles(cb_a, cb_b, 0, 0, 0);
    tile_regs_commit();
    tile_regs_wait();
    pack_tile(0, cb_out);
    tile_regs_release();

    cb_push_back(cb_out, 1);
    cb_pop_front(cb_a, 1);
    cb_pop_front(cb_b, 1);
}

FORCE_INLINE void sum(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) {
    unpack_reconfig_data_format(cb_a, cb_b);
    pack_reconfig_data_format(cb_out);

    add_tiles_init();

    cb_wait_front(cb_a, 1);
    cb_wait_front(cb_b, 1);
    cb_reserve_back(cb_out, 1);

    tile_regs_acquire();
    add_tiles(cb_a, cb_b, 0, 0, 0);
    tile_regs_commit();
    tile_regs_wait();
    pack_tile(0, cb_out);
    tile_regs_release();

    cb_push_back(cb_out, 1);
    cb_pop_front(cb_a, 1);
    cb_pop_front(cb_b, 1);
}

FORCE_INLINE void copy(uint32_t cb_in, uint32_t cb_out) {
    unpack_reconfig_data_format_srca(cb_in);
    pack_reconfig_data_format(cb_out);

    copy_tile_to_dst_init_short();

    cb_wait_front(cb_in, 1);
    cb_reserve_back(cb_out, 1);

    tile_regs_acquire();
    copy_tile(cb_in, 0, 0);
    tile_regs_commit();
    tile_regs_wait();
    pack_tile(0, cb_out);
    tile_regs_release();

    // Don't pop the input tile here - the caller decides when to pop cb_in
    cb_push_back(cb_out, 1);
}

FORCE_INLINE void setup_cb_zeros() {
    // Mark one tile of cb_zeros as valid without writing to it; the underlying
    // buffer is assumed to already contain zeros.
    cb_reserve_back(cb_zeros, 1);
    cb_push_back(cb_zeros, 1);
}

FORCE_INLINE void fill_tile_zeros(uint32_t cb_id) { copy(cb_zeros, cb_id); }

// Computes h[t] = a[t] * h[t-1] + bx[t] one tile at a time, where each tile holds
// one sequence position (see the note above)
FORCE_INLINE void compute_ht(uint32_t cb_a, uint32_t cb_bx, uint32_t cb_out, uint32_t num_tiles) {
    for (uint32_t idx = 0; idx < num_tiles; idx++) {
        mul(cb_a, cb_h_prev, cb_ah);  // a[t] * h[t-1]
        sum(cb_ah, cb_bx, cb_h);      // + bx[t]
        copy(cb_h, cb_h_prev);        // carry h[t] into the next iteration
        copy(cb_h, cb_out);           // TODO: Get rid of this extraneous copy
        cb_pop_front(cb_h, 1);
    }
    // Make sure to remove the last hidden state
    cb_wait_front(cb_h_prev, 1);
    cb_pop_front(cb_h_prev, 1);
}

namespace NAMESPACE {
void MAIN {
    const uint32_t total_tiles = get_arg_val<uint32_t>(0);
    const uint32_t total_tiles_per_row = get_arg_val<uint32_t>(1);
    const uint32_t total_tiles_per_col = get_arg_val<uint32_t>(2);

    const uint32_t num_tilize_per_row =
        (total_tiles_per_row + NUM_TILES_IN_TILIZED_CHUNK - 1) / NUM_TILES_IN_TILIZED_CHUNK;  // ceil(x/y)

    untilize_init(cb_a_in);
    binary_op_init_common(cb_a_in, cb_bx_in);

    setup_cb_zeros();

    // For each row of tiles, process chunks of 32 tiles, repacking the sequence rows into tiles for the scan
    for (uint32_t row_idx = 0; row_idx < total_tiles_per_col; row_idx++) {
        for (uint32_t tilized_chunk_idx = 0; tilized_chunk_idx < num_tilize_per_row; tilized_chunk_idx++) {
            // Start each chunk with a zero hidden state
            fill_tile_zeros(cb_h_prev);

            // If we don't have a full chunk (NUM_TILES_IN_TILIZED_CHUNK tiles), figure out how many
            // tiles are left. The partial chunk is only 2-3 tiles per shard, so there is no need to unroll.
            const uint32_t remaining_tiles_in_chunk =
                tilized_chunk_idx == num_tilize_per_row - 1 && total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK != 0
                    ? total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK
                    : NUM_TILES_IN_TILIZED_CHUNK;

            pack_block_rows_into_tiles(cb_a_in, cb_a_tilize_in, remaining_tiles_in_chunk);
            pack_block_rows_into_tiles(cb_bx_in, cb_bx_tilize_in, remaining_tiles_in_chunk);

            compute_ht(cb_a_tilize_in, cb_bx_tilize_in, cb_tilize_out, NUM_TILES_IN_TILIZED_CHUNK);

            pack_block_tiles_into_rows(cb_tilize_out, cb_out, remaining_tiles_in_chunk);
        }
    }
}
}  // namespace NAMESPACE
@@ -0,0 +1,14 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
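
// The inputs are width-sharded and already resident in this core's L1, so there is
// nothing to fetch over the NOC; this kernel only marks the input CBs as populated.
// (This reading of the kernel is an inference from the sharded setup in the test.)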

void kernel_main() {
    uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
    constexpr uint32_t cb_a_in = get_compile_time_arg_val(0);
    constexpr uint32_t cb_bx_in = get_compile_time_arg_val(1);

    cb_push_back(cb_a_in, num_tiles_per_core);
    cb_push_back(cb_bx_in, num_tiles_per_core);
}
@@ -0,0 +1,11 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "dataflow_api.h"
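
// The output is likewise sharded into L1, so this kernel only blocks until compute
// has produced all of this core's output tiles (again an inference from the setup).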

void kernel_main() {
    uint32_t num_tiles_per_core = get_arg_val<uint32_t>(0);
    constexpr uint32_t cb_out = get_compile_time_arg_val(0);
    cb_wait_front(cb_out, num_tiles_per_core);
}