Skip to content

Commit

Permalink
Add images with flash attention 2 (#651)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Oct 10, 2023
1 parent 1bf0a93 commit ba6b880
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 24 deletions.
41 changes: 36 additions & 5 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ on:
push:
branches:
- main
pull_request:
branches:
- main
paths:
- ./Dockerfile
- .github/workflows/docker.yaml
workflow_dispatch: {}
jobs:
docker-build:
Expand All @@ -13,10 +19,16 @@ jobs:
include:
- name: '1.13.1_cu117'
base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
dep_groups: '[gpu]'
- name: '2.0.1_cu118'
base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
dep_groups: '[gpu]'
- name: '2.1.0_cu121'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
dep_groups: '[gpu]'
- name: '2.1.0_cu121_flash2'
base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
dep_groups: '[gpu-flash2]'

steps:
- name: Maximize Build Space on Worker
Expand Down Expand Up @@ -52,13 +64,32 @@ jobs:
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
if [ "${{ github.event_name }}" == "push" ]; then
echo "Triggered by push event."
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
elif [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
echo "Triggered by unknown event: ${{ github.event_name }}"
exit 1
fi
echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
echo "IMAGE_CACHE=${IMAGE_CACHE}" >> ${GITHUB_ENV}
- name: Build and Push the Docker Image
uses: docker/build-push-action@v3
with:
context: .
tags: mosaicml/llm-foundry:${{ matrix.name }}-latest,
mosaicml/llm-foundry:${{ matrix.name }}-${{ env.IMAGE_TAG }}
tags: ${{ env.IMAGE_TAG }}
push: true
cache-from: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache
cache-to: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache,mode=max
build-args: BASE_IMAGE=${{ matrix.base_image }}
cache-from: type=registry,ref=${{ env.IMAGE_CACHE }}
cache-to: type=registry,ref=${{ env.IMAGE_CACHE }},mode=max
build-args: |
BASE_IMAGE=${{ matrix.base_image }}
DEP_GROUPS=${{ matrix.dep_groups }}
1 change: 1 addition & 0 deletions .github/workflows/pr-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
uses: ./.github/workflows/pytest-gpu.yaml
strategy:
matrix:
# TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
include:
- name: 'gpu-latest'
container: mosaicml/pytorch:latest # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
Expand Down
9 changes: 5 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
ARG BASE_IMAGE
FROM $BASE_IMAGE

ARG DEP_GROUPS

# Install and uninstall foundry to cache foundry requirements
RUN git clone -b main https://github.com/mosaicml/llm-foundry.git && \
pip install --no-cache-dir "./llm-foundry[gpu]" && \
pip uninstall -y llm-foundry && \
rm -rf llm-foundry
RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
RUN pip uninstall -y llm-foundry
RUN rm -rf llm-foundry
62 changes: 48 additions & 14 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def is_flash_v2_installed():
try:
import flash_attn as flash_attn
except:
return False
return version.parse(flash_attn.__version__) >= version.parse('2.0.0')


def is_flash_v1_installed():
try:
import flash_attn as flash_attn
except:
return False
return version.parse(flash_attn.__version__) < version.parse('2.0.0')


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
Expand Down Expand Up @@ -197,7 +213,8 @@ def flash_attn_fn(
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
raise RuntimeError('Please install flash-attn==1.0.3.post0')
raise RuntimeError(
'Please install flash-attn==1.0.9 or flash-attn==2.3.2')

check_valid_inputs(query, key, value)

Expand Down Expand Up @@ -278,18 +295,35 @@ def flash_attn_fn(

reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

output_unpad = flash_attn_interface.flash_attn_unpadded_func(
query_unpad,
key_unpad,
value_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
if is_flash_v1_installed():
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
q=query_unpad,
k=key_unpad,
v=value_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
elif is_flash_v2_installed():
output_unpad = flash_attn_interface.flash_attn_varlen_func(
q=query_unpad,
k=key_unpad,
v=value_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=reset_is_causal,
return_attn_probs=needs_weights)
else:
raise RuntimeError(
'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')

output = bert_padding.pad_input(
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
Expand Down Expand Up @@ -321,7 +355,7 @@ def triton_flash_attn_fn(
if version.parse(torch.__version__) < version.parse('2.0.0'):
_installed = True
# if torch1.13.1 revert to using triton flash attn from HazyResearch
# with flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202
# with flash-attn==1.0.9 and triton==2.0.0.dev20221202
try:
from flash_attn.flash_attn_triton import flash_attn_func
except:
Expand Down
11 changes: 10 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
extra_deps['gpu-flash2'] = [
'flash-attn==2.3.2',
'mosaicml-turbo==0.0.4',
# PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
'loralib==0.1.1', # lora core
Expand All @@ -107,7 +113,10 @@
]
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
if key != 'gpu-flash2')
extra_deps['all-flash2'] = set(
dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')

setup(
name=_PACKAGE_NAME,
Expand Down

0 comments on commit ba6b880

Please sign in to comment.