Add trt support for BF16 #195

Open
wants to merge 279 commits into base: main

Commits
a4b0f13
Add link to fine-tunes collection on Replicate (#130)
zeke Aug 29, 2024
ed51d5e
Add Torch CUDA sync to fix timing code in cli.py (#147)
neilmovva Sep 13, 2024
bc22ee3
Update API interface for FLUX.1.1 [pro]
jenuk Oct 3, 2024
c5ebf2b
CLI: `/n` is for steps, not seeds (#169)
thibautRe Oct 3, 2024
933e54a
Update README.md
timudk Oct 7, 2024
d171e39
Update README.md
timudk Oct 8, 2024
a94a546
Remove unused import and extraneous `f` prefix (#171)
esadek Oct 8, 2024
16fc5e2
update readme for 1.1
timudk Oct 8, 2024
f8747c2
add question logger
andompesta Oct 1, 2024
bf71e81
add cli input for TRT support
andompesta Oct 1, 2024
223017d
initial support for TRT engine builder
andompesta Oct 1, 2024
fd7057e
add onnx export functions taken from
andompesta Oct 1, 2024
0713049
base class for convert to onnx
andompesta Oct 1, 2024
d0ba09a
add missing dependencies
andompesta Oct 1, 2024
959eaa7
fix imports
andompesta Oct 1, 2024
50149bd
moved to wrappers package
andompesta Oct 1, 2024
71c1b0d
moved to wrapper package and renamed into base wrapper
andompesta Oct 1, 2024
407a3ac
implement load engines function
andompesta Oct 1, 2024
ca691f1
remove old wrapper
andompesta Oct 1, 2024
78fbb25
add additional parameters to base class constructor
andompesta Oct 1, 2024
3548085
implemented CLIP wrapper
andompesta Oct 1, 2024
6b01977
remove model as a property
andompesta Oct 1, 2024
2477223
enable float16 optimization
andompesta Oct 1, 2024
969e06d
reorder arguments
andompesta Oct 1, 2024
f8a183b
first wrapper for onnx build
andompesta Oct 1, 2024
5572ae7
add load_engines with minimal parameters
andompesta Oct 1, 2024
713a66e
fix `get_sample_input` interface format and add `get_model_to_trace`
andompesta Oct 1, 2024
b4bebd2
fix stage name
andompesta Oct 1, 2024
e56078c
add imports for wrappers
andompesta Oct 1, 2024
f1fa53b
set assert error message for missing stage
andompesta Oct 2, 2024
0afcc8e
add assert to validate only 1 dtype is active
andompesta Oct 2, 2024
9b600b3
add `set_model_to_dtype` function to set the correct dtype
andompesta Oct 2, 2024
da78bef
call `set_model_to_dtype` instead of doing it manually
andompesta Oct 2, 2024
2c7201f
T5wrapper as a copy of CLIPwrapper
andompesta Oct 2, 2024
d2b2567
rename `get_model_to_trace` to `get_model`
andompesta Oct 2, 2024
d6c18c8
`do_constant_folding` should not be configurable
andompesta Oct 2, 2024
89f616e
removed custom configurations
andompesta Oct 2, 2024
44a3280
fix import
andompesta Oct 2, 2024
27c4268
add t5wrapper to possible imports
andompesta Oct 2, 2024
c35d4d7
use T5Wrapper
andompesta Oct 2, 2024
4e273e1
fluxwrapper as a copy of t5wrapper
andompesta Oct 2, 2024
e111c72
implement flux exporter
andompesta Oct 2, 2024
ec75372
use flux exporter
andompesta Oct 2, 2024
6acb1af
vae as copy of clip wrapper
andompesta Oct 2, 2024
6c941e6
rename and fix imports
andompesta Oct 2, 2024
1e11bc1
removed `embedding_dim` as not used
andompesta Oct 2, 2024
37ddc7e
remove embedding_dim and `forward` function point to `decode`
andompesta Oct 2, 2024
fa22eac
get_output_names = latent
andompesta Oct 2, 2024
eb4e7df
fix interface of `get_sample_input`
andompesta Oct 2, 2024
bf1cca6
save configuration parameters
andompesta Oct 2, 2024
358c8a5
ae wrapper implemented
andompesta Oct 2, 2024
381267d
fix import
andompesta Oct 2, 2024
a8af1d8
add AEWrapper step
andompesta Oct 2, 2024
a47608c
from set_model_to_dtype to prepare_model
andompesta Oct 3, 2024
ea420c5
fix eval mode during inference
andompesta Oct 3, 2024
af2f48b
fix clip onnx export. Now it traces only the needed outputs
andompesta Oct 3, 2024
e6b66bb
fix t5 wrapper
andompesta Oct 3, 2024
cb188d8
reorder input name flux
andompesta Oct 3, 2024
54002de
fix flux input format for text_ids and guidance
andompesta Oct 4, 2024
1cdc0a8
fix Flux imports and scale of inputs to prevent nan
andompesta Oct 4, 2024
c1c3a8d
add torch inference while tracing
andompesta Oct 4, 2024
bb0cc66
fix casting problem in onnx trace
andompesta Oct 4, 2024
21ec7d9
solve optimization problem by removing cleanup steps
andompesta Oct 4, 2024
d6f5e2f
rename to notes
andompesta Oct 4, 2024
577ba49
prevent nan due to large inputs
andompesta Oct 4, 2024
dfd06fc
provide base implementation of `get_model`
andompesta Oct 4, 2024
54b2ceb
format
andompesta Oct 4, 2024
d7ccef4
add trt export step
andompesta Oct 6, 2024
505411b
add engine class for trt build
andompesta Oct 6, 2024
0154232
add `get_input_profile` and `get_minmax_dims` abstract methods
andompesta Oct 6, 2024
cc2d921
add `build_strongly_typed` attribute
andompesta Oct 6, 2024
a2fb731
implement `get_minmax_dims` and `get_input_profile`
andompesta Oct 6, 2024
0096f7a
remove `static_shape` from `get_sample_input`
andompesta Oct 6, 2024
dfb6ded
remove static shape and batch flags
andompesta Oct 7, 2024
50100b5
add typing
andompesta Oct 7, 2024
ea240be
remove static shape and batch flags
andompesta Oct 7, 2024
0c3720c
offload to cpu
andompesta Oct 7, 2024
30f0140
enable device offloading while tracing
andompesta Oct 7, 2024
f2b357a
check cuda is available while building engines
andompesta Oct 7, 2024
a30ec20
clip trt engine build
andompesta Oct 7, 2024
dbeeed9
add pinned transformer dependency
andompesta Oct 9, 2024
0682915
fix nan with onnx and trt when executed on CUDA
andompesta Oct 9, 2024
bef25e0
AE need to be traced in TF32 not FP16
andompesta Oct 9, 2024
c028d8d
add `get_shape_dict` abstract method and device as a property
andompesta Oct 9, 2024
8208e4c
AE should be traced in TF32
andompesta Oct 9, 2024
816ff12
AE explicitly on TF32 and reactivate full pipeline
andompesta Oct 9, 2024
3a341f8
add input profile to flux to enable trt engine build
andompesta Oct 9, 2024
7aa6956
format and add input_profile to t5 for TRT build
andompesta Oct 9, 2024
e68a993
add `TransformersModelWrapper`
andompesta Oct 9, 2024
ea581b7
add TransformersModelWrapper support
andompesta Oct 9, 2024
7e883d5
add `get_shape_dict` interface
andompesta Oct 9, 2024
5080d86
add TransformersModelWrapper support
andompesta Oct 9, 2024
e2b65c4
add shape_dict interface
andompesta Oct 9, 2024
87413e2
t5 in TF32 for numerical reasons
andompesta Oct 11, 2024
8629e50
remove unused options
andompesta Oct 11, 2024
5e711c7
remove unused code
andompesta Oct 11, 2024
02235dc
add `get_shape_dict`
andompesta Oct 11, 2024
6c3c4db
remove custom optimization
andompesta Oct 11, 2024
4b8a973
add garbage collector
andompesta Oct 14, 2024
8e4b103
return error
andompesta Oct 14, 2024
8f45f81
create wrapper specific to Onnx export operation
andompesta Oct 14, 2024
3af1a33
use OnnxWrapper
andompesta Oct 14, 2024
fe024b8
create base wrapper for trt engines
andompesta Oct 14, 2024
68060bd
moved to engine package
andompesta Oct 14, 2024
0f8d8b3
moved to engine package
andompesta Oct 14, 2024
49dc6d1
forbid relative import of trt-builder
andompesta Oct 14, 2024
098391b
remove wrapper and create BaseExporter or BaseEngine
andompesta Oct 14, 2024
bf9c4cb
models not stored in builder class
andompesta Oct 14, 2024
0ee9104
_prepare_model_configs as pure function
andompesta Oct 14, 2024
c7136f8
_get_onnx_exporters as a private method to get onnx exporters
andompesta Oct 14, 2024
ee72695
remove unused dependencies
andompesta Oct 14, 2024
ecf6c4f
from onnxwrapper to onnxengine
andompesta Oct 14, 2024
2a14000
trt engine class
andompesta Oct 14, 2024
c791c53
add `calculate_max_device_memory` to TRTBuilder
andompesta Oct 14, 2024
ce343dc
`get_shape_dict` moved to trt-engine interface
andompesta Oct 14, 2024
66ca1ce
add common inference code
andompesta Oct 14, 2024
7400072
autoencoder inference wrapper
andompesta Oct 14, 2024
aa0d474
add requirements.txt
Oct 16, 2024
d676a18
support guidance for ev model
andompesta Oct 16, 2024
550f660
add support for trt based on env variables
andompesta Oct 16, 2024
fa5993b
format flux
andompesta Oct 16, 2024
bdbbb19
remove stream from constructor
andompesta Oct 16, 2024
f1d86f6
fix iterate over onnx-exporters
andompesta Oct 16, 2024
f065b09
flux is not strongly typed
andompesta Oct 16, 2024
c57410a
move back for numerical stability
andompesta Oct 16, 2024
69f4dca
add logging
andompesta Oct 16, 2024
cc12a14
fix dtype casting for bfloat16
andompesta Oct 17, 2024
961259e
fix default value
andompesta Oct 17, 2024
6e1ca02
add version before merge
andompesta Oct 18, 2024
7217a7b
hacky get it building the engines
ducktrA Oct 15, 2024
c5481a1
requirements.txt
ducktrA Oct 17, 2024
54674c3
adding a separate _engine.py file for all the flux, t5 and clip engines
ducktrA Oct 18, 2024
37003c7
boilerroom and plating. getting parameters handle into setting up the…
ducktrA Oct 18, 2024
fd33eb5
remove _version.py from git
andompesta Oct 18, 2024
99e72e9
create base mixin class to share parameters
andompesta Oct 18, 2024
6678a3b
clipmixin parameters
andompesta Oct 18, 2024
395541d
remove parameters as are part of mixin class
andompesta Oct 18, 2024
315dd9d
clip engine and exporter use common mixin for managing parameters
andompesta Oct 18, 2024
7cdbb03
use mixin class to build engine from exporter
andompesta Oct 18, 2024
55497eb
ae-mixin for shared parameters
andompesta Oct 18, 2024
5917f38
flux exporter and engine unified by mixin class
andompesta Oct 21, 2024
7c156cd
formatting
andompesta Oct 21, 2024
92f13f8
add common `get_latent_dims` method
andompesta Oct 21, 2024
f5acd54
add `get_latent_dims` common method
andompesta Oct 21, 2024
8b182cc
T5 based on mixin class
andompesta Oct 21, 2024
11570dc
build strongly typed flux
andompesta Oct 21, 2024
a9acfa0
enable load with shared device memory
andompesta Oct 21, 2024
c6e94a6
remove boilerplate code to create engines
andompesta Oct 21, 2024
7b07602
add tokenizer to trt engine
andompesta Oct 22, 2024
2dc2460
use static shape to reduce memory consumption
andompesta Oct 22, 2024
40de55c
implement tokenizer into t5 engine
andompesta Oct 22, 2024
c8273c7
fix max_batch size to 8
andompesta Oct 22, 2024
b96fd96
add licence
andompesta Oct 22, 2024
6743bb7
add licence
andompesta Oct 22, 2024
852b444
enable trt runtime tracking
andompesta Oct 22, 2024
95f7822
add static-batch and static-shape options
andompesta Oct 22, 2024
8ac3f84
add cuda stream to load method
andompesta Oct 22, 2024
f93fc87
add inference code
andompesta Oct 22, 2024
528621a
add inference code
andompesta Oct 22, 2024
23e1236
enable static shape
andompesta Oct 22, 2024
dc326df
add `static_shape` option to reduce memory and `_build_engine` as sta…
andompesta Oct 22, 2024
7e3fe14
add `should_be_dtype` field to handle output type conversion
andompesta Oct 22, 2024
41f18e7
from trtbuilder to trt_manager
andompesta Oct 22, 2024
12dee48
from TRTBuilder to TRTManager
andompesta Oct 23, 2024
45997a9
AE engine interface
andompesta Oct 23, 2024
bb9f468
`trt_to_torch_dtype_dict` as property
andompesta Oct 23, 2024
2bde369
clip engine inference
andompesta Oct 23, 2024
359572e
implement flux trt engine inference process
andompesta Oct 23, 2024
e3f0fd9
add scale_factor and shift_factor
andompesta Oct 23, 2024
d91bbde
removed `should_be_dtype`
andompesta Oct 23, 2024
df245db
removed `should_be_dtype`
andompesta Oct 23, 2024
33bc095
remove `should_be_dtype` from t5
andompesta Oct 23, 2024
c330491
add scale and shift factor
andompesta Oct 23, 2024
90b4f11
`max_batch` to 8
andompesta Oct 23, 2024
17c1f7d
implement `TRTManager`
andompesta Oct 23, 2024
811f2ff
from ae to vae to match DD
andompesta Oct 25, 2024
f4ae3ca
remove autocast
andompesta Oct 25, 2024
0fe7c84
`pooled_embeddings` to match DD naming for clip
andompesta Oct 25, 2024
f71091a
rename `flux` to `transformer` engine
andompesta Oct 25, 2024
4055a3e
from flux to transformer mixin
andompesta Oct 25, 2024
2b2bb5b
from flux to transformer exporter
andompesta Oct 25, 2024
b088430
fix trtmanager naming
andompesta Oct 25, 2024
82d658d
fix input names and dimensions. Note that `img_ids` and `txt_ids` ar…
andompesta Oct 25, 2024
3708773
fix shape of inputs according to `text_maxlen` and batch_size
andompesta Oct 25, 2024
7737426
reduce max_batch
andompesta Oct 27, 2024
917d8ff
fix stage naming
andompesta Oct 27, 2024
6473ca1
add support for DD model
andompesta Oct 27, 2024
6d39ad5
add support for DD models
andompesta Oct 27, 2024
753129b
fix dtype configuration
andompesta Oct 28, 2024
149c27c
fix engine dtype
andompesta Oct 28, 2024
55568bf
transformers inference interface to match DD
andompesta Oct 28, 2024
4872169
vae inference script dtype mapping
andompesta Oct 28, 2024
41ee44c
remove dtype checks as multiple can be active
andompesta Oct 28, 2024
a31161d
by default tf32 always active
andompesta Oct 28, 2024
3b91c51
fix trt engine names
andompesta Nov 11, 2024
4ebca7d
add wrapper for fluxmodel to match DD onnx configuration
andompesta Nov 11, 2024
3e9f64f
add autocast back in to match DD setup
andompesta Nov 11, 2024
bb82e4b
fix dependencies for trt support
andompesta Nov 14, 2024
830358e
support trt
andompesta Nov 14, 2024
cdce3a3
add explicit kwargs
andompesta Nov 14, 2024
b789e05
vscode setup
andompesta Nov 14, 2024
8b07e6e
add setup instructions for trt
andompesta Nov 14, 2024
5ffd6d6
`trt` dependencies not part of `all`
andompesta Nov 14, 2024
766d878
from onnx_exporter to exporter
andompesta Nov 14, 2024
6d83690
hide onnx parameters
andompesta Nov 14, 2024
2458486
from onnx-exporter to exporter
andompesta Nov 14, 2024
80a52d7
exporter responsible for building trt engine and onnx export
andompesta Nov 14, 2024
adf2d46
hide onnx parameter
andompesta Nov 14, 2024
e82311f
remove build function from engine class
andompesta Nov 14, 2024
17f6562
remove unused import
andompesta Nov 14, 2024
2512bb2
remove space
andompesta Nov 14, 2024
86614a3
manage t5 and vae separately
andompesta Nov 14, 2024
f14de69
disable autocast
andompesta Nov 14, 2024
3410d34
stronglytyped t5
andompesta Nov 14, 2024
2422538
fix input type and max image size
andompesta Nov 14, 2024
9bef65b
max image size
andompesta Nov 14, 2024
a3bd8fc
T5 not strongly typed
andompesta Nov 14, 2024
e615fa0
testing
andompesta Nov 14, 2024
611efed
fix torch synchronize problem
andompesta Nov 14, 2024
13b1016
don't build already present engines
andompesta Nov 14, 2024
01b508c
remove torch save
andompesta Nov 14, 2024
f57b5a5
removed onnx dependencies
andompesta Nov 14, 2024
9cffa24
add trt dependencies
andompesta Nov 14, 2024
63e29cc
remove trt dependencies from toml
andompesta Nov 14, 2024
c978cc3
rename requirements and fix readme
andompesta Nov 14, 2024
3087c60
remove unused files
andompesta Nov 14, 2024
5c2cba1
fix import format
andompesta Nov 14, 2024
08fbb60
remove comments
andompesta Nov 14, 2024
1b4a41a
add gitignore
andompesta Nov 15, 2024
a404144
reset dependencies
andompesta Nov 15, 2024
a8b8478
add hidden setup files
andompesta Nov 15, 2024
8fa1d22
solve ruff check
andompesta Nov 15, 2024
3f20508
fix imports with ruff
andompesta Nov 15, 2024
7662313
run ruff formatter
andompesta Nov 15, 2024
4691502
update gitignore
andompesta Nov 15, 2024
deb5633
simplify dependencies
andompesta Nov 18, 2024
1de2799
remove gitignore
andompesta Nov 18, 2024
64cbb8f
add cli formatting
andompesta Nov 18, 2024
fd1455e
fix import orders
andompesta Nov 18, 2024
095ee89
Merge pull request #1 from andompesta/add-trt-support-push
andompesta Nov 18, 2024
3d3741e
simplify dependencies
andompesta Nov 18, 2024
f31ffd4
solve vae quality issue
andompesta Nov 26, 2024
728c018
Merge branch 'main' of https://github.com/black-forest-labs/flux
andompesta Nov 26, 2024
1cd9476
Merge branch 'main' into add-trt-support
andompesta Nov 26, 2024
bee6c45
Merge branch 'main' into add-trt-support-cli-conflict
andompesta Nov 26, 2024
f80058f
fix ruff format
andompesta Nov 26, 2024
079778f
fix merge changes
andompesta Nov 26, 2024
a5986b5
format and sort src/flux/cli
andompesta Nov 26, 2024
c7fdb64
fix merge conflicts
andompesta Nov 26, 2024
74c4c7a
Merge pull request #2 from andompesta/add-trt-support-cli-conflict
andompesta Nov 26, 2024
64 changes: 64 additions & 0 deletions README.md
@@ -15,6 +15,19 @@
source .venv/bin/activate
pip install -e ".[all]"
```

## Local installation with TRT support

```bash
docker pull nvcr.io/nvidia/pytorch:24.10-py3
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux
docker run --rm -it --gpus all -v $PWD:/workspace/flux nvcr.io/nvidia/pytorch:24.10-py3 /bin/bash
# inside container
cd /workspace/flux
pip install -e ".[all]"
pip install -r trt_requirements.txt
```
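With the container running, TRT inference is switched on through the new `trt` flag this PR adds to `src/flux/cli.py`. Below is a minimal sketch of driving it from Python; the prompt and sizes are illustrative, and engines are read from or built into the directories given by `TRT_ENGINE_DIR`/`ONNX_DIR` (defaulting to `./engines` and `./onnx`):

```python
# Sketch: call the TRT-enabled CLI entry point directly.
# Assumes the main(...) signature added in this PR; the first run builds the engines.
from flux.cli import main

main(
    name="flux-dev",
    width=1024,
    height=1024,
    prompt="a photo of a forest with mist",  # illustrative prompt
    trt=True,  # swap clip/t5/transformer/vae for TensorRT engines
)
```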

### Models

We are offering an extensive suite of models. For more information about the individual models, please refer to the link under **Usage**.
@@ -40,6 +53,57 @@

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.

We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:

```bash
python demo_gr.py --name flux-schnell --device cuda
```

Options:

- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo

To run the demo with the dev model and create a public link:

```bash
python demo_gr.py --name flux-dev --share
```

## Diffusers integration

`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use them with diffusers, install the library:

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can use `FluxPipeline` to run the model:

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
    prompt,
    output_type="pil",
    num_inference_steps=4,  # use a larger number if you are using [dev]
    generator=torch.Generator("cpu").manual_seed(seed),
).images[0]
image.save("flux-schnell.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.

## API usage

Our API offers access to our models. It is documented here:
3 changes: 0 additions & 3 deletions demo_gr.py
@@ -15,7 +15,6 @@

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
t5 = load_t5(device, max_length=256 if is_schnell else 512)
clip = load_clip(device)
@@ -24,7 +23,6 @@ def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
def __init__(self, model_name: str, device: str, offload: bool):
self.device = torch.device(device)
@@ -153,7 +151,6 @@ def generate_image(
exif_data[ExifTags.Base.Model] = self.model_name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt

img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

return img, str(opts.seed), filename, None
12 changes: 9 additions & 3 deletions src/flux/api.py
@@ -146,7 +146,9 @@ def request(self):
)
result = response.json()
if response.status_code != 200:
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+ raise ApiException(
+     status_code=response.status_code, detail=result.get("detail")
+ )
self.request_id = response.json()["id"]

def retrieve(self) -> dict:
@@ -168,13 +170,17 @@ def retrieve(self) -> dict:
)
result = response.json()
if "status" not in result:
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
+ raise ApiException(
+     status_code=response.status_code, detail=result.get("detail")
+ )
elif result["status"] == "Ready":
self.result = result["result"]
elif result["status"] == "Pending":
time.sleep(0.5)
else:
- raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
+ raise ApiException(
+     status_code=200, detail=f"API returned status '{result['status']}'"
+ )
return self.result

@property
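For context, the polling logic patched above belongs to the `ImageRequest` class defined in this file. A minimal usage sketch, assuming a valid `BFL_API_KEY` in the environment:

```python
# Sketch: request an image from the API and block until it is ready.
from flux.api import ImageRequest

request = ImageRequest("A beautiful beach", name="flux.1.1-pro")
request.save("beach.jpg")  # retrieve() polls through the Ready/Pending states shown above
```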
58 changes: 58 additions & 0 deletions src/flux/cli.py
@@ -5,10 +5,12 @@
from glob import iglob

import torch
from cuda import cudart
from fire import Fire
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.trt.trt_manager import TRTManager
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image

NSFW_THRESHOLD = 0.85
@@ -108,6 +110,8 @@ def main(
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
trt: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
@@ -126,6 +130,8 @@
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
add_sampling_metadata: Add the prompt to the image Exif metadata
trt: use TensorRT backend for optimized inference
kwargs: additional arguments for TensorRT support
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

@@ -158,6 +164,58 @@
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)

if trt:
# offload to CPU to save memory
ae = ae.cpu()
model = model.cpu()
clip = clip.cpu()
t5 = t5.cpu()

torch.cuda.empty_cache()

trt_ctx_manager = TRTManager(
bf16=True,
device=torch_device,
)

engines = trt_ctx_manager.load_engines(
models={
"clip": clip,
"transformer": model,
"t5": t5,
"vae": ae,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
onnx_dir=os.environ.get("ONNX_DIR", "./onnx"),
opt_image_height=height,
opt_image_width=width,
)

torch.cuda.synchronize()

trt_ctx_manager.init_runtime()
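# cudaStreamCreate returns a (status, stream) tuple; keep only the stream handle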
stream = cudart.cudaStreamCreate()[1]

for engine in engines.values():
engine.load(stream)

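# reserve one device buffer sized for the most demanding engine; all contexts share it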
calculate_max_device_memory = trt_ctx_manager.calculate_max_device_memory(engines)
_, shared_device_memory = cudart.cudaMalloc(calculate_max_device_memory)

for engine_name, engine in engines.items():
engine.activate(shared_device_memory)
shape_dict = engine.get_shape_dict(
batch_size=1,
image_height=height,
image_width=width,
)
engine.allocate_buffers(shape_dict, device=torch_device)

ae = engines["vae"]
model = engines["transformer"]
clip = engines["clip"]
t5 = engines["t5"]

rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
2 changes: 1 addition & 1 deletion src/flux/math.py
@@ -14,7 +14,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
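The one-line change above is likely export-motivated: TensorRT engines do not support float64, so deriving the frequency scale from `pos.dtype` keeps the traced graph strongly typed, at some precision cost when `pos` is bf16/fp16. A toy illustration of the pattern (standalone, not this repo's code):

```python
import torch

def rope_scale(pos: torch.Tensor, dim: int = 8) -> torch.Tensor:
    # frequencies follow the dtype of `pos`, matching the patched rope()
    return torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim

pos = torch.zeros(4, dtype=torch.bfloat16)
print(rope_scale(pos).dtype)  # torch.bfloat16 -> no float64 node in the exported graph
```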
6 changes: 6 additions & 0 deletions src/flux/modules/autoencoder.py
@@ -235,6 +235,9 @@
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

def forward(self, z: Tensor) -> Tensor:
# get dtype for proper tracing
upscale_dtype = next(self.up.parameters()).dtype

# z to block_in
h = self.conv_in(z)

@@ -243,6 +246,8 @@ def forward(self, z: Tensor) -> Tensor:
h = self.mid.attn_1(h)
h = self.mid.block_2(h)

# cast to proper dtype
h = h.to(upscale_dtype)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
@@ -277,6 +282,7 @@ def forward(self, z: Tensor) -> Tensor:
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.params = params
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
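The decoder change above pins the upsampling path to the dtype of its own weights, which is what lets the autoencoder run its early blocks in one precision while the rest of the pipeline uses another (see the "AE should be traced in TF32" commits). A minimal standalone sketch of the pattern (toy module, not this repo's code):

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(4, 8, kernel_size=3, padding=1)  # stays fp32
        self.up = nn.Conv2d(8, 8, kernel_size=3, padding=1).to(torch.bfloat16)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # read the target dtype once so tracing records a single explicit cast
        upscale_dtype = next(self.up.parameters()).dtype
        h = self.conv_in(z)
        h = h.to(upscale_dtype)  # cast before entering the upsampling path
        return self.up(h)

print(TinyDecoder()(torch.randn(1, 4, 32, 32)).dtype)  # torch.bfloat16
```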
Empty file added src/flux/trt/__init__.py
29 changes: 29 additions & 0 deletions src/flux/trt/engine/__init__.py
@@ -0,0 +1,29 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flux.trt.engine.base_engine import BaseEngine
from flux.trt.engine.clip_engine import CLIPEngine
from flux.trt.engine.t5_engine import T5Engine
from flux.trt.engine.transformer_engine import TransformerEngine
from flux.trt.engine.vae_engine import VAEEngine

__all__ = [
"BaseEngine",
"CLIPEngine",
"TransformerEngine",
"T5Engine",
"VAEEngine",
]