Skip to content

Commit

Permalink
Add support for FP8 KV cache scales (#2628)
Browse files Browse the repository at this point in the history
* Add support for FP8 KV cache scales

Since FP8 only has limited dynamic range, we can scale keys/values
before storing them into the cache (and unscale them in attention). To
avoid rescaling the cache as the absmax values change, good scales are
usually determined per layer using calibration calibration data and stored
in the checkpoint.

This change adds support for for using key-value scales and loading them
from checkpoints in the two most common formats:

- Separate per-layer `k_scale` and `v_scale` scalars.
- Per-layer `kv_scale` scalar (older format).

Currently, scales are only used with an `float8_e4m3fn` cache.

Besides adding support for key/value scales, the `fp8_quantize` function
is also extended to support quantization with a kernel vendored from
vLLM. This is slightly faster than the PyTorch implementation, but also
scales in FP32, potentially improving accuracy.

* Update FP8 KV cache test to use checkpoint with scales

* `can_scale`: check that the attention is flashinfer
  • Loading branch information
danieldk authored Oct 24, 2024
1 parent 14a0df3 commit eab07f7
Show file tree
Hide file tree
Showing 33 changed files with 486 additions and 155 deletions.
7 changes: 4 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,94 +11,94 @@
},
{
"id": 3923,
"logprob": -5.6328125,
"logprob": -6.1875,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"logprob": -0.93359375,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"logprob": -9.875,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"logprob": -1.1796875,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"logprob": -1.75,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"logprob": -1.109375,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"logprob": -0.005432129,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"logprob": -0.028808594,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"logprob": -0.013671875,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"logprob": -0.69921875,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"logprob": -0.0005874634,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"logprob": -0.026855469,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"logprob": -0.00020885468,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"logprob": -0.17773438,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"id": 18065,
"logprob": -0.703125,
"special": false,
"text": " is"
"text": " involves"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
"generated_text": " Deep learning is a subset of machine learning that involves"
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
Expand All @@ -11,47 +11,89 @@
},
{
"id": 374,
"logprob": -22.96875,
"logprob": -18.0,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"logprob": -11.75,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"logprob": -2.0625,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"logprob": -6.0,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"logprob": 0.0,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"id": 34564,
"logprob": -0.11279297,
"special": false,
"text": "Deep"
},
{
"id": 6975,
"logprob": -0.16015625,
"special": false,
"text": " "
"text": " learning"
},
{
"id": 128001,
"id": 320,
"logprob": -0.25195312,
"special": false,
"text": " ("
},
{
"id": 16931,
"logprob": -1.703125,
"special": false,
"text": "DL"
},
{
"id": 8,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
"special": false,
"text": ")"
},
{
"id": 374,
"logprob": -1.140625,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": 0.0,
"special": false,
"text": " a"
},
{
"id": 1207,
"logprob": -1.3125,
"special": false,
"text": " sub"
},
{
"id": 2630,
"logprob": 0.0,
"special": false,
"text": "field"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
"generated_text": "What is deep learning? \nDeep learning (DL) is a subfield"
}
Loading

0 comments on commit eab07f7

Please sign in to comment.