Beam search logit refactor #771

Open · wants to merge 7 commits into master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]

### Added
- Batched retrieval of logits from the GPU when the `--n-best` flag is specified.
- Local/global sharding with MPI training via `--sharding local`
- fp16 support for factors.
- Correct training with fp16 via `--fp16`.
28 changes: 28 additions & 0 deletions src/tensors/gpu/algorithm.cu
@@ -150,5 +150,33 @@ template void swap_ranges<float>(Ptr<Backend>, float*, float*, float*);
template void swap_ranges<double>(Ptr<Backend>, double*, double*, double*);
// clang-format on

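// Gather kernel: thread i copies d_in[indices[i]] into d_out[i], casting the value to float.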
template <typename T>
__global__ void ggatherIndices(float* d_out, T* d_in, size_t* indices, size_t indicesToGather) {
int index = threadIdx.x + blockDim.x * blockIdx.x;
if(index < indicesToGather) {
d_out[index] = static_cast<float>(d_in[indices[index]]);
}
}

void gatherIndices(Ptr<Backend> backend, float* d_out, float* d_in, size_t* d_indices, size_t indices_size) {
CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
int threadsPerBlock = std::min(MAX_THREADS, (int)indices_size);
int blocks = (indices_size + threadsPerBlock - 1) / threadsPerBlock;
ggatherIndices<<<blocks, threadsPerBlock>>>(d_out, d_in, d_indices, indices_size);
CUDA_CHECK(cudaStreamSynchronize(0));
}

void gatherIndices(Ptr<Backend> backend, float* d_out, float16* d_in, size_t* d_indices, size_t indices_size) {
#if COMPILE_FP16
CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
int threadsPerBlock = std::min(MAX_THREADS, (int)indices_size);
int blocks = (indices_size + threadsPerBlock - 1) / threadsPerBlock;
ggatherIndices<<<blocks, threadsPerBlock>>>(d_out, (__half*)d_in, d_indices, indices_size);
CUDA_CHECK(cudaStreamSynchronize(0));
#else
ABORT("FP16 not supported with current hardware or CUDA version");
#endif
}

} // namespace gpu
} // namespace marian
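
For readers less familiar with the launch pattern above, here is a minimal, self-contained CUDA sketch of the same gather-by-index idea outside of marian; all names (gatherKernel, dIn, dIdx, dOut) are hypothetical and not part of this PR.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One thread per requested index: out[i] = in[idx[i]], mirroring ggatherIndices above.
__global__ void gatherKernel(float* out, const float* in, const size_t* idx, size_t n) {
  size_t i = threadIdx.x + (size_t)blockDim.x * blockIdx.x;
  if(i < n)
    out[i] = in[idx[i]];
}

int main() {
  std::vector<float> hostIn = {10.f, 20.f, 30.f, 40.f, 50.f};
  std::vector<size_t> hostIdx = {4, 0, 2}; // scattered positions to read
  std::vector<float> hostOut(hostIdx.size());

  float *dIn, *dOut;
  size_t *dIdx;
  cudaMalloc(&dIn, hostIn.size() * sizeof(float));
  cudaMalloc(&dIdx, hostIdx.size() * sizeof(size_t));
  cudaMalloc(&dOut, hostOut.size() * sizeof(float));
  cudaMemcpy(dIn, hostIn.data(), hostIn.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dIdx, hostIdx.data(), hostIdx.size() * sizeof(size_t), cudaMemcpyHostToDevice);

  int threadsPerBlock = 256;
  int blocks = (int)((hostIdx.size() + threadsPerBlock - 1) / threadsPerBlock);
  gatherKernel<<<blocks, threadsPerBlock>>>(dOut, dIn, dIdx, hostIdx.size());

  // One device-to-host copy returns all gathered values instead of one small copy per element.
  cudaMemcpy(hostOut.data(), dOut, hostOut.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("%.0f %.0f %.0f\n", hostOut[0], hostOut[1], hostOut[2]); // prints: 50 10 30

  cudaFree(dIn);
  cudaFree(dIdx);
  cudaFree(dOut);
  return 0;
}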
10 changes: 10 additions & 0 deletions src/tensors/gpu/algorithm.h
@@ -1,6 +1,12 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "tensors/backend.h"
#include "common/types.h"

namespace marian {
namespace gpu {
@@ -17,5 +23,9 @@ void setSparse(Ptr<marian::Backend> backend,
const std::vector<size_t>&,
const std::vector<float>&,
float*);

void gatherIndices(Ptr<marian::Backend> backend, float* d_out, float* d_in, size_t* d_indices, size_t indices_size);

void gatherIndices(Ptr<marian::Backend> backend, float* d_out, float16* d_in, size_t* d_indices, size_t indices_size);
} // namespace gpu
} // namespace marian
38 changes: 38 additions & 0 deletions src/tensors/tensor.h
@@ -1,3 +1,8 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "common/definitions.h"
@@ -109,6 +114,39 @@ class TensorBase {
return TensorBase::New(mem, Shape{1, (int)size}, type(), backend_);
}

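// Gathers the values of this tensor at the flat positions given by flattenedIndices (uint64)
// into gatheredResults (float32), on either the CPU or the GPU backend.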
void gatherFromIndices(Tensor gatheredResults, Tensor flattenedIndices) {
ABORT_IF((flattenedIndices->type() != Type::uint64),
"Type of indices must be uint64");

ABORT_IF(gatheredResults->size() < flattenedIndices->size(),
"The result tensor is too small to hold all of the indexed values.");

ABORT_IF(gatheredResults->type() != Type::float32,
"The type of the result tensor must be float32.");

if(backend_->getDeviceId().type == DeviceType::cpu) {
float* gatheredResultsPtr = gatheredResults->data<float>();
size_t* flattenedIndicesPtr = flattenedIndices->data<size_t>();
float* dataToGather = data();

for(size_t i = 0; i < flattenedIndices->size(); ++i) {
gatheredResultsPtr[i] = dataToGather[flattenedIndicesPtr[i]];
}
}
#ifdef CUDA_FOUND
else {
if (type_ == Type::float32) {
return gpu::gatherIndices(backend_, gatheredResults->data<float>(), data<float>(), flattenedIndices->data<size_t>(), flattenedIndices->size());
} else if(type_ == Type::float16) {
return gpu::gatherIndices(backend_, gatheredResults->data<float>(), data<float16>(), flattenedIndices->data<size_t>(), flattenedIndices->size());
} else {
ABORT("gatherFromIndices only supports float32 and float16 source tensors");
}
}
#endif

}

// @TODO: review if we can eliminate GPU-specific code here,
// potentially by moving this to non-class members.
template <typename T>
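
Taken together with the GPU helper above, the intended call pattern for gatherFromIndices looks roughly like the condensed sketch below; it mirrors the beam_search.cpp changes that follow, and logProbs, positions, and allocator are placeholder names, not identifiers from this PR.

// Sketch: batch-read scattered values from a (possibly fp16) tensor with one gather + one copy.
std::vector<uint64_t> positions = {0, 17, 42}; // flattened (beam, batch, word) offsets, for example

Tensor indices, gathered;
allocator->allocate(indices, {(int)positions.size()}, Type::uint64); // index buffer
allocator->allocate(gathered, indices->shape(), Type::float32);      // output is always float32

indices->set(positions);                        // upload all indices at once
logProbs->gatherFromIndices(gathered, indices); // one gather instead of one single-element read per index

std::vector<float> values(positions.size());
gathered->get(values);                          // single device-to-host transfer for all values

allocator->free(indices);
allocator->free(gathered);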
56 changes: 49 additions & 7 deletions src/translator/beam_search.cpp
@@ -1,4 +1,10 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#include "translator/beam_search.h"
#include "tensors/tensor_allocator.h"

#include "data/factored_vocab.h"
#include "translator/helpers.h"
@@ -40,6 +46,12 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
}
}

// Hold the flattened logit indices for each state so we can batch their retrieval later. Additionally, store the original batch index and the old/new beam hypothesis indices so we can update the hypotheses in the new beams.
std::vector<size_t> origBatchIndices;
std::vector<size_t> oldBeamHypIndices;
std::vector<size_t> newBeamHypIndices;
std::vector<std::vector<uint64_t>> flattenedLogitIndices(states.size());

for(size_t i = 0; i < nBestKeys.size(); ++i) { // [currentDimBatch, beamSize] flattened
// Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
// They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1.
@@ -123,23 +135,22 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current

// Set score breakdown for n-best lists
if(options_->get<bool>("n-best")) {
auto breakDown = beam[beamHypIdx]->getScoreBreakdown();
ABORT_IF(factoredVocab && factorGroup > 0 && !factoredVocab->canExpandFactoredWord(word, factorGroup),
"A word without this factor snuck through to here??");
breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
for(size_t j = 0; j < states.size(); ++j) {
for(uint64_t j = 0; j < states.size(); ++j) {
auto lval = states[j]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
// The flatting happens based on actual (current) batch size and batch index computed with batch-pruning as we are looking into the pruned tensor
size_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
uint64_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'

// @TODO: use a function on shape() to index, or new method val->at({i1, i2, i3, i4}) with broadcasting
ABORT_IF(lval->shape() != Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}) &&
(beamHypIdx == 0 && lval->shape() != Shape({1, 1, (int)currentDimBatch, (int)vocabSize})),
"Unexpected shape of logits?? {} != {}", lval->shape(), Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}));

breakDown[j] += lval->get(flattenedLogitIndex);
flattenedLogitIndices[j].push_back(flattenedLogitIndex);
}
hyp->setScoreBreakdown(breakDown);
newBeamHypIndices.push_back(newBeam.size());
origBatchIndices.push_back(origBatchIdx);
oldBeamHypIndices.push_back(beamHypIdx);
}

// Set alignments
@@ -151,6 +162,36 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
newBeam.push_back(hyp);
}

// Set the score breakdowns outside of the main loop so the lookups can be batched. This avoids issuing many separate 4-byte device-to-host copies when using the GPU backend.
if(options_->get<bool>("n-best")) {
Tensor indices;
Tensor logitsTensor;
allocator_->allocate(indices, {(int)flattenedLogitIndices[0].size()}, Type::uint64);
allocator_->allocate(logitsTensor, indices->shape(), Type::float32);
std::vector<float> logits(flattenedLogitIndices[0].size());

for(size_t state = 0; state < states.size(); ++state) {
auto lval = states[state]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
indices->set(flattenedLogitIndices[state]);
lval->gatherFromIndices(logitsTensor, indices);
logitsTensor->get(logits);

for(size_t i = 0; i < flattenedLogitIndices[state].size(); ++i) {
const auto originalBatchIdx = origBatchIndices[i];
const auto beamHypIdx = oldBeamHypIndices[i];
const auto& beam = beams[originalBatchIdx];
auto& newBeam = newBeams[originalBatchIdx];

// For the first state, start from the parent hypothesis' breakdown; for later states, keep the
// contributions already accumulated on the new hypothesis so earlier states are not overwritten.
auto breakDown = (state == 0) ? beam[beamHypIdx]->getScoreBreakdown()
: newBeam[newBeamHypIndices[i]]->getScoreBreakdown();
breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
breakDown[state] += logits[i];
newBeam[newBeamHypIndices[i]]->setScoreBreakdown(breakDown);
}
}
allocator_->free(indices);
allocator_->free(logitsTensor);
}

// if factored vocab and this is not the first factor, we need to
// also propagate factored hypotheses that do not get expanded in this step because they don't have this factor
if (factorGroup > 0) {
@@ -261,6 +302,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
const auto trgUnkId = trgVocab_->getUnkId();

auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId());
allocator_ = graph->getTensorAllocator();

for(auto scorer : scorers_) {
scorer->clear(graph);
7 changes: 7 additions & 0 deletions src/translator/beam_search.h
@@ -1,3 +1,8 @@
/* Part of this file was contributed by NVIDIA under license:
* Copyright (C) 2020 NVIDIA Corporation
* SPDX-License-Identifier: MIT
*/

#pragma once

#include "marian.h"
@@ -6,12 +11,14 @@

namespace marian {

class TensorAllocator;
class BeamSearch {
private:
Ptr<Options> options_;
std::vector<Ptr<Scorer>> scorers_;
size_t beamSize_;
Ptr<const Vocab> trgVocab_;
Ptr<TensorAllocator> allocator_;

const float INVALID_PATH_SCORE;
const bool PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.