Skip to content

Commit

Permalink
chore(rebase): fix invalid references
Browse files Browse the repository at this point in the history
  • Loading branch information
mfuntowicz committed Oct 21, 2024
1 parent f631742 commit 2c8ecdb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 24 deletions.
15 changes: 6 additions & 9 deletions backends/trtllm/src/looper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ use async_trait::async_trait;
use cxx::UniquePtr;
use hashbrown::HashMap;
use log::warn;
use tokenizers::{Encoding, Tokenizer};
use tokenizers::Tokenizer;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::TryAcquireError;
use tokio::task::{spawn_blocking, JoinHandle};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info};
use tracing::{debug, error};

use text_generation_router::infer::InferError::{GenerationError, ValidationError};
use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStreamResponse};
Expand Down Expand Up @@ -160,7 +160,7 @@ fn executor_status_looper(
}
}

fn post_processor_looper<const max_num_tokens: usize>(
fn post_processor_looper<const MAX_NUM_TOKENS: usize>(
tokenizer: Tokenizer,
max_inflight_requests: usize,
mut decoded_tokens: UnboundedReceiver<(u64, InferResult<DecodedTokenContext>)>,
Expand All @@ -180,7 +180,7 @@ fn post_processor_looper<const max_num_tokens: usize>(
.entry(request_id)
.and_modify(|s| s.push(*&ctx.token.id))
.or_insert_with(|| {
let mut state = Vec::with_capacity(max_num_tokens);
let mut state = Vec::with_capacity(MAX_NUM_TOKENS);
state.push(*&ctx.token.id);
state
});
Expand Down Expand Up @@ -279,7 +279,6 @@ fn ensure_paths_exist<P: AsRef<Path>, PP: AsRef<Path>>(
unsafe impl Send for TensorRtLlmBackendImpl {}

pub struct TensorRtLlmBackendV2 {
tokenizer: Tokenizer,
executor_looper: JoinHandle<()>,
post_processor_looper: JoinHandle<()>,
executor: UnboundedSender<GenerationContext>,
Expand Down Expand Up @@ -314,13 +313,11 @@ impl TensorRtLlmBackendV2 {
});

    // Post processor looper is responsible for receiving a bunch of tokens, decoding them and sending them back to the user
let tokenizer_ = tokenizer.clone();
let post_processor_looper = spawn_blocking(move || {
post_processor_looper::<512>(tokenizer_, max_inflight_requests, post_processor_receiver)
post_processor_looper::<256>(tokenizer, max_inflight_requests, post_processor_receiver)
});

Ok(TensorRtLlmBackendV2 {
tokenizer,
executor_looper,
post_processor_looper,
executor: executor_sender,
Expand All @@ -347,7 +344,7 @@ impl TensorRtLlmBackendV2 {
"TensorRT-LLM backend don't support multi-chunk".into(),
)),
1 => match request.inputs.first().expect("Single item-chunk") {
Chunk::Text(text) => Ok(()),
Chunk::Text(_) => Ok(()),
Chunk::Image(_) => Err(ValidationError(UnsupportedModality("image"))),
},
}
Expand Down
17 changes: 2 additions & 15 deletions backends/trtllm/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tracing::info;

use text_generation_backends_trtllm::errors::TensorRtLlmBackendError;
use text_generation_backends_trtllm::TensorRtLlmBackendV2;
use text_generation_router::server::{create_post_processor, get_base_tokenizer};
use text_generation_router::server::get_base_tokenizer;
use text_generation_router::usage_stats::UsageStatsLevel;
use text_generation_router::{server, HubTokenizerConfig};

Expand Down Expand Up @@ -189,20 +189,7 @@ async fn get_tokenizer(
HubTokenizerConfig::default()
});

tokenizer_filename.and_then(|filename| {
let mut tokenizer = Tokenizer::from_file(filename).ok();
if let Some(tokenizer) = &mut tokenizer {
if let Some(class) = &tokenizer_config.tokenizer_class {
if class == "LlamaTokenizer" || class == "LlamaTokenizerFast"{
if let Ok(post_processor) = create_post_processor(tokenizer, &tokenizer_config) {
tracing::info!("Overriding LlamaTokenizer with TemplateProcessing to follow python override defined in https://github.com/huggingface/transformers/blob/4aa17d00690b7f82c95bb2949ea57e22c35b4336/src/transformers/models/llama/tokenization_llama_fast.py#L203-L205");
tokenizer.with_post_processor(post_processor);
}
}
}
}
tokenizer
})
tokenizer_filename.and_then(|filename| Tokenizer::from_file(filename).ok())
}

#[tokio::main]
Expand Down

0 comments on commit 2c8ecdb

Please sign in to comment.