From d7c991b0d11dbe13d23247c768be7a0ddb7dc472 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 5 Nov 2024 00:48:23 -0800 Subject: [PATCH 1/5] flash decoding Signed-off-by: Wang, Yi A --- .../layers/attention/ipex.py | 98 +++++++++++++------ .../layers/attention/kv_cache.py | 4 +- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 677f3f5647d..9b37c1b5771 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,9 +1,12 @@ import intel_extension_for_pytorch as ipex import torch from text_generation_server.layers.attention.kv_cache import KVCache, KVScales -from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen from typing import Optional +from text_generation_server.models.globals import ( + ATTENTION, + BLOCK_SIZE, +) SUPPORTS_WINDOWING = False @@ -28,22 +31,38 @@ def attention( out = torch.empty_like(query) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. - ipex.llm.functional.varlen_attention( - query.contiguous() if query.device.type == "xpu" else query, - key.contiguous() if key.device.type == "xpu" else key, - value.contiguous() if value.device.type == "xpu" else value, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_q, - 0.0, - softmax_scale, - False, - causal, - False, - None, - ) + if ATTENTION == "flashdecoding": + ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + query, + kv_cache.key, + kv_cache.value, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + seqlen.max_q, + seqlen.max_k, + softmax_scale, + causal, + block_tables, + None, + ) + else: + ipex.llm.functional.varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_q, + 0.0, + softmax_scale, + False, + causal, + False, + None, + ) return out @@ -64,20 +83,37 @@ def paged_attention( raise NotImplementedError("softcap is not available in IPEX") out = torch.empty_like(query) - input_lengths = seqlen.input_lengths + seqlen.cache_lengths - ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( - out, - query, - kv_cache.key, - kv_cache.value, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - BLOCK_SIZE, - max_s, - None, - ) + + if ATTENTION == "flashdecoding": + ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + out, + query, + kv_cache.key, + kv_cache.value, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + seqlen.max_q, + seqlen.max_k, + softmax_scale, + True, + block_tables, + None, + ) + else: + input_lengths = seqlen.input_lengths + seqlen.cache_lengths + ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + kv_cache.key, + kv_cache.value, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + BLOCK_SIZE, + max_s, + None, + ) return out diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index cad1d98a0b8..191771ca64d 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py 
@@ -66,7 +66,7 @@ def __init__( else: x = BLOCK_SIZE // element_size - if ATTENTION in {"flashdecoding", "flashinfer"}: + if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex": self.kv_cache = ( torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), @@ -174,7 +174,7 @@ def store( scalar=True, )[0] - if ATTENTION in {"flashdecoding", "flashinfer"}: + if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex": key = key.to(key_cache.dtype) value = value.to(value_cache.dtype) if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: From d04c86c76c37139e416d6848e68a04d10c645f70 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 24 Nov 2024 21:40:00 -0800 Subject: [PATCH 2/5] enable xpu flashdecoding Signed-off-by: Wang, Yi A --- server/text_generation_server/layers/attention/ipex.py | 4 ++-- .../text_generation_server/layers/attention/kv_cache.py | 8 ++++++-- server/text_generation_server/models/globals.py | 6 +++++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 9b37c1b5771..817dfbd3c32 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -34,7 +34,7 @@ def attention( if ATTENTION == "flashdecoding": ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, - query, + query.contiguous() if query.device.type == "xpu" else query, kv_cache.key, kv_cache.value, seqlen.cu_seqlen_q, @@ -87,7 +87,7 @@ def paged_attention( if ATTENTION == "flashdecoding": ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, - query, + query.contiguous() if query.device.type == "xpu" else query, kv_cache.key, kv_cache.value, seqlen.cu_seqlen_q, diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 191771ca64d..0ff6522ea86 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -66,7 +66,9 @@ def __init__( else: x = BLOCK_SIZE // element_size - if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex": + if ATTENTION in {"flashdecoding", "flashinfer"} and not ( + SYSTEM == "ipex" and device == torch.device("cpu") + ): self.kv_cache = ( torch.empty( (num_blocks, BLOCK_SIZE, num_heads, head_size), @@ -174,7 +176,9 @@ def store( scalar=True, )[0] - if ATTENTION in {"flashdecoding", "flashinfer"} and SYSTEM != "ipex": + if ATTENTION in {"flashdecoding", "flashinfer"} and not ( + SYSTEM == "ipex" and key.device == torch.device("cpu") + ): key = key.to(key_cache.dtype) value = value.to(value_cache.dtype) if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 4ac6a6b499f..3561f13f1cb 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,6 +4,7 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master +from text_generation_server.utils.import_utils import SYSTEM ATTENTION = os.environ["ATTENTION"] # default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0" @@ -27,9 +28,12 @@ assert TGI_WIGGLE_ROOM > 0 assert TGI_WIGGLE_ROOM < 1 + # This is overridden by the cli BLOCK_SIZE: int -if ATTENTION == "flashdecoding": +if SYSTEM == "ipex": + 
BLOCK_SIZE = 16 +elif ATTENTION == "flashdecoding": BLOCK_SIZE = 256 elif ATTENTION == "flashinfer": BLOCK_SIZE = 1 From 4ee4ebc03bdfb950153bfc0fe80055603aa3f3c1 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 1 Dec 2024 18:55:05 -0800 Subject: [PATCH 3/5] set flashdecoding blocksize as 64 Signed-off-by: Wang, Yi A --- server/text_generation_server/models/globals.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 3561f13f1cb..8b4c6d4fea4 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -31,10 +31,11 @@ # This is overridden by the cli BLOCK_SIZE: int -if SYSTEM == "ipex": - BLOCK_SIZE = 16 -elif ATTENTION == "flashdecoding": - BLOCK_SIZE = 256 +if ATTENTION == "flashdecoding": + if SYSTEM == "ipex": + BLOCK_SIZE = 64 + else: + BLOCK_SIZE = 256 elif ATTENTION == "flashinfer": BLOCK_SIZE = 1 else: From fd4f861d2ccc73b34a6048135a42411395a77e47 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Mon, 9 Dec 2024 20:49:10 -0800 Subject: [PATCH 4/5] enable flashdecoding, prefill chunking and prefix caching Signed-off-by: Wang, Yi A --- Dockerfile_intel | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index e024f31a563..69c6b479b08 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -100,7 +100,6 @@ WORKDIR /usr/src RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchaudio-2.5.0a0%2B56bc006-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torchvision-0.20.0a0%2B8e8a208-cp311-cp311-linux_x86_64.whl --no-cache-dir -RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp311-cp311-linux_x86_64.whl --no-cache-dir RUN pip install triton-xpu==3.0.0b2 --no-cache-dir @@ -119,6 +118,9 @@ ENV CCL_ZE_IPC_EXCHANGE=sockets #ENV TORCH_LLM_ALLREDUCE=1 #ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0 +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout c3e14caf792ad04824dd921e2fc3f16fca0d462e +RUN cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=OFF BUILD_WITH_CPU=OFF USE_XETLA=ON python setup.py install && rm -rf /usr/src/intel-extension-for-pytorch + # Install benchmarker COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark # Install router @@ -222,9 +224,9 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final -ENV ATTENTION=paged -ENV PREFIX_CACHING=0 -ENV PREFILL_CHUNKING=0 +ENV ATTENTION=flashdecoding +ENV PREFIX_CACHING=1 +ENV PREFILL_CHUNKING=1 ENV CUDA_GRAPHS=0 ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] From 
d6ac8cdf8180e8821966f2cdc102a4413e02f3ef Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 19 Dec 2024 18:33:25 -0800 Subject: [PATCH 5/5] add flashdecoding-ipex Signed-off-by: Wang, Yi A --- Dockerfile_intel | 2 +- launcher/src/main.rs | 4 +++- .../layers/attention/ipex.py | 4 ++-- .../layers/attention/kv_cache.py | 15 ++++++++++----- server/text_generation_server/models/globals.py | 16 +++++++++------- server/text_generation_server/models/model.py | 7 +++++-- 6 files changed, 30 insertions(+), 18 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index 99c6ba8b45e..bc9071b8121 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher FROM ${PLATFORM} AS final -ENV ATTENTION=flashdecoding +ENV ATTENTION=flashdecoding-ipex ENV PREFIX_CACHING=1 ENV PREFILL_CHUNKING=1 ENV CUDA_GRAPHS=0 diff --git a/launcher/src/main.rs b/launcher/src/main.rs index fb6ba2b2554..84dd6450ad0 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -144,7 +144,9 @@ fn resolve_attention(config: &Option, lora_adapters: &Option) -> } } - let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) { + let fallback_attention = if compute_capability.is_none() + || matches!(compute_capability, Some((major, _)) if major < 8) + { "paged" } else { "flashdecoding" diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 817dfbd3c32..54422308f8a 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -31,7 +31,7 @@ def attention( out = torch.empty_like(query) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
- if ATTENTION == "flashdecoding": + if ATTENTION == "flashdecoding-ipex": ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query, @@ -84,7 +84,7 @@ def paged_attention( out = torch.empty_like(query) - if ATTENTION == "flashdecoding": + if ATTENTION == "flashdecoding-ipex": ipex.llm.modules.PagedAttention.flash_attn_varlen_func( out, query.contiguous() if query.device.type == "xpu" else query, diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 969e41c0345..00308601838 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -66,8 +66,8 @@ def __init__( else: x = BLOCK_SIZE // element_size - if ATTENTION in {"flashdecoding", "flashinfer"} and not ( - SYSTEM == "ipex" and device == torch.device("cpu") + if ATTENTION in {"flashdecoding", "flashinfer"} or ( + ATTENTION == "flashdecoding-ipex" and device.type == "xpu" ): self.kv_cache = ( torch.empty( @@ -82,6 +82,7 @@ def __init__( ), ) elif SYSTEM == "ipex" and device == torch.device("cpu"): + # ipex cpu flashdecoding kernel and paged attention kernel share same layout self.kv_cache = ( torch.empty( (num_blocks, num_heads, BLOCK_SIZE, head_size), @@ -176,9 +177,7 @@ def store( scalar=True, )[0] - if ATTENTION in {"flashdecoding", "flashinfer"} and not ( - SYSTEM == "ipex" and key.device == torch.device("cpu") - ): + if ATTENTION in {"flashdecoding", "flashinfer"}: key = key.to(key_cache.dtype) value = value.to(value_cache.dtype) if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: @@ -191,6 +190,12 @@ def store( shape = key_cache.shape key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value + elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu": + import intel_extension_for_pytorch as ipex + + ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slots + ) else: paged_reshape_and_cache(key, value, key_cache, value_cache, slots) diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 5f3d8f35115..889de028dc7 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -4,7 +4,6 @@ from typing import Dict, Optional from text_generation_server.utils.log import log_master -from text_generation_server.utils.import_utils import SYSTEM REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"} ATTENTION = os.environ["ATTENTION"] @@ -15,13 +14,17 @@ } PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -_expected = {"paged", "flashdecoding", "flashinfer"} +_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: +if PREFIX_CACHING and ATTENTION not in { + "flashinfer", + "flashdecoding", + "flashdecoding-ipex", +}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None @@ -33,12 +36,11 @@ # This is overridden by the cli BLOCK_SIZE: 
int if ATTENTION == "flashdecoding": - if SYSTEM == "ipex": - BLOCK_SIZE = 64 - else: - BLOCK_SIZE = 256 + BLOCK_SIZE = 256 elif ATTENTION == "flashinfer": BLOCK_SIZE = 1 +elif ATTENTION == "flashdecoding-ipex": + BLOCK_SIZE = 64 else: BLOCK_SIZE = 16 diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 805fd771058..af4d1f082dc 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -79,10 +79,13 @@ def __init__( "Prefill chunking will be turned off", ) support_chunking = False - if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking: + if ( + ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"] + and support_chunking + ): log_master( logger.warning, - "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.", + "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.", ) support_chunking = False
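For reference, the dispatch that the series converges on after [PATCH 5/5] can be condensed as the sketch below. This is an illustration only, not code from the patches; every name and constant in it is taken from server/text_generation_server/models/globals.py, layers/attention/kv_cache.py, and Dockerfile_intel as patched above.

    # Condensed sketch of the block-size and KV-cache layout selection
    # introduced by this series (illustration only, mirrors the patched files).

    ATTENTION = "flashdecoding-ipex"  # default ENV in the final stage of Dockerfile_intel

    if ATTENTION == "flashdecoding":
        BLOCK_SIZE = 256
    elif ATTENTION == "flashinfer":
        BLOCK_SIZE = 1
    elif ATTENTION == "flashdecoding-ipex":
        BLOCK_SIZE = 64
    else:  # "paged"
        BLOCK_SIZE = 16

    # KV-cache layout after PATCH 5/5 (kv_cache.py __init__):
    #   flashdecoding / flashinfer, or flashdecoding-ipex on an XPU device:
    #       (num_blocks, BLOCK_SIZE, num_heads, head_size)
    #   IPEX on CPU (flashdecoding and paged-attention kernels share one layout):
    #       (num_blocks, num_heads, BLOCK_SIZE, head_size)
    # On XPU, flashdecoding-ipex writes into this cache with
    # ipex.llm.modules.PagedAttention.reshape_and_cache_flash(key, value, key_cache, value_cache, slots),
    # and both prefill and decode go through PagedAttention.flash_attn_varlen_func.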