Commit

chore(rebase): fix invalid references
mfuntowicz committed Oct 21, 2024
1 parent f5b9ee3 commit d73401a
Showing 2 changed files with 9 additions and 30 deletions.
12 changes: 4 additions & 8 deletions backends/trtllm/src/looper.rs
@@ -5,14 +5,13 @@ use std::path::Path;
 use async_trait::async_trait;
 use cxx::UniquePtr;
 use hashbrown::HashMap;
-use log::warn;
-use tokenizers::{Encoding, Tokenizer};
+use tokenizers::Tokenizer;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::TryAcquireError;
 use tokio::task::{spawn_blocking, JoinHandle};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
-use tracing::{debug, error, info};
+use tracing::{debug, error, warn};
 
 use text_generation_router::infer::InferError::{GenerationError, ValidationError};
 use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
@@ -285,7 +284,6 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
 unsafe impl Send for TensorRtLlmBackendImpl {}
 
 pub struct TensorRtLlmBackendV2 {
-    tokenizer: Tokenizer,
     executor_looper: JoinHandle<()>,
     post_processor_looper: JoinHandle<()>,
     executor: UnboundedSender<GenerationContext>,
@@ -320,18 +318,16 @@ impl TensorRtLlmBackendV2 {
         });
 
         // Post processor looper is responsible from receiving a bunch of tokens, decoding them and sending them back to the user
-        let tokenizer_ = tokenizer.clone();
         let post_processor_looper = spawn_blocking(move || {
             post_processor_looper(
-                tokenizer_,
+                tokenizer,
                 512,
                 max_inflight_requests,
                 post_processor_receiver,
             )
         });
 
         Ok(TensorRtLlmBackendV2 {
-            tokenizer,
             executor_looper,
             post_processor_looper,
             executor: executor_sender,
@@ -358,7 +354,7 @@ impl TensorRtLlmBackendV2 {
                 "TensorRT-LLM backend don't support multi-chunk".into(),
             )),
             1 => match request.inputs.first().expect("Single item-chunk") {
-                Chunk::Text(text) => Ok(()),
+                Chunk::Text(_) => Ok(()),
                 Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
             },
         }
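The substance of this file's change: the `tokenizer` field is dropped from `TensorRtLlmBackendV2`, so the `tokenizer.clone()` that previously fed the post-processor thread becomes unnecessary and ownership can move straight into the `spawn_blocking` closure. A minimal sketch of that ownership pattern, using a stand-in `Tokenizer` type rather than the real backend types:

// Requires the `tokio` crate with the "rt-multi-thread" and "macros" features.
use tokio::task::spawn_blocking;

// Stand-in for tokenizers::Tokenizer, to keep the sketch self-contained.
struct Tokenizer;

impl Tokenizer {
    fn decode(&self, ids: &[u32]) -> String {
        ids.iter().map(u32::to_string).collect::<Vec<_>>().join(" ")
    }
}

#[tokio::main]
async fn main() {
    let tokenizer = Tokenizer;
    // `move` hands the tokenizer to the blocking thread outright; since no
    // struct field holds onto it afterwards, no up-front `clone()` is needed.
    let post_processor = spawn_blocking(move || tokenizer.decode(&[1, 2, 3]));
    let decoded = post_processor.await.expect("post-processor task panicked");
    println!("{decoded}");
}

Because `spawn_blocking` requires a `'static` closure, moving the value in is the idiomatic hand-off once nothing else needs it.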
27 changes: 5 additions & 22 deletions backends/trtllm/src/main.rs
@@ -8,7 +8,7 @@ use tracing::info;
 
 use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
 use text_generation_backends_trtllm::TensorRtLlmBackendV2;
-use text_generation_router::server::{create_post_processor, get_base_tokenizer};
+use text_generation_router::server::get_base_tokenizer;
 use text_generation_router::usage_stats::UsageStatsLevel;
 use text_generation_router::{server, HubTokenizerConfig};
@@ -125,10 +125,10 @@ async fn get_tokenizer(
     // Load tokenizer and model info
     let (
         tokenizer_filename,
-        config_filename,
+        _config_filename,
         tokenizer_config_filename,
-        preprocessor_config_filename,
-        processor_config_filename,
+        _preprocessor_config_filename,
+        _processor_config_filename,
     ) = match api {
         Type::None => (
             Some(local_path.join("tokenizer.json")),
@@ -184,25 +184,8 @@ async fn get_tokenizer(
     } else {
         tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
     };
-    let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-        tracing::warn!("Could not find tokenizer config locally and no API specified");
-        HubTokenizerConfig::default()
-    });
 
-    tokenizer_filename.and_then(|filename| {
-        let mut tokenizer = Tokenizer::from_file(filename).ok();
-        if let Some(tokenizer) = &mut tokenizer {
-            if let Some(class) = &tokenizer_config.tokenizer_class {
-                if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
-                    if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
-                        tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
-                        tokenizer.with_post_processor(post_processor);
-                    }
-                }
-            }
-        }
-        tokenizer
-    })
+    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
 }
 
 #[tokio::main]
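In main.rs, tokenizer loading collapses to a single combinator chain, dropping the `LlamaTokenizer`/`LlamaTokenizerFast` post-processor override along with the now-unused `create_post_processor` import; the leading underscores on `_config_filename`, `_preprocessor_config_filename`, and `_processor_config_filename` keep the tuple destructuring intact while telling the compiler those bindings are intentionally unused. A minimal self-contained sketch of the resulting load path (the `load_tokenizer` wrapper and `main` are illustrative only; the `and_then` line is taken from the diff):

use std::path::PathBuf;
use tokenizers::Tokenizer;

// If a tokenizer.json path was resolved, try to parse it; any I/O or parse
// failure collapses to `None` rather than surfacing an error.
fn load_tokenizer(tokenizer_filename: Option<PathBuf>) -> Option<Tokenizer> {
    tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}

fn main() {
    match load_tokenizer(Some(PathBuf::from("tokenizer.json"))) {
        Some(_) => println!("tokenizer loaded"),
        None => println!("no usable tokenizer found"),
    }
}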
