Fixes flash_attn + cascade attention_code to decoder Transformer bloc… #71

Merged
Changes from 3 commits
55 changes: 55 additions & 0 deletions Dockerfile.ewc_flash_attn
@@ -0,0 +1,55 @@
FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-devel

ARG INJECT_MF_CERT

COPY mf.crt /usr/local/share/ca-certificates/mf.crt

# The following two lines are necessary to deal with the MITM sniffing proxy we have internally.
RUN ( test $INJECT_MF_CERT -eq 1 && update-ca-certificates ) || echo "MF certificate not injected"
# Set apt to non-interactive mode
ENV DEBIAN_FRONTEND=noninteractive
ENV MY_APT='apt -o "Acquire::https::Verify-Peer=false" -o "Acquire::AllowInsecureRepositories=true" -o "Acquire::AllowDowngradeToInsecureRepositories=true" -o "Acquire::https::Verify-Host=false"'

RUN $MY_APT update && $MY_APT install -y software-properties-common && add-apt-repository ppa:ubuntugis/ppa
RUN $MY_APT update && $MY_APT install -y curl gdal-bin libgdal-dev libgeos-dev git vim nano sudo libx11-dev tk python3-tk tk-dev libpng-dev libffi-dev dvipng texlive-latex-base texlive-latex-extra texlive-fonts-recommended cm-super openssh-server netcat libeccodes-dev libeccodes-tools

ENV CPLUS_INCLUDE_PATH=/usr/include/gdal
ENV C_INCLUDE_PATH=/usr/include/gdal

ARG REQUESTS_CA_BUNDLE
ARG CURL_CA_BUNDLE

# Build eccodes; a recent version yields far better throughput according to our benchmarks
ARG ECCODES_VER=2.35.0
RUN curl -O https://confluence.ecmwf.int/download/attachments/45757960/eccodes-$ECCODES_VER-Source.tar.gz && tar -xzf eccodes-$ECCODES_VER-Source.tar.gz && mkdir build && cd build && cmake ../eccodes-$ECCODES_VER-Source -DENABLE_AEC=ON -DENABLE_NETCDF=ON -DENABLE_FORTRAN=OFF && make && ctest && make install && ldconfig

RUN pip install --upgrade pip
COPY requirements.txt /root/requirements.txt
RUN set -eux && pip install --default-timeout=100 -r /root/requirements.txt
RUN MAX_JOBS=8 pip install flash-attn --no-build-isolation
ARG USERNAME
ARG GROUPNAME
ARG USER_UID
ARG USER_GUID
ARG HOME_DIR
ARG NODE_EXTRA_CA_CERTS

RUN set -eux && groupadd --gid $USER_GUID $GROUPNAME \
# https://stackoverflow.com/questions/73208471/docker-build-issue-stuck-at-exporting-layers
&& mkdir -p $HOME_DIR && useradd -l --uid $USER_UID --gid $USER_GUID -s /bin/bash --home-dir $HOME_DIR --create-home $USERNAME \
&& chown $USERNAME:$GROUPNAME $HOME_DIR \
&& echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
&& chmod 0440 /etc/sudoers.d/$USERNAME \
&& echo "$USERNAME:$USERNAME" | chpasswd \
&& mkdir /run/sshd

RUN set -eux && pip install pyg-lib==0.4.0 torch-scatter==2.1.2 torch-sparse==0.6.18 torch-cluster==1.6.2 \
    torch-geometric==2.3.1 -f https://data.pyg.org/whl/torch-2.1.2+cpu.html

WORKDIR $HOME_DIR
RUN curl -fsSL https://code-server.dev/install.sh | sh





2 changes: 1 addition & 1 deletion config/models/unetrpp161024_linear_up.json
@@ -1,5 +1,5 @@
{
"num_heads": 16,
"num_heads_encoder": 16,
"hidden_size": 1024,
"linear_upsampling": true
}
7 changes: 7 additions & 0 deletions config/models/unetrpp161024_linear_up_flash.json
@@ -0,0 +1,7 @@
{
"num_heads_encoder": 16,
"hidden_size": 1024,
"linear_upsampling": true,
"attention_code": "flash"
}

2 changes: 1 addition & 1 deletion config/models/unetrpp8512.json
@@ -1,5 +1,5 @@
{
"num_heads": 8,
"num_heads_encoder": 8,
"hidden_size": 512
}

2 changes: 1 addition & 1 deletion config/models/unetrpp8512_linear_up.json
@@ -1,5 +1,5 @@
{
"num_heads": 8,
"num_heads_encoder": 8,
"hidden_size": 512,
"linear_upsampling": true
}
2 changes: 1 addition & 1 deletion config/models/unetrpp8512_linear_up_d2.json
@@ -1,5 +1,5 @@
{
"num_heads": 8,
"num_heads_encoder": 8,
"hidden_size": 512,
"linear_upsampling": true,
"downsampling_rate": 2
16 changes: 16 additions & 0 deletions doc/installEWC.md
@@ -0,0 +1,16 @@
## Installation instructions on EWC ECMWF A100 machines

This procedure uses MF's Docker wrapper `runai` and assumes that it is available on your PATH and that Docker is installed on your machine.

```bash
export RUNAI_DOCKER_FILENAME=Dockerfile.ewc_flash_attn
runai build
```

You should now be able to run a test training with the Dummy dataset using flash_attn and bf16 precision:

```bash
runai exec_gpu python bin/train.py --dataset dummy --model unetrpp --epochs 1 --batch_size 2 --model_conf config/models/unetrpp161024_linear_up_flash.json --precision bf16
```
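
As a side note (not part of this PR), here is a minimal sketch of how such a config could be mapped onto the renamed settings, assuming the JSON keys are passed straight to `UNETRPPSettings`; the actual config plumbing in py4cast may differ.

```python
import json

# Hypothetical glue code: py4cast's real config loading may differ.
from py4cast.models.vision.unetrpp import UNETRPPSettings

with open("config/models/unetrpp161024_linear_up_flash.json") as f:
    conf = json.load(f)

# "num_heads" has been split into "num_heads_encoder" / "num_heads_decoder",
# and "attention_code" selects one of "torch", "flash" or "manual".
settings = UNETRPPSettings(**conf)
print(settings.num_heads_encoder, settings.attention_code)
```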


52 changes: 40 additions & 12 deletions py4cast/models/vision/unetrpp.py
@@ -141,8 +141,7 @@ def __init__(
dropout_rate: fraction of the input units to drop.
pos_embed: bool argument to determine if positional embedding is used.
proj_size: size of the projection space for Spatial Attention.
use_scaled_dot_product_CA : bool argument to determine if torch's scaled_dot_product_attenton
is used for Channel Attention.
attention_code: type of attention implementation to use. See EPA for more details.
"""

super().__init__()
@@ -247,6 +246,7 @@ def __init__(
raise NotImplementedError(
"Attention code should be one of 'torch', 'flash' or 'manual'"
)
self.attention_code = attention_code
if attention_code == "flash":
from flash_attn import flash_attn_func

@@ -274,6 +274,7 @@ def __init__(
self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

def forward(self, x):
# TODO: fully optimize this function for each attention code
B, N, C = x.shape

qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
@@ -298,7 +299,29 @@ def forward(self, x):
q_shared = torch.nn.functional.normalize(q_shared, dim=-1).type_as(q_shared)
k_shared = torch.nn.functional.normalize(k_shared, dim=-1).type_as(k_shared)
if self.use_scaled_dot_product_CA:
x_CA = self.attn_func(q_shared, k_shared, v_CA, dropout_p=self.attn_drop.p)
if self.attention_code == "torch":
x_CA = self.attn_func(
q_shared, k_shared, v_CA, dropout_p=self.attn_drop.p
)
elif self.attention_code == "flash":
# flash attention expects inputs of shape (batch_size, seqlen, nheads, headdim)
# so we need to permute the dimensions from (batch, head, channels, spatial_dim) to (batch, channels, head, spatial_dim)
q_shared = q_shared.permute(0, 2, 1, 3)
k_shared = k_shared.permute(0, 2, 1, 3)
v_CA = v_CA.permute(0, 2, 1, 3)

x_CA = self.attn_func(
q_shared, k_shared, v_CA, dropout_p=self.attn_drop.p
)

# flash attention returns the output in the same shape as the input
# so we need to permute it back
x_CA = x_CA.permute(0, 2, 1, 3)

# we permute back the inputs
q_shared = q_shared.permute(0, 2, 1, 3)
k_shared = k_shared.permute(0, 2, 1, 3)

else:
attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
attn_CA = attn_CA.softmax(dim=-1)
@@ -446,7 +469,7 @@ def __init__(
conv_decoder: bool = False,
linear_upsampling: bool = False,
proj_size: int = 64,
use_scaled_dot_product_CA: bool = True,
attention_code: str = "torch",
) -> None:
"""
Args:
@@ -522,7 +545,7 @@ def __init__(
)
else:
stage_blocks = []
for j in range(depth):
for _ in range(depth):
stage_blocks.append(
TransformerBlock(
input_size=out_size,
@@ -531,7 +554,7 @@ def __init__(
dropout_rate=0.1,
pos_embed=True,
proj_size=proj_size,
attention_code=use_scaled_dot_product_CA,
attention_code=attention_code,
)
)
self.decoder_block.append(nn.Sequential(*stage_blocks))
@@ -557,7 +580,8 @@ def forward(self, inp, skip):
@dataclass
class UNETRPPSettings:
hidden_size: int = 256
num_heads: int = 4
num_heads_encoder: int = 4
num_heads_decoder: int = 4
pos_embed: str = "perceptron"
norm_name: Union[Tuple, str] = "instance"
dropout_rate: float = 0.0
@@ -660,7 +684,7 @@ def __init__(
h_size,
),
depths=settings.depths,
num_heads=settings.num_heads,
num_heads=settings.num_heads_encoder,
spatial_dims=settings.spatial_dims,
in_channels=num_input_features,
downsampling_rate=settings.downsampling_rate,
@@ -686,7 +710,8 @@ def __init__(
out_size=no_pixels // 16,
linear_upsampling=settings.linear_upsampling,
proj_size=settings.decoder_proj_size,
use_scaled_dot_product_CA=settings.attention_code,
attention_code=settings.attention_code,
num_heads=settings.num_heads_decoder,
)
self.decoder4 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
@@ -698,7 +723,8 @@ def __init__(
out_size=no_pixels // 4,
linear_upsampling=settings.linear_upsampling,
proj_size=settings.decoder_proj_size,
use_scaled_dot_product_CA=settings.attention_code,
attention_code=settings.attention_code,
num_heads=settings.num_heads_decoder,
)
self.decoder3 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
@@ -710,7 +736,8 @@ def __init__(
out_size=no_pixels,
linear_upsampling=settings.linear_upsampling,
proj_size=settings.decoder_proj_size,
use_scaled_dot_product_CA=settings.attention_code,
attention_code=settings.attention_code,
num_heads=settings.num_heads_decoder,
)
self.decoder2 = UnetrUpBlock(
spatial_dims=settings.spatial_dims,
@@ -723,7 +750,8 @@ def __init__(
conv_decoder=True,
linear_upsampling=settings.linear_upsampling,
proj_size=settings.decoder_proj_size,
use_scaled_dot_product_CA=settings.attention_code,
attention_code=settings.attention_code,
num_heads=settings.num_heads_decoder,
)
self.out1 = UnetOutBlock(
spatial_dims=settings.spatial_dims,
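
For readers unfamiliar with the layout constraints, here is a minimal sketch (not part of the diff) of the permutation the flash path performs, assuming channel-attention tensors shaped `(batch, heads, channels_per_head, spatial)` as in `EPA` above. The toy sizes are arbitrary, and the `flash_attn_func` call is only attempted when flash-attn and a CUDA device are available.

```python
import torch
import torch.nn.functional as F

B, H, C_per_head, N = 2, 8, 64, 128  # toy sizes, chosen arbitrarily
q = F.normalize(torch.randn(B, H, C_per_head, N), dim=-1)
k = F.normalize(torch.randn(B, H, C_per_head, N), dim=-1)
v = torch.randn(B, H, C_per_head, N)

# "torch" code path: SDPA attends over the channel dimension directly.
out_torch = F.scaled_dot_product_attention(q, k, v)  # (B, H, C_per_head, N)

# "flash" code path: flash_attn_func expects (batch, seqlen, nheads, headdim),
# so heads and channels are swapped before the call and swapped back after,
# mirroring the permute(0, 2, 1, 3) calls in the diff above.
try:
    from flash_attn import flash_attn_func  # needs a CUDA build and fp16/bf16 inputs

    qf, kf, vf = (t.permute(0, 2, 1, 3).bfloat16().cuda() for t in (q, k, v))
    out_flash = flash_attn_func(qf, kf, vf).permute(0, 2, 1, 3)  # back to (B, H, C_per_head, N)
except (ImportError, RuntimeError):
    out_flash = None  # flash-attn unavailable; the SDPA result above is the reference
```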
4 changes: 2 additions & 2 deletions reformat.sh
@@ -4,5 +4,5 @@ set -eux

id
pwd
isort --profile black $1
black $1
python -m isort --profile black $1
python -m black $1