[PyTorch] Add sliding window support to FlashAttention (#551)
* add sliding window to FA

Signed-off-by: Charlene Yang <[email protected]>

* fix forward logic

Signed-off-by: Charlene Yang <[email protected]>

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* change bert test to causal as unfused does not support padding

Signed-off-by: Charlene Yang <[email protected]>

* fix FlashAttention for v2-2.3 versions

Signed-off-by: Charlene Yang <[email protected]>

* verify FA swa works

Signed-off-by: Charlene Yang <[email protected]>

* fix mask related restrictions and duplicate code after merge

Signed-off-by: Charlene Yang <[email protected]>

* fix swa test

Signed-off-by: Charlene Yang <[email protected]>

* add docstring for get_swa func

Signed-off-by: Charlene Yang <[email protected]>

* move repeated code into a function

Signed-off-by: Charlene Yang <[email protected]>

* revert mask change

Signed-off-by: Charlene Yang <[email protected]>

* add determinism filter and fix FA warning message

Signed-off-by: Charlene Yang <[email protected]>

* add message for determinism filter

Signed-off-by: Charlene Yang <[email protected]>

* simplify check_set_window_size()

Signed-off-by: Charlene Yang <[email protected]>

* fix check_set_window_size in transformer layers

Signed-off-by: Charlene Yang <[email protected]>

* fix indent

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
cyanguwa authored Dec 16, 2023
1 parent 4a147e0 commit 27aa609
Showing 3 changed files with 204 additions and 38 deletions.
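
For context, the test changes below exercise the new sliding-window path by threading a (left, right) window_size through DotProductAttention's forward call (see _run_dot_product_attention in the diff). A minimal usage sketch along those lines follows; the import path, constructor arguments, and tensor layout are assumptions based on Transformer Engine's public PyTorch API rather than on this diff, and the flash-attn >= 2.3 requirement from the tests applies.

import torch
from transformer_engine.pytorch import DotProductAttention

# Assumed default 'sbhd' layout: [seq, batch, heads, head_dim].
sq, b, h, d = 128, 2, 16, 64
q, k, v = (torch.randn(sq, b, h, d, dtype=torch.bfloat16, device="cuda") for _ in range(3))

dpa = DotProductAttention(num_attention_heads=h, kv_channels=d)

# window_size=(left, right): each query attends to at most `left` keys behind and
# `right` keys ahead of its end-aligned position (flash-attn >= 2.3 semantics).
out = dpa(q, k, v, window_size=(16, 0))
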
68 changes: 57 additions & 11 deletions tests/pytorch/test_fused_attn.py
@@ -152,10 +152,16 @@ def _is_flash_attention_2_available() -> bool:

@functools.cache
def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.0+ is available"""
"""Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1")

@functools.cache
def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.3")

def _is_flash_attention_supported(config: ModelConfig) -> bool:
"""Check if FlashAttention supports a model configuration"""
if get_device_compute_capability() < (8, 0):
@@ -192,14 +198,26 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool:
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

def get_swa(seq_q, seq_kv, w=None):
"""Generate a random sliding window size (left, right) if w is None,
and create its equivalent attention mask in [seq_q, seq_kv] shape"""
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
ml = ~ ml
return w, ml
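
As a concrete hand-check of the mask get_swa builds (not part of the diff): with seq_q = seq_kv = 4 and w = (1, 0), a query at position i may only attend to keys j with i - 1 <= j <= i, and the returned mask is True exactly where attention is blocked, matching the "arbitrary" mask convention the test uses for the unfused reference.

w, mask = get_swa(4, 4, w=[1, 0])
# mask (True = masked out):
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [ True, False, False,  True],
#  [ True,  True, False, False]]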

@pytest.mark.skipif(_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("workspace_opt", [True, False])
@pytest.mark.parametrize("qkv_layout", [None])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout):
@pytest.mark.parametrize("swa", [False])
def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa):
"""Test DotProductAttention module"""

# Get configs
@@ -224,36 +242,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
config, dtype, qkv_layout=qkv_layout,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")

# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "UnfusedDotProductAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
if swa:
config.attn_mask_type = attn_mask_type

# FusedAttention backend
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FusedAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)

# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt,
dtype, config, "FlashAttention", ckpt_attn, qkv_layout, workspace_opt, swa,
)

if unfused_attn_supported and fused_attn_supported:
@@ -279,7 +304,7 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn, workspace
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None)
test_dot_product_attention(dtype, model_configs, model, True, True, None, False)

model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -303,7 +328,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
@@ -339,7 +364,22 @@ def test_dpa_mask(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False)

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@pytest.mark.skipif(not _is_flash_attention_2_3(), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_swa])
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True)
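
The new test compares FlashAttention run with window_size against the unfused backend run with the equivalent "arbitrary" mask from get_swa: a (left, right) window admits key j for query i exactly when i + (seq_kv - seq_q) - left <= j <= i + (seq_kv - seq_q) + right, with queries aligned to the end of the key sequence. A standalone single-head sketch of that masked softmax, for illustration only (it is not used by the test):

import torch

def reference_swa_attention(q, k, v, window):
    """q: [sq, d]; k, v: [skv, d]; window = (left, right); assumes skv >= sq."""
    sq, skv, d = q.shape[0], k.shape[0], q.shape[1]
    left, right = window
    scores = (q @ k.T) / d ** 0.5
    i = torch.arange(sq, device=q.device).unsqueeze(1)
    j = torch.arange(skv, device=q.device).unsqueeze(0)
    # Key j is visible to query i iff it lies inside the end-aligned window.
    visible = (j >= i + (skv - sq) - left) & (j <= i + (skv - sq) + right)
    scores = scores.masked_fill(~visible, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v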

qkv_layouts = [
'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
@@ -367,7 +407,7 @@ def test_dpa_bias(dtype, model_configs, model):
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False)

def _run_dot_product_attention(
dtype: torch.dtype,
Expand All @@ -376,6 +416,7 @@ def _run_dot_product_attention(
ckpt_attn: bool,
qkv_layout: str,
workspace_opt: bool,
swa: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""

@@ -433,6 +474,10 @@ def _run_dot_product_attention(
.to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
else:
window_size, attention_mask = None, None

# Create input tensors
dim_to_num = {
@@ -515,6 +560,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:

# Run a forward and backward pass
out = block(inp[0], inp[1], inp[2],
window_size=window_size,
attention_mask=attention_mask,
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
(Diffs for the remaining two changed files are not shown.)
