diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 67105057709..a0823c8d07e 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -53,20 +53,21 @@ def __init__(
     ):
         """Construct the key-value cache for a layer."""
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
-            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
-                ATTENTION == "paged" and SYSTEM == "rocm"
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
-                    "float8_e5m2 FP8 KV cache is not supported on AMD Rocm"
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )
 
-            self.kv_cache_dtype_str = "auto"
+            self.kv_cache_dtype = "auto"
             if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-                self.kv_cache_dtype_str = "fp8"
+                self.kv_cache_dtype = "fp8"
                 dtype = torch.uint8
 
         element_size = torch.tensor([], dtype=dtype).element_size()
@@ -123,27 +124,16 @@ def can_scale(self, kv_scales: KVScales) -> bool:
             self.dtype == torch.float8_e4m3fn
             and ATTENTION == "flashinfer"
             and SYSTEM == "cuda"
+        ) or (
+            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
-            return True
-        elif (
-            self.kv_cache_dtype_str == "fp8"
-            and ATTENTION == "paged"
-            and SYSTEM == "rocm"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for flashinfer on CUDA and paged attention on ROCm",
             )
             return False
 
@@ -213,7 +203,7 @@ def store(
                 key_cache,
                 value_cache,
                 slots,
-                self.kv_cache_dtype_str,
+                self.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
             )
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index bc790f06ee3..a5ab0ae9676 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -119,7 +119,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -154,7 +154,7 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                kv_cache.kv_cache_dtype_str,
+                kv_cache.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
             )
@@ -174,7 +174,7 @@ def paged_attention(
                 block_size,
                 max_s,
                 None,
-                kv_cache.kv_cache_dtype_str,
+                kv_cache.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
                 None,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 53df59dfca1..10309006af9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -398,16 +398,10 @@ def forward(self, hidden_states, adapter_data):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)
-            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            # return self.down_proj(
-            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            # )
 
 
 class FlashLlamaLayer(nn.Module):
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 1bc6c7d4766..a45dd1e615e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -520,68 +520,28 @@ def forward(
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if (
-            torch.distributed.get_rank() == 0
-            and input_ids.shape[0] == 262144
-            and cu_seqlen_prefill is not None
-        ):
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=True,
-            ) as prof:
-                true_max_s = max_s
-                if prefill_cache_indices is not None:
-                    # Slots also need to be sliced as it has the same size as the whole kv tensor
-                    slots = slots[prefill_cache_indices]
-                elif self.max_past is not None:
-                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                    # kernel requires the true values
-                    seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-                hidden_states = self.model(
-                    input_ids,
-                    position_ids,
-                    cu_seqlen_prefill,
-                    kv_cache,
-                    block_tables,
-                    slots,
-                    seqlen,
-                    max_s,
-                    true_max_s,
-                    prefill_cache_indices,
-                )
-                if lm_head_indices is not None:
-                    hidden_states = hidden_states[lm_head_indices]
-                logits = self.lm_head(hidden_states)
-
-            prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
-        else:
-            true_max_s = max_s
-            if prefill_cache_indices is not None:
-                # Slots also need to be sliced as it has the same size as the whole kv tensor
-                slots = slots[prefill_cache_indices]
-            elif self.max_past is not None:
-                # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                # kernel requires the true values
-                seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-            hidden_states = self.model(
-                input_ids,
-                position_ids,
-                cu_seqlen_prefill,
-                kv_cache,
-                block_tables,
-                slots,
-                seqlen,
-                max_s,
-                true_max_s,
-                prefill_cache_indices,
-            )
-            if lm_head_indices is not None:
-                hidden_states = hidden_states[lm_head_indices]
-            logits = self.lm_head(hidden_states)
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
+        elif self.max_past is not None:
+            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+            # kernel requires the true values
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
+
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            seqlen,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.lm_head(hidden_states)
         return logits
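# --- Illustrative sketch (not part of the patch) ----------------------------
# The kv_cache.py hunk above replaces an inverted boolean check: the old
# `(flashinfer and cuda) or not (paged and rocm)` condition raised even for
# the supported flashinfer/CUDA combination. The helper below restates the
# new check on its own; the function name and the ATTENTION/SYSTEM stand-ins
# are hypothetical, not identifiers from the repository.
import torch

ATTENTION = "paged"  # stand-in for the module-level global in kv_cache.py
SYSTEM = "rocm"      # stand-in for the module-level global in kv_cache.py


def fp8_kv_cache_dtype(dtype: torch.dtype) -> str:
    """Return the kv_cache_dtype string ("auto" or "fp8") for a requested dtype."""
    if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
        # FP8 KV cache is only valid for flashinfer on CUDA or paged attention on ROCm.
        if not (
            (ATTENTION == "flashinfer" and SYSTEM == "cuda")
            or (ATTENTION == "paged" and SYSTEM == "rocm")
        ):
            raise ValueError(
                "FP8 KV cache is currently only supported for flashinfer on CUDA "
                "and paged attention on ROCm."
            )
        if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
            raise ValueError("float8_e5m2 FP8 KV cache is not supported on AMD ROCm")
        if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
            # ROCm stores the float8_e4m3fn cache as uint8 and tags it "fp8".
            return "fp8"
    return "auto"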