Skip to content

Commit

Permalink
#0: fix in0 height sharding case
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Dec 10, 2024
1 parent c06d6e1 commit cddc7c4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 25 deletions.
42 changes: 28 additions & 14 deletions tests/ttnn/unit_tests/operations/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,8 +996,8 @@ def run_matmul_1d_multiple_output_blocks_per_core(
per_core_M = m // 32
per_core_N = n // num_cores // 32 + uneven_width
else:
in0_block_w = k // 32
per_core_M = m // 32 // num_cores + uneven_width
in0_block_w = k // 32 // 2 # test exracting shards
per_core_M = m // 32 // num_cores
per_core_N = n // 32
out_block_h = per_core_M // num_out_block_h
out_block_w = per_core_N // num_out_block_w
Expand All @@ -1017,13 +1017,21 @@ def run_matmul_1d_multiple_output_blocks_per_core(
in0 = torch.randn(in0_shape).bfloat16().float()
in1 = torch.randn(in1_shape).bfloat16().float()

if in_sharded and mcast_in0:
in0_memory_config = ttnn.create_sharded_memory_config(
(1, 1, m, k),
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
if in_sharded:
if mcast_in0:
in0_memory_config = ttnn.create_sharded_memory_config(
(1, 1, m, k),
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.WIDTH,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
else:
in0_memory_config = ttnn.create_sharded_memory_config(
(1, 1, m, k),
core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
strategy=ttnn.ShardStrategy.HEIGHT,
orientation=ttnn.ShardOrientation.ROW_MAJOR,
)
else:
in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
Expand Down Expand Up @@ -1080,11 +1088,17 @@ def run_matmul_1d_multiple_output_blocks_per_core(
fp32_dest_acc_en=False,
packer_l1_acc=True,
)
if out_sharded and mcast_in0:
out_mem_config = ttnn.MemoryConfig(
memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
buffer_type=ttnn.BufferType.L1,
)
if out_sharded:
if mcast_in0:
out_mem_config = ttnn.MemoryConfig(
memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
buffer_type=ttnn.BufferType.L1,
)
else:
out_mem_config = ttnn.MemoryConfig(
memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
buffer_type=ttnn.BufferType.L1,
)
else:
out_mem_config = ttnn.DRAM_MEMORY_CONFIG

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,17 @@ void kernel_main() {
// In case we need to send multiple blocks per shard, in0 sharded cb is cb2 and we extract the sub-blocks to cb0
constexpr uint32_t shard_read_stride = shard_width_in_tiles * in0_single_tile_size_bytes;
constexpr uint32_t shard_read_width = in0_single_tile_size_bytes * in0_block_w;
constexpr uint32_t shard_num_tiles = shard_width_in_tiles * shard_height_in_tiles;
constexpr uint32_t in0_tensor_next_h_dim_block_stride_bytes =
in0_tensor_next_h_dim_block_stride * in0_single_tile_size_bytes;

uint64_t noc_shard_read_start_addr = 0;
uint32_t noc_shard_read_start_addr = 0;
if constexpr (extract_shard_sub_blocks) {
constexpr uint32_t cb_id_in2 = 2; // in0 sharded cb if extract_shard_sub_blocks
noc_shard_read_start_addr = get_noc_addr(get_read_ptr(cb_id_in2));
noc_shard_read_start_addr = get_read_ptr(cb_id_in2);
} else {
cb_reserve_back(cb_id_in0, in0_block_num_tiles);
cb_push_back(cb_id_in0, in0_block_num_tiles);
cb_reserve_back(cb_id_in0, shard_num_tiles);
cb_push_back(cb_id_in0, shard_num_tiles);
}
#else
constexpr DataFormat in0_data_format = get_dataformat(cb_id_in0);
Expand Down Expand Up @@ -113,9 +116,15 @@ void kernel_main() {
#endif

for (uint32_t b = 0; b < batch; ++b) {
#ifdef IN0_SHARDED
uint32_t in0_tensor_current_h_dim_block_start_addr = noc_shard_read_start_addr;
#endif
uint32_t in0_tensor_current_h_dim_block_tile_id = in0_tensor_start_tile_id;
for (uint32_t bh = 0; bh < num_blocks_h_dim; ++bh) {
for (uint32_t bw = 0; bw < num_blocks_w_dim; ++bw) {
#ifdef IN0_SHARDED
uint32_t in0_tensor_current_inner_dim_block_start_addr = in0_tensor_current_h_dim_block_start_addr;
#endif
uint32_t in0_tensor_current_inner_dim_block_start_tile_id = in0_tensor_current_h_dim_block_tile_id;
for (uint32_t block = 0; block < num_blocks_inner_dim; ++block) {
if constexpr (fuse_op) {
Expand Down Expand Up @@ -159,16 +168,16 @@ void kernel_main() {
in0_start_address = l1_write_addr_in0; // copy start address of block, to be used for mcasting
#endif

uint64_t noc_shard_read_addr = noc_shard_read_start_addr;
noc_shard_read_start_addr += shard_read_width;
uint64_t noc_shard_read_addr = get_noc_addr(in0_tensor_current_inner_dim_block_start_addr);

for (uint32_t i = 0; i < shard_height_in_tiles; i++) {
for (uint32_t i = 0; i < in0_block_h; i++) {
noc_async_read(noc_shard_read_addr, l1_write_addr_in0, shard_read_width);

l1_write_addr_in0 += shard_read_width;
noc_shard_read_addr += shard_read_stride;
}

in0_tensor_current_inner_dim_block_start_addr += shard_read_width;
noc_async_read_barrier();
}
#endif
Expand Down Expand Up @@ -216,6 +225,9 @@ void kernel_main() {
#endif
}
}
#ifdef IN0_SHARDED
in0_tensor_current_h_dim_block_start_addr += in0_tensor_next_h_dim_block_stride_bytes;
#endif
in0_tensor_current_h_dim_block_tile_id += in0_tensor_next_h_dim_block_stride;
}
in0_tensor_start_tile_id += MtKt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
uint32_t in0_block_tiles = in0_block_h * in0_block_w;
uint32_t in0_CB_tiles = in0_block_tiles;
if (in0_is_sharded) {
in0_CB_tiles = num_blocks * in0_CB_tiles * B;
in0_CB_tiles = num_blocks * per_core_M * in0_block_w * B;
} else if (B * num_blocks > 1) {
in0_CB_tiles = in0_CB_tiles * 2; // double buffer
}
Expand All @@ -1045,6 +1045,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
extract_shard_sub_blocks = true;
}
}
uint32_t in2_CB_tiles = in0_block_tiles;
uint32_t in2_CB_size = in2_CB_tiles * in0_single_tile_size;

uint32_t in1_block_tiles = out_block_w * in0_block_w;
uint32_t in1_CB_tiles = in1_block_tiles;
Expand Down Expand Up @@ -1384,7 +1386,7 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
CBHandle cb_src2 = 0;
if (in0_is_sharded and extract_shard_sub_blocks) { // in0_is_sharded is technically redundant
tt_metal::CircularBufferConfig src2_cb_config =
tt_metal::CircularBufferConfig(in0_CB_size, {{src2_cb_index, in0_data_format}})
tt_metal::CircularBufferConfig(in2_CB_size, {{src2_cb_index, in0_data_format}})
.set_page_size(src2_cb_index, in0_single_tile_size)
.set_globally_allocated_address(*in0_buffer)
.set_tile_dims(src2_cb_index, in0_tile);
Expand All @@ -1394,8 +1396,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
"CB {} :: PS = {}, NP = {}, TOTAL = {}",
src2_cb_index,
in0_single_tile_size,
in0_CB_size / in0_single_tile_size,
in0_CB_size);
in2_CB_size / in0_single_tile_size,
in2_CB_size);
}

uint32_t src1_cb_index = 1;
Expand Down

0 comments on commit cddc7c4

Please sign in to comment.