improvements
mht-sharma committed Jan 3, 2025
1 parent fa14d71 commit 43370a1
Showing 4 changed files with 42 additions and 98 deletions.
34 changes: 12 additions & 22 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -53,20 +53,21 @@ def __init__(
     ):
         """Construct the key-value cache for a layer."""
         if dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
-            if (ATTENTION == "flashinfer" and SYSTEM == "cuda") or not (
-                ATTENTION == "paged" and SYSTEM == "rocm"
+            if not (
+                (ATTENTION == "flashinfer" and SYSTEM == "cuda")
+                or (ATTENTION == "paged" and SYSTEM == "rocm")
             ):
                 raise ValueError(
-                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCM"
+                    "FP8 KV cache is currently only supported for flashinfer on CUDA and paged attention on ROCm. "
                 )
             if SYSTEM == "rocm" and dtype == torch.float8_e5m2:
                 raise ValueError(
-                    "float8_e5m2 FP8 KV cache is not supported on AMD Rocm"
+                    "float8_e5m2 FP8 KV cache is not supported on AMD ROCm"
                 )

-        self.kv_cache_dtype_str = "auto"
+        self.kv_cache_dtype = "auto"
         if SYSTEM == "rocm" and dtype == torch.float8_e4m3fn:
-            self.kv_cache_dtype_str = "fp8"
+            self.kv_cache_dtype = "fp8"
             dtype = torch.uint8

         element_size = torch.tensor([], dtype=dtype).element_size()
@@ -123,27 +124,16 @@ def can_scale(self, kv_scales: KVScales) -> bool:
             self.dtype == torch.float8_e4m3fn
             and ATTENTION == "flashinfer"
             and SYSTEM == "cuda"
+        ) or (
+            self.kv_cache_dtype == "fp8" and ATTENTION == "paged" and SYSTEM == "rocm"
         ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
-            return True
-        elif (
-            self.kv_cache_dtype_str == "fp8"
-            and ATTENTION == "paged"
-            and SYSTEM == "rocm"
-        ):
-            log_once(
-                logger.info,
-                "Using FP8 KV cache scales",
-            )
+            log_once(logger.info, "Using FP8 KV cache scales")
             return True
         else:
             # We have scales, but not the correct FP8 cache type, so warn once.
             log_once(
                 logger.info,
-                "Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on flashinfer is supported",
+                "Ignoring FP8 KV cache scales, supported only for flashinfer on CUDA and paged attention on ROCm",
             )
             return False

@@ -213,7 +203,7 @@ def store(
                 key_cache,
                 value_cache,
                 slots,
-                self.kv_cache_dtype_str,
+                self.kv_cache_dtype,
                 kv_scales.key_scale_cpu,
                 kv_scales.value_scale_cpu,
             )
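
For reference, a minimal sketch of the support check this file now enforces. The helper name fp8_kv_cache_supported is hypothetical, and ATTENTION/SYSTEM are passed as plain arguments instead of the module-level globals:

import torch

def fp8_kv_cache_supported(attention: str, system: str, dtype: torch.dtype) -> bool:
    """Sketch of the corrected gate: FP8 KV cache only for flashinfer/CUDA or paged/ROCm."""
    if dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
        return True  # non-FP8 cache dtypes are not affected by this check
    if not (
        (attention == "flashinfer" and system == "cuda")
        or (attention == "paged" and system == "rocm")
    ):
        return False
    if system == "rocm" and dtype == torch.float8_e5m2:
        return False  # float8_e5m2 is rejected on AMD ROCm
    return True

assert fp8_kv_cache_supported("flashinfer", "cuda", torch.float8_e4m3fn)
assert fp8_kv_cache_supported("paged", "rocm", torch.float8_e4m3fn)
assert not fp8_kv_cache_supported("paged", "rocm", torch.float8_e5m2)
assert not fp8_kv_cache_supported("paged", "cuda", torch.float8_e4m3fn)
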
6 changes: 3 additions & 3 deletions server/text_generation_server/layers/attention/rocm.py
@@ -119,7 +119,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -154,7 +154,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
         )
@@ -174,7 +174,7 @@ def paged_attention(
             block_size,
             max_s,
             None,
-            kv_cache.kv_cache_dtype_str,
+            kv_cache.kv_cache_dtype,
             kv_scales.key_scale_cpu,
             kv_scales.value_scale_cpu,
             None,
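
The renamed kv_cache_dtype string and the per-tensor key/value scales are simply forwarded to the ROCm paged attention kernels. As a rough numerical illustration only (not the kernel implementation), an "fp8" KV cache stores entries in float8_e4m3fn with a per-tensor scale and dequantizes with the same scale on read; the helper below and the 448.0 bound (the largest finite float8_e4m3fn value) are illustrative:

import torch

def fp8_kv_roundtrip(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Toy sketch: scale into float8_e4m3fn range, then dequantize on read."""
    q = (x / scale).to(torch.float8_e4m3fn)  # what would sit in the uint8-sized cache storage
    return q.to(x.dtype) * scale             # what the attention kernel effectively consumes

k = torch.randn(4, 8, dtype=torch.float16)
k_scale = k.abs().max().item() / 448.0       # per-tensor scale; 448 is the max finite e4m3fn value
max_err = (fp8_kv_roundtrip(k, k_scale) - k).abs().max()
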
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -398,16 +398,10 @@ def forward(self, hidden_states, adapter_data):
             return self.down_proj(out, adapter_data)
         else:
             gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
-            output_shape = gate_up_states.shape[:-1] + (self.intermediate_size,)
-            out = torch.empty(
-                output_shape, dtype=gate_up_states.dtype, device=gate_up_states.device
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(
+                self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
             )
-            ops.silu_and_mul(out, gate_up_states)
-            return self.down_proj(out, adapter_data)
-            # gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-            # return self.down_proj(
-            #     self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
-            # )


 class FlashLlamaLayer(nn.Module):
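
The old path allocated an output buffer and called the fused ops.silu_and_mul kernel; the new path expresses the same SwiGLU gating with a view and plain PyTorch ops. A small equivalence check, assuming the gate_up projection concatenates [gate | up] along the last dimension and that self.act is SiLU:

import torch
import torch.nn.functional as F

intermediate_size = 16
gate_up_states = torch.randn(4, 2 * intermediate_size)

# Fused-kernel semantics: silu(first half) * second half.
reference = F.silu(gate_up_states[:, :intermediate_size]) * gate_up_states[:, intermediate_size:]

# View-based formulation used by the new code path.
g = gate_up_states.view(-1, 2, intermediate_size)
candidate = F.silu(g[:, 0]) * g[:, 1]

torch.testing.assert_close(candidate, reference)
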
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -520,68 +520,28 @@ def forward(
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if (
-            torch.distributed.get_rank() == 0
-            and input_ids.shape[0] == 262144
-            and cu_seqlen_prefill is not None
-        ):
-            with torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
-                record_shapes=True,
-            ) as prof:
-                true_max_s = max_s
-                if prefill_cache_indices is not None:
-                    # Slots also need to be sliced as it has the same size as the whole kv tensor
-                    slots = slots[prefill_cache_indices]
-                elif self.max_past is not None:
-                    # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                    # kernel requires the true values
-                    seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-                hidden_states = self.model(
-                    input_ids,
-                    position_ids,
-                    cu_seqlen_prefill,
-                    kv_cache,
-                    block_tables,
-                    slots,
-                    seqlen,
-                    max_s,
-                    true_max_s,
-                    prefill_cache_indices,
-                )
-                if lm_head_indices is not None:
-                    hidden_states = hidden_states[lm_head_indices]
-                logits = self.lm_head(hidden_states)
-
-                prof.export_chrome_trace("/tgi/trace_mistral_prefill.json")
-        else:
-            true_max_s = max_s
-            if prefill_cache_indices is not None:
-                # Slots also need to be sliced as it has the same size as the whole kv tensor
-                slots = slots[prefill_cache_indices]
-            elif self.max_past is not None:
-                # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
-                # kernel requires the true values
-                seqlen = seqlen.clamp(max=self.max_past_tensor)
-
-            hidden_states = self.model(
-                input_ids,
-                position_ids,
-                cu_seqlen_prefill,
-                kv_cache,
-                block_tables,
-                slots,
-                seqlen,
-                max_s,
-                true_max_s,
-                prefill_cache_indices,
-            )
-            if lm_head_indices is not None:
-                hidden_states = hidden_states[lm_head_indices]
-            logits = self.lm_head(hidden_states)
+        true_max_s = max_s
+        if prefill_cache_indices is not None:
+            # Slots also need to be sliced as it has the same size as the whole kv tensor
+            slots = slots[prefill_cache_indices]
+        elif self.max_past is not None:
+            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
+            # kernel requires the true values
+            seqlen = seqlen.clamp(max=self.max_past_tensor)
+
+        hidden_states = self.model(
+            input_ids,
+            position_ids,
+            cu_seqlen_prefill,
+            kv_cache,
+            block_tables,
+            slots,
+            seqlen,
+            max_s,
+            true_max_s,
+            prefill_cache_indices,
+        )
+        if lm_head_indices is not None:
+            hidden_states = hidden_states[lm_head_indices]
+        logits = self.lm_head(hidden_states)
         return logits
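
The deleted branch was a debugging aid: for one hard-coded prefill shape on rank 0 it wrapped the forward pass in torch.profiler and exported a Chrome trace. If that kind of profiling is needed again, it can live in the caller instead of in forward; a minimal sketch with placeholder model, inputs, and output path:

import torch

def profile_forward(model, inputs, trace_path="/tmp/trace_prefill.json"):
    """Profile one forward pass and export a Chrome trace (all names are placeholders)."""
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
    ) as prof:
        with torch.no_grad():
            model(**inputs)
    prof.export_chrome_trace(trace_path)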
