add flashdecoding-ipex
Signed-off-by: Wang, Yi A <[email protected]>
sywangyi committed Dec 20, 2024
1 parent ac67673 commit d6ac8cd
Showing 6 changed files with 30 additions and 18 deletions.
Dockerfile_intel (2 changes: 1 addition & 1 deletion)
@@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

FROM ${PLATFORM} AS final
- ENV ATTENTION=flashdecoding
+ ENV ATTENTION=flashdecoding-ipex
ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
launcher/src/main.rs (4 changes: 3 additions & 1 deletion)
@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
}
}

- let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+ let fallback_attention = if compute_capability.is_none()
+     || matches!(compute_capability, Some((major, _)) if major < 8)
+ {
"paged"
} else {
"flashdecoding"
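A minimal Python sketch of the new fallback rule above (the launcher itself is Rust; the function name here is only illustrative): a host that reports no CUDA compute capability, such as a CPU or Intel XPU box, now falls back to `paged` attention the same way a pre-Ampere GPU does.

```python
from typing import Optional, Tuple


def fallback_attention(compute_capability: Optional[Tuple[int, int]]) -> str:
    # Mirrors the Rust condition: no capability reported, or major version < 8,
    # means the flashdecoding CUDA kernels are unavailable.
    if compute_capability is None or compute_capability[0] < 8:
        return "paged"
    return "flashdecoding"


assert fallback_attention(None) == "paged"            # no CUDA device detected
assert fallback_attention((7, 5)) == "paged"          # pre-Ampere, e.g. sm_75
assert fallback_attention((8, 0)) == "flashdecoding"  # Ampere or newer
```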
server/text_generation_server/layers/attention/ipex.py (4 changes: 2 additions & 2 deletions)
@@ -31,7 +31,7 @@ def attention(
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
if ATTENTION == "flashdecoding":
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
@@ -84,7 +84,7 @@ def paged_attention(

out = torch.empty_like(query)

if ATTENTION == "flashdecoding":
if ATTENTION == "flashdecoding-ipex":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
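To see the shape of the dispatch that both hunks above change, here is a self-contained sketch; the kernel functions are stand-ins, not the real `ipex.llm.modules.PagedAttention` API, which is only available when `intel_extension_for_pytorch` is installed.

```python
import os

import torch

ATTENTION = os.environ.get("ATTENTION", "paged")


def _varlen_flash_attention(out: torch.Tensor, query: torch.Tensor) -> None:
    out.copy_(query)  # stand-in for the IPEX varlen flash-attention kernel


def _legacy_attention(out: torch.Tensor, query: torch.Tensor) -> None:
    out.copy_(query)  # stand-in for the pre-existing attention path


def attention(query: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(query)
    if ATTENTION == "flashdecoding-ipex":
        # The IPEX kernel expects a contiguous query on XPU devices.
        q = query.contiguous() if query.device.type == "xpu" else query
        _varlen_flash_attention(out, q)
    else:
        _legacy_attention(out, query)
    return out


print(attention(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```

The only functional change in this file is the flag the guard checks; the kernel calls themselves are untouched.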
server/text_generation_server/layers/attention/kv_cache.py (15 changes: 10 additions & 5 deletions)
@@ -66,8 +66,8 @@ def __init__(
else:
x = BLOCK_SIZE // element_size

- if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-     SYSTEM == "ipex" and device == torch.device("cpu")
+ if ATTENTION in {"flashdecoding", "flashinfer"} or (
+     ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
):
self.kv_cache = (
torch.empty(
@@ -82,6 +82,7 @@
),
)
elif SYSTEM == "ipex" and device == torch.device("cpu"):
+ # ipex cpu flashdecoding kernel and paged attention kernel share same layout
self.kv_cache = (
torch.empty(
(num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -176,9 +177,7 @@ def store(
scalar=True,
)[0]

- if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-     SYSTEM == "ipex" and key.device == torch.device("cpu")
- ):
+ if ATTENTION in {"flashdecoding", "flashinfer"}:
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
@@ -191,6 +190,12 @@
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
import intel_extension_for_pytorch as ipex

ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slots
)
else:
paged_reshape_and_cache(key, value, key_cache, value_cache, slots)

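For context, a runnable sketch of the store path that remains for `flashinfer`/`flashdecoding` (the cache geometry below is illustrative; the real shapes come from `KVCache.__init__`): new key/value vectors are scattered into flat slots through a view, which is exactly what the retained `view(-1, shape[-2], shape[-1])[slots] = ...` lines do. On XPU, the new `flashdecoding-ipex` branch delegates the same job to `ipex.llm.modules.PagedAttention.reshape_and_cache_flash`.

```python
import torch

# Illustrative cache geometry, not the production defaults.
num_blocks, block_size, num_heads, head_size = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros_like(key_cache)

key = torch.randn(3, num_heads, head_size)    # three new tokens
value = torch.randn(3, num_heads, head_size)
slots = torch.tensor([0, 9, 31])              # flat slot index for each token

# Same scatter as in store(): view the cache as (total_slots, heads, head_size).
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value

assert torch.equal(key_cache.view(-1, num_heads, head_size)[slots], key)
```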
server/text_generation_server/models/globals.py (16 changes: 9 additions & 7 deletions)
@@ -4,7 +4,6 @@
from typing import Dict, Optional

from text_generation_server.utils.log import log_master
- from text_generation_server.utils.import_utils import SYSTEM

REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
@@ -15,13 +14,17 @@
}
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
_expected = {"paged", "flashdecoding", "flashinfer"}
_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
if PREFIX_CACHING and ATTENTION not in {
"flashinfer",
"flashdecoding",
"flashdecoding-ipex",
}:
raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -33,12 +36,11 @@
# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
-     if SYSTEM == "ipex":
-         BLOCK_SIZE = 64
-     else:
-         BLOCK_SIZE = 256
+     BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
+ elif ATTENTION == "flashdecoding-ipex":
+     BLOCK_SIZE = 64
else:
BLOCK_SIZE = 16

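Summarizing the block-size selection above (values taken straight from the diff; the dict is only an illustration of the if/elif chain in globals.py):

```python
# KV-cache block size per attention backend after this commit.
BLOCK_SIZE_BY_ATTENTION = {
    "flashdecoding": 256,      # no longer special-cased for SYSTEM == "ipex"
    "flashinfer": 1,
    "flashdecoding-ipex": 64,  # the value ipex previously got via SYSTEM == "ipex"
    "paged": 16,               # final else branch
}

assert BLOCK_SIZE_BY_ATTENTION["flashdecoding-ipex"] == 64
```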
server/text_generation_server/models/model.py (7 changes: 5 additions & 2 deletions)
@@ -79,10 +79,13 @@ def __init__(
"Prefill chunking will be turned off",
)
support_chunking = False
if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
if (
ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
and support_chunking
):
log_master(
logger.warning,
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
"Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
)
support_chunking = False

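A small sketch of the chunking gate after this change (function and constant names are illustrative): prefill chunking stays enabled only for the three flash-style backends and is switched off, with a warning in the real code, for anything else.

```python
CHUNKING_BACKENDS = ("flashinfer", "flashdecoding", "flashdecoding-ipex")


def resolve_chunking(attention: str, support_chunking: bool) -> bool:
    # The real Model.__init__ logs a warning via log_master before disabling.
    if attention not in CHUNKING_BACKENDS and support_chunking:
        return False
    return support_chunking


assert resolve_chunking("flashdecoding-ipex", True) is True
assert resolve_chunking("paged", True) is False
```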
