diff --git a/Dockerfile_intel b/Dockerfile_intel
index 720d7bee012..bc9071b8121 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -100,7 +100,6 @@ WORKDIR /usr/src
 RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir
 RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir
 RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir
-RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir
 RUN pip install https://intel-optimized-pytorch.s3.cn-north-1.amazonaws.com.cn/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir
 RUN pip install triton-xpu==3.0.0b2 --no-cache-dir
 
@@ -119,6 +118,9 @@ ENV CCL_ZE_IPC_EXCHANGE=sockets
 #ENV TORCH_LLM_ALLREDUCE=1
 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 
+RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout c3e14caf792ad04824dd921e2fc3f16fca0d462e
+RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch
+
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
@@ -222,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
-ENV ATTENTION=paged
-ENV PREFIX_CACHING=0
-ENV PREFILL_CHUNKING=0
+ENV ATTENTION=flashdecoding-ipex
+ENV PREFIX_CACHING=1
+ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index fb6ba2b2554..84dd6450ad0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         }
     }
 
-    let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+    let fallback_attention = if compute_capability.is_none()
+        || matches!(compute_capability, Some((major, _)) if major < 8)
+    {
         "paged"
     } else {
         "flashdecoding"
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 677f3f5647d..54422308f8a 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -1,9 +1,12 @@
 import intel_extension_for_pytorch as ipex
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
-from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 
 SUPPORTS_WINDOWING = False
 
@@ -28,22 +31,38 @@ def attention(
     out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    ipex.llm.functional.varlen_attention(
-        query.contiguous() if query.device.type == "xpu" else query,
-        key.contiguous() if key.device.type == "xpu" else key,
-        value.contiguous() if value.device.type == "xpu" else value,
-        out,
-        seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        seqlen.max_q,
-        seqlen.max_q,
-        0.0,
-        softmax_scale,
-        False,
-        causal,
-        False,
-        None,
-    )
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            causal,
+            block_tables,
+            None,
+        )
+    else:
+        ipex.llm.functional.varlen_attention(
+            query.contiguous() if query.device.type == "xpu" else query,
+            key.contiguous() if key.device.type == "xpu" else key,
+            value.contiguous() if value.device.type == "xpu" else value,
+            out,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_q,
+            0.0,
+            softmax_scale,
+            False,
+            causal,
+            False,
+            None,
+        )
 
     return out
 
@@ -64,20 +83,37 @@ def paged_attention(
         raise NotImplementedError("softcap is not available in IPEX")
 
     out = torch.empty_like(query)
-    input_lengths = seqlen.input_lengths + seqlen.cache_lengths
-    ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
-        out,
-        query,
-        kv_cache.key,
-        kv_cache.value,
-        kv_head_mapping,
-        softmax_scale,
-        block_tables,
-        input_lengths,
-        BLOCK_SIZE,
-        max_s,
-        None,
-    )
+
+    if ATTENTION == "flashdecoding-ipex":
+        ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
+            out,
+            query.contiguous() if query.device.type == "xpu" else query,
+            kv_cache.key,
+            kv_cache.value,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            seqlen.max_q,
+            seqlen.max_k,
+            softmax_scale,
+            True,
+            block_tables,
+            None,
+        )
+    else:
+        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
+        ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            BLOCK_SIZE,
+            max_s,
+            None,
+        )
 
     return out
 
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 93d74732408..00308601838 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -66,7 +66,9 @@ def __init__(
         else:
            x = BLOCK_SIZE // element_size
 
-        if ATTENTION in {"flashdecoding", "flashinfer"}:
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
+        ):
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, BLOCK_SIZE, num_heads, head_size),
@@ -80,6 +82,7 @@ def __init__(
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share the same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -187,6 +190,12 @@ def store(
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
 
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index ce8791411f9..889de028dc7 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -14,13 +14,17 @@
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")
 
-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")
 
 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -28,12 +32,15 @@
 assert TGI_WIGGLE_ROOM > 0
 assert TGI_WIGGLE_ROOM < 1
 
+
 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
     BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16
 
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 805fd771058..af4d1f082dc 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -79,10 +79,13 @@ def __init__(
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer`, `flashdecoding`, or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
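
As a companion to the kv_cache.py and globals.py hunks above, here is a minimal sketch of the two key-cache layouts the patch selects between. It uses plain torch only, the tensor sizes are illustrative placeholders, and it is not part of the PR itself.

import torch

# Illustrative sizes only; not taken from the PR.
num_blocks, num_heads, head_size = 8, 4, 64

# ATTENTION="flashdecoding-ipex" on XPU: globals.py sets BLOCK_SIZE = 64 and
# kv_cache.py allocates the flash-style (num_blocks, BLOCK_SIZE, num_heads, head_size)
# layout, the same shape family as the "flashdecoding"/"flashinfer" branch.
flash_block_size = 64
flash_key_cache = torch.empty(num_blocks, flash_block_size, num_heads, head_size)

# The IPEX CPU branch in kv_cache.py keeps the
# (num_blocks, num_heads, BLOCK_SIZE, head_size) layout; with ATTENTION="paged",
# globals.py leaves BLOCK_SIZE at 16.
paged_block_size = 16
paged_key_cache = torch.empty(num_blocks, num_heads, paged_block_size, head_size)

print(flash_key_cache.shape)  # torch.Size([8, 64, 4, 64])
print(paged_key_cache.shape)  # torch.Size([8, 4, 16, 64])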