Commit
feat(trtllm): add stop words handling
mfuntowicz committed Oct 21, 2024
1 parent 5d2171a commit 9cf43a7
Showing 2 changed files with 43 additions and 17 deletions.
20 changes: 11 additions & 9 deletions backends/trtllm/include/backend.h
@@ -5,6 +5,7 @@
#ifndef TGI_TRTLLM_BACKEND_H
#define TGI_TRTLLM_BACKEND_H

#include <array>
#include <cmath>
#include <filesystem>
#include <span>
@@ -72,6 +73,7 @@ namespace huggingface::tgi::backends {

/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;
std::list<std::vector<TokenId>> stopWords;

public:
explicit TensorRtLlmBackend(
@@ -85,20 +87,20 @@
* @param topK
* @param topP
* @param temperature
* @param repetition_penalty
* @param frequency_penalty
* @param repetitionPenalty
* @param frequencyPenalty
* @param seed
* @return Request id related to this generation for reference
*/
[[nodiscard]] RequestId Submit(
const std::vector<TokenId> &tokens,
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
uint32_t maxNewTokens,
int32_t topK,
float_t topP,
float_t temperature,
float_t repetitionPenalty,
float_t frequencyPenalty,
uint64_t seed
);

[[nodiscard]] std::vector<tle::Response> PullNewTokens();
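For context, a hypothetical call site for the updated Submit signature; the backend instance, prompt token ids, and parameter values below are illustrative and not part of this commit. The returned request id is what later correlates responses from PullNewTokens().

// Hypothetical call site (illustrative values, not from this commit).
const auto requestId = backend.Submit(
        /* tokens */            std::vector<TokenId>{1, 15043, 3186},
        /* maxNewTokens */      128,
        /* topK */              50,
        /* topP */              0.9f,
        /* temperature */       0.7f,
        /* repetitionPenalty */ 1.1f,
        /* frequencyPenalty */  0.0f,
        /* seed */              42
);
// Later, match each tle::Response from PullNewTokens() against requestId.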
40 changes: 32 additions & 8 deletions backends/trtllm/lib/backend.cpp
@@ -1,3 +1,4 @@
#include <algorithm>
#include <cstdlib>
#include <fstream>

@@ -104,35 +105,58 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(

// Cache variables
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();

// Attempt to discover stopWords from the generation_config.json
if (auto generationConfigPath = enginesFolder / "generation_config.json"; exists(generationConfigPath)) {
const auto generationConfig = json::parse(std::ifstream(generationConfigPath));
if (const auto eosTokenIds = generationConfig["/eos_token_ids"_json_pointer]; eosTokenIds.is_array()) {
SPDLOG_INFO(FMT_STRING("Found {:d} EOS tokens"), eosTokenIds.size());
stopWords = std::list<decltype(stopWords)::value_type>(eosTokenIds.size());

std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(),
[](const auto tokenIdObj) -> decltype(stopWords)::value_type {
const auto tokenId = tokenIdObj.template get<tle::TokenIdType>();
return {tokenId};
});
}
} else {
SPDLOG_INFO("No EOS tokens found, generation_config.json doesn't exist");
stopWords = {};
}
}
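A minimal, self-contained sketch of the discovery step above, assuming generation_config.json carries "eos_token_ids" as a flat array of integers (the file content and token ids below are illustrative). Each EOS token id becomes a single-token stop word, matching the transform in the constructor:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <list>
#include <vector>
#include <nlohmann/json.hpp>

int main() {
    using json = nlohmann::json;
    using TokenId = int32_t;

    // Illustrative config content; real models may declare one or several EOS ids.
    const auto generationConfig = json::parse(R"({"eos_token_ids": [2, 32000]})");

    std::list<std::vector<TokenId>> stopWords;
    if (generationConfig.contains("eos_token_ids") && generationConfig["eos_token_ids"].is_array()) {
        const auto &eosTokenIds = generationConfig["eos_token_ids"];
        stopWords.resize(eosTokenIds.size());
        std::transform(eosTokenIds.cbegin(), eosTokenIds.cend(), stopWords.begin(),
                       [](const auto tokenIdObj) -> std::vector<TokenId> {
                           return {tokenIdObj.template get<TokenId>()};
                       });
    }

    // Prints: 2 stop words -> [2] [32000]
    std::cout << stopWords.size() << " stop words ->";
    for (const auto &word : stopWords) std::cout << " [" << word.front() << "]";
    std::cout << '\n';
}

The element type std::vector<TokenId> leaves room for multi-token stop sequences as well; this commit only ever produces single-token entries from the EOS ids.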

[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
[[nodiscard("(generationConfigPath)Returned request id needs to be provided back to gather generated tokens")]]
tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const std::vector<tle::TokenIdType> &tokens,
const uint32_t maxNewTokens,
const int32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const float_t repetitionPenalty,
const float_t frequencyPenalty,
const uint64_t seed
) {
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
#ifndef NDEBUG
{
const auto &iterations = executor.getLatestIterationStats();
const auto &lastIteration = iterations.front();
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);


SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
}
#endif

const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
const auto sampling = GetSamplingConfig(topK, topP, temperature, repetitionPenalty, frequencyPenalty, seed);
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});

// Build the request
auto request = tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG};
request.setStopWords(stopWords);

// Submit to the executor for batching
return executor.enqueueRequest(request);
}
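Because request.setStopWords takes a list of token-id sequences, the same hook would also accept multi-token stop sequences. A hypothetical sketch, assuming the stop strings have already been tokenized elsewhere (the ids below are placeholders, not real vocabulary entries):

// Hypothetical extension: each stop word is a token-id sequence, so
// pre-tokenized multi-token stop strings fit in the same list.
const std::list<std::vector<tle::TokenIdType>> stopWords = {
        {2},        // plain EOS token
        {13, 13},   // e.g. a tokenizer's ids for a blank-line separator
};
request.setStopWords(stopWords);

Generation for that request ends once the output matches one of the listed sequences, which is exactly the behavior this commit relies on for the EOS ids discovered at construction time.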

std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
