Fix the issues of tgi-gaudi for v.2.3.1
Signed-off-by: yuanwu <[email protected]>
yuanwu2017 committed Oct 27, 2024
1 parent 7e282b4 commit 372e071
Showing 13 changed files with 93 additions and 37 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
@@ -13,7 +13,6 @@ COPY benchmark benchmark
COPY router router
COPY backends backends
COPY launcher launcher

RUN cargo chef prepare --recipe-path recipe.json

FROM chef AS builder
@@ -44,6 +43,10 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image
FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest as base

ENV ATTENTION=default
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0

# Text Generation Inference base env
ENV HF_HOME=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
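The three new ENV defaults back the environment lookups added to server/text_generation_server/models/globals.py later in this commit: that module reads ATTENTION and PREFIX_CACHING with unguarded os.environ[...] accesses, so an image that did not define them would fail at import time. A minimal sketch of how the defaults are consumed, illustrative only and not the module itself:

```python
import os

# globals.py (later in this commit) uses unguarded lookups such as
#     ATTENTION = os.environ["ATTENTION"]
# so a missing variable raises KeyError when the server imports the module.
# The ENV lines above guarantee the lookups succeed; each value can still be
# overridden at `docker run -e ...` time.
attention = os.environ.get("ATTENTION", "default")                               # Dockerfile default
prefix_caching = os.environ.get("PREFIX_CACHING", "0").lower() in {"1", "true"}
prefill_chunking = os.environ.get("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
print(attention, prefix_caching, prefill_chunking)  # default False False (with the defaults)
```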
13 changes: 12 additions & 1 deletion backends/v2/src/backend.rs
@@ -27,6 +27,8 @@ impl BackendV2 {
pub(crate) fn new(
client: ShardedClient,
waiting_served_ratio: f32,
max_input_tokens: u32,
max_total_tokens: u32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
@@ -48,7 +50,16 @@ impl BackendV2 {
} else {
16
};
let queue = Queue::new(requires_padding, block_size, window_size, speculate);

let queue = Queue::new(
requires_padding,
block_size,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
);

let batching_task_notifier = Arc::new(Notify::new());

// Spawn batching background task that contains all the inference logic
2 changes: 2 additions & 0 deletions backends/v2/src/client/grpc_client.rs
@@ -109,6 +109,7 @@ impl Client {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let mut n_tokens = 0;
@@ -174,6 +175,7 @@ impl Client {
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
2 changes: 2 additions & 0 deletions backends/v2/src/client/sharded_client.rs
@@ -105,6 +105,7 @@ impl ShardedClient {
max_input_length: u32,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_batch_total_tokens: u32,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
let futures: Vec<_> = self
@@ -115,6 +116,7 @@
max_input_length,
max_prefill_tokens,
max_total_tokens,
max_batch_total_tokens,
max_batch_size,
))
})
3 changes: 3 additions & 0 deletions backends/v2/src/lib.rs
@@ -92,6 +92,7 @@ pub async fn connect_backend(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_total_tokens.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))),
max_batch_size,
)
.await
@@ -112,6 +113,8 @@ pub async fn connect_backend(
let backend = BackendV2::new(
sharded_client,
waiting_served_ratio,
max_input_tokens as u32,
max_total_tokens as u32,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
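When --max-batch-total-tokens is not supplied, the warmup call now falls back to the largest of 16000, max_total_tokens, and max_batch_prefill_tokens, which is what the nested Rust max calls evaluate to. A small Python sketch of that selection, with illustrative values:

```python
# Mirrors max_batch_total_tokens.unwrap_or(16000.max(max_total_tokens.max(max_batch_prefill_tokens)))
def warmup_batch_total_tokens(max_batch_total_tokens, max_total_tokens, max_batch_prefill_tokens):
    if max_batch_total_tokens is not None:
        return max_batch_total_tokens          # explicit CLI value wins
    return max(16000, max_total_tokens, max_batch_prefill_tokens)

print(warmup_batch_total_tokens(None, 2048, 4096))    # 16000 (floor value)
print(warmup_batch_total_tokens(None, 32768, 4096))   # 32768 (large context dominates)
print(warmup_batch_total_tokens(65536, 2048, 4096))   # 65536 (flag passed explicitly)
```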
31 changes: 26 additions & 5 deletions backends/v2/src/queue.rs
@@ -43,6 +43,8 @@ impl Queue {
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
) -> Self {
// Create channel
let (queue_sender, queue_receiver) = mpsc::unbounded_channel();
@@ -53,6 +55,8 @@
block_size,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
queue_receiver,
));

@@ -103,9 +107,18 @@ async fn queue_task(
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
mut receiver: mpsc::UnboundedReceiver<QueueCommand>,
) {
let mut state = State::new(requires_padding, block_size, window_size, speculate);
let mut state = State::new(
requires_padding,
block_size,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
);

while let Some(cmd) = receiver.recv().await {
match cmd {
@@ -153,6 +166,12 @@ struct State {

/// Speculation amount
speculate: u32,

/// Max input tokens
max_input_tokens: u32,

/// Max total tokens
max_total_tokens: u32,
}

impl State {
@@ -161,6 +180,8 @@
block_size: u32,
window_size: Option<u32>,
speculate: u32,
max_input_tokens: u32,
max_total_tokens: u32,
) -> Self {
Self {
entries: VecDeque::with_capacity(128),
@@ -170,6 +191,8 @@
block_size,
window_size,
speculate,
max_input_tokens,
max_total_tokens,
}
}

@@ -224,7 +247,6 @@ impl State {
let mut batch_entries =
IntMap::with_capacity_and_hasher(self.entries.len(), BuildNoHashHasher::default());

let mut max_input_length = 0;
let mut prefill_tokens: u32 = 0;
let mut decode_tokens: u32 = 0;

@@ -241,8 +263,7 @@
if self.requires_padding {
// We pad to max input length in the Python shards
// We need to take these padding tokens into the equation
max_input_length = max_input_length.max(entry.request.input_length);
prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length
prefill_tokens = (batch_requests.len() + 1) as u32 * self.max_input_tokens
} else {
// pad to block size
prefill_tokens += ((entry.request.input_length + self.block_size - 1)
@@ -251,7 +272,7 @@
}

if self.requires_padding {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
decode_tokens = (batch_requests.len() + 1) as u32 * (self.max_total_tokens - self.max_input_tokens);
} else {
let max_new_tokens = match self.window_size {
None => entry.request.stopping_parameters.max_new_tokens,
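On a backend that requires padding (static shapes on Gaudi), the queue now budgets for the worst case allowed by the configured limits instead of the lengths currently sitting in the queue: every entry is assumed to be padded to max_input_tokens for prefill and to generate up to max_total_tokens - max_input_tokens tokens during decode. A worked example of the accounting, with illustrative limits:

```python
# Worst-case token accounting for a padded backend (illustrative limits).
max_input_tokens = 1024
max_total_tokens = 2048

def budget(batch_size: int) -> tuple[int, int]:
    # Every entry is padded to max_input_tokens for prefill ...
    prefill_tokens = batch_size * max_input_tokens
    # ... and may emit up to (max_total_tokens - max_input_tokens) new tokens.
    decode_tokens = batch_size * (max_total_tokens - max_input_tokens)
    return prefill_tokens, decode_tokens

for batch_size in (1, 4, 8):
    prefill, decode = budget(batch_size)
    print(batch_size, prefill, decode, prefill + decode)
# 1 1024 1024 2048
# 4 4096 4096 8192
# 8 8192 8192 16384
# The surrounding queue logic compares these totals against its prefill and
# total token budgets before admitting another entry into the batch.
```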
1 change: 0 additions & 1 deletion backends/v3/src/backend.rs
@@ -156,7 +156,6 @@ pub(crate) async fn batching_task(
.await;
let mut waiting_tokens = 1;

tracing::error!("Enter cached batch loop");
// We loop until we do not receive any cached batch from the inference server (== until
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
3 changes: 3 additions & 0 deletions router/src/lib.rs
@@ -23,6 +23,7 @@ pub enum Attention {
Paged,
FlashDecoding,
FlashInfer,
Default,
}

impl Attention {
@@ -31,6 +32,7 @@
Attention::FlashDecoding => 256,
Attention::FlashInfer => 1,
Attention::Paged => 16,
Attention::Default => 16,
}
}
}
@@ -52,6 +54,7 @@
"paged" => Ok(Attention::Paged),
"flashdecoding" => Ok(Attention::FlashDecoding),
"flashinfer" => Ok(Attention::FlashInfer),
"default" => Ok(Attention::Default),
_ => Err(ParseError),
}
}
24 changes: 5 additions & 19 deletions server/text_generation_server/models/__init__.py
@@ -17,14 +17,14 @@
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
#from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
# from text_generation_server.models.custom_modeling.mllama import (
# MllamaForConditionalGeneration,
# )
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
@@ -196,20 +196,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)

if model_type == "mllama":
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)

if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,
2 changes: 1 addition & 1 deletion server/text_generation_server/models/causal_lm.py
@@ -1215,7 +1215,7 @@ def warmup(self, request) -> None:
max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
self.limit_hpu_graph = True
try:
for batch_size in range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE):
for batch_size in range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE):
batches= []
iters = math.floor(batch_size/max_prefill_batch_size)
DECODE_WARMUP_BATCH_SIZE_LIST.append(batch_size)
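Python's range() excludes its stop value, so the old loop range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE) never reached the smallest decode bucket; stopping at 0 warms up every bucket down to BATCH_BUCKET_SIZE. A quick illustration with assumed bucket values:

```python
# Assumed values for illustration: the real ones come from the server config.
BATCH_BUCKET_SIZE = 8
max_decode_batch_size = 32

old = list(range(max_decode_batch_size, BATCH_BUCKET_SIZE, -BATCH_BUCKET_SIZE))
new = list(range(max_decode_batch_size, 0, -BATCH_BUCKET_SIZE))

print(old)  # [32, 24, 16]    -- the smallest bucket (8) was never warmed up
print(new)  # [32, 24, 16, 8] -- every decode bucket size is now covered
```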
32 changes: 32 additions & 0 deletions server/text_generation_server/models/globals.py
@@ -1,8 +1,40 @@
import torch
import os
from typing import Dict, Optional
from loguru import logger
from text_generation_server.utils.log import log_master

ATTENTION = os.environ["ATTENTION"]
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.environ["PREFIX_CACHING"].lower() in {
"1",
"true",
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "flashdecoding", "flashinfer", "default"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16

# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
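With the Dockerfile defaults (ATTENTION=default, PREFIX_CACHING=0), this module resolves to a 16-token block size with prefix caching disabled; flashdecoding and flashinfer keep their 256/1 block sizes and remain the only values that may enable prefix caching. A condensed sketch of the resolution logic, mirroring but not replacing the module above:

```python
def resolve(attention: str, prefix_caching: str) -> tuple[int, bool]:
    """Sketch of the ATTENTION / PREFIX_CACHING / BLOCK_SIZE resolution above."""
    assert attention in {"paged", "flashdecoding", "flashinfer", "default"}
    caching = prefix_caching.lower() in {"1", "true"}
    if caching and attention not in {"flashinfer", "flashdecoding"}:
        raise RuntimeError("Prefix caching is only supported with flashinfer")
    block_size = {"flashdecoding": 256, "flashinfer": 1}.get(attention, 16)
    return block_size, caching

print(resolve("default", "0"))        # (16, False) -- the Gaudi image defaults
print(resolve("flashinfer", "true"))  # (1, True)
```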
10 changes: 3 additions & 7 deletions server/text_generation_server/models/vlm_causal_lm.py
@@ -117,10 +117,6 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
from loguru import logger

logger.info(
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features

elif config.model_type == "paligemma":
@@ -373,9 +369,9 @@ def batch_tokenized_inputs(
(image_inputs["pixel_attention_mask"], dummy_attention), dim=0
)
if "image_sizes" in image_inputs:
dummy_shape = list(image_inputs['image_sizes'].shape)
dummy_shape[0] = missing_inputs
dummy_sizes = torch.randint(dummy_shape)
dummy_shape = list(list(image_inputs['image_sizes'])[0])
dummy_shape = missing_inputs*[dummy_shape]
dummy_sizes = torch.IntTensor(dummy_shape)
new_image_inputs["image_sizes"] = torch.cat(
(image_inputs["image_sizes"], dummy_sizes), dim=0
)
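The old padding path called torch.randint(dummy_shape) with a plain Python list, which does not match torch.randint's signature (it needs bounds plus a size tuple), so padding a batch that carried image_sizes crashed. The new code repeats the first real (height, width) pair for each missing entry and builds an integer tensor that can be concatenated. A small sketch of the new behaviour, with illustrative shapes:

```python
import torch

# Illustrative batch: two real images plus two dummy entries needed for padding.
image_sizes = torch.tensor([[336, 336], [672, 336]], dtype=torch.int32)
missing_inputs = 2

# Repeat the first real (height, width) pair for every dummy entry.
first_size = [int(v) for v in image_sizes[0]]                  # [336, 336]
dummy_sizes = torch.IntTensor(missing_inputs * [first_size])   # shape (2, 2), int32

padded = torch.cat((image_sizes, dummy_sizes), dim=0)
print(padded.shape)  # torch.Size([4, 2])
```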
2 changes: 0 additions & 2 deletions server/text_generation_server/server.py
@@ -25,7 +25,6 @@

try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
@@ -35,7 +34,6 @@
PaliGemmaBatch,
VlmCausalLMBatch,
IdeficsCausalLMBatch,
MllamaCausalLMBatch,
}
except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash.
