Skip to content

Commit

Permalink
#0: add multi-block support for MM1d
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Dec 10, 2024
1 parent d1a9014 commit 69c218a
Show file tree
Hide file tree
Showing 10 changed files with 454 additions and 124 deletions.
217 changes: 217 additions & 0 deletions tests/ttnn/unit_tests/operations/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,223 @@ def test_matmul_1d_tiny_tile(
assert device.num_program_cache_entries() == 1


def run_matmul_1d_multiple_output_blocks_per_core(
device,
m,
k,
n,
has_bias,
grid_size,
in_sharded,
out_sharded,
num_out_block_h,
num_out_block_w,
mcast_in0,
uneven_width,
):
if in_sharded or out_sharded:
fuse_batch = True
else:
fuse_batch = False

if out_sharded and num_out_block_w > 1:
pytest.skip("out sharded not support multiple blocks on w dim")

in0_shape = [1, 1, m, k]
in1_shape = [1, 1, k, n]
bias_shape = [1, 1, n]

num_cores = grid_size[0] * grid_size[1]

if mcast_in0:
in0_block_w = k // num_cores // 32
per_core_M = m // 32
per_core_N = n // num_cores // 32 + uneven_width
else:
in0_block_w = k // 32
per_core_M = (m // 32 // num_cores + uneven_width,)
per_core_N = (n // 32,)
out_block_h = per_core_M // num_out_block_h
out_block_w = per_core_N // num_out_block_w
out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w)

logger.info(f"m: {m}")
logger.info(f"k: {k}")
logger.info(f"n: {n}")
logger.info(f"in0_block_w: {in0_block_w}")
logger.info(f"per_core_M: {per_core_M}")
logger.info(f"per_core_N: {per_core_N}")
logger.info(f"out_block_h: {out_block_h}")
logger.info(f"out_block_w: {out_block_w}")
logger.info(f"out_subblock_h: {out_subblock_h}")
logger.info(f"out_subblock_w: {out_subblock_w}")

in0 = torch.randn(in0_shape).bfloat16().float()
in1 = torch.randn(in1_shape).bfloat16().float()

if in_sharded:
if mcast_in0:
in0_memory_config = ttnn.create_sharded_memory_config(
(1, 1, m, k),
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
else:
in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
in1_memory_config = ttnn.create_sharded_memory_config(
(1, 1, k, n),
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
else:
in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
in0_t = ttnn.from_torch(
in0,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=in0_memory_config,
)
in1_t = ttnn.from_torch(
in1,
dtype=ttnn.bfloat8_b,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=in1_memory_config,
)

if has_bias:
bias = torch.randn(bias_shape).bfloat16().float()
bias_padded = bias.unsqueeze(2)
bias_padded = torch.nn.functional.pad(bias_padded, (0, 0, 0, 32 - bias_padded.size(2)), "constant", 0)
bias_t = ttnn.from_torch(
bias_padded,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)

program_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig(
compute_with_storage_grid_size=grid_size,
in0_block_w=in0_block_w,
out_subblock_h=out_subblock_h,
out_subblock_w=out_subblock_w,
out_block_h=out_block_h,
out_block_w=out_block_w,
per_core_M=per_core_M,
per_core_N=per_core_N,
fuse_batch=fuse_batch,
fused_activation=None,
mcast_in0=mcast_in0,
)

if is_grayskull():
compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.LoFi,
math_approx_mode=True,
)
else:
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.LoFi,
math_approx_mode=True,
fp32_dest_acc_en=False,
packer_l1_acc=True,
)
if out_sharded:
out_mem_config = ttnn.MemoryConfig(
memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
buffer_type=ttnn.BufferType.L1,
)
else:
out_mem_config = ttnn.L1_MEMORY_CONFIG

if has_bias:
output_t = ttnn.linear(
in0_t,
in1_t,
bias=bias_t,
program_config=program_config,
memory_config=out_mem_config,
dtype=ttnn.bfloat16,
compute_kernel_config=compute_kernel_config,
)
else:
output_t = ttnn.matmul(
in0_t,
in1_t,
program_config=program_config,
memory_config=out_mem_config,
dtype=ttnn.bfloat16,
compute_kernel_config=compute_kernel_config,
)
output_tensor = ttnn.to_torch(output_t)
pt_out = in0 @ in1
if has_bias:
pt_out += bias

assert_with_pcc(pt_out, output_tensor, 0.999)


@run_for_wormhole_b0()
@pytest.mark.parametrize("m", [256])
@pytest.mark.parametrize("k", [1024])
@pytest.mark.parametrize("n", [2048])
@pytest.mark.parametrize("has_bias", [False])
@pytest.mark.parametrize("grid_size", [(8, 2)])
@pytest.mark.parametrize("in_sharded", [True, False])
@pytest.mark.parametrize("out_sharded", [True, False])
@pytest.mark.parametrize("num_out_block_h", [1, 2])
@pytest.mark.parametrize("num_out_block_w", [1, 2])
@pytest.mark.parametrize("mcast_in0", [True, False])
@pytest.mark.parametrize("uneven_width", [0, 2])
def test_matmul_1d_multiple_output_blocks_per_core(
device,
m,
k,
n,
has_bias,
grid_size,
in_sharded,
out_sharded,
num_out_block_h,
num_out_block_w,
mcast_in0,
uneven_width,
use_program_cache,
):
for _ in range(2):
run_matmul_1d_multiple_output_blocks_per_core(
device,
m,
k,
n,
has_bias,
grid_size,
in_sharded,
out_sharded,
num_out_block_h,
num_out_block_w,
mcast_in0,
uneven_width,
)
# dummy tensor to change tensor alloc
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.from_torch(
py_dummy_tensor,
dtype=ttnn.DataType.BFLOAT16,
layout=ttnn.TILE_LAYOUT,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 1


# fmt: off
@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH")
@pytest.mark.parametrize("m_size,k_size,n_size", [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "compute_kernel_api/eltwise_unary/sfpu_split_includes.h"

#include "debug/dprint.h"

// Please update
// tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation_copy.cpp
// when making any changes to this file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "dataflow_api.h"
#include "hostdevcommon/common_values.hpp"

#include "debug/dprint.h"

void kernel_main() {
// in0 mcast args
const uint32_t in0_mcast_sender_noc_x = get_arg_val<uint32_t>(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "hostdevcommon/common_values.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp"

#include "debug/dprint.h"

void kernel_main() {
uint32_t rt_args_idx = 0;
// in0 tensor args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "hostdevcommon/common_values.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp"

#include "debug/dprint.h"

void kernel_main() {
// READER
uint32_t rt_args_idx = 0;
Expand Down
28 changes: 28 additions & 0 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ MatmulProgramConfig create_matmul_1d_systolic_array_program_config(
.in0_block_w = k_tiles_per_core,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = batch_and_m_tiles_per_core,
.out_block_w = n_tiles_per_core,
.per_core_M = batch_and_m_tiles_per_core,
.per_core_N = n_tiles_per_core,
.fuse_batch = true,
Expand Down Expand Up @@ -357,6 +359,8 @@ MatmulMultiCoreReuseMultiCast1DProgramConfig get_mcast_1d_config(
.in0_block_w = in0_block_w,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = per_core_M,
.out_block_w = per_core_N,
.per_core_M = per_core_M,
.per_core_N = per_core_N,
.fuse_batch = fuse_batch,
Expand Down Expand Up @@ -701,6 +705,8 @@ MatmulProgramConfig get_matmul_program_config(
.in0_block_w = in0_block_w,
.out_subblock_h = out_subblock_h,
.out_subblock_w = out_subblock_w,
.out_block_h = per_core_M,
.out_block_w = per_core_N,
.per_core_M = per_core_M,
.per_core_N = per_core_N,
.fuse_batch = true,
Expand Down Expand Up @@ -1182,6 +1188,26 @@ void Matmul::validate(
// TODO: For 1D and 2D mcasts, we don't check if tensor is single core or single row/col
// We can uplift these variants to skip mcasting to support single core (1D) or single row/col (2D)
if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCast1DProgramConfig>) {
TT_FATAL(
program_config.per_core_M % program_config.out_block_h == 0,
"Error: incompatible values {} and {}",
program_config.per_core_M,
program_config.out_block_h);
TT_FATAL(
program_config.per_core_N % program_config.out_block_w == 0,
"Error: incompatible values {} and {}",
program_config.per_core_N,
program_config.out_block_w);
TT_FATAL(
program_config.out_block_h % program_config.out_subblock_h == 0,
"Error: incompatible values {} and {}",
program_config.out_block_h,
program_config.out_subblock_h);
TT_FATAL(
program_config.out_block_w % program_config.out_subblock_w == 0,
"Error: incompatible values {} and {}",
program_config.out_block_w,
program_config.out_subblock_w);
TT_FATAL(
!(program_config.mcast_in0 && program_config.gather_in0),
"Matmul1D does not support mcast_in0 and gather_in0 at the same time.");
Expand Down Expand Up @@ -1806,6 +1832,8 @@ operation::ProgramWithCallbacks Matmul::create_program(
program_config.in0_block_w,
program_config.out_subblock_h,
program_config.out_subblock_w,
program_config.out_block_h,
program_config.out_block_w,
program_config.per_core_M,
program_config.per_core_N,
program_config.fuse_batch,
Expand Down
4 changes: 4 additions & 0 deletions ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized(
uint32_t in0_block_w,
uint32_t out_subblock_h,
uint32_t out_subblock_w,
uint32_t out_block_h,
uint32_t out_block_w,
uint32_t per_core_M,
uint32_t per_core_N,
bool fuse_batch,
Expand Down Expand Up @@ -130,6 +132,8 @@ struct MatmulMultiCoreReuseMultiCast1DProgramConfig {
std::size_t in0_block_w;
std::size_t out_subblock_h;
std::size_t out_subblock_w;
std::size_t out_block_h;
std::size_t out_block_w;
std::size_t per_core_M;
std::size_t per_core_N;
bool fuse_batch;
Expand Down
Loading

0 comments on commit 69c218a

Please sign in to comment.