Handle padded shards in ttnn.convert_to_chw
esmalTT committed Dec 11, 2024
1 parent 208db3e commit dbba7ca
Showing 3 changed files with 80 additions and 22 deletions.
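
Background (a sketch, not part of the commit): when a [1, 1, HW, C] tensor is height-sharded, each core's shard height is rounded up to a whole 32-row tile, so it can exceed HW / num_cores; those extra rows are the padding this commit handles. Under that rounding rule, the padded_sharded_dim values in the new test fall out directly:

TILE_HEIGHT = 32  # rows per tile

def div_up(a: int, b: int) -> int:
    # Integer ceiling division, as tt::div_up does in tt-metal.
    return -(-a // b)

def padded_shard_height(HW: int, num_cores: int) -> int:
    # Rows assigned to each core, rounded up to a whole tile.
    return div_up(div_up(HW, num_cores), TILE_HEIGHT) * TILE_HEIGHT

# The two cases exercised by test_convert_to_chw_padded below:
assert padded_shard_height(96, num_cores=2) == 64             # 48 rows/core -> 64
assert padded_shard_height(1056 * 160, num_cores=64) == 2656  # 2640 rows/core -> 2656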
55 changes: 46 additions & 9 deletions tests/ttnn/unit_tests/operations/test_convert_to_chw.py
@@ -39,24 +39,61 @@ def test_convert_to_chw(device, C, HW, core_grid):
         [1, 1, 32, HW], core_grid, ttnn.ShardStrategy.WIDTH, ttnn.ShardOrientation.ROW_MAJOR
     )
     actual = ttnn.experimental.convert_to_chw(input_tensor, memory_config=output_memory_config)
-    actual = ttnn.to_torch(actual)
 
-    assert_with_pcc(expected, actual, 0.9999999)
+    assert_with_pcc(expected, ttnn.to_torch(actual), 1.0)
+
+    return actual
 
 
+@skip_for_grayskull()
+@skip_for_blackhole()
+@pytest.mark.parametrize("C", [1, 2, 4])
+@pytest.mark.parametrize(
+    "HW, core_grid, padded_sharded_dim",
+    (
+        (96, ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))}), 64),
+        (1056 * 160, ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 7))}), 2656),
+    ),
+)
+def test_convert_to_chw_padded(device, C, HW, core_grid, padded_sharded_dim):
+    input_tensor = torch.randn([1, 1, HW, C], dtype=torch.bfloat16)
+    expected = input_tensor.transpose(2, 3)
+
+    input_shard_shape = (padded_sharded_dim, 32)
+    input_shard_spec = ttnn.ShardSpec(core_grid, input_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False)
+    input_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.L1, input_shard_spec)
+
+    output_shard_shape = (C, padded_sharded_dim)
+    output_shard_spec = ttnn.ShardSpec(core_grid, output_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, False)
+    output_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, output_shard_spec)
+
+    input_tensor = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
+    input_tensor = ttnn.to_device(input_tensor, device, memory_config=input_mem_config)
+
+    actual = ttnn.experimental.convert_to_chw(input_tensor, memory_config=output_mem_config)
+
+    assert_with_pcc(expected, ttnn.to_torch(actual), 1.0)
+
+    return actual
+
+
 @skip_for_grayskull()
 @skip_for_blackhole()
 def test_convert_to_chw_with_program_cache(device, use_program_cache):
-    C, HW = 8, 128
-    core_grid = ttnn.CoreGrid(x=2, y=1)
+    C, HW, core_grid = 2, 256, ttnn.CoreGrid(x=2, y=1)
+
+    C_padded, HW_padded, padded_sharded_dim = 4, 96, 64
+    core_grid_padded = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 0))})
 
-    for _ in range(2):
-        test_convert_to_chw(device, C, HW, core_grid)
-        test_convert_to_chw(device, C, HW, core_grid)
-        dummy_shape = [1, 1, 128, 128]
+    a, b, c = None, None, None
+    for _ in range(8):
+        a = test_convert_to_chw_padded(device, C_padded, HW_padded, core_grid_padded, padded_sharded_dim)
+        b = test_convert_to_chw(device, C, HW, core_grid)
+        c = test_convert_to_chw_padded(device, C_padded, HW_padded, core_grid_padded, padded_sharded_dim)
+        dummy_shape = [1, 1, 256, 128]
         py_dummy_tensor = torch.randn(dummy_shape)
         tt_dummy_tensor = (
             ttnn.Tensor(py_dummy_tensor, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
         )
 
-    assert device.num_program_cache_entries() == 1
+    assert device.num_program_cache_entries() == 2
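
How the new assertion count follows (a toy model, not tt-metal internals): the loop interleaves the padded and non-padded variants, and the expected entry count moves from 1 to 2, one compiled program per distinct shard configuration, reused across all eight iterations. Modeling the cache key as just the variant name:

cache = set()
for _ in range(8):
    cache.add("convert_to_chw/padded")      # stands in for the padded-shard program
    cache.add("convert_to_chw/non-padded")  # stands in for the aligned-shard program
assert len(cache) == 2  # mirrors device.num_program_cache_entries() == 2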
@@ -23,10 +23,16 @@ void ConvertToCHW::validate(const std::vector<Tensor>& input_tensors) const {
     TT_FATAL(C <= TILE_HEIGHT, "C must be less than or equal to 32 (was {})", C);
     TT_FATAL(HW % TILE_HEIGHT == 0, "HW must be divisible by tile size");
 
+    TT_FATAL(input.is_sharded(), "Input tensor must be sharded");
+
+    const auto& input_shard_spec = input.memory_config().shard_spec.value();
+    TT_FATAL(
+        input_shard_spec.shape[0] % TILE_HEIGHT == 0,
+        "Shard height must be divisible by tile size");  // input shards can be padded so HW may not match shard height
 
     TT_FATAL(
         this->memory_config.is_sharded() && this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED,
         "Output tensor must be width sharded");
     // TODO: Check that grids match
 }

 std::vector<tt::tt_metal::LegacyShape> ConvertToCHW::compute_output_shapes(
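
The relaxed validation above can be mirrored in a few lines of Python (a sketch under the same definitions, with TILE_HEIGHT = 32): HW must still be tile-aligned, but the shard height is now checked on its own, since a padded shard may hold more rows than HW / num_cores:

TILE_HEIGHT = 32

def validate_padded_input(HW: int, shard_height: int) -> None:
    # Mirrors the two TT_FATAL checks: both quantities must be tile multiples,
    # but shard_height is no longer required to match HW / num_cores exactly.
    assert HW % TILE_HEIGHT == 0, "HW must be divisible by tile size"
    assert shard_height % TILE_HEIGHT == 0, "Shard height must be divisible by tile size"

# From the padded test: HW=96 on 2 cores gives 48 rows/core, stored as 64-row shards.
validate_padded_input(HW=96, shard_height=64)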
@@ -20,20 +20,29 @@ operation::ProgramWithCallbacks multi_core_convert_to_chw(
     const auto HW = input_shape[2];
     const auto C = input_shape[3];
 
+    tt::log_debug(tt::LogType::LogOp, "Running op with HW={}, C={}, shard_shape={}", HW, C, a.shard_spec()->shape);
+
     TT_ASSERT(C < TILE_HEIGHT, "C must be 32 or smaller");
 
-    const uint32_t c_tiles = 1;  // assume C <= 32
-    const uint32_t hw_tiles = HW / TILE_HEIGHT;
-    const uint32_t total_tiles = hw_tiles * c_tiles;
-    const uint32_t total_tiles_per_core = hw_tiles * c_tiles / input_cores.size();
+    const uint32_t total_tiles = HW / TILE_HEIGHT;  // assume C < 32
+    const uint32_t total_tiles_per_core = tt::div_up(total_tiles, input_cores.size());
 
+    tt::log_debug(
+        tt::LogType::LogOp, "Processing {} tiles per core ({} total tiles)", total_tiles_per_core, total_tiles);
+
     const auto create_circular_buffer = [&program, &input_core_grid](
                                             uint32_t index,
-                                            uint32_t num_tiles,
-                                            uint32_t tile_size,
+                                            uint32_t total_size,
+                                            uint32_t page_size,
                                             const tt::DataFormat& format,
                                             Buffer* buffer = nullptr) -> tt::tt_metal::CBHandle {
-        auto config = CircularBufferConfig(num_tiles * tile_size, {{index, format}}).set_page_size(index, tile_size);
+        tt::log_debug(
+            tt::LogType::LogOp,
+            "Creating CB at index {} with total size {} B and page size {} B",
+            index,
+            total_size,
+            page_size);
+        auto config = CircularBufferConfig(total_size, {{index, format}}).set_page_size(index, page_size);
         if (buffer != nullptr) {
             config = config.set_globally_allocated_address(*buffer);
         }
@@ -47,16 +56,22 @@ operation::ProgramWithCallbacks multi_core_convert_to_chw(
     const uint32_t intermediary_tile_size = tt::tt_metal::detail::TileSize(intermediary_format);
 
     const uint32_t cb_in_id = tt::CB::c_in0;
-    const auto cb_in =
-        create_circular_buffer(cb_in_id, total_tiles_per_core, input_tile_size, input_format, a.buffer());
+    const uint32_t cb_in_total_size = total_tiles_per_core * input_tile_size;
+    const uint32_t cb_in_page_size = input_tile_size;
+    const auto cb_in = create_circular_buffer(cb_in_id, cb_in_total_size, cb_in_page_size, input_format, a.buffer());
 
     const uint32_t cb_out_id = tt::CB::c_out0;
+    const uint32_t element_size = tt::datum_size(input_format);
+    const uint32_t cb_out_total_size = C * HW * element_size / input_cores.size();
+    const uint32_t cb_out_page_size = HW / input_cores.size();
     const auto cb_out =
-        create_circular_buffer(cb_out_id, total_tiles_per_core, input_tile_size, input_format, output.buffer());
+        create_circular_buffer(cb_out_id, cb_out_total_size, cb_out_page_size, input_format, output.buffer());
 
     const uint32_t cb_in_transpose_id = tt::CB::c_intermed0;
-    const auto cb_in_transpose =
-        create_circular_buffer(cb_in_transpose_id, 1, intermediary_tile_size, intermediary_format);
+    const uint32_t cb_in_transpose_total_size = intermediary_tile_size;
+    const uint32_t cb_in_transpose_page_size = intermediary_tile_size;
+    const auto cb_in_transpose = create_circular_buffer(
+        cb_in_transpose_id, cb_in_transpose_total_size, cb_in_transpose_page_size, intermediary_format);
 
     std::vector<uint32_t> reader_compile_time_args = {cb_in_id};
     std::vector<uint32_t> writer_compile_time_args = {cb_in_transpose_id, cb_out_id, C};
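
Worked numbers for the new factory arithmetic (a sketch using the padded test case: C=4, HW=96, bfloat16, 2 cores):

TILE_HEIGHT = 32
HW, C, element_size, num_cores = 96, 4, 2, 2  # bfloat16 is 2 bytes

def div_up(a: int, b: int) -> int:
    return -(-a // b)  # integer ceiling division, as tt::div_up

total_tiles = HW // TILE_HEIGHT                        # 3 input tiles
total_tiles_per_core = div_up(total_tiles, num_cores)  # div_up(3, 2) = 2
assert total_tiles_per_core * TILE_HEIGHT == 64        # matches the 64-row padded shard

cb_out_total_size = C * HW * element_size // num_cores  # 384 B of CHW output per core
cb_out_page_size = HW // num_cores                      # 48, exactly as the factory computes it

Plain division, as before this commit, would give 3 // 2 = 1 tile per core and under-size the input CB for a 64-row padded shard; div_up rounds up to 2 tiles, covering the full shard.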
