diff --git a/06_gpu_and_ml/stable_diffusion/flux.py b/06_gpu_and_ml/stable_diffusion/flux.py
index 1aed50974..40f8a0361 100644
--- a/06_gpu_and_ml/stable_diffusion/flux.py
+++ b/06_gpu_and_ml/stable_diffusion/flux.py
@@ -4,10 +4,10 @@
 # tags: ["use-case-image-video-3d", "featured"]
 # ---
 
-# # Run Flux fast with Flash Attention 3 and `torch.compile` on Hopper GPUs
+# # Run Flux fast with `torch.compile` on Hopper GPUs
 
 # In this guide, we'll run Flux as fast as possible on Modal using open source tools.
-# We'll use `torch.compile`, Flash Attention 3, and NVIDIA H100 GPUs.
+# We'll use `torch.compile` and NVIDIA H100 GPUs.
 
 # ## Setting up the image and dependencies
 
@@ -31,11 +31,11 @@
 
 # Now we install most of our dependencies with `apt` and `pip`.
 # For Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library
-# and for Flash Attention 3, we install from GitHub source
-# and so pin to a specific commit.
+# we install from GitHub source and so pin to a specific commit.
+
+# PyTorch added faster attention kernels for Hopper GPUs in version 2.5, so we upgrade to that release.
 
 diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
-flash_commit_sha = "53a4f341634fcbc96bb999a3c804c192ea14f2ea"
 
 flux_image = cuda_dev_image.apt_install(
     "git",
@@ -51,43 +51,16 @@
     "accelerate==0.33.0",
     "safetensors==0.4.4",
     "sentencepiece==0.2.0",
-    "ninja==1.11.1.1",
-    "packaging==24.1",
-    "wheel==0.44.0",
-    "torch==2.4.1",
+    "torch==2.5.0",
     f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
     "numpy<2",
 )
 
-# ### Installing and compiling Flash Attention 3
-
-# [Flash Attention (FA3)](https://github.com/Dao-AILab/flash-attention)
-# is a library of optimized CUDA kernels that
-# make attention blocks in Transformers go
-# [brrrr](https://horace.io/brrr_intro.html) on Hopper GPUs
-
-# Flash Attention kernels break attention operations into tiny blocks that can be
-# computed in the high-bandwidth SRAM of the GPU's streaming multiprocessors
-# with minimal communication to the lower-latency off-chip DRAM.
-# Check out Aleksa Gordic's ELI5 [here](https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad) for more.
-
-# Flash Attention 3 applies additional optimizations to make maximal use of specific features of NVIDIA's Hopper GPUs,
-# like hardware-accelerated indexing and matrix-multiplication accelerators.
-# Read Tri Dao's blog post [here](https://tridao.me/blog/2024/flash3/) for details.
-
-# To use it, we clone the repo and build the library from source.
-
-flux_fa3_image = flux_image.run_commands(
-    # build Flash Attention 3 from source and add it to PYTHONPATH
-    "ln -s /usr/bin/g++ /usr/bin/clang++",  # use clang as cpp compiler
-    "git clone https://github.com/Dao-AILab/flash-attention.git",
-    f"cd flash-attention && git checkout {flash_commit_sha}",
-    "cd flash-attention/hopper && python setup.py install",
-).env({"PYTHONPATH": "/root/flash-attention/hopper"})
 
 # Later, we'll also use `torch.compile` to increase the speed further.
-# This requires a few environment variables to be set.
+# Torch compilation needs to be re-executed each time a new container starts,
+# so we turn on some extra caching to reduce compile times for later containers.
 
-flux_fa3_image = flux_fa3_image.env(
+flux_image = flux_image.env(
     {"TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache"}
 ).env({"TORCHINDUCTOR_FX_GRAPH_CACHE": "1"})
 
@@ -95,9 +68,9 @@
 # set its default image to the one we just constructed,
 # and import `FluxPipeline` for downloading and running Flux.1.
 
-app = modal.App("example-flux", image=flux_fa3_image) +app = modal.App("example-flux", image=flux_image) -with flux_fa3_image.imports(): +with flux_image.imports(): import torch from diffusers import FluxPipeline @@ -117,7 +90,7 @@ @app.cls( - gpu="H100", # FA3 is tuned for Hopper + gpu="H100", # fastest GPU on Modal container_idle_timeout=20 * MINUTES, timeout=60 * MINUTES, # leave plenty of time for compilation volumes={ # add Volumes to store serializable compilation artifacts, see section on torch.compile below @@ -188,7 +161,7 @@ def inference(self, prompt: str) -> bytes: # ``` # By default, we call `generate` twice to demonstrate how much faster -# the inference is after cold start. In our tests, clients received images in about 1.5 seconds. +# the inference is after cold start. In our tests, clients received images in about 1.2 seconds. # We save the output bytes to a temporary file. @@ -231,7 +204,7 @@ def main( # which we verified with our own internal benchmarks. # Review that guide for detailed explanations of the choices made below. -# The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~800 ms), according to our testing. +# The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~700 ms), according to our testing. # _Super schnell_! # Compilation takes up to twenty minutes on first iteration.