feat: make flash attention configurable #60

Merged · 94 commits · Jan 30, 2025

Commits
539e8a2
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
3317138
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
3186a8e
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
a86c9a8
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
105443f
feat: add softcap
theissenhelen Sep 20, 2024
e82a59e
test: add softcap
theissenhelen Sep 20, 2024
e648eb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
6271cd8
feat: flash attention lazy import
theissenhelen Sep 23, 2024
d4940e7
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
9ff6cb9
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
bbd89dc
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
91533c6
feat: make alibi_slope cinfigurable in block, chunk processor
theissenhelen Oct 1, 2024
0eb5c50
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
c04e641
feat: get alibi_slopes
theissenhelen Oct 2, 2024
6523b47
docs: update docstrings
theissenhelen Oct 3, 2024
22623cc
fix: bias shape
theissenhelen Oct 3, 2024
ed07e34
fix: softcap optional
theissenhelen Oct 3, 2024
c841324
fix: import annotations from future
theissenhelen Oct 3, 2024
6c12dda
fix: annotation error
theissenhelen Oct 3, 2024
b7b8f2e
docs: update changelog
theissenhelen Oct 3, 2024
df353d9
fix: type annotation
theissenhelen Oct 7, 2024
fc335c7
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
663fea0
feat: FlashMultiHeadSelfAttention
theissenhelen Sep 17, 2024
a8b3f9d
Chore/multiple fixes ci precommit (#41)
theissenhelen Sep 18, 2024
6595ca1
11 add configurability to dropout in multiheadselfattention module (#12)
theissenhelen Sep 18, 2024
0c55a9c
chore!: drop support for scaled_dot_product_attention
theissenhelen Sep 20, 2024
ea665be
feat: add softcap
theissenhelen Sep 20, 2024
ffa2d99
test: add softcap
theissenhelen Sep 20, 2024
7c2d634
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2024
d2ed932
feat: flash attention lazy import
theissenhelen Sep 23, 2024
3295159
feat: make alibi slopes configurable
theissenhelen Sep 27, 2024
ebde686
chore(deps): add flash-attn
theissenhelen Sep 27, 2024
5102d9a
feat: use scaled_dot_product as default
theissenhelen Oct 1, 2024
3abc286
feat: make alibi_slope cinfigurable in block, chunk processor
theissenhelen Oct 1, 2024
673a25d
chore(deps): remove flash-attn
theissenhelen Oct 1, 2024
f606058
feat: get alibi_slopes
theissenhelen Oct 2, 2024
ef34771
docs: update docstrings
theissenhelen Oct 3, 2024
5136fb3
fix: bias shape
theissenhelen Oct 3, 2024
892c269
fix: softcap optional
theissenhelen Oct 3, 2024
4c42171
fix: import annotations from future
theissenhelen Oct 3, 2024
4bdf464
fix: annotation error
theissenhelen Oct 3, 2024
5a670b2
docs: update changelog
theissenhelen Oct 3, 2024
34db6e4
fix: type annotation
theissenhelen Oct 7, 2024
d424c75
feat: catch low flash-attn version
theissenhelen Oct 7, 2024
222b7d8
feat: attention wrapper
theissenhelen Oct 25, 2024
c2aca14
fix: remove duplicate version check
theissenhelen Oct 25, 2024
b75d225
merge conflict
cathalobrien Nov 1, 2024
147e772
added flex attn wrapper
cathalobrien Nov 1, 2024
f0c24e8
fix: alibi_slopes unassigned
theissenhelen Nov 6, 2024
3c4572b
adding causal wip
cathalobrien Nov 6, 2024
fb731f7
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Nov 8, 2024
f0308f2
added flex attn module
cathalobrien Nov 12, 2024
6dee265
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2024
7fb0b62
Bump min torch version to be able to use Flex Attn
cathalobrien Nov 12, 2024
739aa65
added input parameter checks
cathalobrien Nov 12, 2024
2a2ed11
precommit fix
cathalobrien Nov 12, 2024
fa1474c
merge
cathalobrien Nov 12, 2024
a703688
fix: typo
theissenhelen Nov 26, 2024
f1be563
test: adjust tests
theissenhelen Nov 27, 2024
0dda5d6
fix: no self.use_alibi_slopes
theissenhelen Nov 27, 2024
12facf0
fix: use_alibi_slope default to false
theissenhelen Nov 28, 2024
60e32f1
feat: Add sliding window support for TorchAttention via mask
japols Dec 9, 2024
07d9684
fix: set default flash_attention
japols Dec 10, 2024
9a1827a
fix: pytest
japols Dec 10, 2024
ca8c9fa
fix: tests
japols Dec 13, 2024
ac897ea
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 16, 2024
7ec8142
docs: improve docstrings in MultiHeadSelfAttention
theissenhelen Dec 18, 2024
972d3c5
fix: error instead of SystemExit
theissenhelen Dec 18, 2024
e89fd2e
chore: refactor SDPAAttention update_mask method
theissenhelen Dec 18, 2024
2d122df
feat: add missing pytest.ini
theissenhelen Dec 18, 2024
d4510f6
chore: remove explicit float typing
theissenhelen Dec 19, 2024
6057004
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Dec 19, 2024
8656cae
support running without window size
cathalobrien Dec 19, 2024
705fc6b
Merge commit '8656cae768f8f8359772cdab0b06e0933b8a0bdf' into feature/…
theissenhelen Jan 6, 2025
aa1abe7
test: sepa:rate test for sdpa and flex attention
theissenhelen Jan 13, 2025
eb5ed7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
a10d832
added asserts and tests for flex attn
cathalobrien Jan 13, 2025
a010860
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 13, 2025
c0f462c
fix: embed_dim / num_heads >=16
theissenhelen Jan 14, 2025
752de28
test: fix tests to account for embed_dim constraints
theissenhelen Jan 14, 2025
3818719
fix tests
cathalobrien Jan 17, 2025
c2f8890
Merge branch 'feature/44-make-flash-attention-configurable' of github…
cathalobrien Jan 17, 2025
2d8b775
chore: remove debugging code
theissenhelen Jan 20, 2025
7665c7f
consitency change
cathalobrien Jan 22, 2025
230f044
chore(configs): add attention_implementation
theissenhelen Jan 24, 2025
603ab17
Update models/src/anemoi/models/layers/attention.py
theissenhelen Jan 29, 2025
c2925fa
Update models/src/anemoi/models/layers/attention.py
theissenhelen Jan 29, 2025
88dc6d5
fix: address comments
theissenhelen Jan 29, 2025
1d65779
chore: remove flex_attention
theissenhelen Jan 29, 2025
46ba31c
Merge branch 'main' into feature/44-make-flash-attention-configurable
theissenhelen Jan 29, 2025
e5f0f49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2025
8e7a93c
test: fix merge
theissenhelen Jan 29, 2025
8ef5575
fix test to address breaking change from torch 2.6
anaprietonem Jan 30, 2025
d20d3ed
remove flex_attention references
anaprietonem Jan 30, 2025
2 changes: 1 addition & 1 deletion graphs/tests/test_create.py
@@ -51,6 +51,6 @@ def test_generate_graph(self, config_file: tuple[Path, str], mock_grids_path: tu
 
         if graph_path is not None:
             assert graph_path.exists()
-            graph_saved = torch.load(graph_path)
+            graph_saved = torch.load(graph_path, weights_only=False)
             assert graph.node_types == graph_saved.node_types
             assert graph.edge_types == graph_saved.edge_types
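For context on why this one-line change is needed: PyTorch 2.6 flipped the default of `torch.load`'s `weights_only` flag to `True`, which rejects pickled Python objects such as the saved graph loaded here. A minimal sketch of the fix, assuming a trusted local `graph.pt` file (hypothetical path):

```python
import torch

# Under torch >= 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle arbitrary Python objects (e.g. a torch_geometric HeteroData graph).
# weights_only=False restores the old behaviour; only use it for files you
# trust, since unpickling can execute arbitrary code.
graph_saved = torch.load("graph.pt", weights_only=False)  # hypothetical path
```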
9 changes: 9 additions & 0 deletions models/CHANGELOG.md
@@ -54,15 +54,24 @@ Keep it human-readable, your future self will thank you!

### Added

- add configurability of flash attention (#47)
- configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy

### Changed

- Update CI to inherit from common infrastructure reusable workflows
- run downstream-ci only when src and tests folders have changed
- New error messages for wrong graphs.
- Feature: Change model to be instantiatable in the interface, addressing [#28](https://github.com/ecmwf/anemoi-models/issues/28) through [#45](https://github.com/ecmwf/anemoi-models/pulls/45)
- Bugfixes for CI

### Removed

2 changes: 1 addition & 1 deletion models/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
"anemoi-utils>=0.1.9",
"einops>=0.6.1",
"hydra-core>=1.3",
"torch>=2.2",
"torch>=2.5",
"torch-geometric>=2.3,<2.5",
]
optional-dependencies.all = [ ]
Expand Down
7 changes: 7 additions & 0 deletions models/pytest.ini
@@ -0,0 +1,7 @@
[pytest]
markers =
    data_dependent: marks tests depending on data (deselect with '-m "not data_dependent"')
    auth: marks tests that require authentication (deselect with '-m "not auth"')
    gpu: marks tests that require a GPU (deselect with '-m "not gpu"')

tmp_path_retention_policy = none
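The markers above exist so that the suite can be trimmed on machines without data or GPUs. A small sketch of how they might be deselected; the usual route is the pytest CLI, and `pytest.main` plus the test path are used here only as an illustration:

```python
import pytest

# Run the models test suite while skipping GPU- and data-dependent tests,
# mirroring the deselect hints in the marker descriptions above.
# "models/tests" is an assumed path within the repository layout.
exit_code = pytest.main(["-m", "not gpu and not data_dependent", "models/tests"])
```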
237 changes: 211 additions & 26 deletions models/src/anemoi/models/layers/attention.py
@@ -8,23 +8,19 @@
# nor does it submit to any jurisdiction.


from __future__ import annotations

import logging
import math
from typing import Optional

import einops
import torch
from packaging import version
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.utils.config import DotDict
@@ -33,7 +29,12 @@


class MultiHeadSelfAttention(nn.Module):
"""Multi Head Self Attention Pytorch Layer."""
"""Multi Head Self Attention Pytorch Layer

allows for three different attention implementations:
- scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- flash attention, see https://github.com/Dao-AILab/flash-attention
"""

    def __init__(
        self,
@@ -44,32 +44,89 @@ def __init__(
        is_causal: bool = False,
        window_size: Optional[int] = None,
        dropout_p: float = 0.0,
        attention_implementation: str = "flash_attention",
        softcap: Optional[float] = None,
        use_alibi_slopes: bool = False,
    ):
"""Initialize MultiHeadSelfAttention.

For the flash attention implementation, two additional parameters are available: softcap, use_alibi_slopes

softcap: Softcapping prevents the logits from growing excessively large

use_alibi_slopes: Adds bias of `(-alibi_slope * |i + seqlen_k - seqlen_q - j|)` to the attention score of
query i and key j, where alibi_slope is calculated using get_alibi_slopes

Parameters
----------
num_heads : int
number of heads
embed_dim : int
embedding dimension
bias : bool, optional
bias, by default False
is_causal : bool, optional
apply causal attention mask, by default False
window_size : Optional[int], optional
window_size, by default None
dropout_p : float, optional
dropout probability, by default 0.0
attention_implementation: str, optional
A predefined string which selects which underlying attention
implementation, by default "flash_attention"
softcap : float, optional
Anything > 0 activates softcapping attention, by default None
use_alibi_slopes : bool, optional
Adds bias
HCookie marked this conversation as resolved.
Show resolved Hide resolved
"""
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

        self.attention_implementation = attention_implementation
        self.use_alibi_slopes = use_alibi_slopes

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads  # q k v
        self.window_size = window_size
        self.dropout_p = dropout_p
        self.is_causal = is_causal
        self.softcap = softcap

        self.set_attention_function()

        if self.use_alibi_slopes:
            self.alibi_slopes = get_alibi_slopes(num_heads)
            assert self.alibi_slopes.shape[0] == num_heads, "Error: Number of alibi_slopes must match number of heads"
        else:
            self.alibi_slopes = None

        linear = layer_kernels["Linear"]
        self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)

        self.projection = linear(embed_dim, embed_dim, bias=True)

    def set_attention_function(self):
        attn_funcs = {
            "flash_attention": FlashAttentionWrapper,
            "scaled_dot_product_attention": SDPAAttentionWrapper,
        }
        assert (
            self.attention_implementation in attn_funcs
        ), f"{self.attention_implementation} not supported. \
            Please change model.processor.attention_implementation to one of: {attn_funcs.keys()}"
        LOGGER.info(f"Using {self.attention_implementation}")

        # initialise the selected attention wrapper here
        self.attention = attn_funcs[self.attention_implementation]()

    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:

        query, key, value = self.lin_qkv(x).chunk(3, -1)

        if model_comm_group:
@@ -92,24 +92,151 @@ def forward(
            value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
        dropout_p = self.dropout_p if self.training else 0.0

        out = self.attention(
            query,
            key,
            value,
            batch_size,
            causal=False,
            window_size=self.window_size,
            dropout_p=dropout_p,
            softcap=self.softcap,
            alibi_slopes=self.alibi_slopes,
        )

        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

        out = self.projection(out)

        return out


class SDPAAttentionWrapper(nn.Module):
    """Wrapper for Pytorch scaled dot product attention."""

    def __init__(self):
        super().__init__()

        from torch.nn.functional import scaled_dot_product_attention

        self.attention = scaled_dot_product_attention
        self.mask = None
        self.window_size = None

    def update_mask(self, seq_len, window_size: int, device: str):
        self.mask = (
            torch.abs(
                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
            )
            <= window_size
        )

    def forward(
        self,
        query,
        key,
        value,
        batch_size: int,
        causal=False,
        window_size=None,
        dropout_p=0.0,
        softcap=None,
        alibi_slopes=None,
    ):
        if softcap is not None:
            raise NotImplementedError(
                "Softcap not supported by Pytorch's SDPA. Please switch to flash attention or disable softcap."
            )
        if alibi_slopes is not None:
            raise NotImplementedError(
                "Alibi slopes not supported by Pytorch's SDPA. Please switch to flash attention or disable alibi slopes."
            )

        sequence_len = query.shape[-2]

        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
            self.update_mask(sequence_len, window_size=window_size, device=query.device)

        with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
            out = self.attention(
                query,
                key,
                value,
                attn_mask=self.mask,
                is_causal=causal,
                dropout_p=dropout_p,
            )

        return out


class FlashAttentionWrapper(nn.Module):
    """Wrapper for Flash attention."""

    def __init__(self):
        super().__init__()
        try:
            import flash_attn
        except ImportError:
            raise ImportError("Error: Flash-attn not installed. Please install flash-attn to use Flash Attention")

        if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
            raise RuntimeError("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
        else:
            self.attention = flash_attn.flash_attn_func

    def forward(
        self,
        query,
        key,
        value,
        batch_size: int,
        causal: bool = False,
        window_size: int = None,
        dropout_p: float = 0.0,
        softcap: Optional[float] = None,
        alibi_slopes: torch.Tensor = None,
    ):
        query, key, value = (
            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
        )

        alibi_slopes = alibi_slopes.repeat(batch_size, 1).to(query.device) if alibi_slopes is not None else None

        out = self.attention(
            query,
            key,
            value,
            causal=False,
            window_size=(window_size, window_size),
            dropout_p=dropout_p,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
        )
        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
        return out


def get_alibi_slopes(num_heads: int) -> Tensor:
    """Calculates linearly decreasing slopes for alibi attention.

    Parameters
    ----------
    num_heads : int
        number of attention heads

    Returns
    -------
    Tensor
        ALiBi slopes
    """
    n = 2 ** math.floor(math.log2(num_heads))
    slope_0 = 2 ** (-8 / n)
    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
    if n < num_heads:
        slope_hat_0 = 2 ** (-4 / n)
        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
    return alibi_slopes
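To make the new interface concrete, here is a small usage sketch that is not part of the PR. It assumes `MultiHeadSelfAttention` and `get_alibi_slopes` are importable from `anemoi.models.layers.attention`, that the keyword argument names match the docstring above, and that `layer_kernels` only needs to provide a `"Linear"` factory, as used in `__init__`:

```python
import torch

from anemoi.models.layers.attention import MultiHeadSelfAttention, get_alibi_slopes

# Assumption: a plain dict with an nn.Linear factory satisfies layer_kernels["Linear"].
layer_kernels = {"Linear": torch.nn.Linear}

mhsa = MultiHeadSelfAttention(
    num_heads=16,
    embed_dim=256,  # head_dim = 256 // 16 = 16, respecting the embed_dim / num_heads >= 16 constraint
    layer_kernels=layer_kernels,
    window_size=512,
    dropout_p=0.0,
    attention_implementation="scaled_dot_product_attention",  # avoids the optional flash-attn dependency
    use_alibi_slopes=False,
)

# The ALiBi slope schedule also covers non-power-of-two head counts:
slopes = get_alibi_slopes(num_heads=6)
assert slopes.shape[0] == 6

# The SDPA wrapper's sliding window is just a banded boolean mask:
seq_len, window = 6, 2
idx = torch.arange(seq_len)
band_mask = torch.abs(idx.unsqueeze(0) - idx.unsqueeze(1)) <= window
assert band_mask.shape == (seq_len, seq_len)
```

Selecting `scaled_dot_product_attention` keeps flash-attn optional; `flash_attention` requires flash-attn >= 2.6.0, per the version check in `FlashAttentionWrapper`.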
6 changes: 6 additions & 0 deletions models/src/anemoi/models/layers/block.py
@@ -71,6 +71,9 @@
         window_size: int,
         layer_kernels: DotDict,
         dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: float = None,
+        use_alibi_slopes: bool = None,
     ):
         super().__init__()
 
@@ -91,6 +94,9 @@ def __init__(
             is_causal=False,
             dropout_p=dropout_p,
             layer_kernels=layer_kernels,
+            attention_implementation=attention_implementation,
+            softcap=softcap,
+            use_alibi_slopes=use_alibi_slopes,
         )
 
         self.mlp = nn.Sequential(
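The new arguments are threaded from the processor configuration through the block into `MultiHeadSelfAttention`; the assertion in `set_attention_function` points users at `model.processor.attention_implementation`. A hypothetical sketch of the relevant processor keys (the real schema lives in the anemoi training configs and may differ):

```python
# Hypothetical processor settings; attention_implementation, softcap and
# use_alibi_slopes are the keys this PR adds, window_size and dropout_p existed before.
processor_attention_settings = {
    "attention_implementation": "flash_attention",  # or "scaled_dot_product_attention"
    "softcap": None,            # > 0 enables softcapping (flash attention only)
    "use_alibi_slopes": False,  # ALiBi bias (flash attention only)
    "window_size": 512,         # sliding attention window
    "dropout_p": 0.0,
}
```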