Skip to content

Commit

Permalink
#9589: use num_cores_to_core_range_set api
Browse files Browse the repository at this point in the history
  • Loading branch information
kpaigwar committed Jun 22, 2024
1 parent 844a0cb commit f546d4e
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions ttnn/cpp/ttnn/operations/conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "tt_eager/tt_dnn/op_library/downsample/downsample_op.hpp"
#include "tt_metal/detail/reports/memory_reporter.hpp"
#include "ttnn/cpp/ttnn/op_library/to_dtype/to_dtype_op.hpp"
#include "tt_dnn/op_library/work_split.hpp"

using namespace tt;
namespace ttnn {
Expand Down Expand Up @@ -87,6 +88,8 @@ ParallelConfig determine_parallel_config(
auto compute_with_storage_grid_size = device.compute_with_storage_grid_size();
std::vector<uint32_t> device_grid_size = {
(uint32_t)compute_with_storage_grid_size.x, (uint32_t)compute_with_storage_grid_size.y};
CoreCoord device_grid_size_coord = {
(std::size_t)compute_with_storage_grid_size.x, (std::size_t)compute_with_storage_grid_size.y};
uint32_t max_num_cores = device_grid_size[0] * device_grid_size[1];

auto calculate_num_cores_nhw = [&]() {
Expand All @@ -99,17 +102,7 @@ ParallelConfig determine_parallel_config(

auto calculate_grid = [&](uint32_t num_cores_nhw) {
if (height_sharding) {
uint32_t cores_x_1 = num_cores_nhw >= device_grid_size[0] ? device_grid_size[0] : num_cores_nhw;
uint32_t cores_y_1 = num_cores_nhw > device_grid_size[0] ? (uint32_t)(num_cores_nhw / device_grid_size[0]) : 1;
TT_ASSERT(cores_y_1 <= device_grid_size[1], "Internal Error: Incorrect num_cores_nhw");
CoreRange core_range1 = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x_1 - 1, cores_y_1 - 1}));
CoreRangeSet grid = CoreRangeSet({core_range1});
if (num_cores_nhw >= device_grid_size[0] && num_cores_nhw % device_grid_size[0] != 0) {
uint32_t cores_x_2 = num_cores_nhw % device_grid_size[0];
TT_ASSERT(cores_y_1 + 1 <= device_grid_size[1], "Internal Error: Incorrect num_cores_nhw");
CoreRange core_range2 = CoreRange(CoreCoord({0, cores_y_1}), CoreCoord({cores_x_2 - 1, cores_y_1}));
grid = CoreRangeSet({core_range1, core_range2});
}
CoreRangeSet grid = num_cores_to_core_range_set(num_cores_nhw, device_grid_size_coord, true);
return grid;
} else {
uint32_t total_cores_for_channels =
Expand Down

0 comments on commit f546d4e

Please sign in to comment.