-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix illegal acces mean/stdev, sum add Kahan Summation #2223
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,30 +34,72 @@ RAFT_KERNEL sumKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) | |
IdxType thisRowId = threadIdx.x / ColsPerBlk; | ||
IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); | ||
IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); | ||
Type thread_data = Type(0); | ||
Type thread_sum = Type(0); | ||
const IdxType stride = RowsPerBlkPerIter * gridDim.x; | ||
for (IdxType i = rowId; i < N; i += stride) | ||
thread_data += (colId < D) ? data[i * D + colId] : Type(0); | ||
for (IdxType i = rowId; i < N; i += stride) { | ||
thread_sum += (colId < D) ? data[i * D + colId] : Type(0); | ||
} | ||
__shared__ Type smu[ColsPerBlk]; | ||
if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); | ||
__syncthreads(); | ||
raft::myAtomicAdd(smu + thisColId, thread_data); | ||
raft::myAtomicAdd(smu + thisColId, thread_sum); | ||
__syncthreads(); | ||
if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); | ||
} | ||
|
||
template <typename Type, typename IdxType, int TPB, int ColsPerBlk = 32> | ||
RAFT_KERNEL sumKahanKernelRowMajor(Type* mu, const Type* data, IdxType D, IdxType N) | ||
{ | ||
constexpr int RowsPerBlkPerIter = TPB / ColsPerBlk; | ||
IdxType thisColId = threadIdx.x % ColsPerBlk; | ||
IdxType thisRowId = threadIdx.x / ColsPerBlk; | ||
IdxType colId = thisColId + ((IdxType)blockIdx.y * ColsPerBlk); | ||
IdxType rowId = thisRowId + ((IdxType)blockIdx.x * RowsPerBlkPerIter); | ||
Type thread_sum = Type(0); | ||
Type thread_c = Type(0); | ||
const IdxType stride = RowsPerBlkPerIter * gridDim.x; | ||
for (IdxType i = rowId; i < N; i += stride) { | ||
// KahanBabushkaNeumaierSum | ||
const Type cur_value = (colId < D) ? data[i * D + colId] : Type(0); | ||
const Type t = thread_sum + cur_value; | ||
if (abs(thread_sum) >= abs(cur_value)) { | ||
thread_c += (thread_sum - t) + cur_value; | ||
} else { | ||
thread_c += (cur_value - t) + thread_sum; | ||
} | ||
thread_sum = t; | ||
} | ||
thread_sum += thread_c; | ||
__shared__ Type smu[ColsPerBlk]; | ||
if (threadIdx.x < ColsPerBlk) smu[threadIdx.x] = Type(0); | ||
__syncthreads(); | ||
raft::myAtomicAdd(smu + thisColId, thread_sum); | ||
__syncthreads(); | ||
if (threadIdx.x < ColsPerBlk && colId < D) raft::myAtomicAdd(mu + colId, smu[thisColId]); | ||
} | ||
|
||
template <typename Type, typename IdxType, int TPB> | ||
RAFT_KERNEL sumKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) | ||
RAFT_KERNEL sumKahanKernelColMajor(Type* mu, const Type* data, IdxType D, IdxType N) | ||
{ | ||
typedef cub::BlockReduce<Type, TPB> BlockReduce; | ||
__shared__ typename BlockReduce::TempStorage temp_storage; | ||
Type thread_data = Type(0); | ||
Type thread_sum = Type(0); | ||
Type thread_c = Type(0); | ||
IdxType colStart = N * blockIdx.x; | ||
for (IdxType i = threadIdx.x; i < N; i += TPB) { | ||
IdxType idx = colStart + i; | ||
thread_data += data[idx]; | ||
// KahanBabushkaNeumaierSum | ||
IdxType idx = colStart + i; | ||
const Type cur_value = data[idx]; | ||
const Type t = thread_sum + cur_value; | ||
if (abs(thread_sum) >= abs(cur_value)) { | ||
thread_c += (thread_sum - t) + cur_value; | ||
} else { | ||
thread_c += (cur_value - t) + thread_sum; | ||
} | ||
thread_sum = t; | ||
} | ||
Type acc = BlockReduce(temp_storage).Sum(thread_data); | ||
thread_sum += thread_c; | ||
Type acc = BlockReduce(temp_storage).Sum(thread_sum); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not compensated right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The BlockReduce is not, which is why the compensation is added to the value beforehand. |
||
if (threadIdx.x == 0) { mu[blockIdx.x] = acc; } | ||
} | ||
|
||
|
@@ -66,15 +108,21 @@ void sum(Type* output, const Type* input, IdxType D, IdxType N, bool rowMajor, c | |
{ | ||
static const int TPB = 256; | ||
if (rowMajor) { | ||
static const int RowsPerThread = 4; | ||
static const int ColsPerBlk = 32; | ||
static const int RowsPerBlk = (TPB / ColsPerBlk) * RowsPerThread; | ||
dim3 grid(raft::ceildiv(N, (IdxType)RowsPerBlk), raft::ceildiv(D, (IdxType)ColsPerBlk)); | ||
static const int ColsPerBlk = 8; | ||
static const int MinRowsPerThread = 16; | ||
static const int MinRowsPerBlk = (TPB / ColsPerBlk) * MinRowsPerThread; | ||
static const int MaxBlocksDimX = 8192; | ||
|
||
const IdxType grid_y = raft::ceildiv(D, (IdxType)ColsPerBlk); | ||
const IdxType grid_x = | ||
raft::min((IdxType)MaxBlocksDimX, raft::ceildiv(N, (IdxType)MinRowsPerBlk)); | ||
|
||
dim3 grid(grid_x, grid_y); | ||
RAFT_CUDA_TRY(cudaMemset(output, 0, sizeof(Type) * D)); | ||
sumKernelRowMajor<Type, IdxType, TPB, ColsPerBlk> | ||
sumKahanKernelRowMajor<Type, IdxType, TPB, ColsPerBlk> | ||
<<<grid, TPB, 0, stream>>>(output, input, D, N); | ||
} else { | ||
sumKernelColMajor<Type, IdxType, TPB><<<D, TPB, 0, stream>>>(output, input, D, N); | ||
sumKahanKernelColMajor<Type, IdxType, TPB><<<D, TPB, 0, stream>>>(output, input, D, N); | ||
} | ||
RAFT_CUDA_TRY(cudaPeekAtLastError()); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed offline, we are still loosing accuracy here, because we cannot do atomic compensated summation. In a follow up PR, we should strive to improve this. A few notes:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe you would need to make use of extra smem here
smu[ColsPerBlk * RowsPerBlkPerIter ]
then store the each outputs something likesmu[ thisColId * RowsPerBlkPerIter + thisRowId ] = thread_sum
, followed by per-thread working on summing upRowsPerBlkPerIter
from a single warp0 with kahan algo ifRowsPerBlkPerIter
is small and for largerRowsPerBlkPerIter
like 32 you can use shfl based reduction with kahan algo applied on each of its 5 iteration.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, within the block we can use a second shared memory atomicAdd to store the compensation. With the current blockdim we only have 4 threads adding their intermediate values. I tried that but decided to skip for now until addition across blocks is not compensated afterwards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.