Add MM 1d with multi-block support (#15864)
### Problem description
Matmul1d with multi-block support (similar to the existing matmul2d support).

The current matmul 1d implementation supports only a single output block per core. We need to change this so that large matmuls can split their per-core output into multiple blocks and therefore use less L1.

### What's changed
- Support multiple output blocks on both height and width for interleaved output.
- Support multi-block only on height for sharded output.
- Support in0 with multi-block on height.
- Support in1/bias with multi-block on width.

This gives us three levels of output blocking per core (the middle level was previously missing); the analogous fields exist for the width dimension:
- `per_core_M` — the full output height assigned to a core
- `out_block_h` — one output block within the per-core output
- `out_subblock_h` — one subblock within an output block

A configuration sketch follows below.
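As a rough sketch (not taken verbatim from this commit), the two new fields slot into the 1D program config alongside the existing per-core and subblock sizes; the numbers below are illustrative and only need to satisfy the divisibility checks added in `Matmul::validate`:

```python
import ttnn

# Illustrative values only: per_core_M must be divisible by out_block_h,
# per_core_N by out_block_w, and each out_block_* by its out_subblock_*.
program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
    compute_with_storage_grid_size=(8, 2),
    in0_block_w=2,
    out_subblock_h=1,
    out_subblock_w=4,
    out_block_h=4,   # new: per_core_M (8) split into 2 output blocks of height 4
    out_block_w=8,   # new: per_core_N (16) split into 2 output blocks of width 8
    per_core_M=8,
    per_core_N=16,
    fuse_batch=True,
    fused_activation=None,
    mcast_in0=True,
)
```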

### Checklist
- [x] Post commit CI
https://github.com/tenstorrent/tt-metal/actions/runs/12284509644
- [x] Blackhole Post commit
https://github.com/tenstorrent/tt-metal/actions/runs/12284513170
- [x] Model regression CI testing
https://github.com/tenstorrent/tt-metal/actions/runs/12285308255
- [x] Device performance regression CI testing
https://github.com/tenstorrent/tt-metal/actions/runs/12285312770/job/34283864283
yugaoTT authored Dec 12, 2024
1 parent d43dc8d commit 2b6277e
Showing 9 changed files with 490 additions and 141 deletions.
@@ -556,6 +556,8 @@ INSTANTIATE_TEST_SUITE_P(
.in0_block_w = 2,
.out_subblock_h = 1,
.out_subblock_w = 1,
.out_block_h = 64,
.out_block_w = 2,
.per_core_M = 64,
.per_core_N = 2,
.fuse_batch = true,
226 changes: 226 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -958,6 +958,232 @@ def test_matmul_1d_tiny_tile(
assert device.num_program_cache_entries() == 1


def run_matmul_1d_multiple_output_blocks_per_core(
    device,
    m,
    k,
    n,
    has_bias,
    grid_size,
    in_sharded,
    out_sharded,
    num_out_block_h,
    num_out_block_w,
    mcast_in0,
    uneven_width,
):
    if in_sharded or out_sharded:
        fuse_batch = True
    else:
        fuse_batch = False

    if out_sharded and num_out_block_w > 1:
        pytest.skip("out sharded does not support multiple blocks on w dim")

    if not mcast_in0:
        tmp = m
        m = n
        n = tmp

    in0_shape = [1, 1, m, k]
    in1_shape = [1, 1, k, n]
    bias_shape = [1, 1, n]

    num_cores = grid_size[0] * grid_size[1]

    if mcast_in0:
        in0_block_w = k // num_cores // 32
        per_core_M = m // 32
        per_core_N = n // num_cores // 32 + uneven_width
    else:
        in0_block_w = k // 32 // 2  # test extracting shards
        per_core_M = m // 32 // num_cores
        per_core_N = n // 32
    out_block_h = per_core_M // num_out_block_h
    out_block_w = per_core_N // num_out_block_w
    out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w)

logger.info(f"m: {m}")
logger.info(f"k: {k}")
logger.info(f"n: {n}")
logger.info(f"in0_block_w: {in0_block_w}")
logger.info(f"per_core_M: {per_core_M}")
logger.info(f"per_core_N: {per_core_N}")
logger.info(f"out_block_h: {out_block_h}")
logger.info(f"out_block_w: {out_block_w}")
logger.info(f"out_subblock_h: {out_subblock_h}")
logger.info(f"out_subblock_w: {out_subblock_w}")

    in0 = torch.randn(in0_shape).bfloat16().float()
    in1 = torch.randn(in1_shape).bfloat16().float()

    if in_sharded:
        if mcast_in0:
            in0_memory_config = ttnn.create_sharded_memory_config(
                (1, 1, m, k),
                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
                strategy=ttnn.ShardStrategy.WIDTH,
                orientation=ttnn.ShardOrientation.ROW_MAJOR,
            )
        else:
            in0_memory_config = ttnn.create_sharded_memory_config(
                (1, 1, m, k),
                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
                strategy=ttnn.ShardStrategy.HEIGHT,
                orientation=ttnn.ShardOrientation.ROW_MAJOR,
            )
    else:
        in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
    in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
    in0_t = ttnn.from_torch(
        in0,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=in0_memory_config,
    )
    in1_t = ttnn.from_torch(
        in1,
        dtype=ttnn.bfloat8_b,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=in1_memory_config,
    )

    if has_bias:
        bias = torch.randn(bias_shape).bfloat16().float()
        bias_padded = bias.unsqueeze(2)
        bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, 32 - bias_padded.size(2)), "constant", 0)
        bias_t = ttnn.from_torch(
            bias_padded,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )

    program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
        compute_with_storage_grid_size=grid_size,
        in0_block_w=in0_block_w,
        out_subblock_h=out_subblock_h,
        out_subblock_w=out_subblock_w,
        out_block_h=out_block_h,
        out_block_w=out_block_w,
        per_core_M=per_core_M,
        per_core_N=per_core_N,
        fuse_batch=fuse_batch,
        fused_activation=None,
        mcast_in0=mcast_in0,
    )

    if is_grayskull():
        compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
            math_fidelity=ttnn.MathFidelity.LoFi,
            math_approx_mode=True,
        )
    else:
        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
            math_fidelity=ttnn.MathFidelity.LoFi,
            math_approx_mode=True,
            fp32_dest_acc_en=False,
            packer_l1_acc=True,
        )
    if out_sharded:
        if mcast_in0:
            out_mem_config = ttnn.MemoryConfig(
                memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
                buffer_type=ttnn.BufferType.L1,
            )
        else:
            out_mem_config = ttnn.MemoryConfig(
                memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
                buffer_type=ttnn.BufferType.L1,
            )
    else:
        out_mem_config = ttnn.DRAM_MEMORY_CONFIG

    if has_bias:
        output_t = ttnn.linear(
            in0_t,
            in1_t,
            bias=bias_t,
            program_config=program_config,
            memory_config=out_mem_config,
            dtype=ttnn.bfloat16,
            compute_kernel_config=compute_kernel_config,
        )
    else:
        output_t = ttnn.matmul(
            in0_t,
            in1_t,
            program_config=program_config,
            memory_config=out_mem_config,
            dtype=ttnn.bfloat16,
            compute_kernel_config=compute_kernel_config,
        )
    output_tensor = ttnn.to_torch(output_t)
    pt_out = in0 @ in1
    if has_bias:
        pt_out += bias

    assert_with_pcc(pt_out, output_tensor, 0.999)


@run_for_wormhole_b0()
@pytest.mark.parametrize("m", [256])
@pytest.mark.parametrize("k", [1024])
@pytest.mark.parametrize("n", [2048])
@pytest.mark.parametrize("has_bias", [False])
@pytest.mark.parametrize("grid_size", [(8, 2)])
@pytest.mark.parametrize("in_sharded", [True, False])
@pytest.mark.parametrize("out_sharded", [True, False])
@pytest.mark.parametrize("num_out_block_h", [1, 2])
@pytest.mark.parametrize("num_out_block_w", [1, 2])
@pytest.mark.parametrize("mcast_in0", [True, False])
@pytest.mark.parametrize("uneven_width", [0, 2])
def test_matmul_1d_multiple_output_blocks_per_core(
    device,
    m,
    k,
    n,
    has_bias,
    grid_size,
    in_sharded,
    out_sharded,
    num_out_block_h,
    num_out_block_w,
    mcast_in0,
    uneven_width,
    use_program_cache,
):
    for _ in range(2):
        run_matmul_1d_multiple_output_blocks_per_core(
            device,
            m,
            k,
            n,
            has_bias,
            grid_size,
            in_sharded,
            out_sharded,
            num_out_block_h,
            num_out_block_w,
            mcast_in0,
            uneven_width,
        )
        # dummy tensor to change tensor alloc
        dummy_shape = [1, 1, 32, 32]
        py_dummy_tensor = torch.randn(dummy_shape)
        tt_dummy_tensor = ttnn.from_torch(
            py_dummy_tensor,
            dtype=ttnn.DataType.BFLOAT16,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.L1_MEMORY_CONFIG,
        )
    assert device.num_program_cache_entries() == 1


# fmt: off
@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
@pytest.mark.parametrize("m_size,k_size,n_size", [
2 changes: 2 additions & 0 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -754,6 +754,8 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co
.in0_block_w = conv_blocking_config.act_block_w_ntiles,
.out_subblock_h = conv_blocking_config.out_subblock_h_ntiles,
.out_subblock_w = conv_blocking_config.out_subblock_w_ntiles,
.out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
.out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH),
.per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT),
.per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH),
.fuse_batch = true,
@@ -71,14 +71,17 @@ void kernel_main() {
// In case we need to send multiple blocks per shard, in0 sharded cb is cb2 and we extract the sub-blocks to cb0
constexpr uint32_t shard_read_stride = shard_width_in_tiles * in0_single_tile_size_bytes;
constexpr uint32_t shard_read_width = in0_single_tile_size_bytes * in0_block_w;
constexpr uint32_t shard_num_tiles = shard_width_in_tiles * shard_height_in_tiles;
constexpr uint32_t in0_tensor_next_h_dim_block_stride_bytes =
in0_tensor_next_h_dim_block_stride * in0_single_tile_size_bytes;

uint64_t noc_shard_read_start_addr = 0;
uint32_t noc_shard_read_start_addr = 0;
if constexpr (extract_shard_sub_blocks) {
constexpr uint32_t cb_id_in2 = 2; // in0 sharded cb if extract_shard_sub_blocks
noc_shard_read_start_addr = get_noc_addr(get_read_ptr(cb_id_in2));
noc_shard_read_start_addr = get_read_ptr(cb_id_in2);
} else {
cb_reserve_back(cb_id_in0, in0_block_num_tiles);
cb_push_back(cb_id_in0, in0_block_num_tiles);
cb_reserve_back(cb_id_in0, shard_num_tiles);
cb_push_back(cb_id_in0, shard_num_tiles);
}
#else
constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0);
@@ -113,9 +116,15 @@ void kernel_main() {
#endif

for (uint32_t b = 0; b < batch; ++b) {
#ifdef IN0_SHARDED
uint32_t in0_tensor_current_h_dim_block_start_addr = noc_shard_read_start_addr;
#endif
uint32_t in0_tensor_current_h_dim_block_tile_id = in0_tensor_start_tile_id;
for (uint32_t bh = 0; bh < num_blocks_h_dim; ++bh) {
for (uint32_t bw = 0; bw < num_blocks_w_dim; ++bw) {
#ifdef IN0_SHARDED
uint32_t in0_tensor_current_inner_dim_block_start_addr = in0_tensor_current_h_dim_block_start_addr;
#endif
uint32_t in0_tensor_current_inner_dim_block_start_tile_id = in0_tensor_current_h_dim_block_tile_id;
for (uint32_t block = 0; block < num_blocks_inner_dim; ++block) {
if constexpr (fuse_op) {
@@ -159,16 +168,16 @@ void kernel_main() {
in0_start_address = l1_write_addr_in0; // copy start address of block, to be used for mcasting
#endif

uint64_t noc_shard_read_addr = noc_shard_read_start_addr;
noc_shard_read_start_addr += shard_read_width;
uint64_t noc_shard_read_addr = get_noc_addr(in0_tensor_current_inner_dim_block_start_addr);

for (uint32_t i = 0; i < shard_height_in_tiles; i++) {
for (uint32_t i = 0; i < in0_block_h; i++) {
noc_async_read(noc_shard_read_addr, l1_write_addr_in0, shard_read_width);

l1_write_addr_in0 += shard_read_width;
noc_shard_read_addr += shard_read_stride;
}

in0_tensor_current_inner_dim_block_start_addr += shard_read_width;
noc_async_read_barrier();
}
#endif
@@ -216,6 +225,9 @@ void kernel_main() {
#endif
}
}
#ifdef IN0_SHARDED
in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride_bytes;
#endif
in0_tensor_current_h_dim_block_tile_id += in0_tensor_next_h_dim_block_stride;
}
in0_tensor_start_tile_id += MtKt;
28 changes: 28 additions & 0 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -306,6 +306,8 @@ MatmulProgramConfig create_matmul_1d_systolic_array_program_config(
.in0_block_w = k_tiles_per_core,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = batch_and_m_tiles_per_core,
.out_block_w = n_tiles_per_core,
.per_core_M = batch_and_m_tiles_per_core,
.per_core_N = n_tiles_per_core,
.fuse_batch = true,
@@ -357,6 +359,8 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config(
.in0_block_w = in0_block_w,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = per_core_M,
.out_block_w = per_core_N,
.per_core_M = per_core_M,
.per_core_N = per_core_N,
.fuse_batch = fuse_batch,
@@ -701,6 +705,8 @@ MatmulProgramConfig get_matmul_program_config(
.in0_block_w = in0_block_w,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = per_core_M,
.out_block_w = per_core_N,
.per_core_M = per_core_M,
.per_core_N = per_core_N,
.fuse_batch = true,
@@ -1182,6 +1188,26 @@ void Matmul::validate(
// TODO: For 1D and 2D mcasts, we don't check if tensor is single core or single row/col
// We can uplift these variants to skip mcasting to support single core (1D) or single row/col (2D)
if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCast1DProgramConfig>) {
TT_FATAL(
program_config.per_core_M % program_config.out_block_h == 0,
"Error: incompatible values {} and {}",
program_config.per_core_M,
program_config.out_block_h);
TT_FATAL(
program_config.per_core_N % program_config.out_block_w == 0,
"Error: incompatible values {} and {}",
program_config.per_core_N,
program_config.out_block_w);
TT_FATAL(
program_config.out_block_h % program_config.out_subblock_h == 0,
"Error: incompatible values {} and {}",
program_config.out_block_h,
program_config.out_subblock_h);
TT_FATAL(
program_config.out_block_w % program_config.out_subblock_w == 0,
"Error: incompatible values {} and {}",
program_config.out_block_w,
program_config.out_subblock_w);
TT_FATAL(
!(program_config.mcast_in0 && program_config.gather_in0),
"Matmul1D does not support mcast_in0 and gather_in0 at the same time.");
@@ -1806,6 +1832,8 @@ operation::ProgramWithCallbacks Matmul::create_program(
program_config.in0_block_w,
program_config.out_subblock_h,
program_config.out_subblock_w,
program_config.out_block_h,
program_config.out_block_w,
program_config.per_core_M,
program_config.per_core_N,
program_config.fuse_batch,
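For reference, a hypothetical helper (not part of this commit) that derives block sizes the same way the unit test above does, while respecting the new divisibility checks in `Matmul::validate`, might look like:

```python
def split_per_core_output(per_core_M, per_core_N, num_out_block_h, num_out_block_w):
    # The new checks require per_core_M % out_block_h == 0 and
    # per_core_N % out_block_w == 0 (and each out_block_* to hold whole subblocks).
    assert per_core_M % num_out_block_h == 0, "per_core_M must split evenly into output blocks"
    assert per_core_N % num_out_block_w == 0, "per_core_N must split evenly into output blocks"
    return per_core_M // num_out_block_h, per_core_N // num_out_block_w
```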
4 changes: 4 additions & 0 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
@@ -43,6 +43,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(
uint32_t in0_block_w,
uint32_t out_subblock_h,
uint32_t out_subblock_w,
uint32_t out_block_h,
uint32_t out_block_w,
uint32_t per_core_M,
uint32_t per_core_N,
bool fuse_batch,
Expand Down Expand Up @@ -130,6 +132,8 @@ struct MatmulMultiCoreReuseMultiCast1DProgramConfig {
std::size_t in0_block_w;
std::size_t out_subblock_h;
std::size_t out_subblock_w;
std::size_t out_block_h;
std::size_t out_block_w;
std::size_t per_core_M;
std::size_t per_core_N;
bool fuse_batch;