diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs
index d97fa69fa24..95ba16a9396 100644
--- a/backends/trtllm/src/looper.rs
+++ b/backends/trtllm/src/looper.rs
@@ -5,14 +5,13 @@ use std::path::Path;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
-use log::warn;
-use tokenizers::{Encoding, Tokenizer};
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info};
+use tracing::{debug, error, warn};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -285,7 +284,6 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
 unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
-    tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
     post_processor_looper: JoinHandle<()>,
     executor: UnboundedSender<GenerationContext>,
@@ -320,10 +318,9 @@ impl TensorRtLlmBackendV2 {
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let tokenizer_ = tokenizer.clone();
         let post_processor_looper = spawn_blocking(move || {
             post_processor_looper(
-                tokenizer_,
+                tokenizer,
                 512,
                 max_inflight_requests,
                 post_processor_receiver,
@@ -331,7 +328,6 @@ impl TensorRtLlmBackendV2 {
         });
 
         Ok(TensorRtLlmBackendV2 {
-            tokenizer,
             executor_looper,
             post_processor_looper,
             executor: executor_sender,
@@ -358,7 +354,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(()),
+                Chunk::Text(_) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs
index 3573fe4136c..ec54ccce722 100644
--- a/backends/trtllm/src/main.rs
+++ b/backends/trtllm/src/main.rs
@@ -8,7 +8,7 @@ use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::{create_post_processor, get_base_tokenizer};
+use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig};
 
@@ -125,10 +125,10 @@ async fn get_tokenizer(
     // Load tokenizer and model info
     let (
         tokenizer_filename,
-        config_filename,
+        _config_filename,
         tokenizer_config_filename,
-        preprocessor_config_filename,
-        processor_config_filename,
+        _preprocessor_config_filename,
+        _processor_config_filename,
     ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
@@ -184,25 +184,8 @@ async fn get_tokenizer(
     } else {
         tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
     };
-    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-        tracing::warn!("Could not find tokenizer config locally and no API specified");
-        HubTokenizerConfig::default()
-    });
 
-    tokenizer_filename.and_then(|filename| {
-        let mut tokenizer = Tokenizer::from_file(filename).ok();
-        if let Some(tokenizer) = &mut tokenizer {
-            if let Some(class) = &tokenizer_config.tokenizer_class {
-                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
-                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
-                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
-                        tokenizer.with_post_processor(post_processor);
-                    }
-                }
-            }
-        }
-        tokenizer
-    })
+    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
 }
 
 #[tokio::main]