Add subdevice support to multicore untilize (#16193)
### What's changed
Added `sub_core_grids` as an optional argument so the multicore version of the op can be restricted to a specific set of cores. A minimal usage sketch follows the checklist.

### Checklist
- [x] Post commit CI passes: https://github.com/tenstorrent/tt-metal/actions/runs/12433212242
- [x] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests pass: https://github.com/tenstorrent/tt-metal/actions/runs/12430594736
- [x] New/Existing tests provide coverage for changes
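
For reference, a minimal sketch of the new argument in use from Python, mirroring the test added in this commit (the device setup via `ttnn.open_device`/`ttnn.from_torch` is assumed host boilerplate, not part of this change; the 4x7 core range is illustrative):

```python
import torch
import ttnn

# Assumption: a single local device; adjust device_id as needed.
device = ttnn.open_device(device_id=0)

# One tile row high and many tiles wide, matching the llama shape used in the new test.
torch_input = torch.rand(1, 1, 32, 128 * 1024).bfloat16()
tt_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

# Restrict the multicore untilize to a 4x7 sub-grid of worker cores.
sub_core_grids = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 6))})

tt_output = ttnn.untilize(
    tt_input,
    memory_config=ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1),
    use_multicore=True,
    sub_core_grids=sub_core_grids,
)

torch_output = ttnn.to_torch(tt_output)
ttnn.close_device(device)
```

When `sub_core_grids` is omitted, the op falls back to the existing behaviour of parallelizing over the full compute grid.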
sraizada-tt authored Dec 31, 2024
1 parent f2605cc commit 3949130
Showing 8 changed files with 279 additions and 14 deletions.
@@ -11,6 +11,64 @@
from models.utility_functions import is_grayskull, skip_for_blackhole


@pytest.mark.parametrize(
    "dtype",
    (ttnn.bfloat16,),
    ids=[
        "bfloat16",
    ],
)
@pytest.mark.parametrize(
    "nb, nc, nh, nw",
    (
        # llama shapes
        (1, 1, 32, 128 * 1024),
    ),
)
def test_run_untilize_subcoregrid_test(dtype, nb, nc, nh, nw, device):
    if is_grayskull():
        pytest.skip("Skipping tests on Grayskull")
    device.enable_async(True)
    shape = [nb, nc, nh, nw]

    torch.set_printoptions(precision=3, sci_mode=False, linewidth=3000, threshold=10000, edgeitems=128)

    torch.manual_seed(10)

    inp = torch.rand(*shape).bfloat16()

    a = ttnn.Tensor(
        inp.flatten().tolist(),
        shape,
        dtype,
        ttnn.TILE_LAYOUT,
        device,
    )

    out_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)

    b1 = ttnn.untilize(
        a,
        memory_config=out_mem_config,
        use_multicore=True,
        sub_core_grids=ttnn.CoreRangeSet(
            {
                ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 6)),
            }
        ),
    )
    c1 = b1.cpu().to_torch()

    untilized_inp = untilize(inp)

    if dtype == ttnn.float32:
        passing1, output = comp_pcc(untilized_inp, c1, 0.999999)
        logger.info(output)
    else:
        passing1 = torch.equal(untilized_inp, c1)
    assert passing1


@pytest.mark.parametrize(
"dtype",
(ttnn.bfloat16, ttnn.float32),
@@ -40,7 +98,6 @@
def test_run_untilize_test(dtype, nb, nc, nh, nw, device):
    if is_grayskull() and dtype == ttnn.float32:
        pytest.skip("Skipping float32 tests on Grayskull")

    shape = [nb, nc, 32 * nh, 32 * nw]

    torch.set_printoptions(precision=3, sci_mode=False, linewidth=3000, threshold=10000, edgeitems=128)
@@ -72,7 +129,6 @@ def test_run_untilize_test(dtype, nb, nc, nh, nw, device):
        logger.info(output)
    else:
        passing1 = torch.equal(untilized_inp, c1)

    assert passing1


@@ -113,7 +113,7 @@ operation::ProgramWithCallbacks Untilize::create_program(
    auto& output_tensor = output_tensors.at(0);
    if (this->use_multicore) {
        return detail::untilize_multi_core(
            input_tensor_a, output_tensor, this->use_pack_untilize, this->fp32_dest_acc_en);
            input_tensor_a, output_tensor, this->use_pack_untilize, this->fp32_dest_acc_en, this->sub_core_grids);
    } else {
        return detail::untilize_single_core(
            input_tensor_a, output_tensor, this->use_pack_untilize, this->fp32_dest_acc_en);
@@ -22,6 +22,7 @@ struct Untilize {
    const bool use_multicore;
    const bool use_pack_untilize;
    const bool fp32_dest_acc_en;
    const std::optional<CoreRangeSet> sub_core_grids;

    void validate(const std::vector<Tensor>& input_tensors) const;
    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
@@ -12,6 +12,7 @@
#include "ttnn/operation.hpp"
#include "ttnn/operations/core/work_split/work_split_tilize.hpp"
#include "tt_metal/common/constants.hpp"
#include "tt_metal/common/work_split.hpp"
#include "tt_metal/detail/util.hpp"
#include "tt_metal/host_api.hpp"

@@ -29,6 +30,190 @@ uint32_t get_largest_divisor(uint32_t dividend, uint32_t starting_divisor, uint3
    return 1;
}

operation::ProgramWithCallbacks untilize_multi_core_parallelize_column_subgrid(
    const Tensor& a,
    Tensor& output,
    bool use_pack_untilize,
    bool fp32_dest_acc_en,
    const CoreRangeSet& sub_core_grids) {
    tt::tt_metal::Program program{};

    tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(a.get_dtype());
    uint32_t input_single_tile_size = tt::tt_metal::detail::TileSize(input_cb_data_format);
    tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype());
    uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format);

    Device* device = a.device();

    uint32_t ntiles = a.volume() / TILE_HW;
    uint32_t ncores = sub_core_grids.num_cores();
    for (uint32_t core_id = ncores; core_id >= 1; core_id--) {
        if (ntiles % ncores == 0) {
            break;
        } else {
            ncores--;
        }
    }

    TT_ASSERT(ntiles % (ncores) == 0);

    uint32_t max_tiles = 1;
    uint32_t ntiles_per_block = ntiles / ncores;
    uint32_t stick_s = a.get_legacy_shape()[-1];
    uint32_t ntiles_per_row = stick_s / TILE_WIDTH;
    uint32_t stick_size = stick_s * output.element_size();
    uint32_t ntiles_per_column = ntiles / ntiles_per_row;
    uint32_t starting_tile = ntiles_per_block;
    if (ntiles_per_row > max_tiles) {
        starting_tile = max_tiles;
    }
    ntiles_per_block = get_largest_divisor(ntiles_per_row, starting_tile);
    TT_ASSERT(
        ntiles_per_row % ntiles_per_block == 0 and ntiles_per_block >= 1 and ntiles_per_block <= ntiles_per_row and
        ntiles % ntiles_per_block == 0);

    uint32_t nblocks = (ntiles / ntiles_per_block);
    uint32_t block_size_nbytes = input_single_tile_size;

    auto cores = corerange_to_cores(sub_core_grids, ncores, true);
    auto all_cores = num_cores_to_corerangeset_in_subcoregrids(cores[0], ncores, sub_core_grids, true);
    uint32_t nblocks_per_core = nblocks / ncores;

    bool row_major = true;
    bool src_block_sharded = false;
    uint32_t num_rows_block = 0, block_row_size = 0, output_row_size = 0, last_block_row_size_unpadded = 0,
             num_output_rows_unpadded = 0;
    CoreCoord end_core;
    std::vector<CoreCoord> cores_with_rtargs;

    uint32_t num_input_tiles = ntiles_per_block * 2;
    auto [src0_cb_index, cb_src0] = create_cb(
        tt::CBIndex::c_0, program, all_cores, input_single_tile_size, num_input_tiles, input_cb_data_format, nullptr);

    uint32_t num_output_tiles = ntiles_per_block * 2;
    auto [output_cb_index, cb_output] = create_cb(
        tt::CBIndex::c_16,
        program,
        all_cores,
        output_single_tile_size,
        num_output_tiles,
        output_cb_data_format,
        nullptr);

    Buffer* src0_buffer = a.buffer();
    Buffer* dst_buffer = output.buffer();
    bool src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
    std::vector<uint32_t> reader_ct_args = {(uint32_t)src0_is_dram};

    auto reader_kernel_id = CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_start_id.cpp",
        all_cores,
        ReaderDataMovementConfig(reader_ct_args));

    bool out_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0;
    bool stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size);
    uint32_t log2_stick_size = stick_size_is_power_of_two ? (std::uint32_t)std::log2(stick_size) : 0;
    std::vector<uint32_t> writer_ct_args = {
        (uint32_t)out_is_dram,
        (uint32_t)stick_size_is_power_of_two,
        (uint32_t)log2_stick_size,
    };

    auto writer_kernel_id = CreateKernel(
        program,
        "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/dataflow/"
        "writer_unary_stick_layout_split_rows_interleaved_parallel_columns.cpp",
        all_cores,
        WriterDataMovementConfig(writer_ct_args));

    /** compute */
    std::vector<uint32_t> compute_args = {
        (uint32_t)nblocks_per_core,  // per_core_block_cnt
        (uint32_t)ntiles_per_block,  // per_block_ntiles
    };

    std::string compute_kernel(
        "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/pack_untilize.cpp");
    if (ntiles_per_block > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize) {
        log_debug(tt::LogOp, "Using slow untilize.");
        compute_kernel =
            std::string("ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize.cpp");
    } else {
        log_debug(tt::LogOp, "Using fast pack untilize.");
    }

    auto untilize_kernel_id = CreateKernel(
        program,
        compute_kernel,
        all_cores,
        ComputeConfig{.fp32_dest_acc_en = fp32_dest_acc_en, .compile_args = compute_args});

    uint32_t tile_start_id = 0;
    uint32_t offset_within_stick = 0;

    auto nsticks_per_core = ntiles_per_column * TILE_HEIGHT;

    for (uint32_t i = 0; i < cores.size(); i++) {
        CoreCoord core = cores[i];

        // reader runtime args
        auto ntiles_per_core = ntiles_per_block * nblocks_per_core;
        const std::array reader_rt_args = {
            src0_buffer->address(),  // src_addr
            ntiles_per_core,         // ntiles
            tile_start_id            // start_id
        };

        const std::array writer_rt_args = {
            dst_buffer->address(),               // dst_addr
            nsticks_per_core,                    // nsticks
            stick_size,                          // block_size_nbytes
            ntiles_per_core,                     // ntiles_per_core
            TILE_WIDTH * output.element_size(),  // tile_width_size
            std::uint32_t{0},                    // start stick id = 0, since parallelizing on height
            offset_within_stick};

        tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_rt_args);
        tt::tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_rt_args);
        cores_with_rtargs.push_back(core);
        tile_start_id += ntiles_per_core;
        offset_within_stick += ntiles_per_core * TILE_WIDTH * output.element_size();
    }

    auto override_runtime_arguments_callback = [reader_kernel_id = reader_kernel_id,
                                                writer_kernel_id = writer_kernel_id,
                                                cb_src0 = cb_src0,
                                                cb_output = cb_output,
                                                cores_with_rtargs](
                                                   const void* operation,
                                                   Program& program,
                                                   const std::vector<Tensor>& input_tensors,
                                                   const std::vector<std::optional<const Tensor>>&,
                                                   const std::vector<Tensor>& output_tensors) {
        auto src_buffer = input_tensors.at(0).buffer();
        auto dst_buffer = output_tensors.at(0).buffer();
        {
            auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
            for (const CoreCoord& core : cores_with_rtargs) {
                auto& runtime_args = runtime_args_by_core[core.x][core.y];
                runtime_args[0] = src_buffer->address();
            }
        }

        {
            auto& runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id);
            for (const CoreCoord& core : cores_with_rtargs) {
                auto& runtime_args = runtime_args_by_core[core.x][core.y];
                runtime_args[0] = dst_buffer->address();
            }
        }
    };

    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
}

operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
    const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) {
    tt::tt_metal::Program program{};
@@ -41,12 +226,13 @@ operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
    Device* device = a.device();

    auto grid_size = device->compute_with_storage_grid_size();

    uint32_t ntiles = a.volume() / TILE_HW;
    uint32_t ncores_x = grid_size.x;
    uint32_t ncores_y = grid_size.y;
    // uint32_t ncores_x = 2;

    ncores_x = get_largest_divisor(ntiles, ncores_x);
    uint32_t ncores_y = grid_size.y;
    // uint32_t ncores_y = 1;
    ncores_y = get_largest_divisor(ntiles, ncores_y, ncores_x);

@@ -260,7 +446,11 @@ operation::ProgramWithCallbacks untilize_multi_core_parallelize_column(
}

operation::ProgramWithCallbacks untilize_multi_core(
    const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) {
    const Tensor& a,
    Tensor& output,
    bool use_pack_untilize,
    bool fp32_dest_acc_en,
    const std::optional<CoreRangeSet>& sub_core_grids) {
    tt::tt_metal::Program program{};

    bool src_sharded = a.memory_config().is_sharded();
@@ -289,7 +479,12 @@ operation::ProgramWithCallbacks untilize_multi_core(
    if (!src_sharded and !out_sharded) {
        uint32_t ntiles_height = ntiles / ntiles_per_block;
        if (ntiles_height == 1) {
            return untilize_multi_core_parallelize_column(a, output, use_pack_untilize, fp32_dest_acc_en);
            if (sub_core_grids.has_value()) {
                return untilize_multi_core_parallelize_column_subgrid(
                    a, output, use_pack_untilize, fp32_dest_acc_en, sub_core_grids.value());
            } else {
                return untilize_multi_core_parallelize_column(a, output, use_pack_untilize, fp32_dest_acc_en);
            }
        } else {
            return untilize_single_core(a, output, use_pack_untilize, fp32_dest_acc_en);
        }
@@ -8,8 +8,13 @@

namespace ttnn::operations::data_movement::detail {


tt::tt_metal::operation::ProgramWithCallbacks untilize_multi_core(
    const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en);
    const Tensor& a,
    Tensor& output,
    bool use_pack_untilize,
    bool fp32_dest_acc_en,
    const std::optional<CoreRangeSet>& sub_core_grids);

tt::tt_metal::operation::ProgramWithCallbacks untilize_single_core(
    const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en);
11 changes: 7 additions & 4 deletions ttnn/cpp/ttnn/operations/data_movement/untilize/untilize.cpp
@@ -40,7 +40,8 @@ ttnn::Tensor ExecuteUntilize::invoke(
    const ttnn::Tensor& input_tensor,
    const std::optional<MemoryConfig>& memory_config,
    bool use_multicore,
    bool use_pack_untilize) {
    bool use_pack_untilize,
    const std::optional<CoreRangeSet>& sub_core_grids) {
    bool fp32_dest_acc_en =
        input_tensor.get_dtype() ==
        DataType::UINT32;  // MT: Currently only uint32 is moved to DST directly, fp32 is converted to fp16b
@@ -51,7 +52,8 @@
            memory_config.value_or(input_tensor.memory_config()),
            use_multicore,
            use_pack_untilize,
            fp32_dest_acc_en},
            fp32_dest_acc_en,
            sub_core_grids},
        {input_tensor},
        {},
        {},
@@ -65,8 +67,9 @@
    const ttnn::Tensor& input_tensor,
    const std::optional<MemoryConfig>& memory_config,
    bool use_multicore,
    bool use_pack_untilize) {
    return invoke(DefaultQueueId, input_tensor, memory_config, use_multicore, use_pack_untilize);
    bool use_pack_untilize,
    const std::optional<CoreRangeSet>& sub_core_grids) {
    return invoke(DefaultQueueId, input_tensor, memory_config, use_multicore, use_pack_untilize, sub_core_grids);
}

} // namespace ttnn::operations::data_movement