Skip to content

Commit

Permalink
misc(backend): indent
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Dec 13, 2024
1 parent ab6591e commit 1640da7
Showing 1 changed file with 42 additions and 34 deletions.
76 changes: 42 additions & 34 deletions backends/trtllm/csrc/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ namespace huggingface::tgi::backends::trtllm {
uint64_t seed;

/**
 * Explicit conversion of these sampling parameters to a `tle::SamplingConfig`.
 * The config is built positionally: the literal `1` first, then `top_k`, `top_p`, `seed`,
 * `temperature`, `repetition_penalty` and `frequency_penalty`, with `std::nullopt` passed
 * in every slot that is left unset.
 * NOTE(review): what each position means is defined by the `tle::SamplingConfig` constructor,
 * which is not visible here (the leading `1` is presumably the beam width) - confirm against it.
 * NOTE(review): the argument list appears twice below, once per side of the re-indentation diff;
 * both copies carry the same tokens.
 */
constexpr explicit operator tle::SamplingConfig() const {
return tle::SamplingConfig {
1,
top_k,
top_p,
std::nullopt,
std::nullopt,
std::nullopt,
seed,
temperature,
std::nullopt,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty,
std::nullopt
return tle::SamplingConfig{
1,
top_k,
top_p,
std::nullopt,
std::nullopt,
std::nullopt,
seed,
temperature,
std::nullopt,
std::nullopt,
repetition_penalty,
std::nullopt,
frequency_penalty,
std::nullopt
};
}
};
Expand All @@ -67,10 +67,10 @@ namespace huggingface::tgi::backends::trtllm {
float_t temperature;
std::list<std::vector<int32_t>> stop_words;

constexpr explicit generation_config_t(const json &config):
top_p(config.value("top_p", 1.0f)), temperature( config.value("temperature", 1.0f)), stop_words(0) {
if(config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
const auto& eos_token_id = config["/eos_token_id"_json_pointer];
constexpr explicit generation_config_t(const json &config) :
top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
const auto &eos_token_id = config["/eos_token_id"_json_pointer];
std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
stop_words.emplace_back(1, token_id.template get<int32_t>());
});
Expand All @@ -97,13 +97,13 @@ namespace huggingface::tgi::backends::trtllm {
generation_config_t generation_config_;

public:
/**
 * Build a workspace from lvalue references to the engines folder and the executor worker path.
 * `engines_folder_` and `executor_worker_path_` are initialised from the two arguments, while
 * `config_` and `generation_config_` are initialised from `as_json(...)` applied to
 * `config.json` and `generation_config.json` located directly under `engines_folder`.
 * NOTE(review): how `as_json` reacts to a missing or malformed file is not visible here - confirm.
 * NOTE(review): the constructor is shown twice below (before and after re-indentation); the two
 * copies differ only in whitespace.
 */
backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path):
engines_folder_(engines_folder),
executor_worker_path_(executor_worker_path),
config_(as_json(engines_folder / "config.json")),
generation_config_(as_json(engines_folder / "generation_config.json")) {};
backend_workspace_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) :
engines_folder_(engines_folder),
executor_worker_path_(executor_worker_path),
config_(as_json(engines_folder / "config.json")),
generation_config_(as_json(engines_folder / "generation_config.json")) {};

backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path):
backend_workspace_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) :
engines_folder_(engines_folder),
executor_worker_path_(executor_worker_path),
config_(as_json(engines_folder / "config.json")),
Expand All @@ -120,9 +120,9 @@ namespace huggingface::tgi::backends::trtllm {
* `generation_config.json` holding default generation parameters.
* @return `generation_config_t`
*/
[[nodiscard]] constexpr const generation_config_t& generation_config() const { return generation_config_; }
[[nodiscard]] constexpr const generation_config_t &generation_config() const { return generation_config_; }

/**
/**
* Factory method returning new `tensorrt_llm::executor::ParallelConfig` instance used
* to initialize `tensorrt_llm::executor::Executor` with multi-instance communication information
* @return `tensorrt_llm::executor::ParallelConfig` instance
Expand Down Expand Up @@ -159,8 +159,9 @@ namespace huggingface::tgi::backends::trtllm {

public:
backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);

/**
 * Rvalue-reference overload that delegates to another `backend_t` constructor.
 * Inside the initializer the named parameters are lvalues, so the delegation resolves to the
 * `std::filesystem::path &` overload declared just above; nothing is moved from the arguments.
 * NOTE(review): the delegating initializer is shown twice below (before and after re-indentation).
 */
backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)
: backend_t(engines_folder, executor_worker_path) {};
: backend_t(engines_folder, executor_worker_path) {};

