#9472: Optimize sharded bcast h op
johanna-rock-tt committed Jun 28, 2024
Parent: 4677eda, commit: d3bebbb
Showing 8 changed files with 420 additions and 17 deletions.
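For orientation before the diffs: the op broadcasts a per-batch [1, W] row over the H dimension of the activation and combines the two with ADD or MUL. Below is a torch-only sketch of that reference behaviour, loosely based on the updated test in this commit; names and shapes are illustrative, not part of the change itself.

```python
import torch

def bcast_h_reference(activation: torch.Tensor, weight: torch.Tensor, op: str) -> torch.Tensor:
    # activation: [N, 1, H, W]; weight: [N, 1, 1, W] is broadcast over H by torch
    return torch.add(activation, weight) if op == "add" else torch.mul(activation, weight)

# Shapes loosely based on the test parametrization below
activation = torch.rand(2, 1, 2048, 640, dtype=torch.bfloat16)
weight = torch.rand(2, 1, 1, 640, dtype=torch.bfloat16)
out = bcast_h_reference(activation, weight, op="add")
assert out.shape == activation.shape
```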
42 changes: 29 additions & 13 deletions tests/tt_eager/python_api_testing/unit_testing/misc/test_bcast.py
@@ -20,23 +20,30 @@
"input_height, input_width, num_cores, shard_grid, shard_strategy",
(
(2048, 320, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(8192, 320, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(2048, 640, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(512, 640, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(2048, 1280, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(512, 1280, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(128, 1280, 40, (8, 5), ttnn.ShardStrategy.WIDTH),
(8192, 320, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(2048, 640, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(512, 1280, 40, (8, 5), ttnn.ShardStrategy.BLOCK),
(128, 1280, 32, (4, 8), ttnn.ShardStrategy.BLOCK),
(512, 1280, 64, (8, 8), ttnn.ShardStrategy.BLOCK),
),
)
@pytest.mark.parametrize(
"dtype",
"in0_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
)
@pytest.mark.parametrize(
"in1_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
)
@pytest.mark.parametrize(
"op",
[ttl.tensor.BcastOpMath.ADD, ttl.tensor.BcastOpMath.MUL],
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("in1_batch_size", [1, 2])
@pytest.mark.parametrize("in0_batch_size", [1, 2])
@pytest.mark.parametrize(
"orientation",
[ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR, ttnn.experimental.tensor.ShardOrientation.COL_MAJOR],
@@ -45,13 +45,15 @@ def test_bcast(
device,
use_program_cache,
orientation,
batch_size,
in0_batch_size,
in1_batch_size,
input_height,
input_width,
num_cores,
shard_grid,
shard_strategy,
dtype,
in0_dtype,
in1_dtype,
op,
):
torch.manual_seed(0)
@@ -67,14 +67,15 @@
if shard_grid[0] == 8 and orientation == ttnn.experimental.tensor.ShardOrientation.ROW_MAJOR
else shard_grid
)
input_shape = [batch_size, 1, input_height // batch_size, input_width]
input_shape = [in0_batch_size, 1, input_height, input_width]
input = torch.rand(input_shape, dtype=torch.bfloat16)

tt_input = input.reshape(1, 1, input_height, input_width)
input_tensor = ttnn.from_torch(
tt_input, device=device, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, dtype=dtype
input, device=device, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, dtype=in0_dtype
)
input_2d_height = (
input_tensor.get_legacy_shape()[0] * input_tensor.get_legacy_shape()[1] * input_tensor.get_legacy_shape()[2]
)
input_2d_height = input_tensor.get_legacy_shape()[2]
input_2d_width = input_tensor.get_legacy_shape()[3]
if shard_strategy == ttnn.ShardStrategy.BLOCK:
input_2d_height_padded = _nearest_y(input_2d_height, shard_grid[0] * 32)
@@ -108,15 +118,21 @@

tt_input = ttnn.to_memory_config(input_tensor, memory_config=in_sharded_mem_config)

b_weights_shape = [batch_size, 1, 1, input_width]
if in0_batch_size == 1 and in1_batch_size > 1:
input = input.reshape(in1_batch_size, 1, input_height // in1_batch_size, input_width)

b_weights_shape = [in1_batch_size, 1, 1, input_width]
B_pyt = torch.rand(size=b_weights_shape).bfloat16()
if op == ttl.tensor.BcastOpMath.ADD:
torch_ref_output = torch.add(input, B_pyt)
elif op == ttl.tensor.BcastOpMath.MUL:
torch_ref_output = torch.mul(input, B_pyt)

if in0_batch_size == 1 and in1_batch_size > 1:
torch_ref_output = torch_ref_output.reshape(1, 1, input_height, input_width)

B_pyt = B_pyt.reshape(b_weights_shape)
tt_weight = ttnn.from_torch(B_pyt, device=device, layout=ttnn.TILE_LAYOUT, dtype=dtype)
tt_weight = ttnn.from_torch(B_pyt, device=device, layout=ttnn.TILE_LAYOUT, dtype=in1_dtype)
tt_output = ttl.tensor.bcast(
tt_input,
tt_weight,
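One detail of the updated test worth calling out: when in0 has a single batch but the weight (in1) is batched, the reference output is computed by viewing the activation as in1_batch_size chunks along H, one chunk per weight row, then flattening back. A small torch sketch of that path (shapes are illustrative, not taken from the commit):

```python
import torch

H, W, in1_batch = 128, 64, 2
activation = torch.rand(1, 1, H, W, dtype=torch.bfloat16)      # single-batch in0
weight = torch.rand(in1_batch, 1, 1, W, dtype=torch.bfloat16)  # batched in1

chunked = activation.reshape(in1_batch, 1, H // in1_batch, W)  # [2, 1, 64, 64]
ref = torch.add(chunked, weight).reshape(1, 1, H, W)           # back to [1, 1, 128, 64]
assert ref.shape == activation.shape
```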
1 change: 1 addition & 0 deletions tt_eager/tt_dnn/op_library/CMakeLists.txt
@@ -52,6 +52,7 @@ set(TT_DNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/bcast/bcast_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_h/bcast_op_multi_core_h.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_h/bcast_op_sharded_h.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_h/bcast_op_sharded_h_optimised.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_w/bcast_op_multi_core_w.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bcast/multi_core_hw/bcast_op_multi_core_hw.cpp
${CMAKE_CURRENT_SOURCE_DIR}/bmm/bmm_op.cpp
9 changes: 8 additions & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp
@@ -155,6 +155,8 @@ operation::ProgramWithCallbacks EltwiseBinaryBroadcast::create_program(const std
switch (parallelization_strategy){
case BcastOpParallelizationStrategy::MULTI_CORE_H_SHARDED:
return bcast_sharded_h(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
case BcastOpParallelizationStrategy::MULTI_CORE_H_SHARDED_OPTIMISED:
return bcast_sharded_h_optimised(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
case BcastOpParallelizationStrategy::MULTI_CORE_H:
return bcast_multi_core_h(input_tensor_a, input_tensor_b, output_tensor, this->math_op);
case BcastOpParallelizationStrategy::MULTI_CORE_W:
@@ -183,14 +185,19 @@ const operation::Hash EltwiseBinaryBroadcast::compute_program_hash(

BcastOpParallelizationStrategy EltwiseBinaryBroadcast::get_parallelization_strategy(const std::vector<Tensor> &input_tensors) const {
const auto& input_tensor_a = input_tensors.at(0);
const auto& input_tensor_b = input_tensors.at(1);

uint32_t num_tiles = input_tensor_a.volume() / TILE_HW;
uint32_t Ht = input_tensor_a.get_legacy_shape()[-2] / TILE_HEIGHT;
uint32_t Wt = input_tensor_a.get_legacy_shape()[-1] / TILE_WIDTH;

if(this->dim == BcastOpDim::H){
if(input_tensor_a.is_sharded())
return BcastOpParallelizationStrategy::MULTI_CORE_H_SHARDED;
if (input_tensor_a.get_legacy_shape()[0] == input_tensor_b.get_legacy_shape()[0] || input_tensor_a.get_legacy_shape()[0] > 1 and input_tensor_b.get_legacy_shape()[0] == 1){
return BcastOpParallelizationStrategy::MULTI_CORE_H_SHARDED_OPTIMISED;
} else {
return BcastOpParallelizationStrategy::MULTI_CORE_H_SHARDED;
}
else
return BcastOpParallelizationStrategy::MULTI_CORE_H;
}
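In plain terms, the dispatch above takes the new optimised sharded path only when the batch dimensions of the two inputs are compatible: either both inputs have the same batch, or a is batched while b has a single batch. A Python restatement of that condition, for readability only (not part of the commit):

```python
def pick_sharded_h_strategy(a_batch: int, b_batch: int) -> str:
    """Mirror of the condition in get_parallelization_strategy (illustrative only)."""
    if a_batch == b_batch or (a_batch > 1 and b_batch == 1):
        return "MULTI_CORE_H_SHARDED_OPTIMISED"
    return "MULTI_CORE_H_SHARDED"

assert pick_sharded_h_strategy(2, 2) == "MULTI_CORE_H_SHARDED_OPTIMISED"
assert pick_sharded_h_strategy(2, 1) == "MULTI_CORE_H_SHARDED_OPTIMISED"
assert pick_sharded_h_strategy(1, 2) == "MULTI_CORE_H_SHARDED"
```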
7 changes: 6 additions & 1 deletion tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -19,7 +19,7 @@ enum class BcastOpMath { ADD, SUB, MUL };
enum class BcastOpDim { H, W, HW };

// TODO: Accept parallelization
enum class BcastOpParallelizationStrategy { MULTI_CORE_H_SHARDED, MULTI_CORE_H, MULTI_CORE_W, MULTI_CORE_HW, SINGLE_CORE };
enum class BcastOpParallelizationStrategy { MULTI_CORE_H_SHARDED, MULTI_CORE_H_SHARDED_OPTIMISED, MULTI_CORE_H, MULTI_CORE_W, MULTI_CORE_HW, SINGLE_CORE };

operation::ProgramWithCallbacks bcast_multi_core_h(
const Tensor &input_tensor_a,
@@ -31,6 +31,11 @@ operation::ProgramWithCallbacks bcast_sharded_h(
const Tensor &input_tensor_b,
const Tensor &output_tensor,
BcastOpMath bcast_op);
operation::ProgramWithCallbacks bcast_sharded_h_optimised(
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
const Tensor &output_tensor,
BcastOpMath bcast_op);
operation::ProgramWithCallbacks bcast_multi_core_w(
const Tensor &input_tensor_a,
const Tensor &input_tensor_b,
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <cstdint>
#include "compute_kernel_api/bcast.h"


namespace NAMESPACE {
void MAIN {
constexpr uint32_t onetile = 1;
uint32_t NC = get_arg_val<uint32_t>(0);
uint32_t Ht = get_arg_val<uint32_t>(1);
uint32_t Wt = get_arg_val<uint32_t>(2);
uint32_t h_blk = get_arg_val<uint32_t>(3);
uint32_t batch_b = get_arg_val<uint32_t>(4);
uint32_t Ht_per_batch_b = get_arg_val<uint32_t>(5);

init_bcast<BCAST_LLKOP, BCAST_DIM>(tt::CB::c_in0, tt::CB::c_in1, tt::CB::c_out0);

cb_wait_front(tt::CB::c_in0, Wt*Ht);
cb_reserve_back(tt::CB::c_out0, Wt*Ht);
uint32_t b_offset = 0;
for (uint32_t bn = 0; bn < batch_b; bn++) {
for (uint32_t wt = 0; wt < Wt; wt++) {
cb_wait_front(tt::CB::c_in1, onetile);
for (uint32_t ht = 0; ht < Ht_per_batch_b; ht+=h_blk) {
acquire_dst(tt::DstMode::Half);
for (uint32_t htr = 0; htr<h_blk; htr++) {
uint32_t current_index = b_offset + (ht + htr) * Wt + wt;
BCAST_OP<BroadcastType::ROW>(tt::CB::c_in0, tt::CB::c_in1, current_index, 0, htr);
pack_tile<true>(htr, tt::CB::c_out0, current_index);
}
release_dst(tt::DstMode::Half);
}
cb_pop_front(tt::CB::c_in1, onetile);
}
b_offset += Ht_per_batch_b * Wt;
}
cb_pop_front(tt::CB::c_in0, Wt*Ht);
cb_push_back(tt::CB::c_out0, Wt*Ht);
}
} // NAMESPACE
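To follow the tile traversal in the compute kernel above: for each batch of b and each output tile column, one in1 tile is held while a column of h_blk-tile chunks is broadcast against it, and results are packed back to the same row-major tile index. A pure-Python sketch of the visited indices (illustrative only, not part of the commit):

```python
def walk_bcast_h_tiles(batch_b, Ht_per_batch_b, Wt, h_blk):
    """Yield (in0_tile_index, dst_register_slot) in the order the compute kernel visits them."""
    b_offset = 0
    for _ in range(batch_b):
        for wt in range(Wt):                      # one broadcast (in1) tile per tile column
            for ht in range(0, Ht_per_batch_b, h_blk):
                for htr in range(h_blk):          # one dest-register slot per tile in the block
                    yield b_offset + (ht + htr) * Wt + wt, htr
        b_offset += Ht_per_batch_b * Wt

# Example: 2 weight batches, 4 tile rows per batch, 3 tile columns, blocks of 2 rows.
visited = list(walk_bcast_h_tiles(batch_b=2, Ht_per_batch_b=4, Wt=3, h_blk=2))
assert len(visited) == 2 * 4 * 3                          # every input tile is processed once
assert len({idx for idx, _ in visited}) == 2 * 4 * 3      # and exactly once
```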
@@ -0,0 +1,53 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <stdint.h>
#include "dataflow_api.h"

void kernel_main() {
uint32_t src1_addr = get_arg_val<uint32_t>(0);
uint32_t Ht = get_arg_val<uint32_t>(1);
uint32_t Wt = get_arg_val<uint32_t>(2);
uint32_t offset = get_arg_val<uint32_t>(3);
uint32_t batch_offset = get_arg_val<uint32_t>(4); //if weight has multiple batches
uint32_t w_blk = get_arg_val<uint32_t>(5);
uint32_t batch_b = get_arg_val<uint32_t>(6);

//constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
constexpr bool src1_is_dram = get_compile_time_arg_val(1) == 1;
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);

//constexpr uint32_t cb_id_in0 = 0;
constexpr uint32_t cb_id_in1 = 1;
constexpr uint32_t onetile = 1;

// single-tile ublocks
const uint32_t tile_bytes = get_tile_size(cb_id_in1);
const DataFormat data_format = get_dataformat(cb_id_in1);

const InterleavedAddrGenFast<src1_is_dram> s1 = {
.bank_base_address = src1_addr,
.page_size = tile_bytes,
.data_format = data_format
};


uint32_t l1_write_addr_in0;
uint32_t l1_write_addr_in1;

cb_push_back(cb_id_in0, Ht * Wt);
for (uint32_t b = 0; b < batch_b; b ++) {
for (uint32_t wt = 0; wt < Wt; wt += w_blk) {
cb_reserve_back(cb_id_in1, w_blk);
l1_write_addr_in1 = get_write_ptr(cb_id_in1);
for (uint32_t r = 0; r<w_blk; r++) {
noc_async_read_tile(offset + wt + r, s1, l1_write_addr_in1);
l1_write_addr_in1 += tile_bytes;
}
noc_async_read_barrier();
cb_push_back(cb_id_in1, w_blk);
}
offset += batch_offset;
}
}
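The reader above streams only in1 (the broadcast weight) from DRAM; in0 is the sharded activation already resident in L1, so the kernel just pushes Ht*Wt tiles on cb_id_in0 without issuing NOC reads for it. A sketch of the in1 tile ids fetched per core, assuming offset and batch_offset are supplied by the host-side program factory (illustrative only):

```python
def in1_tile_ids(offset, batch_offset, Wt, w_blk, batch_b):
    """Tile ids of in1 read by one core, in read order (sketch of the reader loop)."""
    ids = []
    for _ in range(batch_b):
        for wt in range(0, Wt, w_blk):
            ids.extend(offset + wt + r for r in range(w_blk))
        offset += batch_offset      # jump to the same columns in the next weight batch
    return ids

# e.g. 4 tile columns per core, read in blocks of 2, two weight batches
assert in1_tile_ids(offset=0, batch_offset=4, Wt=4, w_blk=2, batch_b=2) == [0, 1, 2, 3, 4, 5, 6, 7]
```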
@@ -41,9 +41,11 @@ operation::ProgramWithCallbacks bcast_sharded_h(const Tensor &a, const Tensor &b
TT_FATAL(out_shard_spec.num_cores() == ncores, "Output tensor should have same number of cores {} as input tensor {}", out_shard_spec.num_cores(), ncores);

DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
DataFormat b_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype());

uint32_t input_tile_size = tt::tt_metal::detail::TileSize(act_df);
uint32_t input1_tile_size = tt::tt_metal::detail::TileSize(b_df);
uint32_t output_tile_size = tt::tt_metal::detail::TileSize(out_df);

TT_FATAL(input_tile_size == output_tile_size, "Input and output tile size should be same");
@@ -84,8 +86,8 @@

uint32_t num_input_tiles = (b.get_legacy_shape()[-1] * output.element_size() + TILE_HW - 1)/ TILE_HW;
uint32_t src1_cb_index = CB::c_in1;
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * aligned_input_tile_nbytes, {{src1_cb_index, act_df}})
.set_page_size(src1_cb_index, aligned_input_tile_nbytes);
tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * input1_tile_size, {{src1_cb_index, b_df}})
.set_page_size(src1_cb_index, input1_tile_size);
auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, src1_cb_config);

auto src0_buffer = a.buffer();
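The final hunk sizes the src1 circular buffer with b's own data format instead of reusing a's. This matters once the test mixes in0_dtype and in1_dtype, because tile byte sizes differ per format. A rough sketch of the sizing arithmetic follows; the tile byte counts are assumptions for illustration, not values taken from the commit:

```python
# Assumed sizes in bytes for a 32x32 tile (illustrative assumptions, not from the commit):
# bfloat16 stores 2 bytes per element; bfloat8_b packs 1 byte per element plus shared exponents.
ASSUMED_TILE_BYTES = {"bfloat16": 2 * 32 * 32, "bfloat8_b": 1 * 32 * 32 + 64}

def src1_cb_bytes(num_tiles: int, in1_dtype: str) -> int:
    """Circular-buffer size for the broadcast weight, keyed on in1's own format."""
    return num_tiles * ASSUMED_TILE_BYTES[in1_dtype]

# Sizing with in0's format instead of in1's would be wrong whenever the dtypes differ:
assert src1_cb_bytes(2, "bfloat16") != src1_cb_bytes(2, "bfloat8_b")
```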