pulling in flux improvements from ShariqM fork (#904)
* more comments, run inference twice to show latency difference.

* lock flash attention commit sha

* minor text fixes

* comment improving

* more improvements

* optional compilation, text updates, more links

* minor text fixes

* minor text fixes

---------

Co-authored-by: Shariq Mobin <[email protected]>
charlesfrye and shariq-audiofocus authored Oct 1, 2024
1 parent 5648e6d commit cb91791
Showing 1 changed file with 239 additions and 53 deletions.
06_gpu_and_ml/stable_diffusion/flux.py
# ---
# output-directory: "/tmp/flux"
# args: ["--no-compile"]
# ---
# # Run Flux fast with Flash Attention 3 and `torch.compile` on Hopper GPUs

# In this guide, we'll run the popular [Flux.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
# text-to-image model as fast as possible on Modal using open source tools:
# `torch.compile`, Flash Attention 3, and NVIDIA H100 GPUs.

# Thanks to [@Arro](https://github.com/Arro) for the original contribution.

# ## Setting up the image and dependencies

import time
from io import BytesIO
from pathlib import Path

import modal

# We'll make use of the full [CUDA toolkit](https://modal.com/docs/guide/cuda)
# in this example, so we'll build our container image off of the `nvidia/cuda` base.

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

cuda_dev_image = modal.Image.from_registry(
    f"nvidia/cuda:{tag}", add_python="3.11"
).entrypoint([])

# Now we install most of our dependencies with `apt` and `pip`.
# For Hugging Face's [Diffusers](https://github.com/huggingface/diffusers) library
# and for Flash Attention 3, we install from GitHub source
# and so pin to a specific commit.

diffusers_commit_sha = "81cf3b2f155f1de322079af28f625349ee21ec6b"
flash_commit_sha = "53a4f341634fcbc96bb999a3c804c192ea14f2ea"

flux_image = cuda_dev_image.apt_install(
    "git",
    "libglib2.0-0",
    "libsm6",
    "libxrender1",
    "libxext6",
    "ffmpeg",
    "libgl1",
).pip_install(
    "invisible_watermark==0.2.0",
    "transformers==4.44.0",
    "accelerate==0.33.0",
    "safetensors==0.4.4",
    "sentencepiece==0.2.0",
    "ninja==1.11.1.1",
    "packaging==24.1",
    "wheel==0.44.0",
    "torch==2.4.1",
    f"git+https://github.com/huggingface/diffusers.git@{diffusers_commit_sha}",
    "numpy<2",
)

# ### Installing and compiling Flash Attention 3

# [Flash Attention 3 (FA3)](https://github.com/Dao-AILab/flash-attention)
# is a library of optimized CUDA kernels that
# make attention blocks in Transformers go
# [brrrr](https://horace.io/brrr_intro.html) on Hopper GPUs.

# Flash Attention kernels break attention operations into tiny blocks that can be
# computed in the high-bandwidth SRAM of the GPU's streaming multiprocessors,
# with minimal communication to the higher-latency off-chip DRAM.
# Check out Aleksa Gordic's ELI5 [here](https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad) for more.
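
# As a toy illustration of the tiling idea (this is not the fused FA3 kernel,
# which also tiles over keys and values using an online softmax), attention can be
# computed over blocks of queries so that only a thin slice of the attention matrix
# is ever materialized at once:

# ```python
# import torch
#
# def blockwise_attention(q, k, v, block_size=128):
#     # q, k, v have shape (batch, heads, seq_len, head_dim)
#     scale = q.shape[-1] ** -0.5
#     out = torch.empty_like(q)
#     for start in range(0, q.shape[-2], block_size):
#         q_blk = q[..., start : start + block_size, :]
#         scores = (q_blk @ k.transpose(-2, -1)) * scale  # (..., block_size, seq_len)
#         out[..., start : start + block_size, :] = scores.softmax(dim=-1) @ v
#     return out
# ```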

# Flash Attention 3 applies additional optimizations to make maximal use of specific features of NVIDIA's Hopper GPUs,
# like hardware-accelerated indexing and matrix-multiplication accelerators.
# Read Tri Dao's blog post [here](https://tridao.me/blog/2024/flash3/) for details.

# To use it, we clone the repo and build the library from source.

flux_fa3_image = flux_image.run_commands(
    # build Flash Attention 3 from source and add it to PYTHONPATH
    "ln -s /usr/bin/g++ /usr/bin/clang++",  # provide clang++ by aliasing g++
    "git clone https://github.com/Dao-AILab/flash-attention.git",
    f"cd flash-attention && git checkout {flash_commit_sha}",
    "cd flash-attention/hopper && python setup.py install",
).env({"PYTHONPATH": "/root/flash-attention/hopper"})
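
# If you want to confirm the build is importable before running the full pipeline,
# a quick sanity check (hypothetical, not part of this example) is to import the
# interface module inside a container. Here we assume the `hopper/` build exposes
# a module named `flash_attn_interface`:

# ```python
# sanity_app = modal.App("check-fa3", image=flux_fa3_image)
#
# @sanity_app.function(gpu="H100")  # FA3 kernels target Hopper GPUs
# def check_fa3():
#     import flash_attn_interface  # found via the PYTHONPATH we set above
#     print("FA3 interface loaded from", flash_attn_interface.__file__)
# ```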

# Finally, we construct our Modal [App](https://modal.com/docs/reference/modal.App),
# set its default image to the one we just constructed,
# and import `FluxPipeline` for downloading and running Flux.1.

app = modal.App("example-flux", image=flux_fa3_image)

with flux_fa3_image.imports():
    import torch
    from diffusers import FluxPipeline

# ## Defining a parameterized `Model` inference class

# Next, we map the model's setup and inference code onto Modal.

# 1. We run any setup that can be persisted to disk in methods decorated with `@build`.
# In this example, that includes downloading the model weights.
# 2. We run any additional setup, like moving the model to the GPU, in methods decorated with `@enter`.
# We do our model optimizations in this step. For details, see the section on `torch.compile` below.
# 3. We run the actual inference in methods decorated with `@method`.

MINUTES = 60 # seconds
VARIANT = "schnell" # or "dev", but note [dev] requires you to accept terms and conditions on HF
NUM_INFERENCE_STEPS = 4 # use ~50 for [dev], smaller for [schnell]


@app.cls(
    gpu="H100",  # FA3 is tuned for Hopper
    container_idle_timeout=20 * MINUTES,
    volumes={  # add Volumes to store serializable compilation artifacts
        "/root/.nv": modal.Volume.from_name("nv-cache", create_if_missing=True),
        "/root/.triton": modal.Volume.from_name(
            "triton-cache", create_if_missing=True
        ),
    },
)
class Model:
    compile: int = (  # see section on torch.compile below for details
        modal.parameter()
    )

    def setup_model(self):
        from huggingface_hub import snapshot_download
        from transformers.utils import move_cache

        snapshot_download(f"black-forest-labs/FLUX.1-{VARIANT}")

        move_cache()

        pipe = FluxPipeline.from_pretrained(
            f"black-forest-labs/FLUX.1-{VARIANT}", torch_dtype=torch.bfloat16
        )

        return pipe

    @modal.build()
    def build(self):
        self.setup_model()

    @modal.enter()
    def enter(self):
        pipe = self.setup_model()
        pipe.to("cuda")  # move model to GPU
        self.pipe = optimize(pipe, compile=bool(self.compile))

    @modal.method()
    def inference(self, prompt: str) -> bytes:
        print("🎨 generating image...")
        out = self.pipe(
            prompt,
            output_type="pil",
            num_inference_steps=NUM_INFERENCE_STEPS,
        ).images[0]

        byte_stream = BytesIO()
        out.save(byte_stream, format="JPEG")
        return byte_stream.getvalue()


# ## Calling our inference function

# To generate an image we just need to call the `Model`'s `inference` method
# with `.remote` appended to it.
# You can call `.inference.remote` from any Python environment that has access to your Modal credentials.
# The local environment will get back the image as bytes.

# Here, we wrap the call in a Modal [`local_entrypoint`](https://modal.com/docs/reference/modal.App#local_entrypoint)
# so that it can be run with `modal run`:

# ```bash
# modal run flux.py
# ```

# By default, we call `inference` twice to demonstrate how much faster
# the inference is after cold start. In our tests, clients received images in about 1.5 seconds.
# We save the output bytes to a temporary file.
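
# For example, once the app has been deployed with `modal deploy flux.py`, a separate
# Python process could request an image like this (a sketch that assumes
# `modal.Cls.lookup`, the class-lookup API available at the time of writing):

# ```python
# from pathlib import Path
#
# import modal
#
# Model = modal.Cls.lookup("example-flux", "Model")  # look up the deployed class
# image_bytes = Model(compile=0).inference.remote(
#     "a watercolor painting of a data center at dusk"
# )
# Path("flux_output.jpg").write_bytes(image_bytes)
# ```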


@app.local_entrypoint()
def main(
    prompt: str = "a computer screen showing ASCII terminal art of the"
    " word 'Modal' in neon green. two programmers are pointing excitedly"
    " at the screen.",
    twice: bool = True,
    compile: bool = False,
):
    t0 = time.time()
    image_bytes = Model(compile=compile).inference.remote(prompt)
    print(f"🎨 first inference latency: {time.time() - t0:.2f} seconds")

    if twice:
        t0 = time.time()
        image_bytes = Model(compile=compile).inference.remote(prompt)
        print(f"🎨 second inference latency: {time.time() - t0:.2f} seconds")

    output_path = Path("/tmp") / "flux" / "output.jpg"
    output_path.parent.mkdir(exist_ok=True, parents=True)
    print(f"🎨 saving output to {output_path}")
    output_path.write_bytes(image_bytes)


# ## Speeding up Flux with `torch.compile`

# By default, we do some basic optimizations, like adjusting memory layout
# and re-expressing the attention head projections as a single matrix multiplication.
# But there are additional speedups to be had!

# PyTorch 2 added a compiler that optimizes the
# compute graphs created dynamically during PyTorch execution.
# This feature helps close the gap with the performance of static graph frameworks
# like TensorRT and TensorFlow.
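
# As a minimal illustration of the API (separate from the Flux pipeline below),
# `torch.compile` wraps an ordinary PyTorch function or module and JIT-compiles
# the captured graph the first time it is called:

# ```python
# import torch
#
# def gelu_mul(x, y):
#     return torch.nn.functional.gelu(x) * y
#
# compiled_fn = torch.compile(gelu_mul)  # compilation is triggered by the first call
# out = compiled_fn(torch.randn(8, 8), torch.randn(8, 8))
# ```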

# Here, we follow the suggestions from Hugging Face's
# [guide to fast diffusion inference](https://huggingface.co/docs/diffusers/en/tutorials/fast_diffusion),
# which we verified with our own internal benchmarks.
# Review that guide for detailed explanations of the choices made below.

# The resulting compiled Flux `schnell` deployment returns images to the client in under a second (~800 ms), according to our testing.
# _Super schnell_!

# Compilation takes up to twenty minutes and, at time of writing in October 2024,
# the compilation artifacts cannot be serialized,
# so compilation work must be re-executed every time a new container is started.
# That includes when scaling up an existing deployment or the first time a Function is invoked with `modal run`.

# You can turn on compilation with the `--compile` flag.
# Try it out with:

# ```bash
# modal run flux.py --compile
# ```

# The `compile` option is exposed as a [`modal.parameter`](https://modal.com/docs/reference/modal.parameter#modalparameter) on our class.
# Each different choice for a `parameter` creates a [separate auto-scaling deployment](https://modal.com/docs/guide/parameterized-functions).
# That means your client can use arbitrary logic to decide whether to hit a compiled or eager endpoint.
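
# For example (an illustrative sketch; the variable names are made up), a
# latency-sensitive caller could target the compiled containers while another
# caller sticks with the faster-to-start eager ones:

# ```python
# compiled_model = Model(compile=1)  # routes to containers running the compiled pipeline
# eager_model = Model(compile=0)     # routes to containers running the eager pipeline
#
# image_bytes = compiled_model.inference.remote("a neon sign that reads 'schnell'")
# ```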


def optimize(pipe, compile=True):
    # fuse QKV projections in Transformer and VAE
    pipe.transformer.fuse_qkv_projections()
    pipe.vae.fuse_qkv_projections()

    # switch memory layout to Torch's preferred, channels_last
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.vae.to(memory_format=torch.channels_last)

    if not compile:
        return pipe

    # set torch compile flags
    config = torch._inductor.config
    config.disable_progress = False  # show progress bar
    config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
    # adjust autotuning algorithm
    config.coordinate_descent_tuning = True
    config.coordinate_descent_check_all_directions = True
    config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls

    # tag the compute-intensive modules, the Transformer and VAE decoder, for compilation
    pipe.transformer = torch.compile(
        pipe.transformer, mode="max-autotune", fullgraph=True
    )
    pipe.vae.decode = torch.compile(
        pipe.vae.decode, mode="max-autotune", fullgraph=True
    )

    # trigger torch compilation
    print("🔦 running torch compilation (may take up to 20 minutes)...")

    pipe(
        "dummy prompt to trigger torch compilation",
        output_type="pil",
        num_inference_steps=NUM_INFERENCE_STEPS,  # use ~50 for [dev], smaller for [schnell]
    ).images[0]

    print("🔦 finished torch compilation")

    return pipe
