Skip to content

Commit

Permalink
feat(trtllm): cache maxNumTokens to avoid calling JSON everytime
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Oct 21, 2024
1 parent 3174716 commit e6da212
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
19 changes: 13 additions & 6 deletions backends/trtllm/include/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ namespace huggingface::tgi::backends {
using TokenId = tle::TokenIdType;

const static auto OUTPUT_CONFIG = tle::OutputConfig(true, false, false, true, false);
constexpr auto FMT_EXECUTOR_STATS = FMT_STRING(
"Submitting inference [{}] to the executor ({:d} already in-flight)");
constexpr auto FMT_SAMPLING_CONFIG = FMT_STRING(
"Sampling: topK={:d}, topP={:.1f}, temperature={:.1f}, repetition_penalty={:.1f}, frequency_penalty={:.1f}, seed={:d}");

/**
* Initialize all the components required by TRTLLM.
Expand All @@ -50,12 +54,12 @@ namespace huggingface::tgi::backends {
* @return
*/
tle::SamplingConfig GetSamplingConfig(
const uint32_t topK,
const float_t topP,
const float_t temperature,
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed
uint32_t topK,
float_t topP,
float_t temperature,
float_t repetition_penalty,
float_t frequency_penalty,
uint64_t seed
) noexcept;

/**
Expand All @@ -66,6 +70,9 @@ namespace huggingface::tgi::backends {
const json config;
tle::Executor executor;

/** Frequently accessed variables cached here **/
uint32_t maxNumTokens;

public:
explicit TensorRtLlmBackend(
const std::filesystem::path &engineFolder,
Expand Down
31 changes: 14 additions & 17 deletions backends/trtllm/lib/backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ tle::SamplingConfig huggingface::tgi::backends::GetSamplingConfig(
const float_t repetition_penalty,
const float_t frequency_penalty,
const uint64_t seed) noexcept {

return tle::SamplingConfig(
1, // TGI only use a single beam
topK,
Expand All @@ -100,6 +101,9 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
GetExecutorConfig(config, executorWorker.string())) {
SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());

// Cache variables
maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
}

[[nodiscard("Returned request id needs to be provided back to gather generated tokens")]]
Expand All @@ -113,29 +117,22 @@ tle::IdType huggingface::tgi::backends::TensorRtLlmBackend::Submit(
const float_t frequency_penalty,
const uint64_t seed
) {
const auto maxNewTokensChecked = std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size()));
#ifndef NDEBUG
SPDLOG_DEBUG(
FMT_STRING("Submitting inference [{}] to the executor ({:d} already in-flight)"),
fmt::join(tokens, ", "),
executor.getLatestIterationStats().front().numActiveRequests
);
#endif
{
const auto &iterations = executor.getLatestIterationStats();
const auto &lastIteration = iterations.front();
SPDLOG_DEBUG(FMT_EXECUTOR_STATS, fmt::join(tokens, ", "), lastIteration.numActiveRequests);

const auto maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint64_t>();
const auto maxNewTokensChecked = static_cast<tle::SizeType32>(
std::min(maxNewTokens, static_cast<uint32_t>(maxNumTokens - tokens.size())));

#ifndef NDEBUG
SPDLOG_INFO(
FMT_STRING(
"Sampling config: topK={:d}, topP={:d}, temperature={:d}, repetition_penalty={:d}, frequency_penalty={:d}, seed={:d}"),
topK, topP, temperature, repetition_penalty, frequency_penalty, seed
)
SPDLOG_INFO(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
SPDLOG_DEBUG(FMT_SAMPLING_CONFIG, topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
SPDLOG_DEBUG(FMT_STRING("Asking for max_new_tokens={:d}"), maxNewTokensChecked);
}
#endif

const auto sampling = GetSamplingConfig(topK, topP, temperature, repetition_penalty, frequency_penalty, seed);
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked, true, sampling, OUTPUT_CONFIG});
const auto maxNewTokensChecked_ = static_cast<tle::SizeType32>(maxNewTokensChecked);
return executor.enqueueRequest(tle::Request{tokens, maxNewTokensChecked_, true, sampling, OUTPUT_CONFIG});
}

std::vector<tle::Response> huggingface::tgi::backends::TensorRtLlmBackend::PullNewTokens() {
Expand Down

0 comments on commit e6da212

Please sign in to comment.