/**
* Submit a new request to the executor
Expand All @@ -171,7 +172,8 @@ namespace huggingface::tgi::backends::trtllm {
*/
[[nodiscard("Discarded executor request_id needs to be assigned")]]
std::expected<request_id_t, backend_error_t>
submit(std::span<const token_id_t> token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept;
submit(std::span<const token_id_t> token_ids, generation_params_t generation_params,
sampling_params_t sampling_params) noexcept;

/**
* Query the number of tokens available across all in-flight generations
Expand All @@ -198,26 +200,32 @@ namespace huggingface::tgi::backends::trtllm {
* Create a TensorRT-LLM executor from a workspace
*/
const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {
// The returned `tle::Executor` is list-initialised from the workspace's engines folder, the
// `kDECODER_ONLY` model type and the workspace's executor configuration.
// NOTE(review): the return statement is shown twice (before and after re-indentation); both
// copies carry the same tokens, and as literally written the second one is unreachable.
return { workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY, workspace.executor_config() };
return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY,
workspace.executor_config()};
};
}

/**
* Helper structures to define formatting strategies for various types in the backend
*/
// {fmt} formatter specialization for `generation_params_t`, deriving from `formatter<string_view>`.
// NOTE(review): the one-line headers and the multi-line headers below are the same declarations
// before and after re-indentation; they differ only in whitespace.
template <> struct fmt::formatter<huggingface::tgi::backends::trtllm::generation_params_t>: formatter<string_view> {
auto format(huggingface::tgi::backends::trtllm::generation_params_t const& c, format_context& ctx) const -> format_context::iterator {
template<>
struct fmt::formatter<huggingface::tgi::backends::trtllm::generation_params_t> : formatter<string_view> {
auto format(huggingface::tgi::backends::trtllm::generation_params_t const &c,
format_context &ctx) const -> format_context::iterator {
// Writes `generation_params_t{ max_new_tokens=<n> }` to the context's output iterator;
// the doubled braces in the format string are escapes that emit literal `{` and `}`.
return fmt::format_to(ctx.out(), "generation_params_t{{ max_new_tokens={:d} }}", c.max_new_tokens);
}
};

// {fmt} formatter specialization for `sampling_params_t`, deriving from `formatter<string_view>`.
// NOTE(review): the one-line headers and the multi-line headers below are the same declarations
// before and after re-indentation; they differ only in whitespace.
template <> struct fmt::formatter<huggingface::tgi::backends::trtllm::sampling_params_t>: formatter<string_view> {
auto format(huggingface::tgi::backends::trtllm::sampling_params_t const& c, format_context& ctx) const -> format_context::iterator {
template<>
struct fmt::formatter<huggingface::tgi::backends::trtllm::sampling_params_t> : formatter<string_view> {
auto format(huggingface::tgi::backends::trtllm::sampling_params_t const &c,
format_context &ctx) const -> format_context::iterator {
// Writes every field of `c` to the context's output iterator: `top_k` and `seed` use the
// integer presentation `{:d}`, the remaining fields are printed with three decimals (`{:.3f}`).
// The doubled braces in the format string are escapes that emit literal `{` and `}`.
return fmt::format_to(
ctx.out(),
"sampling_params_t{{ top_k={:d}, top_p={:.3f}, repetition_penalty={:.3f}, frequency_penalty={:.3f}, temperature={:.3f}, seed={:d} }}",
c.top_k, c.top_p, c.repetition_penalty, c.frequency_penalty, c.temperature, c.seed
);
}
};

#endif

0 comments on commit 1640da7

Please sign in to comment.