Skip to content

Commit

Permalink
#10548: Support tile layout for width/height-sharded concat (#13744)
Browse files Browse the repository at this point in the history
Co-authored-by: Artem Yerofieiev <[email protected]>
  • Loading branch information
jerrysky3 and ayerofieiev-tt authored Oct 22, 2024
1 parent 343a607 commit 0326577
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 29 deletions.
40 changes: 37 additions & 3 deletions tests/ttnn/unit_tests/operations/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ def test_concat(device, height, width, dim, async_mode):


@pytest.mark.parametrize(
"inputs, output_shard_shape, shard_grid, strategy, cache_mode",
"inputs, output_shard_shape, shard_grid, strategy, layout, cache_mode",
(
(
[((1, 1, 160, 32), (80, 32)), ((1, 1, 160, 32), (80, 32))],
(80, 64),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
False,
),
(
[((1, 1, 160, 32), (80, 32)), ((1, 1, 160, 16), (80, 16))],
(80, 48),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
False,
),
(
Expand All @@ -60,6 +62,7 @@ def test_concat(device, height, width, dim, async_mode):
}
),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
False,
),
pytest.param(
Expand All @@ -72,6 +75,7 @@ def test_concat(device, height, width, dim, async_mode):
}
),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
True,
marks=pytest.mark.xfail(reason="two tensors concat kernel doesn't work with program cache (#13466)"),
),
Expand All @@ -80,33 +84,63 @@ def test_concat(device, height, width, dim, async_mode):
(8, 48),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
False,
),
(
[((1, 1, 16, 16), (8, 16)), ((1, 1, 16, 16), (8, 16)), ((1, 1, 16, 16), (8, 16))],
(8, 48),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}),
ttnn.ShardStrategy.HEIGHT,
ttnn.ROW_MAJOR_LAYOUT,
True,
),
(
[((1, 1, 8, 64), (8, 16)), ((1, 1, 7, 64), (7, 16)), ((1, 1, 23, 64), (23, 16))],
(38, 16),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}),
ttnn.ShardStrategy.WIDTH,
ttnn.ROW_MAJOR_LAYOUT,
False,
),
(
[((1, 1, 8, 64), (8, 16)), ((1, 1, 7, 64), (7, 16)), ((1, 1, 23, 64), (23, 16))],
(38, 16),
ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}),
ttnn.ShardStrategy.WIDTH,
ttnn.ROW_MAJOR_LAYOUT,
True,
),
(
[((1, 1, 256, 96), (64, 96)), ((1, 1, 256, 64), (64, 64)), ((1, 1, 256, 32), (64, 32))],
(64, 192),
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1)),
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(2, 0)),
}
),
ttnn.ShardStrategy.HEIGHT,
ttnn.TILE_LAYOUT,
False,
),
(
[((1, 1, 32, 512), (32, 64)), ((1, 1, 64, 512), (64, 64)), ((1, 1, 96, 512), (96, 64))],
(192, 64),
ttnn.CoreRangeSet(
{
ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3)),
ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(2, 1)),
}
),
ttnn.ShardStrategy.WIDTH,
ttnn.TILE_LAYOUT,
False,
),
),
)
@pytest.mark.parametrize("async_mode", [True, False], ids=["async_on", "async_off"])
def test_sharded_concat(device, inputs, output_shard_shape, shard_grid, strategy, cache_mode, async_mode):
def test_sharded_concat(device, inputs, output_shard_shape, shard_grid, strategy, layout, cache_mode, async_mode):
device.enable_async(async_mode)
if cache_mode:
device.enable_program_cache()
Expand All @@ -124,7 +158,7 @@ def _gen_inputs(input_specs):
use_height_and_width_as_shard_shape=True,
)
torch_input_tensor = torch.rand(shape, dtype=torch.bfloat16)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device)
input_tensor = ttnn.to_memory_config(input_tensor, input_sharded_memory_config)
input_tensors.append((torch_input_tensor, input_tensor))
return input_tensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) c
}
TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
if (shard_first) {
TT_FATAL(in_ref.get_layout() == Layout::ROW_MAJOR, "Only row major supported for sharded concat.");
// TODO(jerrysky3): Remove this when we replace the two tensors concat kernel with the general one.
TT_FATAL(
input_tensors.size() > 2 || in_ref.get_layout() == Layout::ROW_MAJOR,
"Only row major supported for sharded two tensors concat.");
TT_FATAL(in_ref.shard_spec().has_value(), "Sharded tensors must have a shard spec.");
TT_FATAL(
in_ref.shard_spec().value().grid == first_input.shard_spec().value().grid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ operation::ProgramWithCallbacks s2s_rm_concat_two_tensors_multi_core(
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback};
}

// Concat sharded tensors into sharded output in row-major layout. Currently it only supports height-sharded width
// Concat sharded tensors into sharded output in row-major/tile layout. Currently it only supports height-sharded width
// concat or width-sharded height concat.
//
// It is done by copying each row of each input sharded tensor to the right offset in the sharded output tensor based on
Expand All @@ -164,9 +164,9 @@ operation::ProgramWithCallbacks s2s_rm_concat_two_tensors_multi_core(
// output. The memory address gap between neighboring input rows is exactly the output width. In height concat, all
// input rows are placed at column 0 but in sequential rows of the output. The address gap between neighboring input
// rows is still the output width (which is equal to the input width).
operation::ProgramWithCallbacks s2s_rm_concat_multi_core(
operation::ProgramWithCallbacks s2s_concat_multi_core(
const std::vector<Tensor> &input_tensors, uint32_t dim, Tensor &output) {
TT_FATAL(dim == 2 || dim == 3, "Sharded concat RM only supports dim=2 or 3");
TT_FATAL(dim == 2 || dim == 3, "Sharded concat only supports dim=2 or 3");
const bool is_height_concat = dim == 2;

tt_metal::Program program = tt_metal::CreateProgram();
Expand All @@ -175,37 +175,46 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core(
const uint32_t num_input_tensors = input_tensors.size();
const uint32_t cb_dst_id = 16;
TT_FATAL(num_input_tensors <= cb_dst_id, "Not enough circular buffer for {} inputs.", num_input_tensors);
const tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
const bool rm_layout = output.get_layout() == Layout::ROW_MAJOR;

// Assume inputs and output have the same element size and alignment.
const uint32_t element_size = input_tensors[0].element_size();
const uint32_t alignment = input_tensors[0].buffer()->alignment();

std::vector<uint32_t> all_stick_sizes;
all_stick_sizes.push_back(output.shard_spec().value().shape[1]);
std::transform(
input_tensors.begin(), input_tensors.end(), std::back_inserter(all_stick_sizes), [](const Tensor &tensor) {
return tensor.element_size() * tensor.shard_spec().value().shape[1];
});
const uint32_t page_size = find_greatest_common_page_size(all_stick_sizes, alignment);
const uint32_t elements_per_page = page_size / element_size;
uint32_t page_size;
uint32_t elements_per_page_width;
uint32_t elements_per_page_height;
if (rm_layout) {
std::vector<uint32_t> all_stick_sizes;
all_stick_sizes.push_back(output.shard_spec().value().shape[1]);
std::transform(
input_tensors.begin(), input_tensors.end(), std::back_inserter(all_stick_sizes), [](const Tensor &tensor) {
return tensor.element_size() * tensor.shard_spec().value().shape[1];
});
page_size = find_greatest_common_page_size(all_stick_sizes, alignment);
elements_per_page_width = page_size / element_size;
elements_per_page_height = 1;
} else {
page_size = tt_metal::detail::TileSize(cb_data_format);
elements_per_page_width = TILE_WIDTH;
elements_per_page_height = TILE_HEIGHT;
}

vector<CBHandle> cb_inputs(num_input_tensors);
vector<uint32_t> input_num_pages_per_stick(num_input_tensors);
vector<uint32_t> input_num_sticks(num_input_tensors);
vector<uint32_t> input_write_offsets(num_input_tensors);

const tt::DataFormat cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
// Assume inputs and output have the same sharding grid.
const auto all_cores = input_tensors[0].shard_spec().value().grid;

// Input CBs
uint32_t curr_input_write_offset = 0;
for (uint32_t input_id = 0; input_id < num_input_tensors; input_id++) {
const auto shard_spec = input_tensors[input_id].shard_spec().value();
const uint32_t shard_height = shard_spec.shape[0];
const uint32_t shard_width = shard_spec.shape[1];
input_num_pages_per_stick[input_id] = div_up(shard_width, elements_per_page);
input_num_sticks[input_id] = shard_height;
input_num_pages_per_stick[input_id] = div_up(shard_spec.shape[1], elements_per_page_width);
input_num_sticks[input_id] = div_up(shard_spec.shape[0], elements_per_page_height);
input_write_offsets[input_id] = curr_input_write_offset;

const uint32_t input_num_pages = input_num_pages_per_stick[input_id] * input_num_sticks[input_id];
Expand All @@ -221,12 +230,11 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core(

// Output CB
const auto output_shard_spec = output.shard_spec().value();
const uint32_t output_shard_height = output_shard_spec.shape[0];
const uint32_t output_shard_width = output_shard_spec.shape[1];
const uint32_t output_num_pages_per_stick = div_up(output_shard_width, elements_per_page);
const uint32_t output_num_pages_per_stick = div_up(output_shard_spec.shape[1], elements_per_page_width);
const uint32_t output_num_sticks = div_up(output_shard_spec.shape[0], elements_per_page_height);
const tt_metal::CircularBufferConfig output_cb_config =
tt_metal::CircularBufferConfig(
page_size * output_shard_height * output_num_pages_per_stick, {{cb_dst_id, cb_data_format}})
page_size * output_num_sticks * output_num_pages_per_stick, {{cb_dst_id, cb_data_format}})
.set_page_size(cb_dst_id, page_size)
.set_globally_allocated_address(*output.buffer());
auto cb_output = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config);
Expand All @@ -250,15 +258,13 @@ operation::ProgramWithCallbacks s2s_rm_concat_multi_core(

tt_metal::KernelHandle unary_reader_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/"
"reader_s2s_rm_tensor_concat.cpp",
"ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp",
all_cores,
tt_metal::ReaderDataMovementConfig(compile_time_args));

tt_metal::KernelHandle unary_writer_kernel_id = tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/"
"reader_s2s_rm_tensor_concat.cpp",
"ttnn/cpp/ttnn/operations/data_movement/concat/device/kernels/dataflow/reader_s2s_tensor_concat.cpp",
all_cores,
tt_metal::WriterDataMovementConfig(compile_time_args));

Expand Down Expand Up @@ -422,7 +428,7 @@ operation::ProgramWithCallbacks sharded_concat_multi_core(
// case.
return s2s_rm_concat_two_tensors_multi_core(input_tensors, dim, output);
} else {
return s2s_rm_concat_multi_core(input_tensors, dim, output);
return s2s_concat_multi_core(input_tensors, dim, output);
}
} else {
return s2i_rm_concat_multi_core(input_tensors, dim, output);
Expand Down

0 comments on commit 0326577

Please sign in to comment.