chore(trtllm): validate there are enough GPUs on the system for the desired model
mfuntowicz committed Oct 21, 2024
1 parent 98dcde0 commit 1b56a33
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion backends/trtllm/lib/backend.cpp
@@ -88,7 +88,16 @@ huggingface::tgi::backends::TensorRtLlmBackend::TensorRtLlmBackend(
         config(json::parse(std::ifstream(enginesFolder / "config.json"))),
         executor(enginesFolder, tensorrt_llm::executor::ModelType::kDECODER_ONLY,
                  GetExecutorConfig(config, executorWorker.string())) {
-    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get_ref<const std::string &>());
+
+    SPDLOG_INFO(FMT_STRING("Engine (version={})"), config["/version"_json_pointer].get<std::string_view>());
+
+    // Ensure we have enough GPUs on the system
+    const auto worldSize = config["/pretrained_config/mapping/world_size"_json_pointer].get<size_t>();
+    const auto numGpus = huggingface::hardware::cuda::GetNumDevices().value_or(0);
+    if (numGpus < worldSize) {
+        SPDLOG_CRITICAL(FMT_NOT_ENOUGH_GPUS, numGpus, worldSize);
+        // todo : raise exception to catch on rust side
+    }
+
     // Cache variables
     maxNumTokens = config["/build_config/max_num_tokens"_json_pointer].get<uint32_t>();
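For illustration, below is a minimal, self-contained sketch of the guard this commit adds. The helper name GetNumDevices() and the world_size JSON pointer are taken from the diff above; the CUDA-runtime implementation of the helper and the hard-coded world size are assumptions made for the demo, not the actual TGI code.

    #include <cstddef>
    #include <cstdio>
    #include <optional>

    #include <cuda_runtime_api.h>

    // Assumed implementation of a GetNumDevices()-style helper: wrap
    // cudaGetDeviceCount so callers can fall back to 0 via value_or(0)
    // when no CUDA device (or driver) is present.
    static std::optional<size_t> GetNumDevices() {
        int count = 0;
        if (cudaGetDeviceCount(&count) != cudaSuccess) {
            return std::nullopt;
        }
        return static_cast<size_t>(count);
    }

    int main() {
        // In the backend, world_size is read from the engine's config.json at
        // "/pretrained_config/mapping/world_size"; hard-coded here for the demo.
        const size_t worldSize = 2;

        const auto numGpus = GetNumDevices().value_or(0);
        if (numGpus < worldSize) {
            std::fprintf(stderr, "Not enough GPUs: found %zu, engine expects world_size=%zu\n",
                         numGpus, worldSize);
            return 1;
        }

        std::printf("OK: %zu GPU(s) available for world_size=%zu\n", numGpus, worldSize);
        return 0;
    }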
