Merge branch 'main' into hsadia/micro_sdpa_bf16
h-sadia authored Jan 11, 2025
2 parents ae0dc9b + 38055e0 commit 0c7a877
Showing 95 changed files with 1,888 additions and 791 deletions.
1 change: 1 addition & 0 deletions doc/graph/fusion_patterns/fusion_patterns.md
Original file line number Diff line number Diff line change
@@ -73,6 +73,7 @@ ReduceProd | ReduceSum]
|:--------|:-----------------------------|
| Scaled Dot-Product Attention | Refer to @ref dev_guide_graph_sdpa for more details. |
| Grouped Query Attention | Refer to @ref dev_guide_graph_gqa for more details. |
| Scaled Dot-Product Attention with Compressed Key/Value | Refer to @ref dev_guide_graph_sdpa_compressed_kv for more details. |
| Gated Multi-Layer Perceptron (Gated-MLP) | Refer to @ref dev_guide_graph_gated_mlp for more details. |
| Convolution + BiasAdd\f$^?\f$ + BatchNormInference\f$^?\f$ + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Convolution Neural Networks, for example ResNet, ResNext, SSD, etc. |
| ConvTranspose + BiasAdd\f$^?\f$ + [Unary \| Binary]\f$^{0-3}\f$\f$_{>out}\f$ | This pattern is widely used in Generative Adversarial Networks. |
119 changes: 119 additions & 0 deletions doc/graph/fusion_patterns/sdpa_with_compressed_kv.md
@@ -0,0 +1,119 @@
SDPA with Compressed Key and Value {#dev_guide_graph_sdpa_compressed_kv}
========================================================================

## Overview

Fused Scaled Dot-Product Attention (SDPA) [1] exploits int4 and int8 compression
for Key and Value to reduce the memory footprint of generative inference in
large language models (LLMs), especially when the KV cache mechanism is adopted.
Specifically, the Key and Value tensors are stored in lower-precision data types
such as int4 and int8 to reduce memory usage, and are subsequently dequantized
to wider floating-point data types such as f16 and bf16 for computation.

Note that grouped quantization is required to maintain model accuracy,
especially for int4 data types. In this case, the quantization operations take a
group size attribute, which indicates the number of elements that share the same
scaling factor and zero-point within each quantization group.

The notations used in this topic are:

- N: The mini-batch size.
- H: The head number.
- S: The sequence length.
- D: The size of each head.
- G: The group size.

## SDPA Pattern

The SDPA pattern with compressed Key and Value is defined as a directed
acyclic graph (DAG) using oneDNN Graph API. oneDNN extends the
[SDPA pattern](@ref dev_guide_graph_sdpa) to support the following three kinds
of compressed SDPA patterns:

1. SDPA with compressed Key and Value.
2. SDPA with floating-point Key and compressed Value.
3. SDPA with compressed Key and floating-point Value.

The floating-point data types include f32, f16, and bf16, while the compressed
data types refer to low-precision integral data types: int4 (u4/s4) and
int8 (u8/s8).

oneDNN Graph API supports quantization through patterns built from quantization
operations such as [DynamicDequantize](@ref dev_guide_op_dynamicdequantize) and
[DynamicQuantize](@ref dev_guide_op_dynamicquantize). The supported pattern is
shown below: the blue nodes are required, while the brown nodes are optional.

![compressed SDPA pattern](images/compressed_sdpa_pattern.png)

Compared to a typical SDPA pattern, there are a few differences:

1. Two additional DynamicDequantize operations are applied to the input Key and
Value to convert the integral values to floating-point values.
2. Apart from the Query, Key, and Value inputs, the pattern requires additional
quantization information, namely scales and zero-points, for the dequantization
of the Key and Value tensors. Currently, oneDNN only supports grouped
quantization on one dimension; specifically, the shapes of the scales and
zero-points for Key and Value dequantization should be (N, H, S, D/G).
3. Additionally, the `group_shape` attribute of the quantization operations must
be specified as (1, 1, 1, G) for Key and Value dequantization.

## Data Types

oneDNN supports the following data type combinations for Query, Key, Value,
output, and the scales and zero-points of Key and Value:

| Query | Key | Scale_K | Zp_K | Value | Scale_V | Zp_V | Output |
|:--------|:--------|:--------|:----------------|:-------|:--------|:----------------|:-------|
| dt_fp | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp |
| dt_fp | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp | N/A | N/A | dt_fp |
| dt_fp | dt_fp | N/A | N/A | dt_int | dt_fp | u4,s4,u8,s8,s32 | dt_fp |

Notes:
- dt_fp can be f16, bf16, or f32.
- dt_int can be u8, s8, u4, or s4.
- Zero-point inputs are optional.

You can specify the data type via the input and output data type fields of
logical tensors for each operation. The definition of the data types and their
support status on different CPU and GPU platforms follows the general
description in @ref dev_guide_data_types.

### Floating-point Math Mode

You should set the floating-point math mode
(@ref dev_guide_attributes_fpmath_mode) when using SDPA with compressed Key and
Value. Generally, the math mode should align with the data type of the Query,
which indicates the computation data type. Additionally, the second boolean
flag, `apply_to_int`, should be set to true. You can configure these attribute
values using the `set_fpmath_mode` API
(@ref dnnl::graph::graph::set_fpmath_mode) on the graph object.
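A minimal configuration sketch is shown below. It assumes the
`oneapi/dnnl/dnnl_graph.hpp` C++ header and the enum spellings from the
fpmath-mode guide; treat it as illustrative rather than canonical:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

// Illustrative only: f16 is assumed to be the Query/computation data type.
// `apply_to_int = true` extends the math mode to the integral (compressed)
// Key/Value inputs so they are dequantized into the f16 computation type.
void set_sdpa_fpmath(dnnl::graph::graph &g) {
    g.set_fpmath_mode(dnnl::fpmath_mode::f16, /* apply_to_int = */ true);
}
```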

## Implementation Limitations

- oneDNN primitive-based SDPA with compressed Key and Value is implemented as
a reference implementation on both Intel Architecture Processors and Intel
Graphics Products. The reference implementation requires memory to store the
intermediate results of the dot products between Query and Key, which takes
\f$O(S^2)\f$ memory. This may lead to an out-of-memory error when computing
long-sequence-length inputs on platforms with limited memory.
- The compressed SDPA patterns functionally support all input shapes meeting
the shape requirements of each operation in the graph.
- CPU
- oneDNN currently does not provide an optimized implementation on CPU. All
executions fall back to the primitive-based reference computation.
- GPU
- Optimized implementation is available for 4D Q/K/V tensors with the shape
defined as (N, H, S, D) for Query and Value, (N, H, D, S) for Key,
(N, H, D/G, S) for scales and zero-points of Key (if available) and
(N, H, S, D/G) for scales and zero-points of Value (if available).
- Optimized implementation is available for compressed SDPA with `f16`
computation data type on Intel Graphics Products with Intel(R) Xe Matrix
Extensions (Intel(R) XMX) support.
- If int4 zero-points are specified, the optimized implementation is only
available when the group size equals 16.

## References

[1] Attention Is All You Need, https://arxiv.org/abs/1706.03762v7
7 changes: 4 additions & 3 deletions src/common/memory_desc.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -471,8 +471,9 @@ status_t memory_desc_permute_axes(memory_desc_t &out_memory_desc,
VCHECK_MEMORY(
!memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides(),
invalid_arguments, VERBOSE_UNSUPPORTED_MEM_STRIDE);
VCHECK_MEMORY(in_memory_desc.extra.flags == 0, invalid_arguments,
VERBOSE_UNSUPPORTED_MD_FLAG, "extra");
VCHECK_MEMORY(
check_md_extra_flags_compensation_gpu(in_memory_desc.extra.flags),
invalid_arguments, VERBOSE_UNSUPPORTED_MD_FLAG, "extra");

// verify that perm is indeed a permutation of [0 .. ndims)
unsigned occurrence_mask = 0;
55 changes: 47 additions & 8 deletions src/common/memory_desc.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
* Copyright 2024-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -56,21 +56,30 @@ const rnn_packed_memory_format_t ldio_p = rnn_packed_memory_format_t::ldio_p;
// TODO: convert to 'enum class'.
// Flags for memory special features
enum memory_extra_flags_t {
dnnl_memory_extra_flag_none = 0x0U,
dnnl_memory_extra_flag_none = 0u,
// Indicates the weights have an additional buffer, that depends on the
// @p compensation_mask.
//
// For instance, in 4D case with the compensation mask equals (1 << 0)
// the additional buffer would consist of OC values:
// O[oc : 0,OC] =
// -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
dnnl_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
dnnl_memory_extra_flag_scale_adjust = 0x2U,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 0x4U,
dnnl_memory_extra_flag_compensation_conv_s8s8 = 1u,
dnnl_memory_extra_flag_scale_adjust = 2u,
dnnl_memory_extra_flag_rnn_u8s8_compensation = 4u,
dnnl_memory_extra_flag_gpu_rnn_u8s8_compensation
= dnnl_memory_extra_flag_rnn_u8s8_compensation,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 0x8U,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 0x16U,
dnnl_memory_extra_flag_compensation_conv_asymmetric_src = 8u,
dnnl_memory_extra_flag_rnn_s8s8_compensation = 16u,
// This flag has to be kept separate from *compensation_conv_asymmetric_src
// since the GPU precompute algorithm is incompatible with that of the CPU
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src = 32u,
// This flag depends on *compensation_gpu_conv_asymmetric_src and is used
// when precompute is to be performed for a backward-by-data convolution
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd = 64u,
// This flag depends on *compensation_gpu_conv_asymmetric_src and is used
// when IC and OC are swapped to reinterpret a deconv as a BWD_D conv
dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap = 128u,
};

// Create aliases for extra flags to preserve the old behavior.
@@ -87,8 +96,23 @@ const memory_extra_flags_t rnn_s8s8_compensation
= dnnl_memory_extra_flag_rnn_s8s8_compensation;
const memory_extra_flags_t compensation_conv_asymmetric_src
= dnnl_memory_extra_flag_compensation_conv_asymmetric_src;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_bwd
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_bwd;
const memory_extra_flags_t compensation_gpu_conv_asymmetric_src_swap
= dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src_swap;
} // namespace memory_extra_flags

inline bool check_md_extra_flags_compensation_gpu(uint64_t flags) {
using namespace memory_extra_flags;
const uint64_t c = compensation_gpu_conv_asymmetric_src;
const uint64_t b = compensation_gpu_conv_asymmetric_src_bwd;
const uint64_t s = compensation_gpu_conv_asymmetric_src_swap;
return (flags == none) || (flags == c) || (flags == (c | b))
|| (flags == (c | b | s));
}

// Generic description of blocked data layout for most memory formats.
struct blocking_desc_t {
// The strides between the outermost blocks.
@@ -208,7 +232,12 @@ struct memory_extra_desc_t {
: flags(0)
, compensation_mask(0)
, scale_adjust(0.0f)
, asymm_compensation_mask(0) {}
, asymm_compensation_mask(0)
, idhw {0, 0, 0}
, odhw {0, 0, 0}
, pdhw {0, 0, 0}
, ddhw {0, 0, 0}
, dst_size(0) {}
// The flags contain arbitrary extra information, such as compensation.
// @sa dnnl_memory_extra_flags_t
uint64_t flags;
@@ -218,6 +247,16 @@ struct memory_extra_desc_t {
float scale_adjust;
// Compensation mask for asymmetric quantization
int asymm_compensation_mask;
// Precomp GPU ZP convolution input spatials
dim_t idhw[3];
// Precomp GPU ZP convolution output spatials
dim_t odhw[3];
// Precomp GPU ZP convolution padding spatials
dim_t pdhw[3];
// Precomp GPU ZP convolution dilation spatials
dim_t ddhw[3];
// Precomp GPU ZP convolution destination size
dim_t dst_size;
};

status_t DNNL_API memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims,
36 changes: 18 additions & 18 deletions src/common/memory_desc_wrapper.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -149,30 +149,28 @@ struct memory_desc_wrapper : public c_compatible {
size_t additional_buffer_data_size(uint64_t flag_select) const {
using namespace memory_extra_flags;
if (flag_select & compensation_conv_s8s8) return sizeof(int32_t);
if ((flag_select & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(flag_select))
return sizeof(float);
if (flag_select & rnn_u8s8_compensation) return sizeof(float);
if (flag_select & compensation_conv_asymmetric_src)
return sizeof(int32_t);
if (flag_select & compensation_gpu_conv_asymmetric_src)
return sizeof(int32_t);
return 0;
}

/** return true if memory format has additional buffer */
bool is_additional_buffer() const {
using namespace memory_extra_flags;
// Currently compensation is not required for rnn_s8s8_compensation,
// but it has common bit with rnn_u8s8_compensation constant so we have
// to exclude rnn_s8s8_compensation case explicitly
return ((extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_conv_asymmetric_src))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags));
return extra().flags
& (compensation_conv_s8s8 | rnn_u8s8_compensation
| compensation_gpu_conv_asymmetric_src
| compensation_conv_asymmetric_src);
}

/** returns the size required for a particular extra memory buffer */
size_t additional_buffer_size(memory_extra_flags_t flag) const {
using namespace memory_extra_flags;
const auto flags = extra().flags;
if (!(flags & flag)) return 0;

const auto ndims = this->ndims();
const auto &pdims = padded_dims();
@@ -186,21 +184,21 @@
return (size_t)prod * buff_data_size;
};

if (extra().flags & compensation_conv_s8s8) {
if (flag == compensation_conv_s8s8) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}

if ((extra().flags & rnn_u8s8_compensation)
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
extra().flags)) {
if (flag == rnn_u8s8_compensation) {
return calculate_size(extra().compensation_mask,
additional_buffer_data_size(flag));
}
if (extra().flags & compensation_conv_asymmetric_src) {
if (flag == compensation_conv_asymmetric_src) {
return calculate_size(extra().asymm_compensation_mask,
additional_buffer_data_size(flag));
}
if (flag == compensation_gpu_conv_asymmetric_src) {
return extra().dst_size;
}

return 0;
}
@@ -220,6 +218,8 @@ struct memory_desc_wrapper : public c_compatible {
buff_size += additional_buffer_size(compensation_conv_s8s8);
buff_size += additional_buffer_size(rnn_u8s8_compensation);
buff_size += additional_buffer_size(compensation_conv_asymmetric_src);
buff_size
+= additional_buffer_size(compensation_gpu_conv_asymmetric_src);
return buff_size;
}

17 changes: 12 additions & 5 deletions src/common/primitive_hashing.cpp
@@ -190,11 +190,9 @@ size_t get_md_hash(const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
seed = hash_combine(seed, md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
seed = hash_combine(seed, md.extra.compensation_mask);
}

@@ -206,6 +204,15 @@
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
seed = hash_combine(seed, md.extra.asymm_compensation_mask);
}

if (md.extra.flags
& dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) {
seed = get_array_hash(seed, md.extra.idhw, 3);
seed = get_array_hash(seed, md.extra.odhw, 3);
seed = get_array_hash(seed, md.extra.pdhw, 3);
seed = get_array_hash(seed, md.extra.ddhw, 3);
seed = hash_combine(seed, md.extra.dst_size);
}
}
// Combined hash for a memory descriptor
return seed;
18 changes: 11 additions & 7 deletions src/common/serialization.cpp
@@ -120,22 +120,26 @@ void serialize_md(serialization_stream_t &sstream, const memory_desc_t &md) {

if (md.extra.flags != dnnl_memory_extra_flag_none) {
sstream.write(&md.extra.flags);
if ((md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation))
&& !types::extra_flag_rnn_s8s8_compensation_is_set(
md.extra.flags)) {
if (md.extra.flags
& (dnnl_memory_extra_flag_compensation_conv_s8s8
| dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
sstream.write(&md.extra.compensation_mask);
}

if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
sstream.write(&md.extra.scale_adjust);
}

if (md.extra.flags
& dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
sstream.write(&md.extra.asymm_compensation_mask);
}
if (md.extra.flags
& dnnl_memory_extra_flag_compensation_gpu_conv_asymmetric_src) {
sstream.write(md.extra.idhw, 3);
sstream.write(md.extra.odhw, 3);
sstream.write(md.extra.pdhw, 3);
sstream.write(md.extra.ddhw, 3);
sstream.write(&md.extra.dst_size);
}
}
}
