From cb9179180454426f0ee69b949422c4a16225d2ba Mon Sep 17 00:00:00 2001
From: Charles Frye
Date: Tue, 1 Oct 2024 18:26:51 -0400
Subject: [PATCH] pulling in flux improvements from ShariqM fork (#904)

* more comments, run inference twice to show latency difference.

* lock flash attention commit sha

* minor text fixes

* comment improving

* more improvements

* optional compilation, text updates, more links

* minor text fixes

* minor text fixes

---------

Co-authored-by: Shariq Mobin
---
 06_gpu_and_ml/stable_diffusion/flux.py | 292 ++++++++++++++++++++-----
 1 file changed, 239 insertions(+), 53 deletions(-)

diff --git a/06_gpu_and_ml/stable_diffusion/flux.py b/06_gpu_and_ml/stable_diffusion/flux.py
index 91def06e4..ecb2bd5f9 100644
--- a/06_gpu_and_ml/stable_diffusion/flux.py
+++ b/06_gpu_and_ml/stable_diffusion/flux.py
@@ -1,96 +1,282 @@
 # ---
 # output-directory: "/tmp/flux"
+# args: ["--no-compile"]
 # ---
-# # Run Flux.1 (Schnell) on Modal
-#
-# This example runs the popular [Flux.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) text-to-image model on Modal.
-#
-# Thanks to [@Arro](https://github.com/Arro) for the original contribution.
+
+# # Run Flux fast with Flash Attention 3 and `torch.compile` on Hopper GPUs
+
+# In this guide, we'll run Flux as fast as possible on Modal using open source tools.
+# We'll use `torch.compile`, Flash Attention 3, and NVIDIA H100 GPUs.
+
+# ## Setting up the image and dependencies
+
+import time
 from io import BytesIO
 from pathlib import Path
 
 import modal
 
-VARIANT = "schnell"  # or "dev", but note [dev] requires you to accept terms and conditions on HF
+# We'll make use of the full [CUDA toolkit](https://modal.com/docs/guide/cuda)
+# in this example, so we'll build our container image off of the `nvidia/cuda` base.
 
-diffusers_commit_sha = "1fcb811a8e6351be60304b1d4a4a749c36541651"
-
-flux_image = (
-    modal.Image.debian_slim(python_version="3.12")
-    .apt_install(
-        "git",
-        "libglib2.0-0",
-        "libsm6",
-        "libxrender1",
-        "libxext6",
-        "ffmpeg",
-        "libgl1",
-    )
-    .run_commands(
-        f"pip install git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha} 'numpy<2'"
-    )
-    .pip_install(
-        "invisible_watermark==0.2.0",
-        "transformers==4.44.0",
-        "accelerate==0.33.0",
-        "safetensors==0.4.4",
-        "sentencepiece==0.2.0",
-    )
+cuda_version = "12.4.0"  # should be no greater than host CUDA version
+flavor = "devel"  # includes full CUDA toolkit
+operating_sys = "ubuntu22.04"
+tag = f"{cuda_version}-{flavor}-{operating_sys}"
+
+cuda_dev_image = modal.Image.from_registry(
+    f"nvidia/cuda:{tag}", add_python="3.11"
+).entrypoint([])
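+
+# (Optional) If you want to confirm the toolkit is wired up before stacking more
+# layers on top, one hypothetical smoke test is to add a build step that calls the
+# CUDA compiler, so the image build fails early if `nvcc` is missing:

+# ```python
+# cuda_dev_image = cuda_dev_image.run_commands("nvcc --version")
+# ```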
+
+# Now we install most of our dependencies with `apt` and `pip`.
+# For Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library
+# and for Flash Attention 3, we install from GitHub source
+# and so pin to a specific commit.
+
+diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
+flash_commit_sha = "53a4f341634fcbc96bb999a3c804c192ea14f2ea"
+
+flux_image = cuda_dev_image.apt_install(
+    "git",
+    "libglib2.0-0",
+    "libsm6",
+    "libxrender1",
+    "libxext6",
+    "ffmpeg",
+    "libgl1",
+).pip_install(
+    "invisible_watermark==0.2.0",
+    "transformers==4.44.0",
+    "accelerate==0.33.0",
+    "safetensors==0.4.4",
+    "sentencepiece==0.2.0",
+    "ninja==1.11.1.1",
+    "packaging==24.1",
+    "wheel==0.44.0",
+    "torch==2.4.1",
+    f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
+    "numpy<2",
 )
 
+# ### Installing and compiling Flash Attention 3
+
+# [Flash Attention (FA3)](https://github.com/Dao-AILab/flash-attention)
+# is a library of optimized CUDA kernels that
+# make attention blocks in Transformers go
+# [brrrr](https://horace.io/brrr_intro.html) on Hopper GPUs.
+
+# Flash Attention kernels break attention operations into tiny blocks that can be
+# computed in the high-bandwidth SRAM of the GPU's streaming multiprocessors,
+# with minimal communication to the slower off-chip DRAM.
+# Check out Aleksa Gordic's ELI5 [here](https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad) for more.
+
+# Flash Attention 3 applies additional optimizations to make maximal use of specific features of NVIDIA's Hopper GPUs,
+# like hardware-accelerated indexing and matrix-multiplication accelerators.
+# Read Tri Dao's blog post [here](https://tridao.me/blog/2024/flash3/) for details.
+
+# To use it, we clone the repo and build the library from source.
+
+flux_fa3_image = flux_image.run_commands(
+    # build Flash Attention 3 from source and add it to PYTHONPATH
+    "ln -s /usr/bin/g++ /usr/bin/clang++",  # make `clang++` resolve to g++
+    "git clone https://github.com/Dao-AILab/flash-attention.git",
+    f"cd flash-attention && git checkout {flash_commit_sha}",
+    "cd flash-attention/hopper && python setup.py install",
+).env({"PYTHONPATH": "/root/flash-attention/hopper"})
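+
+# If you'd like to verify that the kernels built correctly before paying for a GPU
+# container, one hypothetical check (assuming the Python wrapper in the repo's
+# `hopper/` directory is still named `flash_attn_interface`) is to import it in
+# one more build step:

+# ```python
+# flux_fa3_image = flux_fa3_image.run_commands(
+#     "python -c 'import flash_attn_interface'"  # fails the image build if FA3 did not build
+# )
+# ```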
 
-app = modal.App("example-flux")
+
+# Finally, we construct our Modal [App](https://modal.com/docs/reference/modal.App),
+# set its default image to the one we just constructed,
+# and import `FluxPipeline` for downloading and running Flux.1.
 
-with flux_image.imports():
+app = modal.App("example-flux", image=flux_fa3_image)
+
+with flux_fa3_image.imports():
     import torch
     from diffusers import FluxPipeline
 
+# ## Defining a parameterized `Model` inference class
+
+# Next, we map the model's setup and inference code onto Modal.
+
+# 1. We run any setup that can be persisted to disk in methods decorated with `@build`.
+# In this example, that includes downloading the model weights.
+# 2. We run any additional setup, like moving the model to the GPU, in methods decorated with `@enter`.
+# We do our model optimizations in this step. For details, see the section on `torch.compile` below.
+# 3. We run the actual inference in methods decorated with `@method`.
+
+MINUTES = 60  # seconds
+VARIANT = "schnell"  # or "dev", but note [dev] requires you to accept terms and conditions on HF
+NUM_INFERENCE_STEPS = 4  # use ~50 for [dev], smaller for [schnell]
+
 
 @app.cls(
-    gpu=modal.gpu.A100(size="40GB"),
-    container_idle_timeout=100,
-    image=flux_image,
+    gpu="H100",  # FA3 is tuned for Hopper
+    container_idle_timeout=20 * MINUTES,
+    volumes={  # add Volumes to store serializable compilation artifacts
+        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
+        "/root/.triton": modal.Volume.from_name(
+            "triton-cache", create_if_missing=True
+        ),
+    },
 )
 class Model:
-    @modal.build()
-    @modal.enter()
-    def enter(self):
+    compile: int = (  # see section on torch.compile below for details
+        modal.parameter()
+    )
+
+    def setup_model(self):
         from huggingface_hub import snapshot_download
         from transformers.utils import move_cache
 
         snapshot_download(f"black-forest-labs/FLUX.1-{VARIANT}")
 
-        self.pipe = FluxPipeline.from_pretrained(
+        move_cache()
+
+        pipe = FluxPipeline.from_pretrained(
             f"black-forest-labs/FLUX.1-{VARIANT}", torch_dtype=torch.bfloat16
         )
-        self.pipe.to("cuda")
-        move_cache()
+
+        return pipe
+
+    @modal.build()
+    def build(self):
+        self.setup_model()
+
+    @modal.enter()
+    def enter(self):
+        pipe = self.setup_model()
+        pipe.to("cuda")  # move model to GPU
+        self.pipe = optimize(pipe, compile=bool(self.compile))
 
     @modal.method()
-    def inference(self, prompt):
-        print("Generating image...")
+    def inference(self, prompt: str) -> bytes:
+        print("🎨 generating image...")
         out = self.pipe(
             prompt,
             output_type="pil",
-            num_inference_steps=4,  # use a larger number if you are using [dev], smaller for [schnell]
+            num_inference_steps=NUM_INFERENCE_STEPS,
         ).images[0]
-        print("Generated.")
 
         byte_stream = BytesIO()
         out.save(byte_stream, format="JPEG")
 
         return byte_stream.getvalue()
+
+
+# ## Calling our inference function
+
+# To generate an image, we just need to call the `Model`'s `inference` method
+# with `.remote` appended to it.
+# You can call `.inference.remote` from any Python environment that has access to your Modal credentials.
+# The local environment will get back the image as bytes.
+
+# Here, we wrap the call in a Modal [`local_entrypoint`](https://modal.com/docs/reference/modal.App#local_entrypoint)
+# so that it can be run with `modal run`:
+
+# ```bash
+# modal run flux.py
+# ```
+
+# By default, we call `inference` twice to demonstrate how much faster
+# the inference is after cold start. In our tests, clients received images in about 1.5 seconds.
+# We save the output bytes to a temporary file.
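+
+# If you have already deployed this app with `modal deploy`, a separate client
+# script could reach the same class without the local entrypoint. A rough sketch
+# (hypothetical, assuming the app is deployed under the name `example-flux`):

+# ```python
+# import modal
+#
+# Model = modal.Cls.lookup("example-flux", "Model")  # look up the deployed class
+# image_bytes = Model(compile=0).inference.remote("a red panda in a spacesuit")
+# ```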
+
+
 @app.local_entrypoint()
 def main(
-    prompt: str = "a computer screen showing ASCII terminal art of the word 'Modal' in neon green. two programmers are pointing excitedly at the screen.",
+    prompt: str = "a computer screen showing ASCII terminal art of the"
+    " word 'Modal' in neon green. two programmers are pointing excitedly"
+    " at the screen.",
+    twice: bool = True,
+    compile: bool = False,
 ):
-    image_bytes = Model().inference.remote(prompt)
+    t0 = time.time()
+    image_bytes = Model(compile=compile).inference.remote(prompt)
+    print(f"🎨 first inference latency: {time.time() - t0:.2f} seconds")
+
+    if twice:
+        t0 = time.time()
+        image_bytes = Model(compile=compile).inference.remote(prompt)
+        print(f"🎨 second inference latency: {time.time() - t0:.2f} seconds")
+
+    output_path = Path("/tmp") / "flux" / "output.jpg"
+    output_path.parent.mkdir(exist_ok=True, parents=True)
+    print(f"🎨 saving output to {output_path}")
+    output_path.write_bytes(image_bytes)
+
+
+# ## Speeding up Flux with `torch.compile`
+
+# By default, we do some basic optimizations, like adjusting memory layout
+# and re-expressing the attention head projections as a single matrix multiplication.
+# But there are additional speedups to be had!
+
+# PyTorch 2 added a compiler that optimizes the
+# compute graphs created dynamically during PyTorch execution.
+# This feature helps close the gap with the performance of static graph frameworks
+# like TensorRT and TensorFlow.
+
+# Here, we follow the suggestions from Hugging Face's
+# [guide to fast diffusion inference](https://huggingface.co/docs/diffusers/en/tutorials/fast_diffusion),
+# which we verified with our own internal benchmarks.
+# Review that guide for detailed explanations of the choices made below.
+
+# The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~800 ms), according to our testing.
+# _Super schnell_!
+
+# Compilation takes up to twenty minutes and, at the time of writing (October 2024),
+# the compilation artifacts cannot be serialized,
+# so compilation work must be re-executed every time a new container is started.
+# That includes when scaling up an existing deployment or the first time a Function is invoked with `modal run`.
+
+# You can turn on compilation with the `--compile` flag.
+# Try it out with:
+
+# ```bash
+# modal run flux.py --compile
+# ```
+
+# The `compile` option is set via a [`modal.parameter`](https://modal.com/docs/reference/modal.parameter#modalparameter) on our class.
+# Each different choice for a `parameter` creates a [separate auto-scaling deployment](https://modal.com/docs/guide/parameterized-functions).
+# That means your client can use arbitrary logic to decide whether to hit a compiled or eager endpoint.
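+
+# For example, a latency-sensitive client could (hypothetically) route requests like this,
+# where `latency_sensitive` stands in for whatever condition your application checks:

+# ```python
+# model = Model(compile=1) if latency_sensitive else Model(compile=0)
+# image_bytes = model.inference.remote(prompt)
+# ```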
+
+
+def optimize(pipe, compile=True):
+    # fuse QKV projections in Transformer and VAE
+    pipe.transformer.fuse_qkv_projections()
+    pipe.vae.fuse_qkv_projections()
+
+    # switch memory layout to Torch's preferred, channels_last
+    pipe.transformer.to(memory_format=torch.channels_last)
+    pipe.vae.to(memory_format=torch.channels_last)
+
+    if not compile:
+        return pipe
+
+    # set torch compile flags
+    config = torch._inductor.config
+    config.disable_progress = False  # show progress bar
+    config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matmuls
+    # adjust autotuning algorithm
+    config.coordinate_descent_tuning = True
+    config.coordinate_descent_check_all_directions = True
+    config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls
+
+    # tag the compute-intensive modules, the Transformer and VAE decoder, for compilation
+    pipe.transformer = torch.compile(
+        pipe.transformer, mode="max-autotune", fullgraph=True
+    )
+    pipe.vae.decode = torch.compile(
+        pipe.vae.decode, mode="max-autotune", fullgraph=True
+    )
+
+    # trigger torch compilation
+    print("🔦 running torch compilation (may take up to 20 minutes)...")
+
+    pipe(
+        "dummy prompt to trigger torch compilation",
+        output_type="pil",
+        num_inference_steps=NUM_INFERENCE_STEPS,  # use ~50 for [dev], smaller for [schnell]
+    ).images[0]
-    dir = Path("/tmp/flux")
-    if not dir.exists():
-        dir.mkdir(exist_ok=True, parents=True)
+
+    print("🔦 finished torch compilation")
-    output_path = dir / "output.jpg"
-    print(f"Saving it to {output_path}")
-    with open(output_path, "wb") as f:
-        f.write(image_bytes)
+
+    return pipe