diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp
index 710f805e604..f20a83e0fde 100644
--- a/ttnn/cpp/ttnn/operation.hpp
+++ b/ttnn/cpp/ttnn/operation.hpp
@@ -400,6 +400,8 @@ constexpr bool implements_get_parallelization_strategy() {
     return std::experimental::is_detected_v<has_get_parallelization_strategy_t, T>;
 }
 
+} // namespace detail
+
 template <typename ConcreteOperation>
 auto default_create_output_tensors(
     const ConcreteOperation& operation,
@@ -427,8 +429,6 @@ auto default_create_output_tensors(
     return output_tensors;
 }
 
-} // namespace detail
-
 template <class OutputTensorsT = Tensors>
 struct DeviceOperation final {
     using storage_t = std::array<std::byte, 1240>;
@@ -628,7 +628,7 @@ struct DeviceOperation final {
                     "create_output_tensors");
                 return operation.create_output_tensors(input_tensors);
             } else if constexpr (detail::implements_compute_output_specs<T>()) {
-                return detail::default_create_output_tensors(operation, input_tensors, output_tensors);
+                return default_create_output_tensors(operation, input_tensors, output_tensors);
             } else {
                 static_assert(
                     tt::stl::concepts::always_false_v<T>,
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
index 079427ad854..e32b1232ae8 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -186,27 +186,24 @@ void AllGather::validate(const std::vector<Tensor>& input_tensors) const {
     }
 }
 
-std::vector<ttnn::SimpleShape> AllGather::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    auto shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
-    shape[this->dim] *= this->ring_size;
-    return std::vector<ttnn::SimpleShape>(input_tensors.size(), shape);
-}
+std::vector<ttnn::TensorSpec> AllGather::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    auto output_shape = input_tensors[0].get_padded_shape(); // TODO: Replace with get_logical_shape()
+    output_shape[this->dim] *= this->ring_size;
 
-std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors[0];
-    auto tile = input_tensor.get_tensor_spec().tile();
+    TensorSpec spec(
+        output_shape,
+        TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config));
     if (this->output_mem_config.is_sharded()) {
-        return {create_device_tensor(
-            this->compute_output_shapes(input_tensors).at(0),
-            input_tensor.get_dtype(),
-            input_tensor.get_layout(),
-            input_tensor.device(),
-            this->output_mem_config,
-            tile)};
-    } else {
-        return operation::generic_create_output_tensors(
-            *this, input_tensors, input_tensor.get_dtype(), input_tensor.get_layout(), this->output_mem_config, tile);
+        return {TensorSpec(
+            output_shape,
+            TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config))};
     }
+    return std::vector<ttnn::TensorSpec>(input_tensors.size(), spec);
+}
+
+std::vector<Tensor> AllGather::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    return operation::default_create_output_tensors(*this, input_tensors, {});
 }
 
 operation::ProgramWithCallbacks AllGather::create_program(
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
index abc697dfab5..5665d0a63fa 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp
@@ -132,7 +132,7 @@ struct AllGather {
     const ccl::Topology topology;
 
     void validate(const std::vector<Tensor> &input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, std::vector<Tensor> &output_tensors) const;
 };
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
index e09aa621dd5..4f3d8d5c2ef 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp
@@ -119,7 +119,7 @@ void OptimizedConvNew::validate(const std::vector<Tensor>& input_tensors, const
             sliding_window_config,
             parallelization_config.num_cores_nhw,
             out_block_h_ntiles);
-        uint32_t out_width_ntiles = this->compute_output_shapes(input_tensors).at(0)[-1] / TILE_WIDTH;
+        uint32_t out_width_ntiles = this->compute_output_specs(input_tensors).at(0).padded_shape()[-1] / TILE_WIDTH;
         if(this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
             TT_FATAL(per_core_out_matrix_width_ntiles == out_width_ntiles, "Error");
             TT_FATAL(this->block_config.out_subblock_w_ntiles == out_width_ntiles || this->block_config.out_subblock_h_ntiles == 1, "Error");
@@ -136,22 +136,13 @@
     }
 }
 
-std::vector<tt::tt_metal::LegacyShape> OptimizedConvNew::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::TensorSpec> OptimizedConvNew::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
    const auto& input_tensor_a_shape = this->input_tensor_shape;
     uint32_t batch_size = input_tensor_a_shape[0];
