Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash decoding kernel adding and prefill-chunking and prefix caching enabling in intel cpu/xpu #2815

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Dockerfile_intel
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ WORKDIR /usr/src
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir

RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
Expand All @@ -119,6 +118,9 @@ ENV CCL_ZE_IPC_EXCHANGE=sockets
#ENV TORCH_LLM_ALLREDUCE=1
#ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0

RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout c3e14caf792ad04824dd921e2fc3f16fca0d462e
RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch

# Install benchmarker
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
Expand Down Expand Up @@ -222,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

FROM ${PLATFORM} AS final
ENV ATTENTION=paged
ENV PREFIX_CACHING=0
ENV PREFILL_CHUNKING=0
ENV ATTENTION=flashdecoding
sywangyi marked this conversation as resolved.
Show resolved Hide resolved
ENV PREFIX_CACHING=1
ENV PREFILL_CHUNKING=1
ENV CUDA_GRAPHS=0
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
98 changes: 67 additions & 31 deletions server/text_generation_server/layers/attention/ipex.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)

SUPPORTS_WINDOWING = False

Expand All @@ -28,22 +31,38 @@ def attention(
out = torch.empty_like(query)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
if ATTENTION == "flashdecoding":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
causal,
block_tables,
None,
)
else:
ipex.llm.functional.varlen_attention(
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,
causal,
False,
None,
)

return out

Expand All @@ -64,20 +83,37 @@ def paged_attention(
raise NotImplementedError("softcap is not available in IPEX")

out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
BLOCK_SIZE,
max_s,
None,
)

if ATTENTION == "flashdecoding":
ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
query.contiguous() if query.device.type == "xpu" else query,
kv_cache.key,
kv_cache.value,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
seqlen.max_q,
seqlen.max_k,
softmax_scale,
True,
block_tables,
None,
)
else:
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
BLOCK_SIZE,
max_s,
None,
)
return out


Expand Down
8 changes: 6 additions & 2 deletions server/text_generation_server/layers/attention/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def __init__(
else:
x = BLOCK_SIZE // element_size

if ATTENTION in {"flashdecoding", "flashinfer"}:
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
SYSTEM == "ipex" and device == torch.device("cpu")
):
self.kv_cache = (
torch.empty(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
Expand Down Expand Up @@ -174,7 +176,9 @@ def store(
scalar=True,
)[0]

if ATTENTION in {"flashdecoding", "flashinfer"}:
if ATTENTION in {"flashdecoding", "flashinfer"} and not (
SYSTEM == "ipex" and key.device == torch.device("cpu")
):
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
Expand Down
7 changes: 6 additions & 1 deletion server/text_generation_server/models/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict, Optional

from text_generation_server.utils.log import log_master
from text_generation_server.utils.import_utils import SYSTEM

REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.environ["ATTENTION"]
Expand All @@ -28,10 +29,14 @@
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1


# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
if SYSTEM == "ipex":
BLOCK_SIZE = 64
else:
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
Expand Down