diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 5d23b3be567..969bdcb72c7 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -6,13 +6,13 @@ use async_trait::async_trait; use cxx::UniquePtr; use hashbrown::HashMap; use log::warn; -use tokenizers::{Encoding, Tokenizer}; +use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::TryAcquireError; use tokio::task::{spawn_blocking, JoinHandle}; use tokio::time::Instant; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, error, info}; +use tracing::{debug, error}; use text_generation_router::infer::InferError::{GenerationError, ValidationError}; use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse}; @@ -160,7 +160,7 @@ fn executor_status_looper( } } -fn post_processor_looper( +fn post_processor_looper( tokenizer: Tokenizer, max_inflight_requests: usize, mut decoded_tokens: UnboundedReceiver<(u64, InferResult)>, @@ -180,7 +180,7 @@ fn post_processor_looper( .entry(request_id) .and_modify(|s| s.push(*&ctx.token.id)) .or_insert_with(|| { - let mut state = Vec::with_capacity(max_num_tokens); + let mut state = Vec::with_capacity(MAX_NUM_TOKENS); state.push(*&ctx.token.id); state }); @@ -279,7 +279,6 @@ fn ensure_paths_exist, PP: AsRef>( unsafe impl Send for TensorRtLlmBackendImpl {} pub struct TensorRtLlmBackendV2 { - tokenizer: Tokenizer, executor_looper: JoinHandle<()>, post_processor_looper: JoinHandle<()>, executor: UnboundedSender, @@ -314,13 +313,11 @@ impl TensorRtLlmBackendV2 { }); // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user - let tokenizer_ = tokenizer.clone(); let post_processor_looper = spawn_blocking(move || { - post_processor_looper::<512>(tokenizer_, max_inflight_requests, post_processor_receiver) + post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver) }); Ok(TensorRtLlmBackendV2 { - tokenizer, executor_looper, post_processor_looper, executor: executor_sender, @@ -347,7 +344,7 @@ impl TensorRtLlmBackendV2 { "TensorRT-LLM backend don't support multi-chunk".into(), )), 1 => match request.inputs.first().expect("Single item-chunk") { - Chunk::Text(text) => Ok(()), + Chunk::Text(_) => Ok(()), Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))), }, } diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 3573fe4136c..d99dd2a0e87 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -8,7 +8,7 @@ use tracing::info; use text_generation_backends_trtllm::errors::TensorRtLlmBackendError; use text_generation_backends_trtllm::TensorRtLlmBackendV2; -use text_generation_router::server::{create_post_processor, get_base_tokenizer}; +use text_generation_router::server::get_base_tokenizer; use text_generation_router::usage_stats::UsageStatsLevel; use text_generation_router::{server, HubTokenizerConfig}; @@ -189,20 +189,7 @@ async fn get_tokenizer( HubTokenizerConfig::default() }); - tokenizer_filename.and_then(|filename| { - let mut tokenizer = Tokenizer::from_file(filename).ok(); - if let Some(tokenizer) = &mut tokenizer { - if let Some(class) = &tokenizer_config.tokenizer_class { - if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{ - if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) { - tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205"); - tokenizer.with_post_processor(post_processor); - } - } - } - } - tokenizer - }) + tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok()) } #[tokio::main]