diff --git a/Dockerfile_intel b/Dockerfile_intel
index 99c6ba8b45e..bc9071b8121 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -224,7 +224,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
-ENV ATTENTION=flashdecoding
+ENV ATTENTION=flashdecoding-ipex
 ENV PREFIX_CACHING=1
 ENV PREFILL_CHUNKING=1
 ENV CUDA_GRAPHS=0
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index fb6ba2b2554..84dd6450ad0 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -144,7 +144,9 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
             }
         }

-        let fallback_attention = if matches!(compute_capability, Some((major, _)) if major < 8) {
+        let fallback_attention = if compute_capability.is_none()
+            || matches!(compute_capability, Some((major, _)) if major < 8)
+        {
             "paged"
         } else {
             "flashdecoding"
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 817dfbd3c32..54422308f8a 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -31,7 +31,7 @@ def attention(
     out = torch.empty_like(query)

     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
@@ -84,7 +84,7 @@ def paged_attention(
     out = torch.empty_like(query)

-    if ATTENTION == "flashdecoding":
+    if ATTENTION == "flashdecoding-ipex":
         ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
             out,
             query.contiguous() if query.device.type == "xpu" else query,
diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py
index 969e41c0345..00308601838 100644
--- a/server/text_generation_server/layers/attention/kv_cache.py
+++ b/server/text_generation_server/layers/attention/kv_cache.py
@@ -66,8 +66,8 @@ def __init__(
         else:
             x = BLOCK_SIZE // element_size

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and device == torch.device("cpu")
+        if ATTENTION in {"flashdecoding", "flashinfer"} or (
+            ATTENTION == "flashdecoding-ipex" and device.type == "xpu"
         ):
             self.kv_cache = (
                 torch.empty(
@@ -82,6 +82,7 @@ def __init__(
                 ),
             )
         elif SYSTEM == "ipex" and device == torch.device("cpu"):
+            # ipex cpu flashdecoding kernel and paged attention kernel share same layout
             self.kv_cache = (
                 torch.empty(
                     (num_blocks, num_heads, BLOCK_SIZE, head_size),
@@ -176,9 +177,7 @@ def store(
             scalar=True,
         )[0]

-        if ATTENTION in {"flashdecoding", "flashinfer"} and not (
-            SYSTEM == "ipex" and key.device == torch.device("cpu")
-        ):
+        if ATTENTION in {"flashdecoding", "flashinfer"}:
             key = key.to(key_cache.dtype)
             value = value.to(value_cache.dtype)
             if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
@@ -191,6 +190,12 @@ def store(
             shape = key_cache.shape
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        elif ATTENTION == "flashdecoding-ipex" and key.device.type == "xpu":
+            import intel_extension_for_pytorch as ipex
+
+            ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
+                key, value, key_cache, value_cache, slots
+            )
         else:
             paged_reshape_and_cache(key, value, key_cache, value_cache, slots)

diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 5f3d8f35115..889de028dc7 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -4,7 +4,6 @@
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
-from text_generation_server.utils.import_utils import SYSTEM

 REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
 ATTENTION = os.environ["ATTENTION"]
@@ -15,13 +14,17 @@
 }
 PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "1").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer"}
+_expected = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}
 assert (
     ATTENTION in _expected
 ), f"Attention is not valid {ATTENTION}, expected {_expected}"
 log_master(logger.info, f"Using Attention = {ATTENTION}")

-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
+if PREFIX_CACHING and ATTENTION not in {
+    "flashinfer",
+    "flashdecoding",
+    "flashdecoding-ipex",
+}:
     raise RuntimeError("Prefix caching is only supported with flashinfer")

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -33,12 +36,11 @@
 # This is overridden by the cli
 BLOCK_SIZE: int
 if ATTENTION == "flashdecoding":
-    if SYSTEM == "ipex":
-        BLOCK_SIZE = 64
-    else:
-        BLOCK_SIZE = 256
+    BLOCK_SIZE = 256
 elif ATTENTION == "flashinfer":
     BLOCK_SIZE = 1
+elif ATTENTION == "flashdecoding-ipex":
+    BLOCK_SIZE = 64
 else:
     BLOCK_SIZE = 16

diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index 805fd771058..af4d1f082dc 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -79,10 +79,13 @@ def __init__(
                 "Prefill chunking will be turned off",
             )
             support_chunking = False
-        if ATTENTION not in ["flashinfer", "flashdecoding"] and support_chunking:
+        if (
+            ATTENTION not in ["flashinfer", "flashdecoding", "flashdecoding-ipex"]
+            and support_chunking
+        ):
             log_master(
                 logger.warning,
-                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` attention types.",
+                "Prefill chunking is only supported with `flashinfer` or `flashdecoding` or `flashdecoding-ipex` attention types.",
             )
             support_chunking = False
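
For reviewers, a minimal standalone sketch (not part of the patch; sizes are illustrative, and it assumes the flat `(num_blocks, BLOCK_SIZE, num_heads, head_size)` layout used by the `flashdecoding`/`flashinfer` branch) of what the kv_cache.py change selects for `flashdecoding-ipex` on XPU: the shared flat cache layout, with new key/value tokens written at flat slot indices. On XPU the patch delegates the write to `ipex.llm.modules.PagedAttention.reshape_and_cache_flash`; the index_put shown below is the equivalent non-ipex path.

```python
import torch

# Illustrative sizes; BLOCK_SIZE = 64 matches the flashdecoding-ipex setting above.
BLOCK_SIZE, num_blocks, num_heads, head_size = 64, 32, 8, 64
dtype, device = torch.float16, torch.device("cpu")  # would be "xpu" on an Intel GPU

# Flat layout: one row per (block, slot-in-block) position.
key_cache = torch.empty(
    (num_blocks, BLOCK_SIZE, num_heads, head_size), dtype=dtype, device=device
)
value_cache = torch.empty_like(key_cache)

# Store K/V for three new tokens at their flat slot indices, mirroring the
# flashdecoding/flashinfer branch of KVCache.store(); flashdecoding-ipex on XPU
# hands the same tensors to PagedAttention.reshape_and_cache_flash instead.
slots = torch.tensor([0, 1, 2])
key = torch.randn(len(slots), num_heads, head_size, dtype=dtype, device=device)
value = torch.randn(len(slots), num_heads, head_size, dtype=dtype, device=device)
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
```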
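
Similarly, a hedged sketch of the backend validation and block-size selection that globals.py performs after this patch. `resolve_block_size` is a hypothetical helper used only for illustration; the real module does this at import time from the `ATTENTION` environment variable.

```python
import os

# Mirrors _expected and the BLOCK_SIZE branches in
# server/text_generation_server/models/globals.py after this patch.
_EXPECTED = {"paged", "flashdecoding", "flashdecoding-ipex", "flashinfer"}


def resolve_block_size(attention: str) -> int:
    assert (
        attention in _EXPECTED
    ), f"Attention is not valid {attention}, expected {_EXPECTED}"
    if attention == "flashdecoding":
        return 256
    if attention == "flashinfer":
        return 1
    if attention == "flashdecoding-ipex":
        return 64  # block size for the ipex flash-decoding path, per globals.py above
    return 16  # "paged"


if __name__ == "__main__":
    # ATTENTION=flashdecoding-ipex is the new default in Dockerfile_intel.
    print(resolve_block_size(os.environ.get("ATTENTION", "flashdecoding-ipex")))
```

This keeps the CUDA flashdecoding path at 256-token blocks while giving the ipex backend the 64-token blocks it previously got through the `SYSTEM == "ipex"` special case.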