update to torch 2.5, drop FA3 (#934)
charlesfrye authored Oct 19, 2024
1 parent a397c20 commit 9283487
Showing 1 changed file with 14 additions and 41 deletions.
55 changes: 14 additions & 41 deletions 06_gpu_and_ml/stable_diffusion/flux.py
@@ -4,10 +4,10 @@
# tags: ["use-case-image-video-3d", "featured"]
# ---

- # # Run Flux fast with Flash Attention 3 and `torch.compile` on Hopper GPUs
+ # # Run Flux fast with `torch.compile` on Hopper GPUs

# In this guide, we'll run Flux as fast as possible on Modal using open source tools.
- # We'll use `torch.compile`, Flash Attention 3, and NVIDIA H100 GPUs.
+ # We'll use `torch.compile` and NVIDIA H100 GPUs.

# ## Setting up the image and dependencies

@@ -31,11 +31,11 @@

# Now we install most of our dependencies with `apt` and `pip`.
# For Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library
- # and for Flash Attention 3, we install from GitHub source
- # and so pin to a specific commit.
+ # we install from GitHub source and so pin to a specific commit.

+ # PyTorch added [faster attention kernels for Hopper GPUs in version 2.5](https://pytorch.org/blog/pytorch2-5/).

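As background for the note above (an editorial aside, not part of the commit): stock PyTorch 2.5 can route `scaled_dot_product_attention` through Hopper-tuned kernels such as the cuDNN backend, which is what makes the separate Flash Attention 3 build unnecessary here. A minimal sketch of pinning those backends, assuming a CUDA-capable GPU and the `torch.nn.attention` context-manager API:

```python
# Illustrative only, not the example's code: restrict SDPA to the cuDNN/FlashAttention
# backends and run a single attention call. Shapes and dtypes are arbitrary.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 24, 1024, 128, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.FLASH_ATTENTION]):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 24, 1024, 128])
```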
diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
- flash_commit_sha = "53a4f341634fcbc96bb999a3c804c192ea14f2ea"

flux_image = cuda_dev_image.apt_install(
"git",
@@ -51,53 +51,26 @@
    "accelerate==0.33.0",
    "safetensors==0.4.4",
    "sentencepiece==0.2.0",
-     "ninja==1.11.1.1",
-     "packaging==24.1",
-     "wheel==0.44.0",
-     "torch==2.4.1",
+     "torch==2.5.0",
    f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
    "numpy<2",
)
- # ### Installing and compiling Flash Attention 3
-
- # [Flash Attention 3 (FA3)](https://github.com/Dao-AILab/flash-attention)
- # is a library of optimized CUDA kernels that
- # make attention blocks in Transformers go
- # [brrrr](https://horace.io/brrr_intro.html) on Hopper GPUs.
-
- # Flash Attention kernels break attention operations into tiny blocks that can be
- # computed in the high-bandwidth SRAM of the GPU's streaming multiprocessors
- # with minimal communication to the higher-latency off-chip DRAM.
- # Check out Aleksa Gordic's ELI5 [here](https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad) for more.
-
- # Flash Attention 3 applies additional optimizations to make maximal use of specific features of NVIDIA's Hopper GPUs,
- # like hardware-accelerated indexing and matrix-multiplication accelerators.
- # Read Tri Dao's blog post [here](https://tridao.me/blog/2024/flash3/) for details.
-
- # To use it, we clone the repo and build the library from source.
-
- flux_fa3_image = flux_image.run_commands(
-     # build Flash Attention 3 from source and add it to PYTHONPATH
-     "ln -s /usr/bin/g++ /usr/bin/clang++",  # use clang as cpp compiler
-     "git clone https://github.com/Dao-AILab/flash-attention.git",
-     f"cd flash-attention && git checkout {flash_commit_sha}",
-     "cd flash-attention/hopper && python setup.py install",
- ).env({"PYTHONPATH": "/root/flash-attention/hopper"})

# Later, we'll also use `torch.compile` to increase the speed further.
# This requires a few environment variables to be set.
# Torch compilation needs to be re-executed when each new container starts,
# so we turn on some extra caching to reduce compile times for later containers.

- flux_fa3_image = flux_fa3_image.env(
+ flux_image = flux_image.env(
    {"TORCHINDUCTOR_CACHE_DIR": "/root/.inductor-cache"}
).env({"TORCHINDUCTOR_FX_GRAPH_CACHE": "1"})

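As a rough illustration of what these two settings do (a standalone sketch with assumed paths, not code from this commit): `TORCHINDUCTOR_CACHE_DIR` relocates Inductor's compiled artifacts, and `TORCHINDUCTOR_FX_GRAPH_CACHE` lets later processes reuse cached FX graphs, so warm containers skip most of the compile work.

```python
# Sketch: set the cache knobs before importing torch, then compile a toy module.
import os

os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/root/.inductor-cache"  # where compiled kernels are written
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # also cache FX graphs for reuse across processes

import torch

model = torch.nn.Linear(16, 16)
compiled = torch.compile(model)  # first call compiles; later processes can hit the cache
print(compiled(torch.randn(4, 16)).shape)
```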
# Finally, we construct our Modal [App](https://modal.com/docs/reference/modal.App),
# set its default image to the one we just constructed,
# and import `FluxPipeline` for downloading and running Flux.1.

- app = modal.App("example-flux", image=flux_fa3_image)
+ app = modal.App("example-flux", image=flux_image)

- with flux_fa3_image.imports():
+ with flux_image.imports():
    import torch
    from diffusers import FluxPipeline

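For readers who haven't used the class being imported, here is a hedged sketch of bare `FluxPipeline` usage outside Modal; the model name and sampler arguments follow the Diffusers documentation for the `schnell` variant, while the example's own setup and inference methods live in the collapsed parts of this diff.

```python
# Standalone sketch, not the example's code: generate one image with Flux.1 [schnell].
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a computer screen showing ASCII terminal art",
    num_inference_steps=4,  # schnell is distilled for few-step sampling
    guidance_scale=0.0,  # schnell does not use classifier-free guidance
).images[0]
image.save("flux_output.jpg")
```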
@@ -117,7 +90,7 @@


@app.cls(
gpu="H100", # FA3 is tuned for Hopper
gpu="H100", # fastest GPU on Modal
container_idle_timeout=20 * MINUTES,
timeout=60 * MINUTES, # leave plenty of time for compilation
volumes={ # add Volumes to store serializable compilation artifacts, see section on torch.compile below
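The rest of the decorator is collapsed in this diff. As a hedged, standalone sketch of the pattern the comment above describes, a Modal Volume mounted at the Inductor cache path lets compilation artifacts survive across containers; all names here are illustrative, not necessarily the ones the example uses.

```python
# Sketch: persist torch.compile / Inductor artifacts in a Volume mounted at the
# TORCHINDUCTOR_CACHE_DIR path configured earlier.
import modal

sketch_app = modal.App("inductor-cache-sketch")
cache_volume = modal.Volume.from_name("torch-inductor-cache", create_if_missing=True)

@sketch_app.cls(
    gpu="H100",
    volumes={"/root/.inductor-cache": cache_volume},
)
class SketchModel:
    @modal.method()
    def ping(self) -> str:
        return "warm"
```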
@@ -188,7 +161,7 @@ def inference(self, prompt: str) -> bytes:
# ```

# By default, we call `generate` twice to demonstrate how much faster
- # the inference is after cold start. In our tests, clients received images in about 1.5 seconds.
+ # the inference is after cold start. In our tests, clients received images in about 1.2 seconds.
# We save the output bytes to a temporary file.

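A hedged sketch of the kind of local entrypoint described here, assuming the file's `app` object and an inference class named `Model` as in this example; the real `main` (partially visible in the next hunk) differs in its flags and defaults.

```python
# Sketch: call the deployed class twice and write the second result to a temp file.
import tempfile
from pathlib import Path

@app.local_entrypoint()
def demo(prompt: str = "a futuristic data center"):
    _cold = Model().inference.remote(prompt)  # includes cold start and compilation
    warm = Model().inference.remote(prompt)  # warm container, much faster
    output_path = Path(tempfile.mkdtemp()) / "flux_demo.jpg"
    output_path.write_bytes(warm)
    print(f"saved image to {output_path}")
```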

@@ -231,7 +204,7 @@ def main(
# which we verified with our own internal benchmarks.
# Review that guide for detailed explanations of the choices made below.

- # The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~800 ms), according to our testing.
+ # The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~700 ms), according to our testing.
# _Super schnell_!

# Compilation takes up to twenty minutes on first iteration.
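The compilation helper itself sits in a collapsed part of the file; as a hedged approximation, the fast-Flux recipe from the guide referenced above boils down to roughly the following, assuming a loaded `FluxPipeline` named `pipe` as in the earlier sketch (exact settings may differ from the example's).

```python
# Sketch: channels_last memory format plus fullgraph, max-autotune compilation of
# the transformer and VAE decoder, then one warmup call to trigger compilation.
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

pipe("warmup prompt", num_inference_steps=4, guidance_scale=0.0)  # slow first call compiles
```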