-    uint32_t conv_activation_h = input_tensor_a_shape[1];
-    uint32_t conv_activation_w = input_tensor_a_shape[2];
-    // TODO: clean up here
-    uint32_t filter_h = (uint32_t)sliding_window_config.window_hw.first;  // filter_h
-    uint32_t filter_w = (uint32_t)sliding_window_config.window_hw.second;  // filter_W
-    uint32_t stride_h = (uint32_t)sliding_window_config.stride_hw.first;
-    uint32_t stride_w = (uint32_t)sliding_window_config.stride_hw.second;
-    uint32_t pad_h = (uint32_t)sliding_window_config.pad_hw.first;
-    uint32_t pad_w = (uint32_t)sliding_window_config.pad_hw.second;
-    auto output_shape = sliding_window_config.get_output_shape();
-    uint32_t conv_output_h = output_shape[1];
-    uint32_t conv_output_w = output_shape[2];
+    auto sliding_window_output_shape = sliding_window_config.get_output_shape();
+    uint32_t conv_output_h = sliding_window_output_shape[1];
+    uint32_t conv_output_w = sliding_window_output_shape[2];
 
     // Tiled output shape is padded shape. Padded to tile shape.
     auto shape_w = batch_size * conv_output_h * conv_output_w;
@@ -160,16 +151,10 @@ std::vector<tt::tt_metal::LegacyShape> OptimizedConvNew::compute_output_shapes(c
     auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH);
     auto output_padding = Padding(
         {{0, 0}, {0, 0}, {0, (padded_shape_w - shape_w)}, {0, (padded_shape_c - shape_c)}}, Padding::PadValue::Zero);
-    auto output_tensor_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding));
-    return {output_tensor_shape.value};
-}
+    auto output_shape = tt::tt_metal::LegacyShape({1, 1, padded_shape_w, padded_shape_c}, output_padding);
 
-std::vector<Tensor> OptimizedConvNew::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor = input_tensors.at(0);
-    const auto& weight_tensor = input_tensors.at(1);
     auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE;
     if (this->memory_config.is_sharded()) {
-        auto output_shape = this->compute_output_shapes(input_tensors).at(0);
         if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) {
             uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
             uint32_t num_cores;
@@ -188,7 +173,7 @@ std::vector<Tensor> OptimizedConvNew::create_output_tensors(const std::vector<T
             auto mem_config = this->memory_config;
             mem_config.shard_spec = shard_spec;
-            return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)};
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
         } else if(this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) {
             uint32_t total_height_tiles = tt::tt_metal::compute_volume(output_shape) / output_shape[-1] / TILE_HEIGHT;
             std::array<uint32_t, 2> shard_shape = {tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH};
@@ -196,15 +181,14 @@
                 this->memory_config.shard_spec.value().orientation};
             auto mem_config = this->memory_config;
             mem_config.shard_spec = shard_spec;
-            return{create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), mem_config)};
-
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), mem_config, ttnn::Shape(output_shape)))};
         } else if (this->memory_config.memory_layout == TensorMemoryLayout::BLOCK_SHARDED) {
-            return {create_device_tensor(output_shape, this->dtype, output_layout, input_tensor.device(), this->memory_config)};
+            return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))};
         } else {
             TT_THROW("Unsupported shard scheme");
         }
     }
-    return operation::generic_create_output_tensors(*this, input_tensors, this->dtype, output_layout, this->memory_config);
+    return {TensorSpec(output_shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(dtype, PageConfig(output_layout), memory_config, ttnn::Shape(output_shape)))};
 }
 
 operation::ProgramWithCallbacks OptimizedConvNew::create_program(const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
index a39e97f4fac..17d22b48a4b 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp
@@ -108,8 +108,7 @@ struct OptimizedConvNew {
         use_non_tile_height(use_non_tile_height) {}
 
     void validate(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, std::vector<Tensor> &output_tensors) const;
 
     operation::OpPerformanceModel create_op_performance_model(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors, const std::vector<Tensor> &output_tensors) const;
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
index 44ba8c8c9e1..96b80a12e3b 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
@@ -66,16 +66,16 @@
     }
 }
 
