Pass the max_batch_total_tokens to causal_lm
refine the warmup

Signed-off-by: yuanwu <[email protected]>
yuanwu2017 committed Oct 23, 2024
1 parent bab529c commit 67ee45a
Showing 15 changed files with 155 additions and 114 deletions.
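In short, the warmup RPC now carries the batch-level token budget (max_batch_total_tokens) alongside the existing per-request limits, and the v3 queue charges padded backends for their full prefill and decode windows. A minimal sketch of the router-side call, assuming the ShardedClient::warmup signature shown in the diffs below; the surrounding variables are illustrative:

// Illustrative call site (not part of the diff): the router forwards the
// batch-level token budget to the shards during warmup so the causal_lm
// backend can size its warmup batches against the real limit.
let max_supported_batch_total_tokens = sharded_client
    .warmup(
        max_input_tokens as u32,
        max_batch_prefill_tokens,
        max_total_tokens as u32,
        max_batch_total_tokens
            .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
        max_batch_size,
    )
    .await?; // returns the backend-reported maximum as Option<u32>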
4 changes: 2 additions & 2 deletions Dockerfile
@@ -97,5 +97,5 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

#ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]
2 changes: 2 additions & 0 deletions backends/client/src/v2/client.rs
@@ -110,6 +110,7 @@ impl Client
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -175,6 +176,7 @@ impl Client
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
2 changes: 2 additions & 0 deletions backends/client/src/v2/sharded_client.rs
@@ -104,6 +104,7 @@ impl ShardedClient
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
2 changes: 2 additions & 0 deletions backends/client/src/v3/client.rs
@@ -110,6 +110,7 @@ impl Client
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
2 changes: 2 additions & 0 deletions backends/client/src/v3/sharded_client.rs
@@ -104,6 +104,7 @@ impl ShardedClient
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -114,6 +115,7 @@ impl ShardedClient
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
5 changes: 5 additions & 0 deletions backends/v3/src/backend.rs
@@ -27,6 +27,8 @@ impl BackendV3
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
@@ -51,6 +53,8 @@ impl BackendV3
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);
let batching_task_notifier = Arc::new(Notify::new());
@@ -152,6 +156,7 @@ pub(crate) async fn batching_task(
.await;
let mut waiting_tokens = 1;

tracing::error!("Enter cached batch loop");
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
2 changes: 2 additions & 0 deletions backends/v3/src/client/grpc_client.rs
@@ -111,6 +111,7 @@ impl Client
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -203,6 +204,7 @@ impl Client
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
2 changes: 2 additions & 0 deletions backends/v3/src/client/sharded_client.rs
@@ -105,6 +105,7 @@ impl ShardedClient
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -115,6 +116,7 @@ impl ShardedClient
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
3 changes: 3 additions & 0 deletions backends/v3/src/lib.rs
@@ -94,6 +94,7 @@ pub async fn connect_backend(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
max_batch_size,
)
.await
@@ -114,6 +115,8 @@
let backend = BackendV3::new(
sharded_client,
waiting_served_ratio,
max_input_tokens as u32,
max_total_tokens as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
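A note on the fallback in the warmup call above: when max_batch_total_tokens is not configured, the value sent to the shards is the largest of 16000, max_total_tokens, and max_batch_prefill_tokens. A worked example with assumed numbers:

// Illustrative only: the unwrap_or expression picks the largest of the three values.
let max_total_tokens: u32 = 4096;
let max_batch_prefill_tokens: u32 = 8192;
let fallback = 16000u32.max(max_total_tokens.max(max_batch_prefill_tokens));
assert_eq!(fallback, 16000); // 16000 wins unless a configured limit exceeds it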
37 changes: 34 additions & 3 deletions backends/v3/src/queue.rs
@@ -49,6 +49,8 @@ impl Queue
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
// Create channel
@@ -61,6 +63,8 @@
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
queue_receiver,
));
@@ -114,6 +118,8 @@ async fn queue_task(
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
@@ -123,6 +129,8 @@
prefix_caching,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
max_batch_total_tokens,
);

@@ -174,6 +182,15 @@ struct State

/// Paged Attention Block Allocation
block_allocator: Option<BlockAllocator>,

/// Require padding
requires_padding: bool,

/// max input tokens
max_input_tokens: u32,

/// max total tokens,
max_total_tokens: u32,
}

impl State {
@@ -183,6 +200,8 @@
prefix_caching: bool,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
) -> Self {
let block_allocator = (!requires_padding).then(|| {
@@ -202,6 +221,9 @@
window_size,
speculate,
block_allocator,
requires_padding,
max_input_tokens,
max_total_tokens,
}
}

@@ -272,10 +294,19 @@ impl State {
None => {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
if self.requires_padding {
prefill_tokens = (batch.len() + 1) as u32 * self.max_input_tokens;
} else{
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch.len() + 1) as u32 * max_input_length;
}

if self.requires_padding {
decode_tokens = (batch.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
} else {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
}

decode_tokens += entry.request.stopping_parameters.max_new_tokens;
let total_tokens = prefill_tokens + decode_tokens + self.speculate;

if prefill_tokens > prefill_token_budget || total_tokens > token_budget {
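The behavioral core of this commit is the budget accounting above: when the backend requires padding, each candidate entry is charged for the fully padded prefill (max_input_tokens) and the full decode window (max_total_tokens - max_input_tokens), rather than for its actual input length and max_new_tokens. A small worked example with assumed numbers:

// Illustrative numbers only; mirrors the requires_padding branch above.
let (max_input_tokens, max_total_tokens) = (1024u32, 2048u32);
let batch_len = 3u32; // entries already selected for the next batch
// Cost of admitting one more entry into the batch:
let prefill_tokens = (batch_len + 1) * max_input_tokens; // 4 * 1024 = 4096
let decode_tokens = (batch_len + 1) * (max_total_tokens - max_input_tokens); // 4 * 1024 = 4096
let total_tokens = prefill_tokens + decode_tokens; // 8192 (the real code also adds self.speculate)
// The entry is admitted only if prefill_tokens fits prefill_token_budget
// and total_tokens fits token_budget.
assert_eq!(total_tokens, 8192);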
1 change: 1 addition & 0 deletions proto/generate.proto
@@ -228,6 +228,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}

message WarmupResponse {
1 change: 1 addition & 0 deletions proto/v3/generate.proto
@@ -261,6 +261,7 @@ message WarmupRequest {
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_total_tokens = 4;
uint32 max_batch_total_tokens = 5;
}

message WarmupResponse {
1 change: 1 addition & 0 deletions server/text_generation_server/cli.py
@@ -97,6 +97,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = "bfloat16" if dtype is None else dtype.value
logger.info(f"quantize={quantize}")
if dtype is not None and quantize not in {
None,
"bitsandbytes",
(The remaining changed files are not rendered in this view.)
