Using both values from the config as they might not be correct. (#2817)
* Using both values from the config as they might not be correct.
* Fixing max_position_embeddings for Falcon.
* Simple attempt to fix the health-check block allocation.
* Much simpler solution.
* Default value for Backend::start_health.
Narsil authored Dec 10, 2024
1 parent a2d878f commit 82c24f7
Showing 6 changed files with 25 additions and 5 deletions.
4 changes: 4 additions & 0 deletions backends/v2/src/backend.rs
@@ -104,6 +104,10 @@ impl Backend for BackendV2 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
4 changes: 4 additions & 0 deletions backends/v3/src/backend.rs
@@ -111,6 +111,10 @@ impl Backend for BackendV3 {
         }
         .is_ok()
     }
+
+    fn start_health(&self) -> bool {
+        true
+    }
 }
 
 /// Batching logic
6 changes: 3 additions & 3 deletions backends/v3/src/client/sharded_client.rs
@@ -217,8 +217,8 @@ impl Health for ShardedClient {
             input_chunks: Some(Input {
                 chunks: vec![Chunk::Text("liveness".into()).into()],
             }),
-            truncate: 10,
-            add_special_tokens: true,
+            truncate: 1,
+            add_special_tokens: false,
             prefill_logprobs: false,
             parameters: Some(NextTokenChooserParameters {
                 temperature: 1.0,
@@ -241,7 +241,7 @@ impl Health for ShardedClient {
             top_n_tokens: 0,
             // Block 0 is reserved for health checks
             blocks: vec![0],
-            slots: (0..16).collect(),
+            slots: vec![0],
             cache_len: 0,
             adapter_id: None,
             chunk_len: None,
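Why the single-slot change matters: the liveness request reserves only block 0, yet the old code claimed sixteen slots. Assuming slots map to blocks as slot / BLOCK_SIZE (an assumption for illustration; the exact mapping and block size are backend-specific), any block size below 16 lets those slots spill into blocks owned by real requests. A minimal sketch, not the repository's code:

// Sketch of the slot-to-block spill problem.
// ASSUMPTION: slot `s` lives in block `s / BLOCK_SIZE`; the real mapping
// and block size depend on the attention backend.
const BLOCK_SIZE: u32 = 1; // illustrative; some paged backends use tiny blocks

fn block_of(slot: u32) -> u32 {
    slot / BLOCK_SIZE
}

fn main() {
    // Old health check: slots 0..16, but only block 0 reserved.
    let spills = (0..16u32).any(|s| block_of(s) != 0);
    assert!(spills); // some slots land outside the reserved block
    // New health check: a one-token probe only ever needs slot 0.
    assert_eq!(block_of(0), 0);
}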
9 changes: 8 additions & 1 deletion router/src/infer/mod.rs
@@ -33,6 +33,13 @@ pub trait Backend {
     ) -> Result<UnboundedReceiverStream<Result<InferStreamResponse, InferError>>, InferError>;
 
     async fn health(&self, current_health: bool) -> bool;
+
+    /// The state of the health on startup
+    /// Typically false, or true if the backend includes
+    /// a warmup phase.
+    fn start_health(&self) -> bool {
+        false
+    }
 }
 
 /// Inference struct
@@ -75,7 +82,7 @@ impl Infer {
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
         // Backend health
-        let backend_health = Arc::new(AtomicBool::new(false));
+        let backend_health = Arc::new(AtomicBool::new(backend.start_health()));
 
         Self {
             validation,
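For context, the new start_health method uses Rust's default-trait-method pattern: the router assumes a backend starts unhealthy unless the backend opts in, which BackendV2 and BackendV3 do above because their warmup phase has already exercised the shards. A simplified, non-async sketch of the pattern (signatures trimmed for illustration; the real health method is async):

// Simplified sketch of the default-method pattern, not the real trait.
trait Backend {
    fn health(&self, current_health: bool) -> bool;

    // Health state assumed at startup: pessimistic by default.
    fn start_health(&self) -> bool {
        false
    }
}

struct WarmedUpBackend;

impl Backend for WarmedUpBackend {
    fn health(&self, _current_health: bool) -> bool {
        true
    }

    // Backends with a warmup phase can opt in to starting healthy,
    // as BackendV2/BackendV3 do above.
    fn start_health(&self) -> bool {
        true
    }
}

fn main() {
    let backend = WarmedUpBackend;
    // Mirrors AtomicBool::new(backend.start_health()) in Infer::new.
    assert!(backend.start_health());
    assert!(backend.health(backend.start_health()));
}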
1 change: 1 addition & 0 deletions server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -78,6 +78,7 @@ def __init__(
         self.alibi = False
         self.rotary = True
         self.rope_theta = rope_theta
+        self.max_position_embeddings = 2048
 
         self.vocab_size = vocab_size
         # Backward compatibility with n_embed kwarg
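The hardcoded 2048 is presumably a fallback for Falcon checkpoints whose config omits max_position_embeddings (2048 matches the original Falcon context length), so the warmup clamp below always has a value to read. A hedged sketch of the idea:

// Hedged sketch: default-then-override for a possibly missing config field.
// ASSUMPTIONS: 2048 as the historical Falcon context length, and that a
// value present in the checkpoint's config.json would take precedence.
fn max_position_embeddings(from_config_json: Option<usize>) -> usize {
    from_config_json.unwrap_or(2048)
}

fn main() {
    assert_eq!(max_position_embeddings(None), 2048);
    assert_eq!(max_position_embeddings(Some(4096)), 4096);
}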
6 changes: 5 additions & 1 deletion server/text_generation_server/models/flash_causal_lm.py
@@ -1304,6 +1304,7 @@ def __init__(
 
         self.num_layers = config.num_hidden_layers
         self.num_heads = config.num_attention_heads // self.process_group.size()
+        self.config = config
         # Validation is done in the model itself
         if num_kv_heads is None:
             num_kv_heads = getattr(config, "num_key_value_heads", None)
@@ -1594,7 +1595,10 @@ def warmup(
         if max_total_tokens is None:
             if get_support_chunking():
                 model_max_length = self.tokenizer.model_max_length
-                max_total_tokens = min(num_blocks * BLOCK_SIZE, model_max_length)
+                max_position_embeddings = self.config.max_position_embeddings
+                max_total_tokens = min(
+                    num_blocks * BLOCK_SIZE, model_max_length, max_position_embeddings
+                )
             else:
                 max_total_tokens = sum(batch.cache_lengths)
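This is the change the commit title refers to: tokenizer.model_max_length is often a placeholder sentinel (Hugging Face tokenizers report a huge value, on the order of 1e30, when the real limit is unknown), and config.max_position_embeddings can be wrong too, so warmup now clamps by both rather than trusting either alone. A sketch of the resulting bound, in Rust purely for illustration (the real code is the Python above):

// Sketch of the warmup bound: take the min of the KV-cache capacity and
// both untrusted limits, so a bogus sentinel from either source cannot win.
fn effective_max_total_tokens(
    num_blocks: usize,
    block_size: usize,
    model_max_length: usize,        // tokenizer value; may be a huge sentinel
    max_position_embeddings: usize, // config value; may also be wrong
) -> usize {
    (num_blocks * block_size)
        .min(model_max_length)
        .min(max_position_embeddings)
}

fn main() {
    // A sentinel model_max_length no longer inflates the limit:
    let t = effective_max_total_tokens(1024, 16, usize::MAX, 2048);
    assert_eq!(t, 2048);
}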