-std::vector<ttnn::SimpleShape> AllGatherMatmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::TensorSpec> AllGatherMatmul::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     // All Gather shape
-    ttnn::SimpleShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0];
-    ttnn::SimpleShape datacopy_output_shape = all_gather_output_shape;
+    ttnn::TensorSpec all_gather_output_shape = this->all_gather_struct.compute_output_specs({input_tensors[0]})[0];
+    ttnn::TensorSpec datacopy_output_shape = all_gather_output_shape;
 
     // Matmul shape
-    ttnn::SimpleShape matmul_output_shapes =
-        this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0];
+    ttnn::TensorSpec matmul_output_specs =
+        this->matmul_struct.compute_output_specs({input_tensors[1], input_tensors[2]})[0];
 
-    return {all_gather_output_shape, matmul_output_shapes, datacopy_output_shape};
+    return {all_gather_output_shape, matmul_output_specs, datacopy_output_shape};
 }
 
 std::vector<Tensor> AllGatherMatmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
index bb1359e0e61..c8af6cc9dd7 100644
--- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
+++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp
@@ -42,7 +42,7 @@ struct AllGatherMatmul {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
index c4745856b80..f3d0b732e62 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp
@@ -1562,9 +1562,11 @@ void Matmul::validate(
         chosen_program_config);
 }
 
-std::vector<ttnn::SimpleShape> Matmul::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
-    const ttnn::SimpleShape input_shape_a = input_tensors.at(0).get_logical_shape();
-    const ttnn::SimpleShape input_shape_b = input_tensors.at(1).get_logical_shape();
+std::vector<ttnn::TensorSpec> Matmul::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
+    const auto& input_tensor_a = input_tensors.at(0);
+    const auto& input_tensor_b = input_tensors.at(1);
+    const ttnn::SimpleShape input_shape_a = input_tensor_a.get_logical_shape();
+    const ttnn::SimpleShape input_shape_b = input_tensor_b.get_logical_shape();
     const uint32_t a_rank = input_shape_a.rank();
     const uint32_t b_rank = input_shape_b.rank();
     const uint32_t out_rank = std::max(a_rank, b_rank);
@@ -1579,12 +1581,7 @@ std::vector<ttnn::SimpleShape> Matmul::compute_output_shapes(const std::vector<
         output_shape[rank_difference + index] = input_shape_a[index];
     }
     output_shape[-1] = input_shape_b[-1];
-    return {std::move(output_shape)};
-}
-
-std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
-    const auto& input_tensor_a = input_tensors.at(0);
-    const auto& input_tensor_b = input_tensors.at(1);
+
     auto in0_tile_shape = input_tensor_a.get_tensor_spec().tile().get_tile_shape();
     auto in1_tile_shape = input_tensor_b.get_tensor_spec().tile().get_tile_shape();
     auto output_tile = this->output_tile.value();
@@ -1594,14 +1591,14 @@
     if (this->output_mem_config.is_sharded()) {
         MatmulProgramConfig chosen_program_config = get_program_config(input_tensor_a, input_tensor_b, this);
         return std::visit(
-            [&](const auto& program_config) -> std::vector<Tensor> {
+            [&](const auto& program_config) -> std::vector<ttnn::TensorSpec> {
                 using ProgramConfigType = std::decay_t<decltype(program_config)>;
                 if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCast1DProgramConfig>) {
                     uint32_t M =
-                        (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1]
-                                                   : input_tensor_a.get_legacy_shape()[-2]) /
+                        (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1]
+                                                   : input_tensor_a.get_padded_shape()[-2]) /
                         in0_tile_shape[0];
-                    uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];
+                    uint32_t N = input_tensor_b.get_padded_shape()[-1] / in1_tile_shape[1];
 
                     uint32_t per_core_M = program_config.per_core_M;
                     uint32_t per_core_N = program_config.per_core_N;
@@ -1622,19 +1619,14 @@
                             ShardOrientation::ROW_MAJOR};
                         mem_config.shard_spec = shard_spec;
                     }
