Skip to content

Commit

Permalink
Fix transposed pdims during autotuning. (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
romerojosh authored Apr 20, 2024
1 parent b8ffecc commit cff5275
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions src/autotune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ void autotuneTransposeBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_d
if (!options->transpose_use_inplace_buffers[i]) need_data2 = true;
}

std::vector<int> pdim0_list;
std::vector<int> pdim1_list;
if (autotune_pdims) {
pdim0_list = getFactors(handle->nranks);
pdim1_list = getFactors(handle->nranks);
} else {
pdim0_list = {grid_desc->config.pdims[0]};
pdim1_list = {grid_desc->config.pdims[1]};
}

int32_t pdims_best[2]{grid_desc->config.pdims[0], grid_desc->config.pdims[1]};
Expand All @@ -149,9 +149,9 @@ void autotuneTransposeBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_d

int64_t data_sz = 0;
int64_t work_sz = 0;
for (auto& pdim0 : pdim0_list) {
grid_desc->config.pdims[0] = handle->nranks / pdim0;
grid_desc->config.pdims[1] = pdim0;
for (auto& pdim1 : pdim1_list) {
grid_desc->config.pdims[0] = handle->nranks / pdim1;
grid_desc->config.pdims[1] = pdim1;
grid_desc->pidx[0] = handle->rank / grid_desc->config.pdims[1];
grid_desc->pidx[1] = handle->rank % grid_desc->config.pdims[1];

Expand Down Expand Up @@ -529,11 +529,11 @@ void autotuneHaloBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_desc,
#endif
}

std::vector<int> pdim0_list;
std::vector<int> pdim1_list;
if (autotune_pdims) {
pdim0_list = getFactors(handle->nranks);
pdim1_list = getFactors(handle->nranks);
} else {
pdim0_list = {grid_desc->config.pdims[0]};
pdim1_list = {grid_desc->config.pdims[1]};
}

int32_t pdims_best[2]{grid_desc->config.pdims[0], grid_desc->config.pdims[1]};
Expand All @@ -546,9 +546,9 @@ void autotuneHaloBackend(cudecompHandle_t handle, cudecompGridDesc_t grid_desc,

int64_t data_sz = 0;
int64_t work_sz = 0;
for (auto& pdim0 : pdim0_list) {
grid_desc->config.pdims[0] = pdim0;
grid_desc->config.pdims[1] = handle->nranks / pdim0;
for (auto& pdim1 : pdim1_list) {
grid_desc->config.pdims[0] = handle->nranks / pdim1;
grid_desc->config.pdims[1] = pdim1;
grid_desc->pidx[0] = handle->rank / grid_desc->config.pdims[1];
grid_desc->pidx[1] = handle->rank % grid_desc->config.pdims[1];

Expand Down

0 comments on commit cff5275

Please sign in to comment.