diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 8e30554475..83c9a63884 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -3,6 +3,12 @@ on:
   push:
     branches:
     - main
+  pull_request:
+    branches:
+    - main
+    paths:
+    - ./Dockerfile
+    - .github/workflows/docker.yaml
   workflow_dispatch: {}
 jobs:
   docker-build:
@@ -13,10 +19,16 @@ jobs:
         include:
         - name: '1.13.1_cu117'
           base_image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
+          dep_groups: '[gpu]'
        - name: '2.0.1_cu118'
          base_image: mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04
+          dep_groups: '[gpu]'
        - name: '2.1.0_cu121'
          base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
+          dep_groups: '[gpu]'
+        - name: '2.1.0_cu121_flash2'
+          base_image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
+          dep_groups: '[gpu-flash2]'

     steps:
     - name: Maximize Build Space on Worker
@@ -52,13 +64,32 @@ jobs:
         GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
         echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
+        if [ "${{ github.event_name }}" == "push" ]; then
+          echo "Triggered by push event."
+          PROD_REPO="mosaicml/llm-foundry"
+          IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
+          IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
+        elif [ "${{ github.event_name }}" == "pull_request" ]; then
+          echo "Triggered by pull_request event."
+          STAGING_REPO="mosaicml/ci-staging"
+          IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
+          IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
+        else
+          echo "Triggered by unknown event: ${{ github.event_name }}"
+          exit 1
+        fi
+
+        echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
+        echo "IMAGE_CACHE=${IMAGE_CACHE}" >> ${GITHUB_ENV}
+

     - name: Build and Push the Docker Image
       uses: docker/build-push-action@v3
       with:
         context: .
-        tags: mosaicml/llm-foundry:${{ matrix.name }}-latest,
-          mosaicml/llm-foundry:${{ matrix.name }}-${{ env.IMAGE_TAG }}
+        tags: ${{ env.IMAGE_TAG }}
         push: true
-        cache-from: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache
-        cache-to: type=registry,ref=mosaicml/llm-foundry:${{ matrix.name }}-buildcache,mode=max
-        build-args: BASE_IMAGE=${{ matrix.base_image }}
+        cache-from: type=registry,ref=${{ env.IMAGE_CACHE }}
+        cache-to: type=registry,ref=${{ env.IMAGE_CACHE }},mode=max
+        build-args: |
+          BASE_IMAGE=${{ matrix.base_image }}
+          DEP_GROUPS=${{ matrix.dep_groups }}
diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index 769b345e39..e16f2c8b40 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -18,6 +18,7 @@ jobs:
     uses: ./.github/workflows/pytest-gpu.yaml
     strategy:
       matrix:
+        # TODO: After the PR with the flash attention 2 images goes in, add the new unit test suite
         include:
         - name: 'gpu-latest'
           container: mosaicml/pytorch:latest  # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
diff --git a/Dockerfile b/Dockerfile
index 0d75241068..6c283660c4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,9 +4,10 @@
 ARG BASE_IMAGE
 FROM $BASE_IMAGE
+ARG DEP_GROUPS

 # Install and uninstall foundry to cache foundry requirements
-RUN git clone -b main https://github.com/mosaicml/llm-foundry.git && \
-    pip install --no-cache-dir "./llm-foundry[gpu]" && \
-    pip uninstall -y llm-foundry && \
-    rm -rf llm-foundry
+RUN git clone -b main https://github.com/mosaicml/llm-foundry.git
+RUN pip install --no-cache-dir "./llm-foundry${DEP_GROUPS}"
+RUN pip uninstall -y llm-foundry
+RUN rm -rf llm-foundry
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index bea6284fb5..39fa7162ac 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -17,6 +17,22 @@
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


+def is_flash_v2_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+
+
+def is_flash_v1_installed():
+    try:
+        import flash_attn as flash_attn
+    except:
+        return False
+    return version.parse(flash_attn.__version__) < version.parse('2.0.0')
+
+
 def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
                      original_is_causal: bool) -> bool:
     # disable causal when it is not needed
@@ -197,7 +213,8 @@ def flash_attn_fn(
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0')
+        raise RuntimeError(
+            'Please install flash-attn==1.0.9 or flash-attn==2.3.2')

     check_valid_inputs(query, key, value)

@@ -278,18 +295,35 @@ def flash_attn_fn(

     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)

-    output_unpad = flash_attn_interface.flash_attn_unpadded_func(
-        query_unpad,
-        key_unpad,
-        value_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        dropout_p,
-        softmax_scale=softmax_scale,
-        causal=reset_is_causal,
-        return_attn_probs=needs_weights)
+    if is_flash_v1_installed():
+        output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+            q=query_unpad,
+            k=key_unpad,
+            v=value_unpad,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    elif is_flash_v2_installed():
+        output_unpad = flash_attn_interface.flash_attn_varlen_func(
+            q=query_unpad,
+            k=key_unpad,
+            v=value_unpad,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=reset_is_causal,
+            return_attn_probs=needs_weights)
+    else:
+        raise RuntimeError(
+            'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')

     output = bert_padding.pad_input(
         rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
@@ -321,7 +355,7 @@ def triton_flash_attn_fn(
        if version.parse(torch.__version__) < version.parse('2.0.0'):
            _installed = True
            # if torch1.13.1 revert to using triton flash attn from HazyResearch
-            # with flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202
+            # with flash-attn==1.0.9 and triton==2.0.0.dev20221202
            try:
                from flash_attn.flash_attn_triton import flash_attn_func
            except:
diff --git a/setup.py b/setup.py
index be5b6708a3..a686dd0808 100644
--- a/setup.py
+++ b/setup.py
@@ -91,6 +91,12 @@
     # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
     'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.9#subdirectory=csrc/xentropy',
 ]
+extra_deps['gpu-flash2'] = [
+    'flash-attn==2.3.2',
+    'mosaicml-turbo==0.0.4',
+    # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
+    'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v2.3.2#subdirectory=csrc/xentropy',
+]

 extra_deps['peft'] = [
     'loralib==0.1.1',  # lora core
@@ -107,7 +113,10 @@
 ]
 extra_deps['all-cpu'] = set(
     dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
-extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
+extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
+                        if key != 'gpu-flash2')
+extra_deps['all-flash2'] = set(
+    dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')

 setup(
     name=_PACKAGE_NAME,
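Not part of the diff: a minimal smoke test, sketched under the assumption that only packaging (and optionally flash-attn) is importable in the image. It reports which flash-attn generation is present and therefore which branch of the new version gate in flash_attn_fn would run; the printed wording is illustrative only.

# Report which flash-attn generation is installed and which kernel entry point
# the new gate in llmfoundry/models/layers/attention.py would select.
from packaging import version

try:
    import flash_attn
except ImportError:
    print('flash-attn not installed: flash_attn_fn would raise RuntimeError')
else:
    v = version.parse(flash_attn.__version__)
    if v >= version.parse('2.0.0'):
        print(f'flash-attn {v}: v2 path via flash_attn_interface.flash_attn_varlen_func')
    else:
        print(f'flash-attn {v}: v1 path via flash_attn_interface.flash_attn_unpadded_func')

Running this inside the [gpu] image should report the v1 path (flash-attn 1.0.9), while the new [gpu-flash2] image should report the v2 path (flash-attn 2.3.2).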