-                    return {create_device_tensor(
-                        this->compute_output_shapes(input_tensors).at(0),
-                        this->output_dtype.value(),
-                        output_layout,
-                        input_tensor_a.device(),
-                        mem_config,
-                        output_tile)};
+                    return {TensorSpec(
+                        output_shape,
+                        TensorLayout(output_dtype.value(), PageConfig(output_layout, output_tile), mem_config))};
                 } else if constexpr (std::is_same_v<
                                          ProgramConfigType,
                                          MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig>) {
-                    uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0];
-                    uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];
-                    auto input_tensor_b_shape = input_tensor_b.get_legacy_shape();
+                    uint32_t M = input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1] / in0_tile_shape[0];
+                    uint32_t N = input_tensor_b.get_padded_shape()[-1] / in1_tile_shape[1];
 
                     uint32_t per_core_M = program_config.per_core_M;
                     uint32_t per_core_N = program_config.per_core_N;
@@ -1645,7 +1637,6 @@
                     uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
                     uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
-                    uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
                     uint32_t num_cores = num_blocks_x * num_blocks_y;
 
                     auto end_core = input_tensor_a.shard_spec()->grid.bounding_box().end_coord;
                     auto grid_size = CoreCoord{end_core.x + 1, end_core.y + 1};
@@ -1656,13 +1647,9 @@
                         ShardOrientation::ROW_MAJOR};
                     auto mem_config = this->output_mem_config;
                     mem_config.shard_spec = shard_spec;
-                    return {create_device_tensor(
-                        this->compute_output_shapes(input_tensors).at(0),
-                        this->output_dtype.value(),
-                        output_layout,
-                        input_tensor_a.device(),
-                        mem_config,
-                        output_tile)};
+                    return {TensorSpec(
+                        output_shape,
+                        TensorLayout(output_dtype.value(), PageConfig(output_layout, output_tile), mem_config))};
                 } else if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseMultiCastProgramConfig>) {
                     uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0];
                     uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];
@@ -1675,8 +1662,6 @@
                     uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
                     uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
-                    uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
-                    uint32_t num_cores = num_blocks_x * num_blocks_y;
 
                     CoreRangeSet all_cores;
                     ShardOrientation shard_orientation;
                     if (program_config.transpose_mcast) {
@@ -1690,13 +1675,9 @@ std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& inp
                         all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, shard_orientation};
                     auto mem_config = this->output_mem_config;
                     mem_config.shard_spec = shard_spec;
-                    return {create_device_tensor(
-                        this->compute_output_shapes(input_tensors).at(0),
-                        this->output_dtype.value(),
-                        output_layout,
-                        input_tensor_a.device(),
-                        mem_config,
-                        output_tile)};
+                    return {TensorSpec(
+                        output_shape,
+                        TensorLayout(output_dtype.value(), PageConfig(output_layout, output_tile), mem_config))};
                 } else if constexpr (std::is_same_v<ProgramConfigType, MatmulMultiCoreReuseProgramConfig>) {
                     uint32_t M = input_tensor_a.volume() / input_tensor_a.get_legacy_shape()[-1] / in0_tile_shape[0];
                     uint32_t N = input_tensor_b.get_legacy_shape()[-1] / in1_tile_shape[1];
@@ -1709,7 +1690,6 @@
                     uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
                     uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
-                    uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
                     uint32_t num_cores = num_blocks_x * num_blocks_y;
 
                     ShardOrientation shard_orientation = ShardOrientation::COL_MAJOR;
                     if (input_tensor_a.is_sharded()) {
@@ -1726,13 +1706,9 @@
                         all_cores, {per_core_M * in0_tile_shape[0], per_core_N * in1_tile_shape[1]}, shard_orientation};
                    auto mem_config = this->output_mem_config;
                     mem_config.shard_spec = shard_spec;
-                    return {create_device_tensor(
-                        this->compute_output_shapes(input_tensors).at(0),
-                        this->output_dtype.value(),
-                        output_layout,
-                        input_tensor_a.device(),
-                        mem_config,
-                        output_tile)};
+                    return {TensorSpec(
+                        output_shape,
+                        TensorLayout(output_dtype.value(), PageConfig(output_layout, output_tile), mem_config))};
                 } else {
                     TT_FATAL(
                         in0_tile_shape[0] == TILE_HEIGHT and in0_tile_shape[1] == TILE_WIDTH,
@@ -1753,8 +1729,12 @@
             chosen_program_config);
     }
 
