From 248ccab7e6546d45f111384a1aec09b3403dea2e Mon Sep 17 00:00:00 2001
From: Bradley Dice
Date: Wed, 13 Sep 2023 15:11:37 -0500
Subject: [PATCH] Some additional kernel thread index refactoring.

---
 cpp/benchmarks/join/generate_input_tables.cuh | 17 ++++++----
 .../type_dispatcher/type_dispatcher.cu        | 33 ++++++++++++--------
 cpp/include/cudf/detail/copy_if_else.cuh      | 17 +++++-----
 3 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/cpp/benchmarks/join/generate_input_tables.cuh b/cpp/benchmarks/join/generate_input_tables.cuh
index 84e607a9f28..ef2e6370760 100644
--- a/cpp/benchmarks/join/generate_input_tables.cuh
+++ b/cpp/benchmarks/join/generate_input_tables.cuh
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include <cudf/detail/utilities/cuda.cuh>
 #include
 #include
 #include
@@ -33,7 +34,7 @@ __global__ static void init_curand(curandState* state, int const nstates)
 {
-  int ithread = threadIdx.x + blockIdx.x * blockDim.x;
+  int ithread = cudf::detail::grid_1d::global_thread_id();
 
   if (ithread < nstates) { curand_init(1234ULL, ithread, 0, state + ithread); }
 }
 
@@ -45,13 +46,14 @@ __global__ static void init_build_tbl(key_type* const build_tbl,
                                       curandState* state,
                                       int const num_states)
 {
-  auto const start_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  auto const stride    = blockDim.x * gridDim.x;
+  auto const start_idx = cudf::detail::grid_1d::global_thread_id();
+  auto const stride    = cudf::detail::grid_1d::grid_stride();
   assert(start_idx < num_states);
 
   curandState localState = state[start_idx];
 
-  for (size_type idx = start_idx; idx < build_tbl_size; idx += stride) {
+  for (cudf::thread_index_type tidx = start_idx; tidx < build_tbl_size; tidx += stride) {
+    auto const idx = static_cast<size_type>(tidx);
     double const x = curand_uniform_double(&localState);
 
     build_tbl[idx] = static_cast<key_type>(x * (build_tbl_size / multiplicity));
@@ -70,13 +72,14 @@ __global__ void init_probe_tbl(key_type* const probe_tbl,
                                curandState* state,
                                int const num_states)
 {
-  auto const start_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  auto const stride    = blockDim.x * gridDim.x;
+  auto const start_idx = cudf::detail::grid_1d::global_thread_id();
+  auto const stride    = cudf::detail::grid_1d::grid_stride();
   assert(start_idx < num_states);
 
   curandState localState = state[start_idx];
 
-  for (size_type idx = start_idx; idx < probe_tbl_size; idx += stride) {
+  for (cudf::thread_index_type tidx = start_idx; tidx < probe_tbl_size; tidx += stride) {
+    auto const idx = static_cast<size_type>(tidx);
     key_type val;
 
     double x = curand_uniform_double(&localState);
diff --git a/cpp/benchmarks/type_dispatcher/type_dispatcher.cu b/cpp/benchmarks/type_dispatcher/type_dispatcher.cu
index 3f985cffb1f..5a2cbe5a395 100644
--- a/cpp/benchmarks/type_dispatcher/type_dispatcher.cu
+++ b/cpp/benchmarks/type_dispatcher/type_dispatcher.cu
@@ -60,13 +60,15 @@ constexpr int block_size = 256;
 
 template <FunctorType functor_type, class T>
 __global__ void no_dispatching_kernel(T** A, cudf::size_type n_rows, cudf::size_type n_cols)
 {
-  using F               = Functor<T, functor_type>;
-  cudf::size_type index = blockIdx.x * blockDim.x + threadIdx.x;
-  while (index < n_rows) {
+  using F           = Functor<T, functor_type>;
+  auto tidx         = cudf::detail::grid_1d::global_thread_id();
+  auto const stride = cudf::detail::grid_1d::grid_stride();
+  while (tidx < n_rows) {
+    auto const index = static_cast<cudf::size_type>(tidx);
     for (int c = 0; c < n_cols; c++) {
       A[c][index] = F::f(A[c][index]);
     }
-    index += blockDim.x * gridDim.x;
+    tidx += stride;
   }
 }
@@ -74,12 +76,14 @@ __global__ void no_dispatching_kernel(T** A, cudf::size_type n_rows, cudf::size_type n_cols)
 
 template <FunctorType functor_type, class T>
 __global__ void host_dispatching_kernel(cudf::mutable_column_device_view source_column)
 {
-  using F               = Functor<T, functor_type>;
-  T* A                  = source_column.data<T>();
-  cudf::size_type index = blockIdx.x * blockDim.x + threadIdx.x;
-  while (index < source_column.size()) {
-    A[index] = F::f(A[index]);
-    index += blockDim.x * gridDim.x;
+  using F           = Functor<T, functor_type>;
+  T* A              = source_column.data<T>();
+  auto tidx         = cudf::detail::grid_1d::global_thread_id();
+  auto const stride = cudf::detail::grid_1d::grid_stride();
+  while (tidx < source_column.size()) {
+    auto const index = static_cast<cudf::size_type>(tidx);
+    A[index]         = F::f(A[index]);
+    tidx += stride;
   }
 }
@@ -127,14 +131,15 @@ template <FunctorType functor_type>
 __global__ void device_dispatching_kernel(cudf::mutable_table_device_view source)
 {
   cudf::size_type const n_rows = source.num_rows();
-  cudf::size_type index        = threadIdx.x + blockIdx.x * blockDim.x;
-
-  while (index < n_rows) {
+  auto tidx         = cudf::detail::grid_1d::global_thread_id();
+  auto const stride = cudf::detail::grid_1d::grid_stride();
+  while (tidx < n_rows) {
+    auto const index = static_cast<cudf::size_type>(tidx);
     for (cudf::size_type i = 0; i < source.num_columns(); i++) {
       cudf::type_dispatcher(
         source.column(i).type(), RowHandle<functor_type>{}, source.column(i), index);
     }
-    index += blockDim.x * gridDim.x;
+    tidx += stride;
   }  // while
 }
diff --git a/cpp/include/cudf/detail/copy_if_else.cuh b/cpp/include/cudf/detail/copy_if_else.cuh
index 04ad1f20196..48c70df1862 100644
--- a/cpp/include/cudf/detail/copy_if_else.cuh
+++ b/cpp/include/cudf/detail/copy_if_else.cuh
@@ -44,18 +44,19 @@ __launch_bounds__(block_size) __global__
   mutable_column_device_view out,
   size_type* __restrict__ const valid_count)
 {
-  size_type const tid            = threadIdx.x + blockIdx.x * block_size;
-  int const warp_id              = tid / warp_size;
+  auto tidx                      = cudf::detail::grid_1d::global_thread_id();
+  auto const stride              = cudf::detail::grid_1d::grid_stride();
+  int const warp_id              = tidx / warp_size;
   size_type const warps_per_grid = gridDim.x * block_size / warp_size;
 
   // begin/end indices for the column data
-  size_type begin = 0;
-  size_type end   = out.size();
+  size_type const begin = 0;
+  size_type const end   = out.size();
 
   // warp indices.  since 1 warp == 32 threads == sizeof(bitmask_type) * 8,
   // each warp will process one (32 bit) of the validity mask via
   // __ballot_sync()
-  size_type warp_begin = cudf::word_index(begin);
-  size_type warp_end   = cudf::word_index(end - 1);
+  size_type const warp_begin = cudf::word_index(begin);
+  size_type const warp_end   = cudf::word_index(end - 1);
 
   // lane id within the current warp
   constexpr size_type leader_lane{0};
@@ -65,8 +66,8 @@
   // current warp.
   size_type warp_cur = warp_begin + warp_id;
 
-  size_type index = tid;
   while (warp_cur <= warp_end) {
+    auto const index = static_cast<size_type>(tidx);
     auto const opt_value =
       (index < end) ? (filter(index) ? lhs[index] : rhs[index]) : thrust::nullopt;
     if (opt_value) { out.element<T>(index) = static_cast<T>(*opt_value); }
@@ -84,7 +85,7 @@
 
     // next grid
     warp_cur += warps_per_grid;
-    index += block_size * gridDim.x;
+    tidx += stride;
   }
 
   if (has_nulls) {
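
Note: the patch applies one pattern throughout: fetch a 64-bit global thread index and grid stride from cudf::detail::grid_1d, loop in 64 bits so the index cannot overflow on grids with more than INT_MAX threads, and narrow back to size_type only at the point of element access. Below is a minimal standalone sketch of that pattern, assuming only the grid_1d helpers already used in the diff (global_thread_id() and grid_stride(), both returning cudf::thread_index_type); the kernel name and its per-element work are hypothetical.

  #include <cudf/detail/utilities/cuda.cuh>
  #include <cudf/types.hpp>

  // Hypothetical example kernel: adds one to each element of `data`.
  __global__ void example_increment_kernel(int* data, cudf::size_type n)
  {
    // 64-bit global thread index and stride; blockIdx.x * blockDim.x can
    // exceed INT_MAX on large grids, so the loop variable stays 64-bit.
    auto tidx         = cudf::detail::grid_1d::global_thread_id();
    auto const stride = cudf::detail::grid_1d::grid_stride();
    while (tidx < n) {
      // Narrow to size_type only for element access, where tidx < n
      // guarantees the value fits in 32 bits.
      auto const index = static_cast<cudf::size_type>(tidx);
      data[index] += 1;
      tidx += stride;
    }
  }

The same shape appears in every kernel touched above: the 64-bit tidx drives the loop, and a narrowed index is used for memory access.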