#15621: CNN Op support for new Tensor infra.
shwetankTT committed Dec 10, 2024
1 parent c2d7b09 commit aebfe75
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
@@ -33,7 +33,7 @@ void validate_pool2d(
     TensorMemoryLayout in_memory_layout = input.memory_config().memory_layout;
     if (in_memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
         uint32_t num_shards_c = sliding_window_config.num_cores_c;
-        const tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
+        const tt::tt_metal::SimpleShape input_shape = input.get_logical_shape();
         TT_FATAL(
             input_shape[3] % num_shards_c == 0,
             "For width and block sharding, input channels ({}) should be divisible by num_shards ({})",
@@ -60,8 +60,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     // NOTE: Only for RM
     // NOTE2: Assuming { N, 1, H * W, C }
     // NOTE3: Assuming output data type is same as input
-    const auto input_shape = input.get_legacy_shape();
-
+    const auto input_shape = input.get_padded_shape();
     // confirm that the output size supplied to the function matches
     uint32_t out_h = sliding_window_config.get_output_shape()[1];
     uint32_t out_w = sliding_window_config.get_output_shape()[2];
@@ -73,15 +72,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     uint32_t out_c_padded = tt::round_up(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH);
     uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w;
 
-    uint32_t out_nhw_padded =
-        tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw);
-
-    // {1, 1, N * H * W, C}
-    const ttnn::SmallVector<uint32_t> out_dims({1, 1, out_nhw_padded, out_c_padded});
-    const auto padding = Padding(
-        {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}},
-        Padding::PadValue::NegativeInfinity);
-    auto output_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding));
+    ttnn::SimpleShape output_shape({sliding_window_config.batch_size, out_h, out_w, out_c_padded});
 
     auto mem_config = out_mem_config;
     if (mem_config.shard_spec.has_value()) {
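
For context, a worked example of the channel rounding kept above, assuming tt::constants::TILE_WIDTH is 32 (its usual value in tt-metal); the inputs are hypothetical:

    uint32_t out_c = 24;                              // more than 16 channels
    uint32_t out_c_padded = tt::round_up(out_c, 32);  // == 32, tile-aligned
    // out_c = 12 would take the small-channel path: round_up(12, 16) == 16.
    // Note the NHW dimension is no longer padded: out_nhw_padded and the
    // NegativeInfinity Padding bookkeeping are dropped, and the output shape
    // becomes the plain NHWC {N, out_h, out_w, out_c_padded}.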
@@ -91,13 +82,11 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
         TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match");
         uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores;
         CoreRangeSet shard_grid = sliding_window_config.core_range_set;
-        std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]};
+        std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_logical_shape()[-1]};
         mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
     }
 
-    return TensorSpec(
-        output_shape.logical_shape(),
-        TensorLayout::fromLegacyPaddedShape(output_dtype, PageConfig(input.get_layout()), mem_config, output_shape));
+    return TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input.get_layout()), mem_config));
 }
 
 Pool2D::tensor_return_value_t Pool2D::create_output_tensors(
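For context, a side-by-side sketch of the spec construction this commit migrates; it assumes the ttnn types behave as the diff suggests, and dtype/page_config stand in for the actual arguments:

    // Before: a padded LegacyShape carried Padding metadata, unpacked into
    // the layout through a legacy bridge:
    //   TensorSpec(output_shape.logical_shape(),
    //              TensorLayout::fromLegacyPaddedShape(
    //                  dtype, page_config, mem_config, output_shape));
    //
    // After: a SimpleShape plus a TensorLayout; tile alignment is derived by
    // the layout itself, so no Padding object is threaded through:
    //   TensorSpec(output_shape, TensorLayout(dtype, page_config, mem_config));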
