Add support for FP8 KV cache scales
Since FP8 has only a limited dynamic range, we can scale keys/values
before storing them in the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration data and stored in the
checkpoint.
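
As an illustration of that idea, here is a minimal sketch of scaling into and out of FP8 with a fixed per-layer scale. The helper names, shapes, and scale value are made up for the example; this is not the kernel code used by this change:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def scale_to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dividing by the calibrated scale maps the expected value range onto the
    # FP8 representable range before clamping and casting for cache storage.
    return (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

def unscale_from_fp8(x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Multiplying by the same scale undoes the mapping when the cached
    # keys/values are read back in attention.
    return (x.float() * scale).to(dtype)

key = torch.randn(16, 8, 128, dtype=torch.float16)  # hypothetical [tokens, heads, head_dim]
k_scale = torch.tensor(0.02)                         # hypothetical calibrated per-layer scale
key_cached = scale_to_fp8(key, k_scale)
key_restored = unscale_from_fp8(key_cached, k_scale, torch.float16)
```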

This change adds support for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with a `float8_e4m3fn` KV cache.
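
For illustration, the two layouts and the fallback order look roughly like this when sketched over a plain dict. The layer prefix and values are made up, and the real loader works on the server's `Weights` object rather than a dict:

```python
import torch

# Newer format: separate per-layer scalars.
new_format = {
    "model.layers.0.self_attn.k_scale": torch.tensor(0.0213),
    "model.layers.0.self_attn.v_scale": torch.tensor(0.0187),
}

# Older format: a single per-layer scalar shared by keys and values.
old_format = {
    "model.layers.0.self_attn.kv_scale": torch.tensor(0.0213),
}

def resolve_scales(tensors: dict, prefix: str) -> tuple[torch.Tensor, torch.Tensor]:
    # Prefer the separate pair, fall back to the shared scale, otherwise
    # default to 1.0 (which is treated as "no scaling").
    if f"{prefix}.k_scale" in tensors and f"{prefix}.v_scale" in tensors:
        return tensors[f"{prefix}.k_scale"], tensors[f"{prefix}.v_scale"]
    if f"{prefix}.kv_scale" in tensors:
        return tensors[f"{prefix}.kv_scale"], tensors[f"{prefix}.kv_scale"]
    one = torch.tensor(1.0)
    return one, one

print(resolve_scales(new_format, "model.layers.0.self_attn"))  # (0.0213, 0.0187)
print(resolve_scales(old_format, "model.layers.0.self_attn"))  # (0.0213, 0.0213)
```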
danieldk committed Oct 21, 2024
1 parent 5e0fb46 commit a4cb3d3
Showing 29 changed files with 332 additions and 49 deletions.
7 changes: 4 additions & 3 deletions flake.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion flake.nix
@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
26 changes: 13 additions & 13 deletions server/poetry.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions server/pyproject.toml
@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"

marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
3 changes: 2 additions & 1 deletion server/text_generation_server/layers/attention/__init__.py
@@ -28,10 +28,11 @@
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
from .kv_cache import KVCache
from .kv_cache import KVCache, get_kv_scales

__all__ = [
"attention",
"get_kv_scales",
"paged_attention",
"SUPPORTS_WINDOWING",
"KVCache",
14 changes: 13 additions & 1 deletion server/text_generation_server/layers/attention/cuda.py
@@ -1,5 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@@ -8,6 +8,7 @@
from text_generation_server.layers.attention import Seqlen
from typing import Optional


major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
@@ -21,6 +22,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -46,6 +49,8 @@ def paged_attention(
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

can_scale = kv_cache.can_scale(kv_scales)

# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@@ -59,6 +64,8 @@ def paged_attention(
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)
elif ATTENTION == "flashdecoding":
max_q = 1
@@ -204,13 +211,16 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
can_scale = kv_cache.can_scale(kv_scales)

if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
@@ -226,6 +236,8 @@ def attention(
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
k_scale=kv_scales.key_scale_cpu if can_scale else 1.0,
v_scale=kv_scales.value_scale_cpu if can_scale else 1.0,
)

# If we are using flashdecoding or paged, we always use flash-attn for
3 changes: 2 additions & 1 deletion server/text_generation_server/layers/attention/flashinfer.py
@@ -204,6 +204,7 @@ def use_decode_state(
num_kv_heads: int,
head_size: int,
page_size: int,
kv_cache_dtype: torch.dtype,
dtype: torch.dtype,
window_left: int,
):
@@ -240,7 +241,7 @@ def use_decode_state(
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
data_type=kv_cache_dtype,
q_data_type=dtype,
window_left=window_left,
)
5 changes: 4 additions & 1 deletion server/text_generation_server/layers/attention/ipex.py
@@ -1,6 +1,6 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
@@ -14,6 +14,7 @@ def attention(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
@@ -55,6 +56,8 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
):
if softcap is not None:
93 changes: 90 additions & 3 deletions server/text_generation_server/layers/attention/kv_cache.py
@@ -1,8 +1,38 @@
from typing import Tuple
from dataclasses import dataclass, field

from loguru import logger
import torch

from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights


@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a CPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""

key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)

def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")

self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()


class KVCache:
@@ -76,6 +106,29 @@ def __init__(
),
)

def can_scale(self, kv_scales: KVScales) -> bool:
"""Check if the cache can be scaled by the given scales."""
if kv_scales.key_scale_cpu == 1.0 and kv_scales.value_scale_cpu == 1.0:
return False
elif self.dtype == torch.float8_e4m3fn and SYSTEM == "cuda":
log_once(
logger.info,
"Using FP8 KV cache scales",
)
return True
else:
# We have scales, but not the correct FP8 cache type, so warn once.
log_once(
logger.info,
"Ignoring FP8 KV cache scales, only float8_e4m3fn KV cache on CUDA is supported",
)
return False

@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype

@property
def key(self):
"""Get the key cache."""
@@ -94,17 +147,33 @@ def store(
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""

key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]

if self.can_scale(kv_scales):
if kv_scales.key_scale_cpu != 1.0:
key = fp8_quantize(
key.float(),
scale=kv_scales.key_scale,
qdtype=self.dtype,
scalar=True,
)[0]
if kv_scales.value_scale_cpu != 1.0:
value = fp8_quantize(
value.float(),
scale=kv_scales.value_scale,
qdtype=self.dtype,
scalar=True,
)[0]

if ATTENTION in {"flashdecoding", "flashinfer"}:
# TODO: add scale
key = key.to(key_cache.dtype)
value = value.to(value_cache.dtype)
if key_cache.dtype in {torch.float8_e5m2, torch.float8_e4m3fn}:
if key_cache.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
# Torch index_put does not support float8_{e5m2,e4m3fn} yet, so
# put as raw data instead.
key_cache = key_cache.view(torch.uint8)
@@ -151,5 +220,23 @@ def paged_reshape_and_cache(
)
else:
raise NotImplementedError(
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supportedattention"
f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
)


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""

key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale

return KVScales(key_scale=key_scale, value_scale=value_scale)
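
Not part of the diff above: a self-contained sketch of the storage trick used in `store()` for FP8 caches. Since `index_put` does not support float8 dtypes yet, the quantized keys/values are written through a `uint8` view of the cache. The cache shape, slots, and scale below are hypothetical:

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

# Toy cache in a flashdecoding/flashinfer-style [slots, heads, head_dim] layout.
num_slots, heads, head_dim = 32, 2, 8
key_cache = torch.zeros(num_slots, heads, head_dim, dtype=FP8)

key = torch.randn(4, heads, head_dim)   # four new tokens (hypothetical)
slots = torch.tensor([5, 6, 7, 8])      # target cache slots (hypothetical)
k_scale = torch.tensor(0.05)            # calibrated per-layer scale (hypothetical)

# Scale and cast to FP8, as in the store() path when scaling is enabled.
key_fp8 = (key / k_scale).clamp(-FP8_MAX, FP8_MAX).to(FP8)

# index_put does not support float8 dtypes yet, so write the raw bytes
# through uint8 views of both the cache and the quantized keys.
key_cache.view(torch.uint8)[slots] = key_fp8.view(torch.uint8)

# Reading back through the uint8 view and unscaling approximates the original keys.
restored = key_cache.view(torch.uint8)[slots].view(FP8).float() * k_scale
print((restored - key).abs().max())  # small quantization error
```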