diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
index 98a972b0fc7..61e2e5b28a6 100644
--- a/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
+++ b/ttnn/cpp/ttnn/operations/pool/generic/device/pool_op.cpp
@@ -33,7 +33,7 @@ void validate_pool2d(
     TensorMemoryLayout in_memory_layout = input.memory_config().memory_layout;
     if (in_memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) {
         uint32_t num_shards_c = sliding_window_config.num_cores_c;
-        const tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape();
+        const tt::tt_metal::SimpleShape input_shape = input.get_logical_shape();
         TT_FATAL(
             input_shape[3] % num_shards_c == 0,
             "For width and block sharding, input channels ({}) should be divisible by num_shards ({})",
@@ -60,8 +60,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
     // NOTE: Only for RM
     // NOTE2: Assuming { N, 1, H * W, C }
     // NOTE3: Assuming output data type is same as input
-    const auto input_shape = input.get_legacy_shape();
-
+    const auto input_shape = input.get_padded_shape();
     // confirm that the output size supplied to the function matches
     uint32_t out_h = sliding_window_config.get_output_shape()[1];
     uint32_t out_w = sliding_window_config.get_output_shape()[2];
@@ -73,15 +72,7 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(

     uint32_t out_c_padded = tt::round_up(out_c, (out_c <= 16) ? 16 : tt::constants::TILE_WIDTH);
     uint32_t out_nhw = sliding_window_config.batch_size * out_h * out_w;
-    uint32_t out_nhw_padded =
-        tt::round_up(out_nhw, (is_out_tiled ? tt::constants::TILE_HEIGHT : 1) * sliding_window_config.num_cores_nhw);
-
-    // {1, 1, N * H * W, C}
-    const ttnn::SmallVector<uint32_t> out_dims({1, 1, out_nhw_padded, out_c_padded});
-    const auto padding = Padding(
-        {{0, 0}, {0, 0}, {0, out_nhw_padded - out_nhw}, {0, out_c_padded - out_c}},
-        Padding::PadValue::NegativeInfinity);
-    auto output_shape = Shape(tt::tt_metal::LegacyShape(out_dims, padding));
+    ttnn::SimpleShape output_shape({sliding_window_config.batch_size, out_h, out_w, out_c_padded});

     auto mem_config = out_mem_config;
     if (mem_config.shard_spec.has_value()) {
@@ -91,13 +82,11 @@ Pool2D::spec_return_value_t Pool2D::compute_output_specs(
         TT_FATAL(ncores == sliding_window_config.num_cores_nhw, "Number of cores should match");
         uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores;
         CoreRangeSet shard_grid = sliding_window_config.core_range_set;
-        std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_legacy_shape()[-1]};
+        std::array<uint32_t, 2> shard_shape = {out_nhw_per_core, input.get_logical_shape()[-1]};
         mem_config.shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR, false};
     }

-    return TensorSpec(
-        output_shape.logical_shape(),
-        TensorLayout::fromLegacyPaddedShape(output_dtype, PageConfig(input.get_layout()), mem_config, output_shape));
+    return TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input.get_layout()), mem_config));
 }

 Pool2D::tensor_return_value_t Pool2D::create_output_tensors(
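
For context, here is a minimal standalone sketch (not part of the change) of the output-shape arithmetic the refactored compute_output_specs performs: channels are padded up to 16 when out_c <= 16, otherwise to the tile width (assumed to be 32, as in tt::constants::TILE_WIDTH), the logical output shape becomes {N, out_h, out_w, out_c_padded} with no NHW padding, and the per-core shard height is derived from the unpadded N*H*W. All concrete values below are illustrative.

```cpp
// Standalone sketch of the padding arithmetic; round_up mirrors tt::round_up,
// and TILE_WIDTH is assumed to be 32 as in tt::constants. Not tt-metal code.
#include <array>
#include <cstdint>
#include <iostream>

static uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
    constexpr uint32_t TILE_WIDTH = 32;  // assumed tile width

    // Hypothetical pool output dims: N=2, out_h=56, out_w=56, out_c=24.
    uint32_t batch = 2, out_h = 56, out_w = 56, out_c = 24;

    // Channel padding: to 16 for small channel counts, otherwise to a full tile width.
    uint32_t out_c_padded = round_up(out_c, (out_c <= 16) ? 16 : TILE_WIDTH);  // 24 -> 32

    // Logical output shape is now {N, H, W, C_padded}; NHW is no longer padded.
    std::array<uint32_t, 4> output_shape = {batch, out_h, out_w, out_c_padded};

    // Per-core shard height, as in the unchanged shard_spec branch (hypothetical core count).
    uint32_t ncores = 8;
    uint32_t out_nhw_per_core = output_shape[0] * output_shape[1] * output_shape[2] / ncores;  // 6272 / 8 = 784

    std::cout << "out_c_padded=" << out_c_padded
              << " out_nhw_per_core=" << out_nhw_per_core << "\n";
}
```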