-    return operation::generic_create_output_tensors(
-        *this, input_tensors, this->output_dtype.value(), Layout::TILE, this->output_mem_config, output_tile);
+    return {TensorSpec(
+        output_shape, TensorLayout(output_dtype.value(), PageConfig(Layout::TILE, output_tile), output_mem_config))};
+}
+
+std::vector<Tensor> Matmul::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+    return operation::default_create_output_tensors(*this, input_tensors, {});
 }
 
 operation::ProgramWithCallbacks Matmul::create_program(
diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
index a4b41cb6519..9ef2eaeafb9 100644
--- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
+++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp
@@ -178,7 +178,7 @@ struct Matmul {
     void validate(
         const std::vector<Tensor>& input_tensors,
         const std::vector<std::optional<const Tensor>>& optional_input_tensors) const;
-    std::vector<ttnn::SimpleShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors,
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
index e89e4c6fa6f..6a8422d4873 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.cpp
@@ -30,8 +30,7 @@ void HaloDeviceOperation::validate(const std::vector<Tensor>& input_tensors) con
     TT_FATAL(input_tensor.shard_spec().has_value(), "Shard spec should not be empty");
 }
 
-std::vector<tt::tt_metal::LegacyShape> HaloDeviceOperation::compute_output_shapes(
-    const std::vector<Tensor>& input_tensors) const {
+std::vector<ttnn::TensorSpec> HaloDeviceOperation::compute_output_specs(const std::vector<Tensor>& input_tensors) const {
     const auto& input = input_tensors.at(0);
     const auto& input_shape = input.get_legacy_shape();
     tt::tt_metal::LegacyShape output_shape = input_shape;
@@ -50,14 +49,9 @@ std::vector<tt::tt_metal::LegacyShape> HaloDeviceOperation::compute_output_shape
     log_debug(tt::LogOp, "max_out_nsticks_per_core: {}", max_out_nsticks_per_core_);
     log_debug(tt::LogOp, "num_cores_nhw: {}", config_.num_cores_nhw);
 
-    return {output_shape};
-}
-
-std::vector<Tensor> HaloDeviceOperation::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
     const auto& input_tensor = input_tensors.at(0);
     DataType output_dtype =
         input_tensor.get_dtype() == DataType::BFLOAT8_B ? DataType::BFLOAT16 : input_tensor.get_dtype();
-    auto output_shape = this->compute_output_shapes(input_tensors).at(0);
 
     TT_FATAL(
         input_tensor.memory_config().memory_layout == output_memory_config_.memory_layout,
@@ -77,7 +71,10 @@ std::vector<Tensor> HaloDeviceOperation::create_output_tensors(const std::vector
     out_mem_config.shard_spec->shape[0] = tt::div_up(output_shape[0] * output_shape[2], config_.num_cores_nhw);
     out_mem_config.shard_spec->shape[1] = input_tensor.memory_config().shard_spec->shape[1];
     out_mem_config.shard_spec->halo = true;
-    return {create_device_tensor(output_shape, output_dtype, Layout::ROW_MAJOR, input_tensor.device(), out_mem_config)};
+    return {TensorSpec(
+        output_shape.logical_shape(),
+        TensorLayout::fromLegacyPaddedShape(
+            output_dtype, PageConfig(Layout::ROW_MAJOR), out_mem_config, ttnn::Shape(output_shape)))};
 }
 
 operation::ProgramWithCallbacks HaloDeviceOperation::create_program(
diff --git a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
index d9801d306b8..88e502181f1 100644
--- a/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/sliding_window/halo/device/halo_device_operation.hpp
@@ -27,8 +27,7 @@ struct HaloDeviceOperation {
     bool is_out_tiled_;
 
     void validate(const std::vector<Tensor>& input_tensors) const;
-    std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor>& input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor>& input_tensors) const;
+    std::vector<ttnn::TensorSpec> compute_output_specs(const std::vector<Tensor>& input_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor>& input_tensors, std::vector<Tensor>& output_tensors) const;
     // const operation::Hash compute_program_hash(const std::vector<Tensor> &input_tensors) const;